from __future__ import annotations
from typing import Callable, List, Tuple, Dict, cast, Union, Optional, TypeVar, Generic
import functools, itertools, operator
from tinygrad.helpers import DEBUG, DType, merge_dicts, getenv, all_int
from tinygrad.device import Device, JITRunner, CompiledASTRunner, Buffer
from tinygrad.tensor import Tensor
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.symbolic import Variable, NumNode, Node
from weakref import ref, WeakKeyDictionary
from dataclasses import dataclass
[docs]
@dataclass(frozen=True)
class JitItem:
    """One captured entry in a JIT cache: a runnable program plus its buffers.

    Attributes:
        prg: The runnable for this entry -- a JITRunner, or a graph executor
            (e.g. MetalGraph) that batches several kernels into one launch.
        rawbufs: Buffers the program reads/writes. Entries may be None while
            an input slot is waiting to be filled in at call time.
    """
    prg: JITRunner  # or a graph executor like MetalGraph
    rawbufs: List[Optional[Buffer]]
[docs]
def get_jit_stats(jit_cache: List[JitItem]) -> Tuple[Node, Node]:
    """Sum the op and memory estimates over every item in a jit cache.

    Args:
        jit_cache: Captured JitItems whose programs carry symbolic
            ``op_estimate`` / ``mem_estimate`` values.

    Returns:
        A ``(total_op_estimate, total_mem_estimate)`` tuple of symbolic
        Nodes. Both totals start from ``NumNode(0)``, so an empty cache
        yields symbolic zeros.
    """
    op_total: Node = NumNode(0)
    mem_total: Node = NumNode(0)
    for item in jit_cache:
        op_total = op_total + item.prg.op_estimate
        mem_total = mem_total + item.prg.mem_estimate
    return op_total, mem_total
[docs]
def get_jc_idxs_with_updatable_launch_dims(jit_cache: List[JitItem]) -> List[int]:
    """Find cache entries whose launch dimensions contain symbolic values.

    Args:
        jit_cache: The captured list of JitItems to scan.

    Returns:
        Indices of CompiledASTRunner entries whose global and/or local
        sizes are not all concrete ints, i.e. entries whose launch dims
        must be recomputed from variable values on each call.
    """
    idxs: List[int] = []
    for idx, item in enumerate(jit_cache):
        if not isinstance(item.prg, CompiledASTRunner):
            continue
        symbolic_global = item.prg.global_size and not all_int(tuple(item.prg.global_size))
        symbolic_local = item.prg.local_size and not all_int(tuple(item.prg.local_size))
        if symbolic_global or symbolic_local:
            idxs.append(idx)
    return idxs
[docs]
def get_jc_idxs_with_updatable_var_vals(jit_cache: List[JitItem]) -> List[int]:
    """Find cache entries that take symbolic variables as inputs.

    Args:
        jit_cache: The captured list of JitItems to scan.

    Returns:
        Indices of entries backed by a CompiledASTRunner with a non-empty
        ``vars`` list; those entries need fresh variable values each call.
    """
    result: List[int] = []
    for idx, item in enumerate(jit_cache):
        if isinstance(item.prg, CompiledASTRunner) and item.prg.vars:
            result.append(idx)
    return result
[docs]
class GraphException(Exception):
    """Raised when condensing a jit cache into a backend graph executor fails.

    Backends' graph constructors raise this to signal that the captured
    kernels cannot be batched; the caller catches it and falls back to
    running the cached kernels individually.
    """
