from __future__ import annotations
import numpy as np
from collections import defaultdict
from typing import TYPE_CHECKING, Union, Any, List, Optional, Dict, Callable
import importlib, inspect, functools, pathlib, time, re, ctypes
from tinygrad.helpers import (
ansilen,
DEBUG,
getenv,
GlobalCounters,
colored,
BEAM,
NOOPT,
all_int,
to_function_name,
DType,
from_mv,
dtypes,
flat_mv,
ImageDType,
round_up,
)
from tinygrad.shape.symbolic import Variable, sym_infer, sint
from tinygrad.ops import (
LazyOp,
TernaryOps,
get_lazyop_info,
ReduceOps,
BufferOps,
BinaryOps,
UnaryOps,
Op,
)
if TYPE_CHECKING:
from tinygrad.codegen.linearizer import Linearizer
from tinygrad.codegen.kernel import LinearizerOptions
class _Device:
"""
This class represents a device, which can be a CPU, GPU or any other hardware
accelerator that TinyGrad supports. It contains methods for getting the default
device and canonicalizing a device string.
The class uses functools.lru_cache to cache the results of certain methods for
better performance.
"""
def __init__(self) -> None:
"""
Initializes the Device class by populating the list of available
buffers. It looks for files in the 'runtime' directory that start with
'ops_' and adds their uppercase counterparts to the list.
"""
self._buffers: List[str] = [
x.stem[len("ops_") :].upper()
for x in (pathlib.Path(__file__).parent / "runtime").iterdir()
if x.stem.startswith("ops_")
]
def canonicalize(self, device: Optional[str]) -> str:
"""
Takes a device string and returns its canonicalized form by converting
it to uppercase and removing any trailing ":0" characters. If the device
is None, it returns the default device.
Args:
device (Optional[str]): The device string to canonicalize.
Returns:
str: The canonicalized device string.
"""
return (
(
device.split(":", 1)[0].upper()
+ ((":" + device.split(":", 1)[1]) if ":" in device else "")
).replace(":0", "")
if device is not None
else self.DEFAULT
)
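# A few illustrative canonicalizations (using the Device singleton defined
# below; the values follow directly from the string handling above):
#   >>> Device.canonicalize("gpu:0")   # ":0" is redundant and dropped
#   'GPU'
#   >>> Device.canonicalize("cuda:1")  # a non-zero index is kept
#   'CUDA:1'
#   >>> Device.canonicalize(None) == Device.DEFAULT
#   True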
def __getitem__(self, ix: str) -> Union[Interpreted, Compiled]:
"""
Gets the canonicalized item for a given device string.
Args:
ix (str): The device string.
Returns:
Union[Interpreted, Compiled]: The canonicalized item.
"""
return self.__get_canonicalized_item(self.canonicalize(ix))
@functools.lru_cache(
maxsize=None
) # this class is a singleton, pylint: disable=method-cache-max-size-none
def __get_canonicalized_item(self, ix: str) -> Union[Interpreted, Compiled]:
"""
Gets the canonicalized item for a given device string. This method is
cached to improve performance.
Args:
ix (str): The device string.
Returns:
Union[Interpreted, Compiled]: The canonicalized item.
"""
x = ix.split(":")[0].upper()
ret = [
cls
for cname, cls in inspect.getmembers(
importlib.import_module(f"tinygrad.runtime.ops_{x.lower()}")
)
if (cname.lower() == x.lower() + "device") and x in self._buffers
][0]
if isinstance(ret, type):
ret = ret(ix)
return ret
@functools.cached_property
def DEFAULT(self) -> str:
"""
Gets the default device by checking the TINYGRAD_DEVICE environment
variable and then looking for available devices in a specific order:
METAL, CUDA, GPU. If none of these are found, it defaults to "CPU".
Returns:
str: The default device string.
"""
device_from_env: Optional[str] = functools.reduce(lambda val, ele: ele if getenv(ele) == 1 else val, self._buffers, None) # type: ignore
if device_from_env:
return device_from_env
for device in ["METAL", "CUDA", "GPU"]:
try:
if self[device]:
return device
except Exception:
pass
return "CPU"
Device = _Device()
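# The backend can be pinned by setting the device name to 1 in the
# environment before DEFAULT is first read; otherwise METAL, CUDA and GPU
# are probed in order. A hypothetical sketch:
#   >>> import os; os.environ["CPU"] = "1"   # must happen before first use
#   >>> Device.DEFAULT
#   'CPU'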
class JITRunner:
"""
This class defines a Just-In-Time (JIT) runner that is responsible for executing operations and caching the results.
The primary method of interest here is the 'exec' method, which takes in a list of Buffer objects and an optional dictionary
of variable values. It returns an estimated time of execution.
"""
def __init__(self):
"""
Initializes a JITRunner object with default operation and memory estimates set to 0.
"""
self.op_estimate, self.mem_estimate = 0, 0
def exec(
self, rawbufs: List[Buffer], var_vals: Optional[Dict[Variable, int]] = None
) -> Optional[float]:
"""
This method is responsible for executing the operations associated with a given list of Buffer objects and an optional
dictionary of variable values. It first checks if 'var_vals' is not None; if it is, an empty dictionary is created. Then,
it imports CacheCollector and adds the current JITRunner object along with the buffer and variable dictionaries to the cache.
:param rawbufs: A list of Buffer objects to be executed.
:param var_vals: An optional dictionary containing variable values. Default is None.
:return: An estimated time of execution as a float value or None if not available.
"""
var_vals = var_vals if var_vals is not None else {}
from tinygrad.jit import CacheCollector
et = self(rawbufs, var_vals)
CacheCollector.add(self, rawbufs, var_vals)
return et
def __call__(
self,
rawbufs: List[Buffer],
var_vals: Dict[Variable, int],
wait=False,
jit=False,
) -> Optional[float]:
"""
This method is not implemented and raises a NotImplementedError when called. It should be overridden by a subclass to
provide the actual implementation of operation execution.
:param rawbufs: A list of Buffer objects to be executed.
:param var_vals: A dictionary containing variable values.
:param wait: An optional boolean flag indicating whether to wait for the operations to complete before returning. Default is False.
:param jit: An optional boolean flag indicating whether to use Just-In-Time (JIT) compilation. Default is False.
:return: The elapsed execution time in seconds, or None if it was not measured.
"""
raise NotImplementedError("override this")
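# A minimal sketch of the subclass contract (hypothetical runner): override
# __call__, do the work, report through update_stats (defined below), and
# return the elapsed time, or None when it was not measured:
#
#   class _NoopRunner(JITRunner):
#     def __call__(self, rawbufs, var_vals, wait=False, jit=False):
#       st = time.perf_counter()
#       # ... launch work on rawbufs here ...
#       et = time.perf_counter() - st
#       update_stats("<noop>", 0, 0, var_vals, et, len(rawbufs), jit=jit)
#       return et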
def update_stats(
name: str,
op_estimate: sint,
mem_estimate: sint,
var_vals: Optional[Dict[Variable, int]],
et: Optional[float],
buf_count,
jit=False,
num_kernels=1,
lra: Optional[Dict] = None,
):
"""
This function updates the global counters for kernels, operations and memory usage, and prints per-kernel debugging information when the DEBUG level is 2 or higher.
:param name: The name of the operation being executed.
:param op_estimate: An estimated number of operations executed; may be a symbolic expression.
:param mem_estimate: An estimated amount of memory moved, in bytes; may be a symbolic expression.
:param var_vals: An optional dictionary of variable values used to resolve symbolic estimates. Default is None.
:param et: The measured execution time in seconds, if available. Default is None.
:param buf_count: The number of buffers (i.e., the argument count) associated with the operation.
:param jit: An optional boolean flag indicating whether to use Just-In-Time (JIT) compilation. Default is False.
:param num_kernels: The number of kernels used in the operation. Default is 1.
:param lra: An optional dictionary containing local and global size information for the operation. Default is None.
"""
if var_vals is None:
var_vals = {}
op_estimate, mem_estimate = sym_infer(op_estimate, var_vals), sym_infer(
mem_estimate, var_vals
)
GlobalCounters.kernel_count += num_kernels
GlobalCounters.global_ops += op_estimate
GlobalCounters.global_mem += mem_estimate
if et is not None:
GlobalCounters.time_sum_s += et
if DEBUG >= 2:
print(
f"{colored(f'*** {GlobalCounters.kernel_count:4d}', ('magenta' if num_kernels == 1 else 'CYAN') if jit else None)} {name+' '*(37-ansilen(name))} arg {buf_count:3d} sz {str(lra.get('global_size', '') if lra else ''):18s} {str(lra.get('local_size', '') if lra else ''):12s} OPs {int(op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB "
+ (
str()
if et is None
else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({op_estimate/((et or 1e-20)*1e9):8.2f} GFLOPS, {mem_estimate/((et or 1e-20)*1e9):7.2f} GB/s)"
)
)
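# Why sym_infer is needed here: op/mem estimates may be symbolic expressions
# over Variables (e.g. kernels with variable shapes), and sym_infer
# substitutes the concrete values from var_vals. A small sketch:
#   >>> v = Variable("i", 1, 10)
#   >>> sym_infer(v * 100, {v: 4})
#   400
#   >>> sym_infer(12345, {})   # plain ints pass through unchanged
#   12345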
class Buffer:
"""
This class represents a buffer that can be used to store data of a certain size and type. It also provides methods for copying data into and out of the buffer, as well as converting the buffer to a numpy array. The buffer is allocated using an allocator specific to the device it will be used with.
"""
def __init__(self, device: str, size: int, dtype: DType, opaque: Any = None):
"""
Initialize a new buffer.
:param device: The name of the device that this buffer will be used with.
:param size: The number of items to store in the buffer.
:param dtype: The type of data to store in the buffer.
:param opaque: An optional parameter for providing an existing buffer to use instead of allocating a new one.
"""
assert isinstance(dtype, DType)
self.device, self.size, self.dtype = device, size, dtype
self.allocator = Device[self.device].allocator
# TODO: image hack shouldn't be here. where should it be?
if isinstance(dtype, ImageDType) and hasattr(self.allocator, "_cast_image"):
assert opaque is None
row_pitch_items = round_up(dtype.shape[1], 256) * 4
self.size = (
row_pitch_items * dtype.shape[0]
) # adjust the size to include the image padding
self._real_buf = self.allocator.alloc(self.size * dtype.itemsize)
self._buf = self.allocator._cast_image(
self._real_buf, dtype, row_pitch_items * dtype.itemsize
)
else:
self._buf = (
opaque
if opaque is not None
else self.allocator.alloc(size * dtype.itemsize)
)
# TODO: mem_used for all devices
if self.device == Device.DEFAULT:
GlobalCounters.mem_used += self.size * self.dtype.itemsize
def __del__(self):
"""
Clean up the buffer when it is no longer needed, freeing its resources.
"""
if self.device == Device.DEFAULT:
GlobalCounters.mem_used -= self.size * self.dtype.itemsize
if isinstance(self.dtype, ImageDType):
self.allocator._free(self._buf)
self.allocator.free(self._real_buf, self.size * self.dtype.itemsize)
else:
self.allocator.free(self._buf, self.size * self.dtype.itemsize)
def __repr__(self):
"""
Return a string representation of the buffer.
:return: A string containing information about the buffer's device, size and data type.
"""
return f"<buf device:{self.device} size:{self.size} dtype:{self.dtype}>"
def copyin(self, mv: memoryview):
"""
Copy data from a memory view into the buffer.
:param mv: The memory view to copy data from.
:return: The buffer.
"""
mv = flat_mv(mv)
assert (
len(mv) == self.size * self.dtype.itemsize
), f"size mismatch, {len(mv)=} != {self.dtype=} {self.size=}"
self.allocator.copyin(self._buf, mv)
return self
@staticmethod
def fromCPU(device: str, x: np.ndarray):
"""
Create a new buffer and copy data from a numpy array into it.
:param device: The name of the device that the buffer will be used with.
:param x: The numpy array to copy data from.
:return: The newly created buffer.
"""
return Buffer(device, x.size, dtypes.from_np(x.dtype)).copyin(x.data)
def toCPU(self) -> np.ndarray:
"""
Converts the data from GPU to CPU.
Checks if the allocator has a method 'as_buffer'. If it does, then the buffer is copied into a NumPy array with zero copy by using the 'frombuffer' function.
Otherwise, an empty NumPy array of size same as self.size and datatype same as self.dtype is created. The data from GPU buffer is then copied to this new NumPy array.
:return: A Numpy array containing the data which was previously on GPU.
"""
# zero copy with as_buffer
if hasattr(self.allocator, "as_buffer"):
return np.frombuffer(self.allocator.as_buffer(self._buf), dtype=np.dtype(self.dtype.np, metadata={"backing": self._buf})) # type: ignore
ret = np.empty(self.size, self.dtype.np)
if self.size > 0:
self.allocator.copyout(flat_mv(ret.data), self._buf)
return ret
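# A minimal host round-trip sketch (assuming the CPU backend is available):
#   >>> a = np.arange(4, dtype=np.float32)
#   >>> buf = Buffer.fromCPU("CPU", a)
#   >>> np.array_equal(buf.toCPU(), a)
#   True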
class _BufferCopy(JITRunner):
# TODO: make wait work
def __call__(
self,
rawbufs: List[Buffer],
var_vals: Dict[Variable, int],
wait=False,
jit=False,
):
dest, src = rawbufs
assert (
dest.size == src.size and dest.dtype == src.dtype
), "buffer copy size/dtype mismatch"
if DEBUG >= 2:
print(
f"*** copy {dest.device} <- {src.device} size {dest.size:<16d} dtype {dest.dtype}"
)
if hasattr(dest.allocator, "transfer") and type(dest.allocator) is type(
src.allocator
):
# fast path, used on HIP between GPUs
# NOTE: it's important we use the dest device here to ensure the transfer is ready
dest.allocator.transfer(
dest._buf, src._buf, dest.size * dest.dtype.itemsize
)
return
if (
getenv("FROM_BUFFER")
and hasattr(dest.allocator, "from_buffer")
and hasattr(dest.allocator, "transfer")
and hasattr(src.allocator, "as_buffer")
):
# fast path, used on Metal in OS X Sonoma
# NOTE: this is *only* faster if the pages from disk are already loaded into memory
fb = dest.allocator.from_buffer(src.allocator.as_buffer(src._buf))
if fb:
dest.allocator.transfer(dest._buf, fb, dest.size * dest.dtype.itemsize)
return
if hasattr(dest.allocator, "as_buffer"):
# fast(ish) path, uses readinto in diskbuffers
src.allocator.copyout(dest.allocator.as_buffer(dest._buf), src._buf)
elif hasattr(src.allocator, "as_buffer"):
dest.allocator.copyin(dest._buf, src.allocator.as_buffer(src._buf))
else:
# slow path, allocates a CPU buffer
dest.copyin(src.toCPU().data)
BufferCopy = _BufferCopy()
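# BufferCopy is itself a JITRunner, so a device-to-device copy is dispatched
# (and JIT-recorded) through exec() like any other kernel; the destination
# buffer comes first. A hypothetical sketch:
#   >>> src = Buffer.fromCPU("CPU", np.ones(4, dtype=np.float32))
#   >>> dst = Buffer("CPU", 4, dtypes.float32)
#   >>> BufferCopy.exec([dst, src])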
# TODO: size, dest, src are the same type. can we enforce this?
class Allocator:
"""
This is a class for an Allocator. It has methods to allocate, free and copy memory.
Attributes:
None
"""
def alloc(self, size: int):
"""
The alloc method checks if the given size is positive. If it's not, an assertion error will be raised.
Then, it calls the private _alloc method to perform the actual allocation.
Args:
size (int): The size of memory to allocate in bytes. It must be a positive integer.
Returns:
The result from the private _alloc method.
Raises:
AssertionError: If the given size is not positive.
"""
assert size > 0, f"alloc size must be positive, got {size}"
return self._alloc(size)
def _alloc(self, size: int):
"""
The _alloc method is a placeholder that raises a NotImplementedError. This should be overridden by a child class to provide the actual allocation functionality.
Args:
size (int): The size of memory to allocate in bytes. It must be a positive integer.
Returns:
None
Raises:
NotImplementedError: This method is not implemented and should be overridden by a child class.
"""
raise NotImplementedError("need alloc")
def free(self, opaque, size: int):
"""
The free method calls the private _free method to perform the actual freeing of memory. In some cases, if you are returning a Python object, you don't need a free method and it can be a no-op.
Args:
opaque: Some data used for identifying the memory block to be freed. Its type or content depends on the implementation.
size (int): The size of memory to be freed in bytes. It must be a positive integer.
Returns:
None
"""
self._free(
opaque
) # if you are returning a Python object, you don't need a free
def _free(self, opaque):
"""
The _free method is a placeholder that does nothing. This should be overridden by a child class to provide the actual freeing functionality.
Args:
opaque: Some data used for identifying the memory block to be freed. Its type or content depends on the implementation.
Returns:
None
"""
pass
def copyin(self, dest, src: memoryview):
"""
The copyin method copies data from a memoryview object (src) to an allocated memory block (dest). It raises a NotImplementedError because it should be implemented by a child class.
Args:
dest: The destination allocated memory block where the data will be copied to.
src (memoryview): The source memoryview object that the data will be copied from.
Returns:
None
Raises:
NotImplementedError: This method is not implemented and should be overridden by a child class.
"""
raise NotImplementedError("need copyin")
def copyout(self, dest: memoryview, src):
"""
The copyout method copies data from an allocated memory block (src) to a memoryview object (dest). It raises a NotImplementedError because it should be implemented by a child class.
Args:
dest (memoryview): The destination memoryview object where the data will be copied to.
src: The source allocated memory block that the data will be copied from.
Returns:
None
Raises:
NotImplementedError: This method is not implemented and should be overridden by a child class.
"""
raise NotImplementedError("need copyout")
class LRUAllocator(Allocator): # pylint: disable=abstract-method
"""
This class defines an Allocator that uses the Least Recently Used (LRU)
strategy to manage memory allocation and deallocation. It is a subclass of
the Allocator parent class.
"""
def __init__(self):
"""
Initializes an instance of LRUAllocator. Here, we initialize an empty
defaultdict for self.cache to store allocated memory blocks, with keys
being the sizes of these blocks and values being lists of corresponding
memory opaque objects returned by alloc().
"""
self.cache: Dict[int, Any] = defaultdict(list)
def alloc(self, size: int):
"""
This method is used to allocate a block of memory of the specified
'size'. If there are any available opaque objects in the cache for this
size, we return one from there. Otherwise, we try to allocate memory by
calling super().alloc(size). If that raises a MemoryError exception,
we free up some memory by calling self.free_cache() and then retry
allocating memory using super().alloc(size) again.
"""
if len(c := self.cache[size]):
return c.pop()
try:
return super().alloc(size)
except MemoryError:
self.free_cache()
return super().alloc(size)
def free_cache(self):
"""
This method is used to free up some memory by calling the _free() method
for each opaque object in the cache, then clearing the cache.
"""
for opaques in self.cache.values():
for opaque in opaques:
self._free(opaque)
opaques.clear()
def free(self, opaque: Any, size: int):
"""
This method is used to deallocate a block of memory specified by the
'size' and corresponding 'opaque' object. If the LRU environment variable
"LRU" is set to 1 (default), we append the opaque object to the cache for
future use; otherwise, we directly free up this memory block by calling
self._free(opaque).
"""
if getenv("LRU", 1):
self.cache[size].append(opaque)
else:
self._free(opaque)
class _MallocAllocator(LRUAllocator):
"""
The _MallocAllocator class is a subclass of the LRUAllocator class. It provides methods for allocation, copying data in and out, and presenting as a buffer.
Attributes:
None
"""
def _alloc(self, size: int):
"""
Allocates memory based on the input size.
Args:
size (int): The amount of memory to be allocated in bytes.
Returns:
ctypes.c_uint8: An array of unsigned 8-bit integers representing the allocated memory.
"""
return (ctypes.c_uint8 * size)()
def as_buffer(self, src) -> memoryview:
"""
Converts the input to a memoryview object which provides a view on the underlying buffer.
Args:
src: The source data to be converted to a memoryview.
Returns:
memoryview: A view on the underlying buffer of the source data.
"""
return flat_mv(memoryview(src))
def copyin(self, dest, src: memoryview):
"""
Copies data from a memoryview object to allocated memory.
Args:
dest: The destination where the data will be copied.
src (memoryview): The source data to be copied.
Returns:
None
"""
ctypes.memmove(dest, from_mv(src), len(src))
def copyout(self, dest: memoryview, src):
"""
Copies data from allocated memory to a memoryview object.
Args:
dest (memoryview): The destination where the data will be copied.
src: The source data to be copied.
Returns:
None
"""
ctypes.memmove(from_mv(dest), src, len(dest))
MallocAllocator = _MallocAllocator()
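# Sketch of the LRU reuse path inherited from LRUAllocator: free() parks the
# allocation in the size-keyed cache (when LRU=1, the default), and the next
# alloc() of the same size pops it back out instead of calling _alloc again:
#   >>> buf = MallocAllocator.alloc(16)
#   >>> MallocAllocator.free(buf, 16)       # cached, not actually freed
#   >>> MallocAllocator.alloc(16) is buf    # the same ctypes array is reused
#   True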
class InterpretedASTRunner(JITRunner):
"""
This class is used to run interpreted Abstract Syntax Trees (ASTs). It inherits from JITRunner.
Attributes:
fxn: A callable function
op_estimate: Estimated number of floating point operations required for the operation
mem_estimate: Estimated memory requirement for the operation
"""
def __init__(self, ast: LazyOp, fxn: Callable):
"""
Initializes an instance of the InterpretedASTRunner class.
Parameters:
ast: A lazy operation (LazyOp object)
fxn: A callable function
Returns:
None
"""
super().__init__()
self.fxn = fxn
info = get_lazyop_info(ast)
self.op_estimate, self.mem_estimate = info.flops, info.mem_estimate
def __call__(
self,
rawbufs: List[Buffer],
var_vals: Dict[Variable, int],
wait=False,
jit=False,
) -> float:
"""
This method performs the actual operation by executing the function on the given buffers with the specified variable values.
Parameters:
rawbufs: A list of Buffer objects
var_vals: A dictionary containing Variable objects as keys and integers as values
wait: An optional boolean parameter indicating whether to wait for completion (default is False)
jit: An optional boolean parameter indicating whether to use Just-In-Time compilation (default is False)
Returns:
The execution time of the operation in seconds
"""
st = time.perf_counter()
rawbufs[0]._buf = self.fxn([x._buf for x in rawbufs], var_vals)
et = time.perf_counter() - st
update_stats(
f"<interpreted {rawbufs[0].size}>",
self.op_estimate,
self.mem_estimate,
var_vals,
et,
len(rawbufs),
jit,
)
return et
class Interpreted:
"""
The main class of the interpreter. This class is responsible for handling the allocation and
execution of operations. It uses an allocator to manage memory and a dictionary of callable
functions to perform various operations.
Attributes:
allocator (Allocator): Object used to allocate and deallocate memory.
fxn_for_op (Dict[Op, Callable]): A dictionary mapping operation types to their corresponding callable function.
synchronize (function): A no-op; interpreted backends execute eagerly, so there is nothing to wait on.
codegen (None): Placeholder for code generation functionality. Currently set to None.
graph (None): Placeholder for computation graph functionality. Currently set to None.
"""
def __init__(self, allocator: Allocator, fxn_for_op: Dict[Op, Callable]):
"""
Initializes an instance of the Interpreted class.
Args:
allocator (Allocator): Object used to allocate and deallocate memory.
fxn_for_op (Dict[Op, Callable]): A dictionary mapping operation types to their corresponding callable function.
"""
self.allocator, self.fxn_for_op = allocator, fxn_for_op
self.synchronize, self.codegen, self.graph = lambda: None, None, None
@functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
def get_runner(self, ast: LazyOp) -> InterpretedASTRunner:
"""
Retrieves a runner function for the given abstract syntax tree (AST). The function is retrieved from the fxn_for_op dictionary.
Args:
ast (LazyOp): The abstract syntax tree to be executed.
Returns:
InterpretedASTRunner: A callable function that can execute the given AST.
"""
return _get_interpreted_fxn(self.fxn_for_op, ast)
def _get_interpreted_fxn(
fxn_for_op: Dict[Op, Callable], ast: LazyOp
) -> InterpretedASTRunner:
if DEBUG >= 3:
from tinygrad.graph import print_tree
print_tree(ast)
tglob: Dict[str, Any] = {"Variable": Variable}
@functools.lru_cache(None)
def gstr(x: Any, nm=None) -> str:
if "Variable" in (str_arg := repr(x)) or "NumNode" in str_arg:
str_arg = re.sub(
r"Variable\(.*?\)", lambda m: f"var_vals[{str(m.group(0))}]", str_arg
)
# TODO: (Variable - Variable) might create NumNode. can we remove it?
return re.sub(r"NumNode\((.*?)\)", r"\1", str_arg)
ret = str(nm).replace(".", "_") if nm else f"m{len(tglob):04d}"
tglob[ret] = x
return ret
lines: List[str] = []
@functools.lru_cache(None)
def _interpret_ast(ast: LazyOp) -> str:
# TODO: shortcutted store won't work with strides
if ast.op == BufferOps.STORE:
return _interpret_ast(ast.src[0])
if (
TernaryOps.MULACC in fxn_for_op
and ast.op == ReduceOps.SUM
and isinstance(ast.src[0], LazyOp)
and ast.src[0].op == BinaryOps.MUL
):
ast = LazyOp(TernaryOps.MULACC, ast.src[0].src, ast.arg)
if ast.op in BufferOps:
if ast.op == BufferOps.CONST:
tmp = f"{gstr(fxn_for_op[ast.op], ast.op)}({gstr(ast.arg.val)}, {gstr(ast.arg.dtype)})"
else:
tmp = f"{gstr(fxn_for_op[UnaryOps.CAST], UnaryOps.CAST)}(inputs[{ast.arg.idx}], ({gstr(ast.arg.dtype)}, True))"
for mop, arg in ast.arg.st.to_movement_ops():
tmp = f"{gstr(fxn_for_op[mop], mop)}({tmp}, {gstr(arg)})"
else:
tmp = f"{gstr(fxn_for_op[ast.op], ast.op)}({', '.join([_interpret_ast(src) for src in ast.src] + ([gstr(ast.arg)] if ast.arg else []))})"
ret = f"a{len(lines)}"
lines.append(f" {ret} = {tmp}")
return ret
ret = _interpret_ast(ast)
src = "\n".join(["def run(inputs, var_vals):"] + lines + [f" return {ret}"])
if DEBUG >= 4:
print(
functools.reduce(
lambda x, y: (x.replace(y[0], str(y[1])) if y[0][0:2] == "m0" else x),
tglob.items(),
src,
)
)
exec(compile(src, "<ast>", "exec"), tglob) # pylint: disable=exec-used
return InterpretedASTRunner(ast, tglob["run"])
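# For intuition, the generated source for a small AST looks roughly like the
# sketch below; global names such as UnaryOps_CAST and m0001 are bound into
# tglob by gstr, and the exact ops shown here are illustrative:
#
#   def run(inputs, var_vals):
#     a0 = UnaryOps_CAST(inputs[1], (m0001, True))   # m0001 = the dtype
#     a1 = UnaryOps_NEG(a0)
#     return a1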
class CompiledASTRunner(JITRunner):
"""
A class for running compiled code generated from an Abstract Syntax Tree (AST). Inherits from JITRunner.
"""
def __init__(
self,
ast: Optional[LazyOp],
name: str,
prg: str,
global_size: Optional[List[int]] = None,
local_size: Optional[List[int]] = None,
runtime_args: Optional[dict] = None,
):
"""
Initializes an instance of CompiledASTRunner.
:param ast: The Abstract Syntax Tree (AST) to be compiled and run.
:type ast: Optional[LazyOp]
:param name: Name of the function or kernel.
:type name: str
:param prg: Source code for the function or kernel.
:type prg: str
:param global_size: Global size for the kernel execution, defaults to None
:type global_size: Optional[List[int]], optional
:param local_size: Local size for the kernel execution, defaults to None
:type local_size: Optional[List[int]], optional
:param runtime_args: Additional arguments for the runtime, defaults to None
:type runtime_args: Optional[dict], optional
"""
super().__init__()
if DEBUG >= 4:
print(prg)
if global_size is not None:
global_size = global_size + [1] * (3 - len(global_size))
if local_size is not None:
local_size = local_size + [1] * (3 - len(local_size))
(
self.name,
self.display_name,
self.prg,
self.global_size,
self.local_size,
self.runtime_args,
) = (
to_function_name(name),
name,
prg,
global_size,
local_size,
runtime_args if runtime_args is not None else {},
)
self.vars: List[Variable] = []
if ast:
info = get_lazyop_info(ast)
self.op_estimate, self.mem_estimate = info.flops, info.mem_estimate
from tinygrad.lazy import vars_from_ast
self.vars = vars_from_ast(ast)
assert all(
v._val is None for v in self.vars
), f"ASTRunner contains bound Variable {self.vars}"
def build(self, compiler, runtime):
"""
Builds the kernel from source code.
:param compiler: The compiler used to compile the kernel.
:type compiler: Function
:param runtime: The runtime used to execute the kernel.
:type runtime: Function
:return: Returns an instance of CompiledASTRunner.
:rtype: CompiledASTRunner
"""
self.lib = (
compiler.__wrapped__(self.prg)
if getenv("DISABLE_COMPILER_CACHE")
else compiler(self.prg)
)
self.clprg = runtime(self.name, self.lib)
return self
def launch_dims(self, var_vals):
"""
Computes the launch dimensions for the kernel execution.
:param var_vals: Values of the variables used in the kernel.
:type var_vals: Dict[Variable, int]
:return: Returns the global and local sizes for the kernel execution.
:rtype: Tuple[List[int], List[int]]
"""
global_size = (
[sym_infer(sz, var_vals) for sz in self.global_size]
if self.global_size is not None
else self.global_size
)
local_size = (
[sym_infer(sz, var_vals) for sz in self.local_size]
if self.local_size is not None
else self.local_size
)
return global_size, local_size
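# Sketch: launch sizes may be symbolic and are resolved per call with
# sym_infer. Assuming a runner 'prg' whose global size was left symbolic
# (names and values here are illustrative):
#   >>> i = Variable("i", 1, 64)
#   >>> prg.global_size, prg.local_size = [i * 2, 1, 1], None
#   >>> prg.launch_dims({i: 8})
#   ([16, 1, 1], None)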
def __call__(
self,
rawbufs: List[Buffer],
var_vals: Dict[Variable, int],
wait=False,
jit=False,
) -> Optional[float]:
"""The call method for the class.
This method is responsible for executing the OpenCL program and retrieving its execution time.
Args:
rawbufs (List[Buffer]): A list of Buffer objects to be passed as arguments to the OpenCL program.
var_vals (Dict[Variable, int]): A dictionary mapping variables to their integer values.
wait (bool, optional): Whether to wait for the execution to finish before returning. Defaults to False.
jit (bool, optional): If True, just-in-time compilation is used. Defaults to False.
Returns:
Optional[float]: The execution time of the OpenCL program in seconds, or None if wait is set to False.
"""
global_size, local_size = self.launch_dims(var_vals)
if global_size is not None and local_size is None and all_int(self.global_size): # type: ignore[arg-type]
# TODO: this is copied from get_program
from tinygrad.features.search import optimize_local_size
local_size = self.local_size = optimize_local_size(
self.clprg, global_size, rawbufs
)
global_size = self.global_size = [
g // l if g % l == 0 else g / l for g, l in zip(global_size, local_size)
]
lra = self.runtime_args.copy()
if global_size:
lra["global_size"] = global_size
if local_size and "local_size" not in lra:
lra["local_size"] = local_size
et = self.clprg(
*[x._buf for x in rawbufs],
**lra,
vals=tuple(var_vals[k] for k in self.vars),
wait=wait or DEBUG >= 2,
)
update_stats(
self.display_name,
self.op_estimate,
self.mem_estimate,
var_vals,
et,
len(rawbufs),
jit,
lra=lra,
)
return et
class Compiled:
"""
The Compiled class is responsible for compiling and executing the given AST (Abstract Syntax Tree).
It takes in an allocator, linearizer options, renderer, compiler, runtime, and optionally a graph.
:param Allocator allocator: Memory allocator for the compiled code.
:param LinearizerOptions linearizer_opts: Options for the linearizer.
:param renderer: Renderer object to convert the AST into an executable format.
:param compiler: Compiler used to compile the rendered code.
:param runtime: Runtime environment for executing the compiled code.
:param graph: (Optional) A device-specific graph runner used by the JIT to batch kernels. Default is None.
"""
def __init__(
self,
allocator: Allocator,
linearizer_opts: LinearizerOptions,
renderer,
compiler,
runtime,
graph=None,
):
(
self.allocator,
self.linearizer_opts,
self.renderer,
self.compiler,
self.runtime,
self.graph,
) = (allocator, linearizer_opts, renderer, compiler, runtime, graph)
def synchronize(self):
"""
This function is a placeholder for device-specific synchronization code.
It should be overridden in the derived class with specific implementation for the desired device.
"""
pass # override this in your device
def to_program(self, k: Linearizer) -> CompiledASTRunner:
"""
Converts a linearized AST into an executable format using the renderer and then compiles it
using the compiler and runtime environment.
:param Linearizer k: The linearized AST to be converted.
:return: An instance of CompiledASTRunner with the compiled code.
"""
k.linearize()
src, runtime_args = self.renderer(to_function_name(k.name), k.uops)
return CompiledASTRunner(
k.ast, k.name, src, k.global_size, k.local_size, runtime_args
).build(self.compiler, self.runtime)
def get_linearizer(self, ast: LazyOp) -> Linearizer:
"""
Optimizes the given AST using a series of optimization techniques and returns the optimized linearized version.
:param LazyOp ast: The abstract syntax tree to be optimized.
:return: An instance of Linearizer with the optimized code.
"""
if DEBUG >= 3:
from tinygrad.graph import print_tree
print_tree(ast)
from tinygrad.codegen.linearizer import Linearizer
k = Linearizer(ast, self.linearizer_opts)
if not NOOPT:
if not (used_tensor_cores := k.apply_tensor_cores(getenv("TC", 1))):
k.hand_coded_optimizations()
if BEAM >= 1:
lins = [(("tc" if used_tensor_cores else "hc"), k)]
if used_tensor_cores:
lins.append(("hc", Linearizer(ast, self.linearizer_opts)))
lins[-1][1].hand_coded_optimizations()
kb = Linearizer(ast, self.linearizer_opts)
from tinygrad.features.search import (
beam_search,
time_linearizer,
bufs_from_lin,
)
# TODO: this shouldn't use Device.DEFAULT, it should get the device from the LinearizerOptions
test_rawbuffers = bufs_from_lin(
kb
) # allocate scratch buffers for optimization
lins.append(
(
f"beam{BEAM.value}",
beam_search(
kb,
test_rawbuffers,
BEAM.value,
bool(getenv("BEAM_ESTIMATE", 1)),
),
)
)
timed = sorted(
[
(
nm,
tk,
time_linearizer(
tk,
test_rawbuffers,
allow_test_size=False,
clear_l2=True,
),
)
for nm, tk in lins
],
key=lambda x: x[2],
)
if DEBUG >= 1:
print(
" < ".join(
f"{nm:6s} : {lin.colored_shape(30, dense=True)} : {tm*1e6:8.2f} us"
for nm, lin, tm in timed
)
)
k = timed[0][1]
return k
@functools.lru_cache(None) # pylint: disable=method-cache-max-size-none
def get_runner(self, ast: LazyOp) -> CompiledASTRunner:
"""
A cached version of the to_program function that takes an AST and returns a runner for it.
:param LazyOp ast: The abstract syntax tree to be executed.
:return: An instance of CompiledASTRunner with the compiled code.
"""
return self.to_program(self.get_linearizer(ast))