from __future__ import annotations
import sys, math
from typing import Callable, Optional, Tuple, Union, List, Dict, Any, cast, Mapping, Set
from weakref import ref, WeakSet, WeakValueDictionary
import numpy as np
from tinygrad.helpers import (
prod,
getenv,
DType,
dtypes,
flatten,
dedup,
merge_dicts,
all_int,
ImageDType,
DEBUG,
)
from tinygrad.ops import (
ScheduleItem,
UnaryOps,
BinaryOps,
TernaryOps,
ReduceOps,
MovementOps,
LoadOps,
OpType,
LazyOp,
MemBuffer,
ConstBuffer,
BufferOps,
get_lazyop_info,
)
from tinygrad.shape.shapetracker import ShapeTracker, get_contraction
from tinygrad.shape.symbolic import Variable, sint
from tinygrad.device import Buffer
"""
**sys.setrecursionlimit(10000)**: Increases the recursion limit in Python to 10,000. This allows for deeper recursion in the program.
**OPT = getenv("OPT", 2)**: Gets the value of the environment variable "OPT". If it doesn't exist, sets OPT to 2.
**LAZYCACHE = getenv("LAZYCACHE", 1)**: Gets the value of the environment variable "LAZYCACHE". If it doesn't exist, sets LAZYCACHE to 1.
**(REMOVE_MOVEMENT_NOPS, MERGE_ELEMENTWISE_INTO_REDUCE, SHUFFLE_MOVEMENT_OPS, MERGE_ELEMENTWISE_OPS) = (OPT >= 1, OPT >= 1, OPT >= 1, OPT >= 1)**: Sets boolean values for various optimizations based on the value of OPT. If OPT is 1 or greater, sets these to True; otherwise, False.
**MERGE_ONE_REDUCE_INTO_ELEMENTWISE, SHUFFLE_PAD_OPS = OPT >= 2, OPT >= 2**: Sets boolean values for more specific optimizations based on the value of OPT. If OPT is 2 or greater, sets these to True; otherwise, False.
**PUSH_PERMUTES, PUSH_CONTIGUOUS = OPT >= 3, OPT >= 3**: Sets boolean values for further optimizations based on the value of OPT. If OPT is 3 or greater, sets these to True; otherwise, False.
**PUSH_RESHAPES = OPT >= 4**: Sets a boolean value for an additional optimization based on the value of OPT. If OPT is 4 or greater, sets this to True; otherwise, False.
"""
sys.setrecursionlimit(10000)
OPT = getenv("OPT", 2)
LAZYCACHE = getenv("LAZYCACHE", 1)
# TODO: movement ops that only change shape are really nops. treat them as such
(
REMOVE_MOVEMENT_NOPS,
MERGE_ELEMENTWISE_INTO_REDUCE,
SHUFFLE_MOVEMENT_OPS,
MERGE_ELEMENTWISE_OPS,
) = (OPT >= 1, OPT >= 1, OPT >= 1, OPT >= 1)
MERGE_ONE_REDUCE_INTO_ELEMENTWISE, SHUFFLE_PAD_OPS = OPT >= 2, OPT >= 2
PUSH_PERMUTES, PUSH_CONTIGUOUS = OPT >= 3, OPT >= 3
PUSH_RESHAPES = OPT >= 4
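"""
Illustration (hypothetical helper, not necessarily tinygrad's exact implementation):
getenv-style flags read an environment variable and cast it to the type of the
default, so the passes above can be toggled per run, e.g. `OPT=3 python train.py`.

    import os

    def getenv_sketch(key: str, default=0):
        # assumption: mirrors the basic behavior of tinygrad.helpers.getenv
        return type(default)(os.environ.get(key, default))

    OPT = getenv_sketch("OPT", 2)
    PUSH_PERMUTES = OPT >= 3
"""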
# **** ast fixing functions ****
def _ast_reduceops(op: LazyOp) -> LazyOp:
"""
Reduce operations in Abstract Syntax Trees (AST).
This function is designed to optimize the AST by reducing operations. It can also realize a binary op after the reduce, not just before. The function takes a single argument `op` of type `LazyOp`.
Attributes:
src (LazyOp): The source from which to retrieve the operation.
Returns:
LazyOp: The optimized `LazyOp` object after performing the reduction operation.
"""
src = op.src[0]
if not src.realized:
assert isinstance(
src.op, LazyOp
), "if not src.realized, then src.op must be a LazyOp"
if (
MERGE_ELEMENTWISE_INTO_REDUCE
and src.optype is BinaryOps
and len(src.children) <= 1
):
src = src.op
return LazyOp(op.op, (src,), op.arg)
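"""
Illustration: a minimal, self-contained toy (hypothetical Buf/Op stand-ins, not the
real LazyBuffer/LazyOp classes) of the inlining rule above: a reduce whose source
buffer carries an unrealized elementwise op with no other consumers absorbs that op.

    from dataclasses import dataclass
    from typing import Optional, Tuple, Union

    @dataclass
    class Buf:                      # stand-in for LazyBuffer
        op: Optional["Op"] = None   # unrealized: carries the op producing it
        n_children: int = 1

    @dataclass
    class Op:                       # stand-in for LazyOp
        name: str
        src: Tuple[Union["Op", "Buf"], ...] = ()

    def ast_reduceops(reduce_op: Op) -> Op:
        src = reduce_op.src[0]
        if isinstance(src, Buf) and src.op is not None and src.n_children <= 1:
            src = src.op            # inline the elementwise AST over the buffer
        return Op(reduce_op.name, (src,))

    mul_buf = Buf(op=Op("MUL", (Buf(), Buf())))
    print(ast_reduceops(Op("SUM", (mul_buf,))))   # SUM over the inlined MUL
"""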
# this supports late merging an upstream Reduce op and even an Elementwise op above that
def _ast_binaryops(op: LazyOp, shape: Tuple[sint, ...]) -> LazyOp:
"""
This function supports late merging an upstream Reduce op and even an Elementwise op above that.
Attributes:
op (LazyOp): The input lazy operation.
shape (Tuple[sint, ...]): The output shape of the operation.
Returns:
LazyOp: The transformed lazy operation with the late merging applied.
"""
real_srcs: Dict[LazyBuffer, Optional[Union[LazyOp, LazyBuffer]]] = {
x: None for x in op.buffers
}
# NOTE: contiguous does not always mean the same size with SHRINK. this is still mergeable, but requires more thought on how
# TODO: this can also support late fusion of BinaryOps, required for test_fold_conv_sgd
psrcs = [
(buf, root)
for buf in op.buffers
if len(buf.children) <= 1
and (root := get_movementroot_contiguous(buf)).optype == ReduceOps
and not root.realized
and prod(root.shape) == prod(buf.shape)
and len(root.children) <= 1
]
intermediate_shape = shape
if MERGE_ONE_REDUCE_INTO_ELEMENTWISE and psrcs:
# NOTE: right now we can't handle multiple, as we'd have to check for loop
buf, root = psrcs[0]
top = _ast_reduceops(root.op)
real_srcs[buf] = top
real_srcs.update(
{x: x for x in top.buffers}
) # the reduce op buffers are not modified
# if the ReduceOp is followed by a reshape, we push this reshape before all the ElementwiseOp inputs
if buf.shape != root.shape:
intermediate_shape = root.shape
assert buf.shape == shape, f"shape mismatch {buf.shape} != {shape}"
# reshape all the late ops into the output shape
# NOTE: these RESHAPEs will return self if they don't change the shape
for buf, src in real_srcs.items():
if src is None:
real_srcs[buf] = buf.reshape(intermediate_shape)
# NOTE: cast the type to remove the Optional
ast = op.map_buffers(cast(Dict[LazyBuffer, Union[LazyOp, LazyBuffer]], real_srcs))
return (
LazyOp(MovementOps.RESHAPE, (ast,), shape)
if intermediate_shape != shape
else ast
)
def _replace_bufferops(op: LazyOp) -> Tuple[LazyOp, List[LazyBuffer]]:
"""
Replace buffer operations in a lazy op with new ones.
This function takes a lazy operation (op) as input and replaces its buffer operations
with new ones based on certain conditions. It returns the updated lazy operation and a list of base buffers.
Args:
op (LazyOp): The input lazy operation.
Returns:
Tuple[LazyOp, List[LazyBuffer]]: A tuple containing the updated lazy operation and
a list of base buffers.
Raises:
NotImplementedError: If a certain buffer is not handled by the function.
"""
replacements: Dict[LazyBuffer, LazyOp] = {}
base_bufs = dedup([x.base for x in op.buffers if not x.is_unrealized_const()])
for x in op.buffers:
st = x.st.simplify().unbind()
if x.base in base_bufs:
replacements[x] = LazyOp(
BufferOps.LOAD, (), MemBuffer(base_bufs.index(x.base) + 1, x.dtype, st)
)
elif not x.realized and x.base.op.op == LoadOps.CONST:
replacements[x] = LazyOp(
BufferOps.CONST, (), ConstBuffer(float(x.base.op.arg), x.dtype, st)
)
else:
raise NotImplementedError(f"not handled {x}")
return (
op.src[0] if op.op in {MovementOps.RESHAPE, LoadOps.CONTIGUOUS} else op
).map_buffers(replacements), base_bufs
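"""
Illustration: a standalone sketch (hypothetical dict-based buffers, not the real
classes) of the indexing scheme above: views of the same base share one LOAD index,
and constants are inlined rather than indexed.

    def index_buffers(buffers):
        base_bufs, replacements = [], {}
        for i, buf in enumerate(buffers):
            if buf["kind"] == "const":
                replacements[i] = ("CONST", buf["value"])
            else:
                if buf["base"] not in base_bufs:
                    base_bufs.append(buf["base"])
                replacements[i] = ("LOAD", base_bufs.index(buf["base"]) + 1)
        return replacements, base_bufs

    bufs = [{"kind": "mem", "base": "A"}, {"kind": "const", "value": 2.0},
            {"kind": "mem", "base": "A"}]
    print(index_buffers(bufs))
    # both views of base "A" map to LOAD index 1; the const is inlined
"""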
# **** lazy operations ****
def get_movementroot(root: LazyBuffer, allow_contiguous=False) -> LazyBuffer:
"""
Recursively retrieve the root of a movement operation or contiguous data.
This function is used to locate the origin of a series of operations by
traversing back through the chain of operations until it reaches the original
source buffer. It will continue this process as long as the current root is
not realized and its operation type matches specific criteria.
:param root: The current root node in the operation tree.
:type root: LazyBuffer
:param allow_contiguous: A flag indicating whether to include operations
that result in contiguous data, defaults to False.
:type allow_contiguous: bool, optional
:return: The original root node of the operation tree or the current root if
it is realized or does not meet the specified criteria.
:rtype: LazyBuffer
"""
return (
get_movementroot(cast(LazyBuffer, root.op.src[0]), allow_contiguous)
if not root.realized
and (
root.optype == MovementOps
or (
root.op.op == LoadOps.CONTIGUOUS
and allow_contiguous
and root.op.src[0].st.contiguous
)
)
else root
)
def get_movementroot_contiguous(x: LazyBuffer) -> LazyBuffer:
"""
Recursively obtain the root of movement for a contiguous operation.
This function is used to identify and return the root of a series of operations
that lead up to a contiguous operation on a lazy buffer. It does this by checking
if the current operation is a contiguous operation, and if it is not, recursively
calling itself on the source operation until it finds the root of the contiguous operation.
:param x: The input lazy buffer to check for the root of its movement.
:type x: LazyBuffer
:return: The root of the movement for the contiguous operation.
:rtype: LazyBuffer
"""
return (
get_movementroot_contiguous(cast(LazyBuffer, x.op.src[0]))
if not x.realized and x.op.op == LoadOps.CONTIGUOUS
else (
get_movementroot(x, True)
if x.optype == MovementOps and x.st.contiguous
else x
)
)
# NOTE: this is the canonical order
def vars_from_ast(ast: LazyOp) -> List[Variable]:
"""
Retrieve variables from abstract syntax tree (AST).
This function extracts all unique variables from the AST by leveraging
the `LazyOp` operations that belong to `BufferOps`. The resulting set of
variables is then sorted based on their expression string representation.
Attributes:
ast (LazyOp): Abstract syntax tree object to extract variables from.
Returns:
List[Variable]: Sorted list of unique variables extracted from AST.
"""
return sorted(
set.union(
*[x.arg.st.vars() for x in ast.get_lazyops() if x.op in BufferOps], set()
),
key=lambda x: str(x.expr),
)
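"""
Note on the trailing set() above: set.union(*args) with an empty argument list
raises a TypeError, so the extra empty set makes ASTs with no BufferOps safe.

    var_sets = []                          # e.g. an AST with no BufferOps
    merged = set.union(*var_sets, set())   # -> set(); without it, TypeError
    print(sorted(merged, key=str))         # -> []
"""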
lazycache: WeakValueDictionary = WeakValueDictionary()
def create_lazybuffer(
device: str,
st: ShapeTracker,
optype: OpType,
op: LazyOp,
dtype: DType,
base: Optional[LazyBuffer] = None,
):
"""
Create a lazy buffer for the given device, shape tracker, operation type, operation, data type, and optional base.
Args:
device (str): The device to create the lazy buffer on.
st (ShapeTracker): The shape tracker to use.
optype (OpType): The operation type.
op (LazyOp): The lazy operation.
dtype (DType): The data type of the lazy buffer.
base (Optional[LazyBuffer]): The optional base for the lazy buffer. Default is None.
Returns:
LazyBuffer: The created lazy buffer.
"""
# rewrite 0 size into a CONST
if 0 in st.shape:
return LazyBuffer(
device,
ShapeTracker.from_shape(st.shape),
LoadOps,
LazyOp(LoadOps.CONST, tuple(), 0.0),
dtype,
)
# fromcpu aren't cached
if not LAZYCACHE or (
optype is LoadOps and op.op in {LoadOps.EMPTY, LoadOps.CUSTOM, LoadOps.CONST}
):
return LazyBuffer(device, st, optype, op, dtype, base=base)
# wop is the deduping key. i feel this used to compare more deeply
wop = (device, dtype, optype, ref(op), ref(base) if base else None)
if wop in lazycache:
for x in op.buffers:
x.children.add(lazycache[wop])
return lazycache[wop]
lazycache[wop] = ret = LazyBuffer(device, st, optype, op, dtype, base=base)
return ret
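"""
Illustration: a self-contained sketch of the WeakValueDictionary dedup pattern used
by lazycache (Entry is a hypothetical stand-in). Equal keys return the same live
object, and entries disappear once nothing else references the value.

    from weakref import WeakValueDictionary

    class Entry:
        def __init__(self, key): self.key = key

    _cache: WeakValueDictionary = WeakValueDictionary()

    def create(key):
        if key in _cache: return _cache[key]
        _cache[key] = ret = Entry(key)
        return ret

    a, b = create(("CPU", "float32")), create(("CPU", "float32"))
    assert a is b              # deduped while a reference is alive
    del a, b
    print(len(_cache))         # 0 under CPython refcounting: entry was dropped
"""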
class LazyBuffer:
"""
LazyBuffer class for lazy operations on buffers.
Attributes:
__deletable__ (tuple): Tuple containing the attribute 'op' which can be deleted.
device (str): Device where this buffer is located.
st (ShapeTracker): Shape tracker object.
optype (OpType): Operation type.
op (Optional[LazyOp]): Lazy operation object.
dtype (DType): Data type of the buffer.
src (Optional[Buffer]): Source buffer, if any. Default is None.
base (Optional[LazyBuffer]): Base lazy buffer, if any. Default is None.
output_buffer (Optional[Buffer]): Output buffer. Default is None.
children (WeakSet[LazyBuffer]): Weak set of child lazy buffers.
views (WeakSet[LazyBuffer]): Weak set of view lazy buffers.
"""
__deletable__ = ("op",)
def __init__(
self,
device: str,
st: ShapeTracker,
optype: OpType,
op: Optional[LazyOp],
dtype: DType,
src: Optional[Buffer] = None,
base: Optional[LazyBuffer] = None,
):
"""
Initializes a new instance of the LazyBuffer class.
Args:
device (str): Device where this buffer is located.
st (ShapeTracker): Shape tracker object.
optype (OpType): Operation type.
op (Optional[LazyOp]): Lazy operation object.
dtype (DType): Data type of the buffer.
src (Optional[Buffer]): Source buffer, if any. Default is None.
base (Optional[LazyBuffer]): Base lazy buffer, if any. Default is None.
"""
self.device, self.st, self.shape, self.optype, self._dtype, self._realized = (
device,
st,
st.shape,
optype,
dtype,
src,
)
self.output_buffer: Optional[
Buffer
] = None # TODO: do we really need this? or can we just use realized
# TODO: does children have to be a ref count instead of a set? can a Buffer be a double child?
self.children: WeakSet[LazyBuffer] = WeakSet()
self.views: WeakSet[LazyBuffer] = WeakSet()
# NOTE: op should be read only after construction of LazyBuffer. it is now with schedule
if op is not None:
self.op = op
for x in op.buffers:
x.children.add(self)
assert optype != MovementOps or (
base is not None and base.optype != MovementOps
), "MovementOps must be based"
self._base = base
if base:
base.views.add(self)
else:
assert st.contiguous, "unbased LazyBuffers must be contiguous"
@property
def base(self):
"""
Return the base of this LazyBuffer.
Returns:
The base of this LazyBuffer if it is not None, else self.
"""
return self._base if self._base is not None else self
def is_unrealized_const(self):
"""
Check whether this LazyBuffer is an unrealized constant.
Returns:
True if the buffer is unrealized and its base operation is a CONST LoadOp, False otherwise.
"""
return not self.realized and self.base.op.op == LoadOps.CONST
def is_unrealized_contiguous_const(self):
"""
Check whether this LazyBuffer is an unrealized contiguous constant.
Returns:
True if the buffer is an unrealized constant and its ShapeTracker is contiguous, False otherwise.
"""
return self.is_unrealized_const() and self.st.contiguous
@property
def realized(self):
"""
Return whether this LazyBuffer is realized.
Returns:
True if the buffer is realized, False otherwise.
"""
return self.base._realized
@realized.setter
def realized(self, val: Buffer):
"""
Set the realization of this LazyBuffer.
Args:
val (Buffer): The buffer to set as the realization of this LazyBuffer.
Raises:
AssertionError: If _base is not None when trying to set the realized value.
"""
assert self._base is None, "no setting realized of based LazyBuffers"
self._realized = val
@property
def dtype(self):
"""
Get the data type of this LazyBuffer.
Returns:
The data type of this LazyBuffer.
"""
return self.base._dtype
@dtype.setter
def dtype(self, val: DType):
"""
Set the data type for this LazyBuffer.
Args:
val (DType): The data type to set for this LazyBuffer.
Raises:
AssertionError: If attempting to set the dtype of a based LazyBuffer.
"""
assert self._base is None, "no setting dtype of based LazyBuffers"
self._dtype = val
def __repr__(self):
"""
Get a string representation of this LazyBuffer.
Returns:
A string containing the shape, dtype, op (or realized buffer), and ShapeTracker of this LazyBuffer.
"""
return f"<LB {self.shape} {self.dtype} op={self.op.op if hasattr(self, 'op') else self._realized} st={self.st}>"
def _device_extra_args(self) -> Dict[str, str]:
"""
Get extra arguments for the device based on its representation.
Returns:
A dictionary containing any extra arguments necessary for the device.
"""
return {"device": self.device.split(":", 1)[1]} if ":" in self.device else {}
@property
def buffers(self) -> Tuple[LazyBuffer, ...]:
"""
Return a tuple containing the instance of `LazyBuffer`.
Returns:
Tuple[LazyBuffer]: A tuple containing the instance of `LazyBuffer`.
"""
return (self,)
def map_buffers(self, real_srcs: Mapping[Any, Union[LazyBuffer, LazyOp]]):
"""
Retrieve the corresponding `LazyBuffer` or `LazyOp` object from the mapping of sources.
Args:
real_srcs (Mapping[Any, Union[LazyBuffer, LazyOp]]): A mapping of objects, where keys are any hashable objects and values are either `LazyBuffer` or `LazyOp`.
Returns:
Union[LazyBuffer, LazyOp]: The corresponding `LazyBuffer` or `LazyOp` object. If the instance is not found in the mapping, it returns itself.
"""
return real_srcs.get(self, self)
def get_lazyops(self) -> List[LazyOp]:
"""
Return an empty list of `LazyOp` objects.
This method is a placeholder and always returns an empty list. Subclasses may override this method to provide specific behavior.
Returns:
List[LazyOp]: An empty list.
"""
return []
# *** scheduling ***
def schedule(self, seen: Optional[Set[LazyBuffer]] = None) -> List[ScheduleItem]:
"""
Schedules the computation of this lazy buffer.
Args:
seen (Optional[Set[LazyBuffer]]): Set of already scheduled buffers. Defaults to None.
Returns:
List[ScheduleItem]: A list of schedule items for this buffer's computations.
Attributes:
seen (Optional[Set[LazyBuffer]]): Set of already scheduled buffers. Defaults to an empty set if not provided.
ret (List[ScheduleItem]): List of schedule items for this buffer's computations.
var_vals (Dict): Merged dictionary of variable values from this buffer and its operand buffers.
op (ASTNode): The abstract syntax tree node representing the operation to be performed on this buffer.
base_bufs (List[Buffer]): List of base buffers for this buffer's computation.
"""
if seen is None:
seen = set()
if self in seen or self.realized or self.is_unrealized_const():
return []
seen.add(self)
if self.base is not self:
return self.base.schedule(seen)
op = self.op
if self.optype is BinaryOps:
op = _ast_binaryops(op, self.shape)
elif self.optype is ReduceOps:
op = _ast_reduceops(op)
# schedule the past
ret: List[ScheduleItem] = []
for x in op.buffers:
ret += x.schedule(seen)
var_vals = merge_dicts(
[self.st.var_vals] + [buf.st.var_vals for buf in op.buffers]
)
op, base_bufs = _replace_bufferops(op)
# check if we can reuse the output buffer
# if it's aliased, don't use it
# TODO: this is pretty wrong actually, who knows where else this buffer is used?
# TODO: what if an assign is required? this silently is wrong
# NOTE: this has been moved to schedule, as this is only an issue if buffers are already realized
if self.output_buffer is not None:
for i, a in enumerate(base_bufs):
# TODO: if this is contiguous it's fine
if a.realized == self.output_buffer:
if any(
not x.arg.st.contiguous
for x in op.get_lazyops()
if x.op == BufferOps.LOAD and x.arg.idx == i + 1
):
self.output_buffer = None
break
if op.op not in LoadOps:
# add the store
info = get_lazyop_info(op)
assert info.dtype == self.dtype or isinstance(
self.dtype, ImageDType
), f"dtype mismatch {info.dtype=} != {self.dtype=}"
if isinstance(self.dtype, ImageDType) and (
prod(self.shape) != prod(self.dtype.shape)
or not any(self.shape[x] % 4 == 0 for x in self.st.unit_stride_axes())
):
if DEBUG >= 3:
print(f"forcing image {self.dtype} to float32")
self.dtype = (
dtypes.float32
) # NOTE: this is what makes the dtype above not match
op = LazyOp(UnaryOps.CAST, (op,), (dtypes.float32, False))
# TODO: why doesn't this match?
# assert info.shape == self.shape, f"shape mismatch {info.shape=} != {self.shape=}"
op = LazyOp(
BufferOps.STORE,
(op,),
MemBuffer(0, self.dtype, ShapeTracker.from_shape(info.shape)),
)
else:
# check loadop validity of bufferops
for i, s in enumerate(op.src):
assert (
isinstance(s, LazyOp)
and s.op == BufferOps.LOAD
and s.arg.idx == i + 1
and s.arg.st.contiguous
), f"bad LoadOps src {i}: {s}"
return ret + [
ScheduleItem(
op, self, tuple(base_bufs), {k: var_vals[k] for k in vars_from_ast(op)}
)
]
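"""
Illustration: a toy version of this scheduling walk (dict-based nodes, a hypothetical
stand-in for LazyBuffer/ScheduleItem): depth-first over inputs with a `seen` set, so
shared subgraphs are scheduled exactly once and parents come after their inputs.

    def toy_schedule(node, seen=None):
        if seen is None: seen = set()
        if id(node) in seen or node.get("realized"): return []
        seen.add(id(node))
        items = []
        for src in node.get("src", ()):
            items += toy_schedule(src, seen)
        return items + [node["name"]]

    x = {"name": "x"}
    y = {"name": "y", "src": (x,)}
    z = {"name": "z", "src": (x, y)}
    print(toy_schedule(z))   # ['x', 'y', 'z'] -- x is scheduled only once
"""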
# *** creation/special ops ***
@staticmethod
def loadop(
op,
shape: Tuple[sint, ...],
dtype: DType,
device: str,
arg=None,
src: Optional[LazyBuffer] = None,
) -> LazyBuffer:
"""
Load operation factory method.
Creates and returns a new `LazyBuffer` object based on the given parameters. This is a static method and does not require an instance of the class.
Args:
op: The operation to be performed.
shape (Tuple[sint, ...]): The shape of the data to be loaded.
dtype: The data type of the data to be loaded.
device (str): The device where the data will be loaded onto.
arg: An optional argument for the operation. Default is None.
src (Optional[LazyBuffer]): An optional source `LazyBuffer` object. Default is None.
Returns:
LazyBuffer: A new `LazyBuffer` object created with the given parameters.
"""
return create_lazybuffer(
device,
ShapeTracker.from_shape(shape),
LoadOps,
LazyOp(op, tuple() if src is None else (src,), arg),
dtype,
)
# create a constant with the shape and dtype of self
def const(self, val: Union[float, int]) -> LazyBuffer:
"""
Creates a new constant `LazyBuffer` object based on the shape and data type of this instance.
Returns:
LazyBuffer: A new constant `LazyBuffer` object with the same shape and data type as this instance.
"""
# NOTE: dtypes.from_np(self.dtype.np) to deal with image types
return (
LazyBuffer.loadop(
LoadOps.CONST,
tuple(),
dtypes.from_np(self.dtype.np),
self.device,
arg=val,
)
.reshape((1,) * len(self.shape))
.expand(self.shape)
)
def copy_to_device(self, device: str) -> LazyBuffer:
"""
This method is used to copy a lazy buffer to a specified device. It will check if the buffer is already on the target device and return it directly if true. Otherwise, it will create a new lazy buffer on the target device.
Attributes:
self (LazyBuffer): The source lazy buffer.
device (str): The target device to copy the buffer to.
Returns:
LazyBuffer: The copied lazy buffer on the target device.
"""
# back off a FROM if it's a double FROM
if (
not self.realized
and self.op.op == LoadOps.FROM
and cast(LazyBuffer, self.op.src[0]).device == device
):
return cast(LazyBuffer, self.op.src[0])
return LazyBuffer.loadop(
LoadOps.FROM, self.shape, self.dtype, device, src=self.contiguous()
)
def contiguous(self: LazyBuffer) -> LazyBuffer:
"""
This method is used to ensure a lazy buffer is stored in a contiguous manner. If the source buffer is already contiguous, it will return itself directly. Otherwise, it will create a new contiguous lazy buffer.
Attributes:
self (LazyBuffer): The source lazy buffer.
Returns:
LazyBuffer: A contiguous version of the source lazy buffer.
"""
if not self.realized and self.op.op in LoadOps and self.op.op != LoadOps.CONST:
return self # all LoadOps are already contiguous (except CONST)
if (
self.st.contiguous
and self.st.size() == self.base.st.size()
and not self.is_unrealized_const()
):
# this will turn into nothing, it's based and a copy
# TODO: based lazybuffers shouldn't take dtype or var_vals, same issue in movementops
return create_lazybuffer(
self.device,
ShapeTracker.from_shape(tuple(self.shape)),
LoadOps,
LazyOp(LoadOps.CONTIGUOUS, (self,), None),
self.dtype,
base=self.base,
)
return LazyBuffer.loadop(
LoadOps.CONTIGUOUS, self.shape, self.dtype, self.device, src=self
)
@staticmethod
def fromCPU(x: np.ndarray) -> LazyBuffer:
"""
Create a new `LazyBuffer` object from a numpy array.
Args:
x (np.ndarray): The numpy array to be used for creating the `LazyBuffer`.
Returns:
LazyBuffer: A new `LazyBuffer` object created from the input numpy array.
"""
return LazyBuffer(
"CPU",
ShapeTracker.from_shape(x.shape),
LoadOps,
None,
dtypes.from_np(x.dtype),
Buffer("CPU", prod(x.shape), dtypes.from_np(x.dtype), x.flatten()),
)
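"""
Example usage (assumes tinygrad is importable; the `from tinygrad.lazy import
LazyBuffer` path matches this file's location in the tree):

    import numpy as np
    from tinygrad.lazy import LazyBuffer

    lb = LazyBuffer.fromCPU(np.arange(6, dtype=np.float32).reshape(2, 3))
    print(lb.shape, lb.dtype)   # (2, 3), the dtypes entry matching np.float32
"""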
def cast(self, dtype: DType, bitcast: bool = False):
"""
Cast the elements of this buffer to a new data type.
Attributes:
dtype (DType): The desired data type for the elements of this buffer.
bitcast (bool): Whether to allow a bit-level cast, which is faster but may result in unspecified behavior
if used incorrectly. Defaults to `False`.
Returns:
self: This method modifies the `LazyBuffer` object in-place and returns it for chaining purposes.
"""
return self.e(UnaryOps.CAST, arg=(dtype, bitcast))
# *** elementwise ops ***
def e(
self: LazyBuffer,
op: Union[UnaryOps, BinaryOps, TernaryOps],
*srcs: LazyBuffer,
arg: Optional[Any] = None,
) -> LazyBuffer:
"""
This method performs an operation on the input buffers and returns a new LazyBuffer.
:param self: The instance of the LazyBuffer class.
:type self: LazyBuffer
:param op: The operation to be performed (UnaryOps, BinaryOps, TernaryOps).
:type op: Union[UnaryOps, BinaryOps, TernaryOps]
:param srcs: The input buffers for the operation.
:type srcs: LazyBuffer
:param arg: An optional argument for certain operations, defaults to None.
:type arg: Optional[Any], optional
:return: A new LazyBuffer with the result of the operation.
:rtype: LazyBuffer
Attributes:
srcs (LazyBuffer): The input buffers for the operation. Includes self.
out_device (str): The output device.
out_shape (Tuple[int, ...]): The output shape.
out_dtype (DType): The output data type.
"""
# srcs includes self
srcs = (self,) + srcs
# if we are separated from other binary ops by movement ops, we push those movement ops above those binaryops
if SHUFFLE_MOVEMENT_OPS:
srcs = _push_movement_ops(srcs)
# get outputs now
out_device, out_shape, out_dtype = (
srcs[0].device,
srcs[0].shape,
max([x.dtype for x in srcs])
if op != UnaryOps.CAST
else cast(Tuple[DType, bool], arg)[0],
)
# push all contiguous to the end of BinaryOps
if PUSH_CONTIGUOUS and any(
not x.realized
and x.op.op == LoadOps.CONTIGUOUS
and len(x.op.src[0].children) <= 1
for x in srcs
):
new_srcs: List[LazyBuffer] = []
for x in srcs:
if (
not x.realized
and x.op.op == LoadOps.CONTIGUOUS
and len(x.op.src[0].children) <= 1
):
x.op.src[0].children.discard(x)
x = cast(LazyBuffer, x.op.src[0])
new_srcs.append(x)
return new_srcs[0].e(op, *new_srcs[1:], arg=arg).contiguous()
if MERGE_ELEMENTWISE_OPS:
# remove the buffers from any (childless) BinaryOps that feed into this
_srcs = tuple(
[
x.op
if x.optype == BinaryOps and not x.children and not x.realized
else x
for x in srcs
]
)
# TODO: needs general merge limiting
if (
out_device != "WEBGPU"
or len(
dedup(
[
x.base
for _src in _srcs
for x in _src.buffers
if not x.is_unrealized_const()
]
)
)
< 7
):
srcs = _srcs # type: ignore
return create_lazybuffer(
out_device,
ShapeTracker.from_shape(out_shape),
BinaryOps,
LazyOp(op, srcs, arg),
out_dtype,
)
# *** reduce ops ***
def _reduce_op(
self: LazyBuffer, op: ReduceOps, new_shape: Tuple[sint, ...]
) -> LazyBuffer:
"""
Create a new LazyBuffer with reduced dimensions.
This method is used to reduce the dimensions of the current `LazyBuffer` object by applying a reduction operation specified by `op`. The reduced shape is given by `new_shape`. If the current shape and `new_shape` are equal, this method returns the original LazyBuffer.
Args:
    op (ReduceOps): The reduction operation to be applied.
    new_shape (Tuple[sint, ...]): The desired shape of the reduced LazyBuffer.
Returns:
LazyBuffer: A new `LazyBuffer` object with reduced dimensions.
"""
if self.shape == tuple(new_shape):
return self
srcs = _push_movement_ops((self,)) if SHUFFLE_MOVEMENT_OPS else (self,)
unbound_new_shape = tuple(
s.unbind()[0] if not isinstance(s, int) else s for s in new_shape
)
return create_lazybuffer(
self.device,
ShapeTracker.from_shape(new_shape),
ReduceOps,
LazyOp(op, srcs, unbound_new_shape),
self.dtype,
)
def r(self: LazyBuffer, op: ReduceOps, new_shape: Tuple[sint, ...]) -> LazyBuffer:
"""
Alias for `_reduce_op`.
This method is an alias for `_reduce_op` and provides another way to call it. It takes the same arguments as `_reduce_op` and returns a new `LazyBuffer` with reduced dimensions.
Attributes:
self (LazyBuffer): The `LazyBuffer` object on which the method is called.
op (ReduceOps): The reduction operation to be applied.
new_shape (Tuple[sint, ...]): The desired shape of the reduced LazyBuffer.
Returns:
LazyBuffer: A new `LazyBuffer` object with reduced dimensions.
"""
# TODO: can we split symbolic shape if the reduce axis is not symbolic?
if (
not all_int(self.shape)
or (0 in self.shape)
or prod(self.shape) // prod(new_shape)
< getenv("REDUCEOP_SPLIT_THRESHOLD", 32768)
):
return self._reduce_op(op, new_shape)
# choose the largest divisor of 256 (must be >= 16) to split on, penalizing large strides
heuristic, divisor, dim_to_split = max(
    ((divisor := math.gcd(256, old)) / (stride or math.inf), divisor, i)
    for i, (old, new, stride) in enumerate(
        zip(self.shape, new_shape, self.st.real_strides())
    )
    if old != new
)  # type: ignore
if divisor < 16 or heuristic < 0.1:
    return self._reduce_op(op, new_shape)
def splitted_shape(dim_aft_div):
return (
self.shape[:dim_to_split]
+ (self.shape[dim_to_split] // divisor,)
+ dim_aft_div
+ self.shape[dim_to_split + 1 :]
)
return (
self.reshape(splitted_shape((divisor,)))
._reduce_op(op, splitted_shape((1,)))
.reshape(splitted_shape(()))
._reduce_op(op, new_shape)
)
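"""
Worked example of the split arithmetic (standalone; the dim and divisor are computed
directly for a concrete case here, not via the stride heuristic): summing a
(4096, 4096) tensor down to (1, 1).

    import math

    shape, dim_to_split = (4096, 4096), 0
    divisor = math.gcd(256, shape[dim_to_split])      # -> 256

    def splitted_shape(dim_aft_div):
        return (shape[:dim_to_split] + (shape[dim_to_split] // divisor,)
                + dim_aft_div + shape[dim_to_split + 1:])

    print(splitted_shape((divisor,)))   # (16, 256, 4096): reshape before stage 1
    print(splitted_shape((1,)))         # (16, 1, 4096):   stage-1 partial reduce
    print(splitted_shape(()))           # (16, 4096):      reshape, then reduce to (1, 1)
"""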
# *** movement ops ***
def reshape(self: LazyBuffer, arg: Tuple[sint, ...]) -> LazyBuffer:
"""
Reshapes the buffer.
Args:
    arg (Tuple[sint, ...]): The new shape for the buffer.
Returns:
LazyBuffer: The reshaped lazy buffer.
"""
if self.shape == arg:
return self
if not self.realized and self.op.op == MovementOps.RESHAPE:
assert isinstance(self.op.src[0], LazyBuffer)
self.op.src[0].children.discard(
self
) # NOTE: this is only required in reshape and when pushing permutes, why??
return self.op.src[0].reshape(arg)
return self._movement_op(self.st.reshape(arg), MovementOps.RESHAPE, arg)
def pad(self: LazyBuffer, arg: Tuple[Tuple[int, int], ...]) -> LazyBuffer:
"""
Pad the buffer object.
This method pads the current buffer with specified start and end indices. If all padding values are 0, it simply returns the original buffer. If the buffer is not realized and the last operation was a pad operation, the pad operation is combined with the new one. Otherwise, a new movement operation is created with the pad argument.
:param arg: A tuple of tuples containing start and end padding values.
:type arg: Tuple[Tuple[int, int], ...]
:return: The padded buffer object.
:rtype: LazyBuffer
"""
if all(b == 0 and e == 0 for b, e in arg):
return self
if not self.realized and self.op.op == MovementOps.PAD:
return self.op.src[0].pad(
tuple(
[(b1 + b2, e1 + e2) for (b1, e1), (b2, e2) in zip(self.op.arg, arg)]
)
)
return self._movement_op(self.st.pad(arg), MovementOps.PAD, arg)
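"""
Why consecutive pads fold into one: per-axis pad amounts simply add, which is the
(b1 + b2, e1 + e2) rule above. A quick numpy check on one axis:

    import numpy as np

    x = np.array([1.0, 2.0])
    twice = np.pad(np.pad(x, (1, 2)), (3, 0))
    once = np.pad(x, (1 + 3, 2 + 0))
    assert np.array_equal(twice, once)
"""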
def expand(self: LazyBuffer, arg: Tuple[sint, ...]) -> LazyBuffer:
"""
Expand the current LazyBuffer based on the given argument.
If the new shape equals the current shape, self is returned. If the buffer is unrealized and was itself produced by an EXPAND, the expand is folded into its source. Otherwise a new EXPAND movement op is created via `_movement_op`.
:param arg: A tuple containing sint objects representing the new shape of the LazyBuffer.
:type arg: Tuple[sint, ...]
:return: A new LazyBuffer with the expanded shape.
:rtype: LazyBuffer
"""
if self.shape == arg:
return self
if not self.realized and self.op.op == MovementOps.EXPAND:
return self.op.src[0].expand(arg)
return self._movement_op(self.st.expand(arg), MovementOps.EXPAND, arg)
def permute(self: LazyBuffer, arg: Tuple[int, ...]) -> LazyBuffer:
"""
Permute the current LazyBuffer based on the given argument.
This function is not yet implemented.
:param self: The current LazyBuffer object.
:type self: LazyBuffer
:param arg: A tuple containing int objects representing the permutation of the dimensions.
:type arg: Tuple[int, ...]
:return: A new LazyBuffer with the permuted shape.
:rtype: LazyBuffer
"""
if arg == tuple(range(len(self.shape))):
return self
if not self.realized and self.op.op == MovementOps.PERMUTE:
return self.op.src[0].permute(tuple([self.op.arg[i] for i in arg]))
if SHUFFLE_MOVEMENT_OPS and not self.realized:
if PUSH_PERMUTES and self.optype == ReduceOps:
# reduceops have one buffer input, permute it
narg = tuple([self.op.arg[a] for a in arg])
src, rop = self.op.src[0], self.op.op
src.children.discard(self)
del self # TODO: why doesn't this delete remove it from the children
return src.permute(arg).r(cast(ReduceOps, rop), narg)
# move permutes before expands (always, this is safe)
if self.op.op == MovementOps.EXPAND:
return (
self.op.src[0]
.permute(arg)
.expand(tuple([self.op.arg[a] for a in arg]))
)
# move permutes before reshapes if we can
if (
PUSH_PERMUTES
and self.op.op == MovementOps.RESHAPE
and isinstance(self.op.src[0], LazyBuffer)
):
if shape_idx_groups := get_contraction(
self.op.src[0].shape, self.shape
):
self.op.src[0].children.discard(
self
) # NOTE: this is only required in reshape and when pushing permutes, why??
return (
self.op.src[0]
.permute(tuple(flatten(shape_idx_groups[i] for i in arg)))
.reshape(self.st.permute(arg).shape)
)
return self._movement_op(self.st.permute(arg), MovementOps.PERMUTE, arg)
def shrink(self: LazyBuffer, arg: Tuple[Tuple[sint, sint], ...]) -> LazyBuffer:
"""
Shrinks the buffer based on the provided arguments.
Args:
    arg (Tuple[Tuple[sint, sint], ...]): A tuple of (start, end) indices for each dimension of the buffer.
Returns:
LazyBuffer: The shrunken lazy buffer.
"""
if all(b - a == s for s, (a, b) in zip(self.shape, arg)):
return self
if not self.realized and self.op.op == MovementOps.SHRINK:
return self.op.src[0].shrink(
tuple(
[(b1 + b2, b1 + e2) for (b1, _), (b2, e2) in zip(self.op.arg, arg)]
)
)
return self._movement_op(self.st.shrink(arg), MovementOps.SHRINK, arg)
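"""
Why consecutive shrinks fold: the inner window is re-based against the outer one,
giving the (b1 + b2, b1 + e2) rule above. Checked with plain Python slicing on one axis:

    xs = list(range(10))
    b1, e1 = 2, 9          # first shrink: xs[2:9]
    b2, e2 = 1, 4          # second shrink, relative to the first
    assert xs[b1:e1][b2:e2] == xs[b1 + b2 : b1 + e2]
"""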
[docs]
def stride(self: LazyBuffer, arg: Tuple[int, ...]) -> LazyBuffer:
"""
Applies a stride operation to the buffer based on the provided arguments.
Args:
    arg (Tuple[int, ...]): The stride for each dimension of the buffer.
Returns:
LazyBuffer: The lazy buffer after the stride operation has been applied.
"""
if all(a == 1 for a in arg):
return self
if not self.realized and self.op.op == MovementOps.STRIDE:
return self.op.src[0].stride(
tuple(a1 * a2 for a1, a2 in zip(arg, self.op.arg))
)
return self._movement_op(self.st.stride(arg), MovementOps.STRIDE, arg)
def _movement_op(
self,
st: ShapeTracker,
op: MovementOps,
arg: Union[Tuple[sint, ...], Tuple[Tuple[sint, sint], ...]],
) -> LazyBuffer:
"""
Perform movement operation on the current instance.
This function checks certain conditions and based on them either replaces the current instance with a new one
created using movement operations or creates a new `LazyBuffer` object.
Parameters:
st (ShapeTracker): Shape tracker for the operation.
op (MovementOps): Movement operation to be performed.
arg (Union[Tuple[sint, ...], Tuple[Tuple[sint, sint], ...]]): Arguments for the movement operation.
Returns:
LazyBuffer: The result of the movement operation.
Attributes:
SHUFFLE_MOVEMENT_OPS (bool): If True, shuffle movement operations are performed.
self.realized (bool): Indicates if the object is realized or not.
self.optype (BinaryOps): The type of operation to be performed.
self.children (list): List of children for the current instance.
MovementOps.SHRINK, MovementOps.STRIDE, MovementOps.PERMUTE (MovementOps): Enumeration values for shrink, stride and permute operations.
self.op.op (UnaryOps): The type of unary operation to be performed.
PUSH_RESHAPES (bool): If True, reshapes are pushed.
Attributes:
REMOVE_MOVEMENT_NOPS (bool): If True, no-operation movement are removed.
get_movementroot (function): Function to get the root of movement operations.
self.st.contiguous (bool): Indicates if the shape tracker is contiguous or not.
prod (function): Function to calculate the product of a tuple of integers.
Attributes:
self.device (str): Device on which the operation will be performed.
st (ShapeTracker): Shape tracker for the operation.
MovementOps (Enum): Enum class for movement operations.
LazyOp (class): Class representing a lazy operation.
self.dtype (data-type): Data type of the elements in the buffer.
self.base (object): Base object for the current instance.
"""
if (
SHUFFLE_MOVEMENT_OPS
and not self.realized
and self.optype == BinaryOps
and not self.children
):
if op in {MovementOps.SHRINK, MovementOps.STRIDE, MovementOps.PERMUTE} or (
op == MovementOps.RESHAPE and (self.op.op in UnaryOps or PUSH_RESHAPES)
):
return self.op.replace_with_movement_ops([(op, arg)])
if REMOVE_MOVEMENT_NOPS and not self.realized and st.contiguous:
# MovementOps aren't stacked any more, they each have one parent, find the root
if (
(root := get_movementroot(self)) != self
and root.st.contiguous
and prod(st.shape) == prod(root.shape)
):
return root.reshape(st.shape)
return create_lazybuffer(
self.device,
st,
MovementOps,
LazyOp(op, (self,), arg),
self.dtype,
base=self.base,
)
def replace_with_movement_ops(
self: LazyBuffer, ops: List[Tuple[MovementOps, Any]]
) -> LazyBuffer:
"""
This method takes a list of tuples as an argument where each tuple contains a movement operation and its corresponding argument.
Attributes:
self (LazyBuffer): The lazy buffer instance on which the operations are to be performed.
ops (List[Tuple[MovementOps, Any]]): A list of tuples where each tuple contains a MovementOps enum member and its corresponding argument.
Returns:
LazyBuffer: The updated lazy buffer instance after performing all the movement operations in sequence.
This method iterates over the list of tuples and for each tuple, it retrieves the corresponding function from the MOVEMENT_OPS_DISPATCHER dictionary using the MovementOps enum member as the key.
The function is then called with the lazy buffer instance and its argument to perform the movement operation.
The result of this operation is then stored back in the lazy buffer instance for subsequent operations.
Once all the operations have been performed, the final updated lazy buffer instance is returned.
"""
y = self
for op, arg in ops:
y = MOVEMENT_OPS_DISPATCHER[op](y, arg)
return y
"""
Constants and configurations.
Attributes:
UNSAFE_PAD_OPS (set): A set of unsafe padding operations. These include division, comparison less than, base-2 logarithm, base-2 exponentiation, and reciprocal (1/x).
"""
UNSAFE_PAD_OPS = {
BinaryOps.DIV,
BinaryOps.CMPLT,
UnaryOps.LOG2,
UnaryOps.EXP2,
UnaryOps.RECIP,
}
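"""
Why these ops block pad pushing: they do not map 0 to 0, so padding with zeros and
then applying the op differs from applying the op and then padding. A numpy check
with exp2 (where exp2(0) == 1, not 0):

    import numpy as np

    x = np.array([1.0, 2.0])
    pad_then_op = np.exp2(np.pad(x, (1, 1)))   # [1., 2., 4., 1.]
    op_then_pad = np.pad(np.exp2(x), (1, 1))   # [0., 2., 4., 0.]
    assert not np.array_equal(pad_then_op, op_then_pad)
"""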
def _push_movement_ops(srcs: Tuple[LazyBuffer, ...]) -> Tuple[LazyBuffer, ...]:
"""
For each source, backwalks its chain of movement ops (never EXPAND, and PAD only with SHUFFLE_PAD_OPS) and, when safe, replays them above the underlying BinaryOps so the elementwise op fuses cleanly.
Args:
    srcs (Tuple[LazyBuffer, ...]): A tuple of LazyBuffer objects.
Returns:
Tuple[LazyBuffer, ...]: A tuple of updated LazyBuffer objects with movement operations pushed to their sources.
"""
new_srcs = []
for x in srcs:
mops: List[Tuple[MovementOps, Any]] = []
bx = x
# backwalk all the movement ops. don't push PAD or EXPAND
while (
not bx.realized
and bx.optype is MovementOps
and bx.op.op is not MovementOps.EXPAND
and (SHUFFLE_PAD_OPS or bx.op.op is not MovementOps.PAD)
and len(bx.children) <= 1
):
assert isinstance(bx.op.op, MovementOps) and isinstance(
bx.op.src[0], LazyBuffer
)
mops.append((bx.op.op, bx.op.arg))
bx = bx.op.src[0]
# NOTE: can't push pads past anything where f(0, 0) != 0 or f(0) != 0
if (
mops
and not bx.realized
and bx.optype is BinaryOps
and len(bx.children) <= 1
and (
all(y[0] is not MovementOps.PAD for y in mops)
or all(y.op not in UNSAFE_PAD_OPS for y in bx.op.get_lazyops())
)
):
x = bx.op.replace_with_movement_ops(mops[::-1])
new_srcs.append(x)
return tuple(new_srcs)
MOVEMENT_OPS_DISPATCHER: Dict[MovementOps, Callable] = {
MovementOps.RESHAPE: LazyBuffer.reshape,
MovementOps.EXPAND: LazyBuffer.expand,
MovementOps.SHRINK: LazyBuffer.shrink,
MovementOps.PERMUTE: LazyBuffer.permute,
MovementOps.PAD: LazyBuffer.pad,
MovementOps.STRIDE: LazyBuffer.stride,
}
"""
This dictionary acts as a dispatcher for various movement operations on `LazyBuffer` objects. It maps each operation to its corresponding function in the `LazyBuffer` class.
Attributes:
MovementOps (Enum): An enumeration of different movement operations like reshape, expand, shrink, permute, pad, and stride.
LazyBuffer (Class): The `LazyBuffer` class which contains the methods corresponding to each operation in this dispatcher.
Dict[MovementOps, Callable] (Type Hinting): Type hint indicating that keys are movement operations and values are their corresponding functions in `LazyBuffer`.
The dictionary keys:
MovementOps.RESHAPE: Corresponds to the reshape operation, mapped to `LazyBuffer.reshape` method.
MovementOps.EXPAND: Corresponds to the expand operation, mapped to `LazyBuffer.expand` method.
MovementOps.SHRINK: Corresponds to the shrink operation, mapped to `LazyBuffer.shrink` method.
MovementOps.PERMUTE: Corresponds to the permute operation, mapped to `LazyBuffer.permute` method.
MovementOps.PAD: Corresponds to the pad operation, mapped to `LazyBuffer.pad` method.
MovementOps.STRIDE: Corresponds to the stride operation, mapped to `LazyBuffer.stride` method.
"""