# Source code for tinygrad.lazy

from __future__ import annotations
import sys, math
from typing import Callable, Optional, Tuple, Union, List, Dict, Any, cast, Mapping, Set
from weakref import ref, WeakSet, WeakValueDictionary

import numpy as np
from tinygrad.helpers import (
    prod,
    getenv,
    DType,
    dtypes,
    flatten,
    dedup,
    merge_dicts,
    all_int,
    ImageDType,
    DEBUG,
)
from tinygrad.ops import (
    ScheduleItem,
    UnaryOps,
    BinaryOps,
    TernaryOps,
    ReduceOps,
    MovementOps,
    LoadOps,
    OpType,
    LazyOp,
    MemBuffer,
    ConstBuffer,
    BufferOps,
    get_lazyop_info,
)
from tinygrad.shape.shapetracker import ShapeTracker, get_contraction
from tinygrad.shape.symbolic import Variable, sint
from tinygrad.device import Buffer

"""
**sys.setrecursionlimit(10000)**: Increases the recursion limit in Python to 10,000. This allows for deeper recursion in the program.

**OPT = getenv("OPT", 2)**: Gets the value of the environment variable "OPT". If it doesn't exist, sets OPT to 2.

**LAZYCACHE = getenv("LAZYCACHE", 1)**: Gets the value of the environment variable "LAZYCACHE". If it doesn't exist, sets LAZYCACHE to 1.

**(REMOVE_MOVEMENT_NOPS, MERGE_ELEMENTWISE_INTO_REDUCE, SHUFFLE_MOVEMENT_OPS, MERGE_ELEMENTWISE_OPS) = (OPT >= 1, OPT >= 1, OPT >= 1, OPT >= 1)**: Sets boolean values for various optimizations based on the value of OPT. If OPT is 1 or greater, sets these to True; otherwise, False.

**MERGE_ONE_REDUCE_INTO_ELEMENTWISE, SHUFFLE_PAD_OPS = OPT >= 2, OPT >= 2**: Sets boolean values for more specific optimizations based on the value of OPT. If OPT is 2 or greater, sets these to True; otherwise, False.

**PUSH_PERMUTES, PUSH_CONTIGUOUS = OPT >= 3, OPT >= 3**: Sets boolean values for further optimizations based on the value of OPT. If OPT is 3 or greater, sets these to True; otherwise, False.

**PUSH_RESHAPES = OPT >= 4**: Sets a boolean value for an additional optimization based on the value of OPT. If OPT is 4 or greater, sets this to True; otherwise, False.
"""
sys.setrecursionlimit(10000)

OPT = getenv("OPT", 2)
LAZYCACHE = getenv("LAZYCACHE", 1)

# TODO: movement ops that only change shape are really nops. treat them as such
(
    REMOVE_MOVEMENT_NOPS,
    MERGE_ELEMENTWISE_INTO_REDUCE,
    SHUFFLE_MOVEMENT_OPS,
    MERGE_ELEMENTWISE_OPS,
) = (OPT >= 1, OPT >= 1, OPT >= 1, OPT >= 1)
MERGE_ONE_REDUCE_INTO_ELEMENTWISE, SHUFFLE_PAD_OPS = OPT >= 2, OPT >= 2
PUSH_PERMUTES, PUSH_CONTIGUOUS = OPT >= 3, OPT >= 3
PUSH_RESHAPES = OPT >= 4

# **** ast fixing functions ****


def _ast_reduceops(op: LazyOp) -> LazyOp:
    """
    Build the AST for a reduce op, inlining its elementwise input when allowed.

    The reduce's single source is kept as-is when it is already realized. When it is
    unrealized, MERGE_ELEMENTWISE_INTO_REDUCE is enabled, the source is a BinaryOps
    buffer, and that source has at most one child, the source's own LazyOp is inlined.
    The elementwise work and the reduce then form a single AST. Only the input side
    of the reduce is fused here; consumers of the reduce are not examined.

    Args:
        op (LazyOp): A reduce LazyOp whose first source is a LazyBuffer.

    Returns:
        LazyOp: A new LazyOp with the same op and arg as `op` and a possibly-inlined source.

    Raises:
        AssertionError: If the source is unrealized but carries no LazyOp.
    """
    reduce_input = op.src[0]
    if reduce_input.realized:
        # realized inputs are loaded from memory, nothing to inline
        return LazyOp(op.op, (reduce_input,), op.arg)

    assert isinstance(
        reduce_input.op, LazyOp
    ), "if not src.realized, then src.op must be a LazyOp"

    # only inline when no other consumer would need the elementwise result materialized
    can_inline = (
        MERGE_ELEMENTWISE_INTO_REDUCE
        and reduce_input.optype is BinaryOps
        and len(reduce_input.children) <= 1
    )
    new_source = reduce_input.op if can_inline else reduce_input
    return LazyOp(op.op, (new_source,), op.arg)