ReturnType = TypeVar("ReturnType")
"""
Type variable for return type.
This is a TypeVar that can be used to specify the expected return type of a function or method. This allows for more flexibility in
type checking and can help catch errors early on.
Example:
def get_value() -> ReturnType:
...
"""
class TinyJit(Generic[ReturnType]):
    """Just-in-time capture-and-replay wrapper for a tensor function.

    The first call runs ``fxn`` normally. The second call runs it again while
    recording every kernel launch into ``jit_cache`` (optionally condensing
    them into a single graph executor). From the third call on, the recorded
    kernels are replayed directly with the new input buffers swapped in,
    skipping Python-side scheduling entirely.

    Attributes:
        fxn: The wrapped function.
        jit_cache: Recorded JitItems to replay.
        input_replace: Maps (cache index, rawbuf index) -> input position, so
            fresh input buffers can be patched into the cache each call.
        cnt: Number of completed calls (selects ignore/capture/exec behavior).
        ret: Return value of the most recent call of ``fxn``.
        expected_vals: Symbolic variables seen at capture time; later calls
            must bind exactly the same set.
        expected_name_sts_dtype: (name, ShapeTracker, dtype) per input seen at
            capture time; later calls must match.
    """

    def __init__(self, fxn: Callable[..., ReturnType]):
        """Wrap ``fxn`` and start in the pristine (un-captured) state."""
        self.fxn = fxn
        self.reset()

    def reset(self):
        """Discard any captured cache and return to the un-captured state."""
        self.jit_cache: List[JitItem] = []
        self.input_replace: Dict[Tuple[int, int], int] = {}
        self.cnt: int = 0
        self.ret: Optional[ReturnType] = None
        self.expected_vals: Optional[Tuple[Variable, ...]] = None
        self.expected_name_sts_dtype: Optional[Tuple[Tuple[Union[int, str], ShapeTracker, DType], ...]] = None

    # add support for instance methods
    def __get__(self, obj, objtype):
        """Bind the jitted callable as a method: prepend ``obj`` to the args."""
        return functools.partial(self.__call__, obj)

    def __call__(self, *args, **kwargs) -> ReturnType:
        """Run (cnt==0), capture (cnt==1), or replay (cnt>=2) the function."""
        # all inputs are realized; keyed by positional index or kwarg name
        input_tensors: Dict[Union[int, str], Tensor] = {
            cast(Union[int, str], k): v.realize()
            for k, v in itertools.chain(enumerate(args), kwargs.items())
            if v.__class__ is Tensor
        }
        expected_name_sts_dtype = tuple(
            [(k, v.lazydata.st.unbind(), v.dtype) for k, v in input_tensors.items()]
        )

        # get rawbuffers
        input_rawbuffers: List[Buffer] = [
            cast(Buffer, v.lazydata.realized) for v in input_tensors.values()
        ]
        assert len(set(input_rawbuffers)) == len(input_rawbuffers), "duplicate inputs to JIT"

        # get variables: they can either be in Tensors or passed in as arguments, and all must be bound. these are all global
        var_vals: Dict[Variable, int] = merge_dicts(
            [arg.lazydata.st.var_vals for arg in input_tensors.values()]
            + [dict(x.unbind() for x in itertools.chain(args, kwargs.values()) if isinstance(x, Variable))]
        )
        expected_vals = tuple(var_vals.keys())

        if self.cnt >= 2:
            # jit exec: validate the call matches the captured signature, then replay
            assert self.expected_vals == expected_vals, "mismatch of var_vals"
            assert (
                self.expected_name_sts_dtype == expected_name_sts_dtype
            ), f"mismatch of sts, expected {self.expected_name_sts_dtype} got {expected_name_sts_dtype}"
            # patch the fresh input buffers into the recorded cache slots
            for (j, i), input_idx in self.input_replace.items():
                self.jit_cache[j].rawbufs[i] = input_rawbuffers[input_idx]
            for ji in self.jit_cache:
                ji.prg(cast(List[Buffer], ji.rawbufs), var_vals, wait=DEBUG >= 2, jit=True)
        elif self.cnt == 1:
            # jit capture: record every kernel launched while fxn runs
            self.expected_vals, self.expected_name_sts_dtype = expected_vals, expected_name_sts_dtype
            CacheCollector.start(var_vals)
            self.ret = self.fxn(*args, **kwargs)
            self.jit_cache = CacheCollector.finish()
            assert len(self.jit_cache) != 0, "didn't JIT anything!"
            # if your Device supports it, condense the items into a graph executor
            # (JIT=2 env var opts out of graphing)
            if (make_graph := Device[Device.DEFAULT].graph) and getenv("JIT") != 2:
                try:
                    if DEBUG >= 1:
                        print(f"JIT GRAPHing {len(self.jit_cache)} kernels with {len(input_rawbuffers)} inputs")
                    self.jit_cache = [
                        JitItem(
                            make_graph(self.jit_cache, input_rawbuffers, var_vals),
                            cast(List[Optional[Buffer]], input_rawbuffers),
                        )
                    ]
                except GraphException as e:
                    # best-effort: graphing is an optimization, fall back to per-kernel replay
                    if DEBUG >= 1:
                        print(f"graph create failed {e}")
            else:
                if DEBUG >= 1:
                    print(f"JIT captured {len(self.jit_cache)} kernels with {len(input_rawbuffers)} inputs")
            self.input_replace = get_input_replace(self.jit_cache, input_rawbuffers)
        elif self.cnt == 0:
            # jit ignore: first call runs normally so shapes/buffers settle
            self.ret = self.fxn(*args, **kwargs)

        # clear jit inputs so cached entries don't keep caller buffers alive
        for j, i in self.input_replace.keys():
            self.jit_cache[j].rawbufs[i] = None

        self.cnt += 1
        return cast(ReturnType, self.ret)
