tinygrad jit

Note

You likely want the upstream tinygrad, not tinygrab. Tinygrab contains AI generated docstrings for a tinygrad snapshot. Upstream: https://tinygrad.org

exception tinygrad.jit.GraphException

Bases: Exception

Custom exception for graph-related operations.

It subclasses Python's built-in Exception class, so graph-related failures can be raised and caught as their own exception type.

message

Description of the error, shown as part of the error message when the exception is raised.

Type:

str

class tinygrad.jit.JitItem(prg: JITRunner, rawbufs: List[Buffer | None])

Bases: object

A dataclass representing a single just-in-time (JIT) item: one program together with the buffers it runs on.

prg

A JITRunner object, or a graph executor such as MetalGraph, that will be executed.

Type:

tinygrad.device.JITRunner

rawbufs

A list of optional Buffer objects: the raw buffers bound to the program when it is launched, covering its output as well as its inputs.

Type:

List[tinygrad.device.Buffer | None]
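
To make the two fields concrete, here is a minimal sketch that mirrors JitItem's shape as a plain dataclass. Buffer and JITRunner below are hypothetical stand-ins defined locally so the snippet runs without tinygrad; the real classes live in tinygrad.device.

```python
from dataclasses import dataclass
from typing import List, Optional

# Hypothetical stand-ins for tinygrad.device.Buffer and JITRunner,
# defined here only so the sketch runs without tinygrad.
class Buffer:
    def __init__(self, name: str) -> None:
        self.name = name

class JITRunner:
    def __init__(self, name: str) -> None:
        self.name = name

@dataclass
class JitItem:
    prg: JITRunner                   # program (or graph executor) to launch
    rawbufs: List[Optional[Buffer]]  # buffers bound to that launch

out, a, b = Buffer("out"), Buffer("a"), Buffer("b")
item = JitItem(JITRunner("add_kernel"), [out, a, b])
print(item.prg.name, [buf.name for buf in item.rawbufs])  # add_kernel ['out', 'a', 'b']
```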

class tinygrad.jit.PlaceHolder(buf: Buffer)

Bases: object

alloc_if_needed(buffer_cache: Dict[PlaceHolder, Buffer]) → Buffer
to_tuple()
class tinygrad.jit.ReturnType

Type variable for return type.

This is a TypeVar that can be used to specify the return type of a function or method generically. In this module it parameterizes TinyJit, so a type checker can infer that calling a TinyJit-wrapped function returns the same type as the original function.

Example

def first(items: List[ReturnType]) -> ReturnType: return items[0]

alias of TypeVar(‘ReturnType’)
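
The sketch below shows how a TypeVar like ReturnType lets a generic wrapper class carry the wrapped function's return type, the same pattern TinyJit follows through its Generic[ReturnType] base. Wrapper and add are hypothetical names used only for illustration; at runtime the call is an ordinary function call.

```python
from typing import Any, Callable, Generic, TypeVar

ReturnType = TypeVar("ReturnType")

# Hypothetical wrapper shaped like TinyJit's Generic[ReturnType] base:
# a type checker infers ReturnType from the wrapped function.
class Wrapper(Generic[ReturnType]):
    def __init__(self, fxn: Callable[..., ReturnType]) -> None:
        self.fxn = fxn

    def __call__(self, *args: Any, **kwargs: Any) -> ReturnType:
        return self.fxn(*args, **kwargs)

def add(x: int, y: int) -> int:
    return x + y

wrapped = Wrapper(add)  # a checker types this as Wrapper[int]
result = wrapped(2, 3)  # ...so result is typed as int
print(result)  # 5
```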

class tinygrad.jit.TinyJit(fxn: Callable[..., ReturnType])

Bases: Generic[ReturnType]

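This class wraps a function with tinygrad's just-in-time (JIT) mechanism. Rather than caching return values, TinyJit records the kernels the function launches and replays them on later calls: the first call runs the function normally, the second call runs it again while capturing its kernels into jit_cache, and from the third call on the captured kernels are replayed with the new input buffers substituted in, skipping the Python function body. This speeds up repeated calls, provided the inputs keep the same shapes and data types.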

fxn

The original function wrapped by the JIT.

jit_cache

The list of JitItems captured during the second call. Each entry pairs a program (or graph executor) with the buffers it was launched with.

input_replace

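A dictionary mapping each (j, i) position in jit_cache, where j indexes a JitItem and i a buffer within its rawbufs, to the index of the input buffer that is substituted at that position on replay.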

cnt

An integer counter, which is incremented each time the function is called.

ret

The return value of the last call to the function.

expected_vals

A tuple of variables that are expected as arguments for the function.

expected_name_sts_dtype

A tuple containing information about the expected argument names, shapes and data types.

reset()

Resets the internal state of the TinyJit instance, including the cache, input replacement mapping, counter, return value and expected values/name-shape-data type information.
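
The call-count logic described above can be modeled in plain Python. ToyJit below is a hypothetical, heavily simplified stand-in, not tinygrad code: the wrapped function reports each "kernel" it runs through a launch() hook, the second call records those kernels in jit_cache, and later calls replay them without executing the function body. The real TinyJit captures device kernels and swaps in new input buffers via input_replace instead of threading a value through a chain of Python callables.

```python
from typing import Any, Callable, List

class ToyJit:
    """Hypothetical model of TinyJit's warmup/capture/replay cycle (not tinygrad code)."""

    def __init__(self, fxn: Callable[..., Any]) -> None:
        self.fxn = fxn
        self.reset()

    def reset(self) -> None:
        # Forget everything captured so far, like TinyJit.reset()
        self.cnt = 0
        self.jit_cache: List[Callable[[Any], Any]] = []
        self.ret: Any = None

    def launch(self, kernel: Callable[[Any], Any], arg: Any) -> Any:
        # fxn calls this for every "kernel"; kernels are recorded only on the capture call
        if self.cnt == 1:
            self.jit_cache.append(kernel)
        return kernel(arg)

    def __call__(self, x: Any) -> Any:
        if self.cnt >= 2:
            # replay: run the captured kernels in order, skipping fxn's Python body
            out = x
            for kernel in self.jit_cache:
                out = kernel(out)
            self.ret = out
        else:
            # cnt == 0: warmup; cnt == 1: capture (launch() records the kernels)
            self.ret = self.fxn(self, x)
        self.cnt += 1
        return self.ret

body_runs = [0]  # counts how often the Python body actually executes

def double_plus_one(jit: ToyJit, x: int) -> int:
    body_runs[0] += 1
    y = jit.launch(lambda v: v * 2, x)
    return jit.launch(lambda v: v + 1, y)

f = ToyJit(double_plus_one)
results = [f(i) for i in range(5)]
print(results)       # [1, 3, 5, 7, 9]
print(body_runs[0])  # 2: the body ran only for the warmup and capture calls
```

The replay path is only valid here because double_plus_one launches the same kernels for every input, an assumption this toy model does not check.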

tinygrad.jit.get_input_replace(jit_cache: List[JitItem], input_rawbuffers: List[Buffer]) → Dict[Tuple[int, int], int]

This function maps the position of each input buffer in the JIT cache to its index in input_rawbuffers. The keys are tuples (j, i), where j is the index of a JitItem in jit_cache and i is the index of a buffer in that item's rawbufs; each value is the index of the same buffer in input_rawbuffers.

Parameters:
  • jit_cache (List[JitItem]) – A list of JitItem objects that represent the JIT cache.

  • input_rawbuffers (List[Buffer]) – A list of Buffer objects that represent the raw input buffers.

Returns:

A dictionary where keys are tuples (j, i) representing the position of an input tensor in the JIT cache, and values are integers representing the index of that tensor in the input_rawbuffers list.

Return type:

Dict[Tuple[int, int], int]
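
A minimal sketch of the mapping described above, using hypothetical stand-in Buffer and JitItem classes whose buffers compare by identity. It follows the documented behavior; tinygrad's own implementation may differ in details.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

class Buffer:
    """Hypothetical stand-in; instances compare by identity."""

@dataclass
class JitItem:
    rawbufs: List[Optional[Buffer]]  # only the field this sketch needs

def get_input_replace(jit_cache: List[JitItem],
                      input_rawbuffers: List[Buffer]) -> Dict[Tuple[int, int], int]:
    input_replace: Dict[Tuple[int, int], int] = {}
    for j, ji in enumerate(jit_cache):          # j: position of the JitItem
        for i, buf in enumerate(ji.rawbufs):    # i: position of the buffer in rawbufs
            if buf is not None and buf in input_rawbuffers:
                input_replace[(j, i)] = input_rawbuffers.index(buf)
    return input_replace

x, y, tmp, out = Buffer(), Buffer(), Buffer(), Buffer()
cache = [JitItem([tmp, x, y]), JitItem([out, tmp, x])]
print(get_input_replace(cache, [x, y]))  # {(0, 1): 0, (0, 2): 1, (1, 2): 0}
```

Intermediate buffers such as tmp and out never appear in the result, since they are not among the caller-supplied inputs.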

tinygrad.jit.get_jc_idxs_with_updatable_launch_dims(jit_cache: List[JitItem]) → List[int]

This function returns the indices of the items in jit_cache whose program's global and/or local sizes are not all integers, that is, whose launch dimensions contain symbolic values.

Parameters:

jit_cache (List[JitItem]) – A list of JitItems, each representing information about a just-in-time compiled code object.

Returns:

A list of indices where the program’s global and/or local sizes are not all integers.

Return type:

List[int]
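
The filter described above can be sketched as a list comprehension. Node, CompiledASTRunner, JitItem and the all_int helper below are hypothetical stand-ins carrying only what the check needs; a Node instance plays the role of a symbolic size.

```python
from dataclasses import dataclass
from typing import List, Optional, Sequence, Union

class Node:
    """Hypothetical stand-in for a symbolic size expression."""
    def __init__(self, expr: str) -> None:
        self.expr = expr

Size = Union[int, Node]

@dataclass
class CompiledASTRunner:  # stand-in holding only the launch sizes
    global_size: Optional[List[Size]] = None
    local_size: Optional[List[Size]] = None

@dataclass
class JitItem:
    prg: object

def all_int(sizes: Sequence[Size]) -> bool:
    return all(isinstance(s, int) for s in sizes)

def get_jc_idxs_with_updatable_launch_dims(jit_cache: List[JitItem]) -> List[int]:
    # keep items whose global or local size contains a non-integer (symbolic) entry
    return [j for j, ji in enumerate(jit_cache)
            if isinstance(ji.prg, CompiledASTRunner)
            and ((ji.prg.global_size is not None and not all_int(ji.prg.global_size))
                 or (ji.prg.local_size is not None and not all_int(ji.prg.local_size)))]

cache = [
    JitItem(CompiledASTRunner([4, 1, 1], [32, 1, 1])),          # fixed sizes
    JitItem(CompiledASTRunner([Node("i"), 1, 1], [32, 1, 1])),  # symbolic global size
    JitItem(object()),                                          # not a compiled kernel
    JitItem(CompiledASTRunner([8, 1, 1], [Node("j"), 1, 1])),   # symbolic local size
]
print(get_jc_idxs_with_updatable_launch_dims(cache))  # [1, 3]
```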

tinygrad.jit.get_jc_idxs_with_updatable_var_vals(jit_cache: List[JitItem]) → List[int]

This function returns the indices of the items in jit_cache whose prg is a CompiledASTRunner with a non-empty vars list. These are the items whose symbolic variable values need to be updated between calls.

Parameters:

jit_cache (List[JitItem]) – A list of JitItem objects, where each JitItem represents an item in the just-in-time execution cache.

Returns:

A list of integer indices that correspond to elements within the input jit_cache which have an associated CompiledASTRunner object with vars.

Return type:

List[int]
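
A minimal sketch of this filter, assuming hypothetical stand-in classes in which vars is a list of variable names (tinygrad stores variable objects there). An empty vars list is falsy, so kernels without variables drop out.

```python
from dataclasses import dataclass, field
from typing import List

@dataclass
class CompiledASTRunner:
    """Hypothetical stand-in; vars names the symbolic variables the kernel takes."""
    vars: List[str] = field(default_factory=list)

@dataclass
class JitItem:
    prg: object

def get_jc_idxs_with_updatable_var_vals(jit_cache: List[JitItem]) -> List[int]:
    # an empty vars list is falsy, so kernels without variables are skipped
    return [j for j, ji in enumerate(jit_cache)
            if isinstance(ji.prg, CompiledASTRunner) and ji.prg.vars]

cache = [JitItem(CompiledASTRunner()),        # no variables
         JitItem(CompiledASTRunner(["i"])),   # uses variable i
         JitItem(object())]                   # not a compiled kernel
print(get_jc_idxs_with_updatable_var_vals(cache))  # [1]
```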

tinygrad.jit.get_jit_stats(jit_cache: List[JitItem]) → Tuple[Node, Node]

Calculate the operation and memory estimates for a given list of JitItems.

This function takes a list of JitItems and returns a tuple of two nodes: the total operation estimate and the total memory estimate. Each total is computed by reducing the prg.op_estimate (respectively prg.mem_estimate) attributes of the JitItems with the add operator.

Parameters:

jit_cache (List[JitItem]) – A list of JitItems for which the operation and memory estimates will be calculated.

Returns:

A tuple containing two nodes; the first represents total operation estimate, and the second represents total memory estimate.

Return type:

Tuple[Node, Node]
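
The reduction described above can be sketched with functools.reduce and operator.add. Prg and JitItem are hypothetical stand-ins, and plain ints replace the symbolic Node estimates, so this sketch returns Tuple[int, int]. Starting the reduction at 0 is a choice of this sketch that makes an empty cache yield (0, 0).

```python
import functools
import operator
from dataclasses import dataclass
from typing import List, Tuple

@dataclass
class Prg:
    """Hypothetical stand-in; in tinygrad the estimates may be symbolic Nodes."""
    op_estimate: int
    mem_estimate: int

@dataclass
class JitItem:
    prg: Prg

def get_jit_stats(jit_cache: List[JitItem]) -> Tuple[int, int]:
    # Sum each estimate with operator.add; starting at 0 makes an empty cache give (0, 0)
    ops = functools.reduce(operator.add, [ji.prg.op_estimate for ji in jit_cache], 0)
    mem = functools.reduce(operator.add, [ji.prg.mem_estimate for ji in jit_cache], 0)
    return ops, mem

cache = [JitItem(Prg(100, 48)), JitItem(Prg(250, 96))]
print(get_jit_stats(cache))  # (350, 144)
```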