# this supports late merging an upstream Reduce op and even an Elementwise op above that
def _ast_binaryops(op: LazyOp, shape: Tuple[sint, ...]) -> LazyOp:
    """
    Build the AST for an elementwise (BinaryOps) LazyOp, optionally fusing one upstream reduce.

    Every input buffer of `op` is mapped to a replacement. The fusion applies when
    MERGE_ONE_REDUCE_INTO_ELEMENTWISE is enabled and an input meets all of these:
    - it has at most one child;
    - its contiguous movement root is an unrealized ReduceOps buffer;
    - that root has the same element count and at most one child.
    The first such input is then replaced by the reduce's AST (from `_ast_reduceops`),
    so the reduce and this elementwise op end up in one AST. The reduce's own input
    buffers are mapped to themselves. Every other input is reshaped to the
    intermediate shape.

    If the fused input's shape differs from the reduce's shape (the reduce was followed
    by a reshape), the elementwise op is evaluated at the reduce's output shape. A final
    RESHAPE to `shape` is then wrapped around the result.

    Args:
        op (LazyOp): The elementwise lazy operation to rewrite.
        shape (Tuple[sint, ...]): The output shape of the buffer that owns `op`.

    Returns:
        LazyOp: The rewritten lazy operation, possibly wrapped in a RESHAPE to `shape`.

    Raises:
        AssertionError: If the fused input was reshaped but its shape differs from `shape`.
    """
    # None marks "not decided yet"; every None is filled with a reshape below
    real_srcs: Dict[LazyBuffer, Optional[Union[LazyOp, LazyBuffer]]] = {
        x: None for x in op.buffers
    }
    # NOTE: contiguous does not always mean the same size with SHRINK. this is still mergeable but requires more thought how
    # TODO: this can also support late fusion of BinaryOps, required for test_fold_conv_sgd
    # candidate (input buffer, reduce root) pairs that could be fused into this op
    psrcs = [
        (buf, root)
        for buf in op.buffers
        if len(buf.children) <= 1
        and (root := get_movementroot_contiguous(buf)).optype == ReduceOps
        and not root.realized
        and prod(root.shape) == prod(buf.shape)
        and len(root.children) <= 1
    ]
    intermediate_shape = shape
    if MERGE_ONE_REDUCE_INTO_ELEMENTWISE and psrcs:
        # NOTE: right now we can't handle multiple, as we'd have to check for loop
        buf, root = psrcs[0]
        top = _ast_reduceops(root.op)
        real_srcs[buf] = top
        real_srcs.update(
            {x: x for x in top.buffers}
        )  # the reduce op buffers are not modified

        # if the ReduceOp is followed by a reshape, we push this reshape before all the ElementwiseOp inputs
        if buf.shape != root.shape:
            intermediate_shape = root.shape
            assert buf.shape == shape, f"shape mismatch {buf.shape} != {shape}"

    # reshape all the late ops into the output shape
    # NOTE: these RESHAPEs will return self if they don't change the shape
    # (only existing keys are reassigned, so mutating the dict while iterating is safe)
    for buf, src in real_srcs.items():
        if src is None:
            real_srcs[buf] = buf.reshape(intermediate_shape)
    # NOTE: cast the type to remove the Optional
    ast = op.map_buffers(cast(Dict[LazyBuffer, Union[LazyOp, LazyBuffer]], real_srcs))
    # undo the pushed-up reshape so the result has the caller's output shape
    return (
        LazyOp(MovementOps.RESHAPE, (ast,), shape)
        if intermediate_shape != shape
        else ast
    )


def _replace_bufferops(op: LazyOp) -> Tuple[LazyOp, List[LazyBuffer]]:
    """
    Swap every LazyBuffer leaf of `op` for a BufferOps node.

    Each distinct non-const base buffer gets a LOAD slot, numbered from 1 in first-seen
    order; slot 0 is left for the output (schedule() stores into MemBuffer 0). A buffer
    whose base has a slot becomes a LOAD of that slot. An unrealized constant becomes a
    CONST carrying `float(arg)`. Each leaf keeps its own simplified, unbound
    ShapeTracker. A top-level RESHAPE or CONTIGUOUS is stripped: its source is rewritten
    instead.

    Args:
        op (LazyOp): The lazy operation whose buffers should be replaced.

    Returns:
        Tuple[LazyOp, List[LazyBuffer]]: The rewritten op, and the base buffers in
        LOAD-slot order (slot i + 1 is list index i).

    Raises:
        NotImplementedError: If a buffer is neither backed by a slotted base nor an
            unrealized constant.
    """
    slot_bases = dedup(
        [leaf.base for leaf in op.buffers if not leaf.is_unrealized_const()]
    )
    swap: Dict[LazyBuffer, LazyOp] = {}
    for leaf in op.buffers:
        view = leaf.st.simplify().unbind()
        if leaf.base in slot_bases:
            slot = slot_bases.index(leaf.base) + 1
            swap[leaf] = LazyOp(BufferOps.LOAD, (), MemBuffer(slot, leaf.dtype, view))
            continue
        if not leaf.realized and leaf.base.op.op == LoadOps.CONST:
            const_value = float(leaf.base.op.arg)
            swap[leaf] = LazyOp(
                BufferOps.CONST, (), ConstBuffer(const_value, leaf.dtype, view)
            )
            continue
        raise NotImplementedError(f"not handled {leaf}")

    strip_top = op.op in {MovementOps.RESHAPE, LoadOps.CONTIGUOUS}
    target = op.src[0] if strip_top else op
    return target.map_buffers(swap), slot_bases


# **** lazy operations ****


def get_movementroot(root: LazyBuffer, allow_contiguous=False) -> LazyBuffer:
    """
    Walk back through unrealized movement ops to the buffer they originate from.

    Starting at `root`, follow `op.src[0]` as long as the current buffer is unrealized
    and either:
    - is a MovementOps buffer, or
    - (only when `allow_contiguous` is set) is a CONTIGUOUS LoadOp whose source view is
      already contiguous.

    This is a loop rather than the equivalent tail recursion, so long movement chains
    cannot exhaust the interpreter's recursion limit.

    Args:
        root (LazyBuffer): The buffer to start from.
        allow_contiguous (bool): Also step through CONTIGUOUS LoadOps over contiguous
            sources. Defaults to False.

    Returns:
        LazyBuffer: The first buffer on the chain that does not meet the criteria
        (possibly `root` itself).
    """
    while not root.realized and (
        root.optype == MovementOps
        or (
            root.op.op == LoadOps.CONTIGUOUS
            and allow_contiguous
            and root.op.src[0].st.contiguous
        )
    ):
        root = cast(LazyBuffer, root.op.src[0])
    return root