class PlaceHolder:
    """A hashable, weakly-referencing stand-in for a Buffer captured during JIT recording.

    Keeps a weak reference to the Buffer plus enough metadata (size, dtype,
    device, and the identity of the underlying allocation) to recreate an
    equivalent Buffer later if the original has been garbage collected.
    """

    def __init__(self, buf: Buffer):
        # bufid is the id of the underlying allocation (buf._buf), so two
        # Buffer wrappers around the same storage compare/hash equal.
        self.size, self.dtype, self.device, self.ref, self.bufid = (
            buf.size,
            buf.dtype,
            buf.device,
            ref(buf),
            id(buf._buf),
        )

    def to_tuple(self):
        """Return the strong identity fields (everything except the weakref)."""
        return (self.size, self.dtype, self.device, self.bufid)

    def __hash__(self):
        return hash(self.to_tuple())

    def __eq__(self, x):
        return isinstance(x, PlaceHolder) and self.to_tuple() == x.to_tuple()

    def alloc_if_needed(self, buffer_cache: Dict[PlaceHolder, Buffer]) -> Buffer:
        """Return the original Buffer if still alive, else a fresh one.

        Replacement Buffers are memoized in ``buffer_cache`` so every
        PlaceHolder for the same dead allocation resolves to one Buffer.
        """
        ret = self.ref()
        if ret:
            return ret
        if self not in buffer_cache:
            buffer_cache[self] = Buffer(self.device, self.size, self.dtype)
        return buffer_cache[self]
class _CacheCollector:
def __init__(self):
self.cache: Optional[
List[Tuple[JITRunner, List[Union[Buffer, PlaceHolder]]]]
] = None
def start(self, var_vals: Optional[Dict[Variable, int]] = None):
self.cache = []
self.placeholders: WeakKeyDictionary[Buffer, PlaceHolder] = WeakKeyDictionary()
self.var_vals = var_vals if var_vals is not None else {}
def add(self, prg, rawbufs, var_vals):
if self.cache is None:
return
for k, v in var_vals.items():
assert (
k in self.var_vals and self.var_vals[k] == v
), f"var_vals {k} mismatch {v} != {self.var_vals.get(k)}"
self.placeholders[rawbufs[0]] = PlaceHolder(
rawbufs[0]
) # NOTE: this is making an assumption that 0 is special
self.cache.append(
(
prg,
[
self.placeholders.get(x, x) if isinstance(x, Buffer) else x
for x in rawbufs
],
)
)
def finish(self) -> List[JitItem]:
if self.cache is None:
return []
buffer_cache: Dict[PlaceHolder, Buffer] = {}
saved_cache, self.cache = self.cache, None
return [
JitItem(
prg,
[
x.alloc_if_needed(buffer_cache) if isinstance(x, PlaceHolder) else x
for x in pl
],
)
for prg, pl in saved_cache
]
CacheCollector = _CacheCollector()