def get_movementroot_contiguous(x: LazyBuffer) -> LazyBuffer:
    """
    Find the root buffer behind any CONTIGUOUS LoadOps and contiguous movement ops.

    First skips every unrealized CONTIGUOUS LoadOp on top of `x`. If the buffer reached
    is a MovementOps buffer with a contiguous view, the search continues via
    `get_movementroot(..., allow_contiguous=True)`. Otherwise that buffer is returned.
    The CONTIGUOUS skip is a loop rather than the equivalent tail recursion.

    Args:
        x (LazyBuffer): The buffer to start from.

    Returns:
        LazyBuffer: The root found, or the first non-CONTIGUOUS buffer reached.
    """
    # strip any stack of unrealized CONTIGUOUS LoadOps
    while not x.realized and x.op.op == LoadOps.CONTIGUOUS:
        x = cast(LazyBuffer, x.op.src[0])
    # a contiguous movement op can be traced further back to its movement root
    if x.optype == MovementOps and x.st.contiguous:
        return get_movementroot(x, True)
    return x
# NOTE: this is the canonical order
def vars_from_ast(ast: LazyOp) -> List[Variable]:
    """
    Collect the symbolic Variables referenced by the BufferOps nodes of an AST.

    Every LazyOp in `ast` whose op is a BufferOps member contributes the variables of
    its ShapeTracker (`arg.st.vars()`). The result is deduplicated and sorted by
    `str(var.expr)`. That order is the canonical one: schedule() builds each
    ScheduleItem's var_vals mapping by iterating this list.

    Args:
        ast (LazyOp): The abstract syntax tree to scan.

    Returns:
        List[Variable]: Unique variables, sorted by the string form of their expression.
    """
    found: Set[Variable] = set()
    for lop in ast.get_lazyops():
        if lop.op in BufferOps:
            found.update(lop.arg.st.vars())
    return sorted(found, key=lambda var: str(var.expr))
# dedup table for create_lazybuffer; entries vanish once no strong reference remains
lazycache: WeakValueDictionary = WeakValueDictionary()


def create_lazybuffer(
    device: str,
    st: ShapeTracker,
    optype: OpType,
    op: LazyOp,
    dtype: DType,
    base: Optional[LazyBuffer] = None,
):
    """
    Create (or reuse) a LazyBuffer for the given device, view, op type, op and dtype.

    - A zero-size shape is rewritten into a CONST LoadOp of 0.0 on a fresh contiguous
      ShapeTracker, and `base` is dropped.
    - With LAZYCACHE disabled, or for EMPTY / CUSTOM / CONST LoadOps, a new LazyBuffer
      is always built.
    - Otherwise an existing LazyBuffer with the same (device, dtype, optype, op, base)
      key is returned if still alive. It is registered as a child of `op`'s buffers
      first.

    Args:
        device (str): The device to create the lazy buffer on.
        st (ShapeTracker): The shape tracker to use.
        optype (OpType): The operation type.
        op (LazyOp): The lazy operation.
        dtype (DType): The data type of the lazy buffer.
        base (Optional[LazyBuffer]): Optional base for a view. Defaults to None.

    Returns:
        LazyBuffer: The created or cached lazy buffer.
    """
    # rewrite 0 size into a CONST
    if 0 in st.shape:
        return LazyBuffer(
            device,
            ShapeTracker.from_shape(st.shape),
            LoadOps,
            LazyOp(LoadOps.CONST, tuple(), 0.0),
            dtype,
        )

    # these loads are never deduplicated
    if not LAZYCACHE or (
        optype is LoadOps and op.op in {LoadOps.EMPTY, LoadOps.CUSTOM, LoadOps.CONST}
    ):
        return LazyBuffer(device, st, optype, op, dtype, base=base)

    # wop is the deduping key. i feel this used to compare more deeply
    wop = (device, dtype, optype, ref(op), ref(base) if base else None)
    # Look up once and hold a strong reference: with a WeakValueDictionary an
    # `in` check followed by indexing can race with garbage collection (KeyError).
    cached = lazycache.get(wop)
    if cached is not None:
        for x in op.buffers:
            x.children.add(cached)
        return cached

    lazycache[wop] = ret = LazyBuffer(device, st, optype, op, dtype, base=base)
    return ret
[docs] class LazyBuffer: """ LazyBuffer class for lazy operations on buffers. Attributes: __deletable__ (tuple): Tuple containing the attribute 'op' which can be deleted. device (str): Device where this buffer is located. st (ShapeTracker): Shape tracker object. optype (OpType): Operation type. op (Optional[LazyOp]): Lazy operation object. dtype (DType): Data type of the buffer. src (Optional[Buffer]): Source buffer, if any. Default is None. base (Optional[LazyBuffer]): Base lazy buffer, if any. Default is None. output_buffer (Optional[Buffer]): Output buffer. Default is None. children (WeakSet[LazyBuffer]): Weak set of child lazy buffers. views (WeakSet[LazyBuffer]): Weak set of view lazy buffers. """ __deletable__ = ("op",) def __init__( self, device: str, st: ShapeTracker, optype: OpType, op: Optional[LazyOp], dtype: DType, src: Optional[Buffer] = None, base: Optional[LazyBuffer] = None, ): """ Initializes a new instance of the LazyBuffer class. Args: device (str): Device where this buffer is located. st (ShapeTracker): Shape tracker object. optype (OpType): Operation type. op (Optional[LazyOp]): Lazy operation object. dtype (DType): Data type of the buffer. src (Optional[Buffer]): Source buffer, if any. Default is None. base (Optional[LazyBuffer]): Base lazy buffer, if any. Default is None. """ self.device, self.st, self.shape, self.optype, self._dtype, self._realized = ( device, st, st.shape, optype, dtype, src, ) self.output_buffer: Optional[ Buffer ] = None # TODO: do we really need this? or can we just use realized # TODO: does children have to be a ref count instead of a set? can a Buffer be a double child? self.children: WeakSet[LazyBuffer] = WeakSet() self.views: WeakSet[LazyBuffer] = WeakSet() # NOTE: op should be read only after construction of LazyBuffer. 
it is now with schedule if op is not None: self.op = op for x in op.buffers: x.children.add(self) assert optype != MovementOps or ( base is not None and base.optype != MovementOps ), "MovementOps must be based" self._base = base if base: base.views.add(self) else: assert st.contiguous, "unbased LazyBuffers must be contiguous" @property def base(self): """ Return the base of this LazyBuffer. Returns: The base of this LazyBuffer if it is not None, else self. """ return self._base if self._base is not None else self
[docs] def is_unrealized_const(self): """ Check whether this LazyBuffer is an unrealized constant. Returns: True if the buffer is unrealized and its base operation is a CONST LoadOp, False otherwise. """ return not self.realized and self.base.op.op == LoadOps.CONST
[docs] def is_unrealized_contiguous_const(self): """ Check whether this LazyBuffer is an unrealized contiguous constant. Returns: True if the buffer is both unrealized and contiguous, False otherwise. """ return self.is_unrealized_const() and self.st.contiguous
@property def realized(self): """ Return whether this LazyBuffer is realized. Returns: True if the buffer is realized, False otherwise. """ return self.base._realized @realized.setter def realized(self, val: Buffer): """ Set the realization of this LazyBuffer. Args: val (Buffer): The buffer to set as the realization of this LazyBuffer. Raises: AssertionError: If _base is not None when trying to set the realized value. """ assert self._base is None, "no setting realized of based LazyBuffers" self._realized = val @property def dtype(self): """ Get the data type of this LazyBuffer. Returns: The data type of this LazyBuffer. """ return self.base._dtype @dtype.setter def dtype(self, val: DType): """ Set the data type for this LazyBuffer. Args: val (DType): The data type to set for this LazyBuffer. Raises: AssertionError: If attempting to set the dtype of a based LazyBuffer. """ assert self._base is None, "no setting dtype of based LazyBuffers" self._dtype = val def __repr__(self): """ Get a string representation of this LazyBuffer. Returns: A string containing the shape, data type, operation, and storage type of this LazyBuffer. """ return f"<LB {self.shape} {self.dtype} op={self.op.op if hasattr(self, 'op') else self._realized} st={self.st}>" def _device_extra_args(self) -> Dict[str, str]: """ Get extra arguments for the device based on its representation. Returns: A dictionary containing any extra arguments necessary for the device. """ return {"device": self.device.split(":", 1)[1]} if ":" in self.device else {} @property def buffers(self) -> Tuple[LazyBuffer, ...]: """ Return a tuple containing the instance of `LazyBuffer`. Returns: Tuple[LazyBuffer]: A tuple containing the instance of `LazyBuffer`. """ return (self,)
[docs] def map_buffers(self, real_srcs: Mapping[Any, Union[LazyBuffer, LazyOp]]): """ Retrieve the corresponding `LazyBuffer` or `LazyOp` object from the mapping of sources. Args: real_srcs (Mapping[Any, Union[LazyBuffer, LazyOp]]): A mapping of objects, where keys are any hashable objects and values are either `LazyBuffer` or `LazyOp`. Returns: Union[LazyBuffer, LazyOp]: The corresponding `LazyBuffer` or `LazyOp` object. If the instance is not found in the mapping, it returns itself. """ return real_srcs.get(self, self)
[docs] def get_lazyops(self) -> List[LazyOp]: """ Return an empty list of `LazyOp` objects. This method is a placeholder and always returns an empty list. Subclasses may override this method to provide specific behavior. Returns: List[LazyOp]: An empty list. """ return []
# *** scheduling ***
[docs] def schedule(self, seen: Optional[Set[LazyBuffer]] = None) -> List[ScheduleItem]: """ Schedules the computation of this lazy buffer. Args: seen (Optional[Set[LazyBuffer]]): Set of already scheduled buffers. Defaults to None. Returns: List[ScheduleItem]: A list of schedule items for this buffer's computations. Attributes: seen (Optional[Set[LazyBuffer]]): Set of already scheduled buffers. Defaults to an empty set if not provided. ret (List[ScheduleItem]): List of schedule items for this buffer's computations. var_vals (Dict): Merged dictionary of variable values from this buffer and its operand buffers. op (ASTNode): The abstract syntax tree node representing the operation to be performed on this buffer. base_bufs (List[Buffer]): List of base buffers for this buffer's computation. """ if seen is None: seen = set() if self in seen or self.realized or self.is_unrealized_const(): return [] seen.add(self) if self.base is not self: return self.base.schedule(seen) op = self.op if self.optype is BinaryOps: op = _ast_binaryops(op, self.shape) elif self.optype is ReduceOps: op = _ast_reduceops(op) # schedule the past ret: List[ScheduleItem] = [] for x in op.buffers: ret += x.schedule(seen) var_vals = merge_dicts( [self.st.var_vals] + [buf.st.var_vals for buf in op.buffers] ) op, base_bufs = _replace_bufferops(op) # check if we can reuse the output buffer # if it's aliased, don't use it # TODO: this is pretty wrong actually, who knows where else this buffer is used? # TODO: what if an assign is required? 
this silently is wrong # NOTE: this has been moved to schedule, as this is only an issue if buffers are already realized if self.output_buffer is not None: for i, a in enumerate(base_bufs): # TODO: if this is contiguous it's fine if a.realized == self.output_buffer: if any( not x.arg.st.contiguous for x in op.get_lazyops() if x.op == BufferOps.LOAD and x.arg.idx == i + 1 ): self.output_buffer = None break if op.op not in LoadOps: # add the store info = get_lazyop_info(op) assert info.dtype == self.dtype or isinstance( self.dtype, ImageDType ), f"dtype mismatch {info.dtype=} != {self.dtype=}" if isinstance(self.dtype, ImageDType) and ( prod(self.shape) != prod(self.dtype.shape) or not any(self.shape[x] % 4 == 0 for x in self.st.unit_stride_axes()) ): if DEBUG >= 3: print(f"forcing image {self.dtype} to float32") self.dtype = ( dtypes.float32 ) # NOTE; this is what makes the dtype above not match op = LazyOp(UnaryOps.CAST, (op,), (dtypes.float32, False)) # TODO: why doesn't this match? # assert info.shape == self.shape, f"shape mismatch {info.shape=} != {self.shape=}" op = LazyOp( BufferOps.STORE, (op,), MemBuffer(0, self.dtype, ShapeTracker.from_shape(info.shape)), ) else: # check loadop validity of bufferops for i, s in enumerate(op.src): assert ( isinstance(s, LazyOp) and s.op == BufferOps.LOAD and s.arg.idx == i + 1 and s.arg.st.contiguous ), f"bad LoadOps src {i}: {s}" return ret + [ ScheduleItem( op, self, tuple(base_bufs), {k: var_vals[k] for k in vars_from_ast(op)} ) ]
# *** creation/special ops ***
[docs] @staticmethod def loadop( op, shape: Tuple[sint, ...], dtype: DType, device: str, arg=None, src: Optional[LazyBuffer] = None, ) -> LazyBuffer: """ Load operation factory method. Creates and returns a new `LazyBuffer` object based on the given parameters. This is a static method and does not require an instance of the class. Attributes: op: The operation to be performed. shape (Tuple[sint, ...]): The shape of the data to be loaded. dtype: The data type of the data to be loaded. device (str): The device where the data will be loaded onto. arg: An optional argument for the operation. Default is None. src (Optional[LazyBuffer]): An optional source `LazyBuffer` object. Default is None. Returns: LazyBuffer: A new `LazyBuffer` object created with the given parameters. """ return create_lazybuffer( device, ShapeTracker.from_shape(shape), LoadOps, LazyOp(op, tuple() if src is None else (src,), arg), dtype, )
# create a constant with the shape and dtype of self
[docs] def const(self, val: Union[float, int]) -> LazyBuffer: """ Creates a new constant `LazyBuffer` object based on the shape and data type of this instance. Returns: LazyBuffer: A new constant `LazyBuffer` object with the same shape and data type as this instance. """ # NOTE: dtypes.from_np(self.dtype.np) to deal with image types return ( LazyBuffer.loadop( LoadOps.CONST, tuple(), dtypes.from_np(self.dtype.np), self.device, arg=val, ) .reshape((1,) * len(self.shape)) .expand(self.shape) )
[docs] def copy_to_device(self, device: str) -> LazyBuffer: """ This method is used to copy a lazy buffer to a specified device. It will check if the buffer is already on the target device and return it directly if true. Otherwise, it will create a new lazy buffer on the target device. Attributes: self (LazyBuffer): The source lazy buffer. device (str): The target device to copy the buffer to. Returns: LazyBuffer: The copied lazy buffer on the target device. """ # back off a FROM if it's a double FROM if ( not self.realized and self.op.op == LoadOps.FROM and cast(LazyBuffer, self.op.src[0]).device == device ): return cast(LazyBuffer, self.op.src[0]) return LazyBuffer.loadop( LoadOps.FROM, self.shape, self.dtype, device, src=self.contiguous() )
[docs] def contiguous(self: LazyBuffer) -> LazyBuffer: """ This method is used to ensure a lazy buffer is stored in a contiguous manner. If the source buffer is already contiguous, it will return itself directly. Otherwise, it will create a new contiguous lazy buffer. Attributes: self (LazyBuffer): The source lazy buffer. Returns: LazyBuffer: A contiguous version of the source lazy buffer. """ if not self.realized and self.op.op in LoadOps and self.op.op != LoadOps.CONST: return self # all LoadOps are already contiguous (except CONST) if ( self.st.contiguous and self.st.size() == self.base.st.size() and not self.is_unrealized_const() ): # this will turn into nothing, it's based and a copy # TODO: based lazybuffers shouldn't take dtype or var_vals, same issue in movementops return create_lazybuffer( self.device, ShapeTracker.from_shape(tuple(self.shape)), LoadOps, LazyOp(LoadOps.CONTIGUOUS, (self,), None), self.dtype, base=self.base, ) return LazyBuffer.loadop( LoadOps.CONTIGUOUS, self.shape, self.dtype, self.device, src=self )
[docs] @staticmethod def fromCPU(x: np.ndarray) -> LazyBuffer: """ Create a new `LazyBuffer` object from a numpy array. Attributes: x (np.ndarray): The numpy array to be used for creating the `LazyBuffer`. Returns: LazyBuffer: A new `LazyBuffer` object created from the input numpy array. """ return LazyBuffer( "CPU", ShapeTracker.from_shape(x.shape), LoadOps, None, dtypes.from_np(x.dtype), Buffer("CPU", prod(x.shape), dtypes.from_np(x.dtype), x.flatten()), )
[docs] def cast(self, dtype: DType, bitcast: bool = False): """ Cast the elements of this buffer to a new data type. Attributes: dtype (DType): The desired data type for the elements of this buffer. bitcast (bool): Whether to allow a bit-level cast, which is faster but may result in unspecified behavior if used incorrectly. Defaults to `False`. Returns: self: This method modifies the `LazyBuffer` object in-place and returns it for chaining purposes. """ return self.e(UnaryOps.CAST, arg=(dtype, bitcast))
# *** elementwise ops ***
[docs] def e( self: LazyBuffer, op: Union[UnaryOps, BinaryOps, TernaryOps], *srcs: LazyBuffer, arg: Optional[Any] = None, ) -> LazyBuffer: """ This method performs an operation on the input buffers and returns a new LazyBuffer. :param self: The instance of the LazyBuffer class. :type self: LazyBuffer :param op: The operation to be performed (UnaryOps, BinaryOps, TernaryOps). :type op: Union[UnaryOps, BinaryOps, TernaryOps] :param srcs: The input buffers for the operation. :type srcs: LazyBuffer :param arg: An optional argument for certain operations, defaults to None. :type arg: Optional[Any], optional :return: A new LazyBuffer with the result of the operation. :rtype: LazyBuffer Attributes: srcs (LazyBuffer): The input buffers for the operation. Includes self. out_device (str): The output device. out_shape (Tuple[int, ...]): The output shape. out_dtype (DType): The output data type. """ # srcs includes self srcs = (self,) + srcs # if we are separated from other binary ops by movement ops, we push those movement ops above those binaryops if SHUFFLE_MOVEMENT_OPS: srcs = _push_movement_ops(srcs) # get outputs now out_device, out_shape, out_dtype = ( srcs[0].device, srcs[0].shape, max([x.dtype for x in srcs]) if op != UnaryOps.CAST else cast(Tuple[DType, bool], arg)[0], ) # push all contiguous to the end of BinaryOps if PUSH_CONTIGUOUS and any( not x.realized and x.op.op == LoadOps.CONTIGUOUS and len(x.op.src[0].children) <= 1 for x in srcs ): new_srcs: List[LazyBuffer] = [] for x in srcs: if ( not x.realized and x.op.op == LoadOps.CONTIGUOUS and len(x.op.src[0].children) <= 1 ): x.op.src[0].children.discard(x) x = cast(LazyBuffer, x.op.src[0]) new_srcs.append(x) return new_srcs[0].e(op, *new_srcs[1:], arg=arg).contiguous() if MERGE_ELEMENTWISE_OPS: # remove the buffers from any (childless) BinaryOps that feed into this _srcs = tuple( [ x.op if x.optype == BinaryOps and not x.children and not x.realized else x for x in srcs ] ) # TODO: needs general merge limiting if 
( out_device != "WEBGPU" or len( dedup( [ x.base for _src in _srcs for x in _src.buffers if not x.is_unrealized_const() ] ) ) < 7 ): srcs = _srcs # type: ignore return create_lazybuffer( out_device, ShapeTracker.from_shape(out_shape), BinaryOps, LazyOp(op, srcs, arg), out_dtype, )
# *** reduce ops *** def _reduce_op( self: LazyBuffer, op: ReduceOps, new_shape: Tuple[sint, ...] ) -> LazyBuffer: """ Create a new LazyBuffer with reduced dimensions. This method is used to reduce the dimensions of the current `LazyBuffer` object by applying a reduction operation specified by `op`. The reduced shape is given by `new_shape`. If the current shape and `new_shape` are equal, this method returns the original LazyBuffer. Attributes: self (LazyBuffer): The `LazyBuffer` object on which the method is called. op (ReduceOps): The reduction operation to be applied. new_shape (Tuple[sint, ...]): The desired shape of the reduced LazyBuffer. Returns: LazyBuffer: A new `LazyBuffer` object with reduced dimensions. """ if self.shape == tuple(new_shape): return self srcs = _push_movement_ops((self,)) if SHUFFLE_MOVEMENT_OPS else (self,) unbound_new_shape = tuple( s.unbind()[0] if not isinstance(s, int) else s for s in new_shape ) return create_lazybuffer( self.device, ShapeTracker.from_shape(new_shape), ReduceOps, LazyOp(op, srcs, unbound_new_shape), self.dtype, )
[docs] def r(self: LazyBuffer, op: ReduceOps, new_shape: Tuple[sint, ...]) -> LazyBuffer: """ Alias for `_reduce_op`. This method is an alias for `_reduce_op` and provides another way to call it. It takes the same arguments as `_reduce_op` and returns a new `LazyBuffer` with reduced dimensions. Attributes: self (LazyBuffer): The `LazyBuffer` object on which the method is called. op (ReduceOps): The reduction operation to be applied. new_shape (Tuple[sint, ...]): The desired shape of the reduced LazyBuffer. Returns: LazyBuffer: A new `LazyBuffer` object with reduced dimensions. """ # TODO: can we split symbolic shape if the reduce axis is not symbolic? if ( not all_int(self.shape) or (0 in self.shape) or prod(self.shape) // prod(new_shape) < getenv("REDUCEOP_SPLIT_THRESHOLD", 32768) ): return self._reduce_op(op, new_shape) heuristic, divisor, dim_to_split = max(((divisor := math.gcd(256, old)) / (stride or math.inf), divisor, i) for i, (old, new, stride) in enumerate(zip(self.shape, new_shape, self.st.real_strides())) if old != new) # type: ignore if divisor < 16 or heuristic < 0.1: return self._reduce_op(op, new_shape) # choose largest divisor (>=16) to split on, penalize large strides def splitted_shape(dim_aft_div): return ( self.shape[:dim_to_split] + (self.shape[dim_to_split] // divisor,) + dim_aft_div + self.shape[dim_to_split + 1 :] ) return ( self.reshape(splitted_shape((divisor,))) ._reduce_op(op, splitted_shape((1,))) .reshape(splitted_shape(())) ._reduce_op(op, new_shape) )
# *** movement ops ***
[docs] def reshape(self: LazyBuffer, arg: Tuple[sint, ...]) -> LazyBuffer: """ Reshapes the buffer. Attributes: self (LazyBuffer): The current lazy buffer. arg (Tuple[sint, ...]): The new shape for the buffer. Returns: LazyBuffer: The reshaped lazy buffer. """ if self.shape == arg: return self if not self.realized and self.op.op == MovementOps.RESHAPE: assert isinstance(self.op.src[0], LazyBuffer) self.op.src[0].children.discard( self ) # NOTE: this is only required in reshape and when pushing permutes, why?? return self.op.src[0].reshape(arg) return self._movement_op(self.st.reshape(arg), MovementOps.RESHAPE, arg)
[docs] def pad(self: LazyBuffer, arg: Tuple[Tuple[int, int], ...]) -> LazyBuffer: """ Pad the buffer object. This method pads the current buffer with specified start and end indices. If all padding values are 0, it simply returns the original buffer. If the buffer is not realized and the last operation was a pad operation, the pad operation is combined with the new one. Otherwise, a new movement operation is created with the pad argument. :param arg: A tuple of tuples containing start and end padding values. :type arg: Tuple[Tuple[int, int], ...] :return: The padded buffer object. :rtype: LazyBuffer """ if all(b == 0 and e == 0 for b, e in arg): return self if not self.realized and self.op.op == MovementOps.PAD: return self.op.src[0].pad( tuple( [(b1 + b2, e1 + e2) for (b1, e1), (b2, e2) in zip(self.op.arg, arg)] ) ) return self._movement_op(self.st.pad(arg), MovementOps.PAD, arg)
[docs] def expand(self: LazyBuffer, arg: Tuple[sint, ...]) -> LazyBuffer: """ Expand the current LazyBuffer based on the given argument. This function checks if the shape of the current LazyBuffer is equal to the provided argument. If so, it returns the LazyBuffer itself. If not, and if the LazyBuffer hasn't been realized yet and its operation is an expansion operation, it returns the source of the operation with index 0, also expanded using the provided argument. Otherwise, it creates a new LazyBuffer by calling the movement_op function with the result of expanding the current state, the MovementOps.EXPAND operation, and the provided argument. :param self: The current LazyBuffer object. :type self: LazyBuffer :param arg: A tuple containing sint objects representing the new shape of the LazyBuffer. :type arg: Tuple[sint, ...] :return: A new LazyBuffer with the expanded shape. :rtype: LazyBuffer """ if self.shape == arg: return self if not self.realized and self.op.op == MovementOps.EXPAND: return self.op.src[0].expand(arg) return self._movement_op(self.st.expand(arg), MovementOps.EXPAND, arg)
[docs] def permute(self: LazyBuffer, arg: Tuple[int, ...]) -> LazyBuffer: """ Permute the current LazyBuffer based on the given argument. This function is not yet implemented. :param self: The current LazyBuffer object. :type self: LazyBuffer :param arg: A tuple containing int objects representing the permutation of the dimensions. :type arg: Tuple[int, ...] :return: A new LazyBuffer with the permuted shape. :rtype: LazyBuffer """ if arg == tuple(range(len(self.shape))): return self if not self.realized and self.op.op == MovementOps.PERMUTE: return self.op.src[0].permute(tuple([self.op.arg[i] for i in arg])) if SHUFFLE_MOVEMENT_OPS and not self.realized: if PUSH_PERMUTES and self.optype == ReduceOps: # reduceops have one buffer input, permute it narg = tuple([self.op.arg[a] for a in arg]) src, rop = self.op.src[0], self.op.op src.children.discard(self) del self # TODO: why doesn't this delete remove it from the children return src.permute(arg).r(cast(ReduceOps, rop), narg) # move permutes before expands (always, this is safe) if self.op.op == MovementOps.EXPAND: return ( self.op.src[0] .permute(arg) .expand(tuple([self.op.arg[a] for a in arg])) ) # move permutes before reshapes if we can if ( PUSH_PERMUTES and self.op.op == MovementOps.RESHAPE and isinstance(self.op.src[0], LazyBuffer) ): if shape_idx_groups := get_contraction( self.op.src[0].shape, self.shape ): self.op.src[0].children.discard( self ) # NOTE: this is only required in reshape and when pushing permutes, why?? return ( self.op.src[0] .permute(tuple(flatten(shape_idx_groups[i] for i in arg))) .reshape(self.st.permute(arg).shape) ) return self._movement_op(self.st.permute(arg), MovementOps.PERMUTE, arg)
[docs] def shrink(self: LazyBuffer, arg: Tuple[Tuple[sint, sint], ...]) -> LazyBuffer: """ Shrinks the buffer based on the provided arguments. Attributes: self (LazyBuffer): The lazy buffer to be shrunk. arg (Tuple[Tuple[sint, sint], ...]): A tuple of tuples containing the start and end indices for each dimension of the buffer. Returns: LazyBuffer: The shrunken lazy buffer. """ if all(b - a == s for s, (a, b) in zip(self.shape, arg)): return self if not self.realized and self.op.op == MovementOps.SHRINK: return self.op.src[0].shrink( tuple( [(b1 + b2, b1 + e2) for (b1, _), (b2, e2) in zip(self.op.arg, arg)] ) ) return self._movement_op(self.st.shrink(arg), MovementOps.SHRINK, arg)
[docs] def stride(self: LazyBuffer, arg: Tuple[int, ...]) -> LazyBuffer: """ Applies a stride operation to the buffer based on the provided arguments. Attributes: self (LazyBuffer): The lazy buffer to have the stride operation applied to. arg (Tuple[int, ...]): A tuple of integers representing the strides for each dimension of the buffer. Returns: LazyBuffer: The lazy buffer after the stride operation has been applied. """ if all(a == 1 for a in arg): return self if not self.realized and self.op.op == MovementOps.STRIDE: return self.op.src[0].stride( tuple(a1 * a2 for a1, a2 in zip(arg, self.op.arg)) ) return self._movement_op(self.st.stride(arg), MovementOps.STRIDE, arg)
def _movement_op(
    self,
    st: ShapeTracker,
    op: MovementOps,
    arg: Union[Tuple[sint, ...], Tuple[Tuple[sint, sint], ...]],
) -> LazyBuffer:
    """Apply movement op ``op`` (whose resulting view is ``st``), optimizing where possible.

    Strategies, in order:

    1. Push into the op. Applies when SHUFFLE_MOVEMENT_OPS is on, this buffer
       is unrealized, its ``optype`` is ``BinaryOps`` (judging by the UnaryOps
       check below, this class also covers unary ops) and it has no children.
       SHRINK, STRIDE and PERMUTE are then handed to
       ``self.op.replace_with_movement_ops``. RESHAPE is handed over too, but
       only when ``self.op.op`` is a UnaryOps or PUSH_RESHAPES is on.
    2. Skip the intermediate view. Applies when REMOVE_MOVEMENT_NOPS is on,
       this buffer is unrealized and ``st`` is contiguous. If the movement root
       returned by ``get_movementroot`` is not ``self``, is contiguous, and has
       the same element count as ``st``, the result is ``root.reshape(st.shape)``.
    3. Otherwise, create a new MovementOps lazybuffer wrapping
       ``LazyOp(op, (self,), arg)`` on the same device and dtype, with
       ``base=self.base``.

    Args:
        st: ShapeTracker of the view after applying ``op``.
        op: the movement op being applied.
        arg: the op's argument (a shape/permutation/step tuple, or per-dim pairs).

    Returns:
        The LazyBuffer representing the moved view.
    """
    can_shuffle = (
        SHUFFLE_MOVEMENT_OPS
        and not self.realized
        and self.optype == BinaryOps
        and not self.children
    )
    if can_shuffle:
        pushable = op in {MovementOps.SHRINK, MovementOps.STRIDE, MovementOps.PERMUTE}
        if not pushable and op == MovementOps.RESHAPE:
            pushable = self.op.op in UnaryOps or PUSH_RESHAPES
        if pushable:
            return self.op.replace_with_movement_ops([(op, arg)])
    if REMOVE_MOVEMENT_NOPS and not self.realized and st.contiguous:
        # MovementOps aren't stacked any more, they each have one parent; find the root
        root = get_movementroot(self)
        same_size = prod(st.shape) == prod(root.shape) if root != self and root.st.contiguous else False
        if same_size:
            return root.reshape(st.shape)
    return create_lazybuffer(
        self.device,
        st,
        MovementOps,
        LazyOp(op, (self,), arg),
        self.dtype,
        base=self.base,
    )
def replace_with_movement_ops(
    self: LazyBuffer, ops: List[Tuple[MovementOps, Any]]
) -> LazyBuffer:
    """Replay a list of ``(movement_op, arg)`` pairs on this buffer, in order.

    Each pair is dispatched through ``MOVEMENT_OPS_DISPATCHER`` to the matching
    LazyBuffer method (reshape, expand, shrink, permute, pad, stride). Each
    step's result feeds the next. The chain is built in a local variable and
    nothing is assigned back onto ``self``.

    Args:
        ops: movement ops to apply, first to last.

    Returns:
        The buffer produced by the last op, or ``self`` when ``ops`` is empty.
    """
    result = self
    for movement, movement_arg in ops:
        apply_movement = MOVEMENT_OPS_DISPATCHER[movement]
        result = apply_movement(result, movement_arg)
    return result
# Ops that a zero PAD must not be pushed beneath. Pushing a pad below an op is
# only valid when the op maps zeros to zero (f(0) == 0, f(0, 0) == 0); these
# ops (DIV, CMPLT, LOG2, EXP2, RECIP) are treated as not satisfying that.
UNSAFE_PAD_OPS = {
    BinaryOps.DIV,
    BinaryOps.CMPLT,
    UnaryOps.LOG2,
    UnaryOps.EXP2,
    UnaryOps.RECIP,
}


def _push_movement_ops(srcs: Tuple[LazyBuffer, ...]) -> Tuple[LazyBuffer, ...]:
    """Try to sink each source's chain of movement ops into the op beneath it.

    For every buffer in ``srcs``:

    1. Walk down through unrealized MovementOps buffers that have at most one
       child, collecting their ``(op, arg)`` pairs. The walk always stops at an
       EXPAND, and stops at a PAD unless SHUFFLE_PAD_OPS is on.
    2. The collected ops are replayed, innermost first, through
       ``node.op.replace_with_movement_ops`` when all of these hold:
       - at least one op was collected;
       - the walk ended on an unrealized ``BinaryOps``-typed buffer with at
         most one child;
       - either no collected op is a PAD, or none of that buffer's lazyops is
         in UNSAFE_PAD_OPS.
    3. Otherwise the source is kept unchanged.

    Args:
        srcs: buffers to process.

    Returns:
        A tuple with one entry per element of ``srcs``, in the same order.
    """
    pushed: List[LazyBuffer] = []
    for src in srcs:
        walked: List[Tuple[MovementOps, Any]] = []
        node = src
        # backwalk the movement ops; EXPAND is never pushed, PAD only with SHUFFLE_PAD_OPS
        while (
            not node.realized
            and node.optype is MovementOps
            and node.op.op is not MovementOps.EXPAND
            and (SHUFFLE_PAD_OPS or node.op.op is not MovementOps.PAD)
            and len(node.children) <= 1
        ):
            assert isinstance(node.op.op, MovementOps) and isinstance(
                node.op.src[0], LazyBuffer
            )
            walked.append((node.op.op, node.op.arg))
            node = node.op.src[0]
        # NOTE: can't push pads past anything where f(0, 0) != 0 or f(0) != 0
        if (
            walked
            and not node.realized
            and node.optype is BinaryOps
            and len(node.children) <= 1
            and (
                not any(kind is MovementOps.PAD for kind, _ in walked)
                or not any(
                    lop.op in UNSAFE_PAD_OPS for lop in node.op.get_lazyops()
                )
            )
        ):
            # ops were collected outermost-first; replay them innermost-first
            src = node.op.replace_with_movement_ops(list(reversed(walked)))
        pushed.append(src)
    return tuple(pushed)


# Maps each MovementOps member to the LazyBuffer method that applies it; used by
# LazyBuffer.replace_with_movement_ops to replay (op, arg) lists. Must be defined
# after the LazyBuffer class, since it references the class's methods directly.
MOVEMENT_OPS_DISPATCHER: Dict[MovementOps, Callable] = {
    MovementOps.RESHAPE: LazyBuffer.reshape,
    MovementOps.EXPAND: LazyBuffer.expand,
    MovementOps.SHRINK: LazyBuffer.shrink,
    MovementOps.PERMUTE: LazyBuffer.permute,
    MovementOps.PAD: LazyBuffer.pad,
    MovementOps.STRIDE: LazyBuffer.stride,
}