Source code for tinygrad.shape.shapetracker

# ShapeTracker allows movement operations to a buffer that don't require a copy to be made.
from __future__ import annotations
import functools, itertools, operator
from dataclasses import dataclass
from typing import Tuple, List, Optional, Dict, Set, cast, Union, Iterable
from tinygrad.ops import MovementOps

from tinygrad.helpers import prod, DEBUG, merge_dicts
from tinygrad.shape.symbolic import Variable, MulNode, Node, SumNode, NumNode, sint
from tinygrad.shape.view import View, _merge_dims


def expr_node_mask(view: View, idx: Node, valid: Optional[Node] = None) -> Node:
    """Build the validity expression of a flat index under a view's mask.

    Args:
        view: The view whose ``shape`` and ``mask`` are read.
        idx: Flat index expression into ``view.shape``.
        valid: Optional existing condition; when given it is the first term
            of the conjunction.

    Returns:
        ``Variable.ands`` over ``valid`` (if any) and, for every masked
        dimension, the two bounds checks ``coord >= lo`` and ``coord < hi``.
    """
    conditions = [] if valid is None else [valid]
    if view.mask is None:
        return Variable.ands(conditions)

    stride = 1
    # Walk the dimensions innermost-first so `stride` is the product of all
    # dimensions to the right of the current one.
    for dim, (lo, hi) in list(zip(view.shape, view.mask))[::-1]:
        # A mask of (0, dim) covers the whole axis and needs no condition.
        if lo != 0 or hi != dim:
            coord = (idx // stride) % dim
            conditions.extend([coord >= lo, coord < hi])
        stride *= dim
    return Variable.ands(conditions)
# generate an expression if you have a single idx variable
def expr_node(view: View, idx: Optional[Node] = None) -> Node:
    """Build the address expression of a view from a single flat index.

    Args:
        view: The view whose ``shape``, ``strides`` and ``offset`` are read.
        idx: Flat index expression. When omitted, a variable named ``"idx"``
            ranging over ``0 .. prod(view.shape) - 1`` is created.

    Returns:
        ``Variable.sum`` of the offset (only when it is truthy) and one
        ``((idx // stride) % size) * step`` term per entry returned by
        ``_merge_dims(view.shape, view.strides)``.
    """
    if idx is None:
        idx = Variable("idx", 0, prod(view.shape) - 1)

    terms: List[Node] = []
    if view.offset:
        # Plain ints are wrapped so every term handed to Variable.sum is a Node.
        terms.append(
            NumNode(view.offset) if isinstance(view.offset, int) else view.offset
        )

    stride = 1
    # Innermost group first; the third element of each triple is not used here.
    for size, step, _ in reversed(_merge_dims(view.shape, view.strides)):
        terms.append(((idx // stride) % size) * step)
        stride *= size
    return Variable.sum(terms)
# generate an expression if you have a variable or expression for each index
def expr_idxs(view: View, idxs: Tuple[Node, ...]) -> Node:
    """Build the address expression of a view from one index per dimension.

    Args:
        view: The view whose ``shape``, ``strides`` and ``offset`` are read.
        idxs: One index expression per dimension of ``view.shape``.

    Returns:
        ``Variable.sum`` of the offset and ``idx * stride`` for every
        dimension whose size is not 1 and whose stride is not 0.

    Raises:
        AssertionError: If ``len(idxs)`` differs from ``len(view.shape)``.
    """
    assert len(idxs) == len(
        view.shape
    ), f"need an idx for all dimensions {idxs} vs {view.shape}"

    # The offset is always the first term; plain ints are wrapped in NumNode.
    if isinstance(view.offset, int):
        terms = [NumNode(view.offset)]
    else:
        terms = [view.offset]

    for index, size, step in zip(idxs, view.shape, view.strides):
        # Size-1 and stride-0 dimensions contribute nothing to the address.
        if size != 1 and step != 0:
            terms.append(index * step)
    return Variable.sum(terms)
@functools.lru_cache(maxsize=None)
def merge_views(vm2: View, vm1: View) -> Optional[View]:
    """Try to collapse the view pair ``(vm2, vm1)`` into a single view.

    The pair is evaluated as ``ShapeTracker((vm2, vm1))``, i.e. ``vm1`` is the
    last view of the stack. Results are memoised with ``lru_cache``.

    Args:
        vm2: The earlier view of the pair.
        vm1: The later view of the pair.

    Returns:
        A view built with ``vm1.shape``, the real strides of the pair,
        ``vm2.offset`` and ``vm1.mask``; or ``None`` when ``vm2`` has a mask,
        ``vm1`` has a non-zero offset, or any real stride is ``None``.
    """
    # this isn't supported yet
    if vm2.mask:
        return None
    if vm1.offset != 0:
        return None

    combined_strides = ShapeTracker((vm2, vm1)).real_strides()
    if None in combined_strides:
        return None
    return View.create(
        vm1.shape, cast(Tuple[sint, ...], combined_strides), vm2.offset, vm1.mask
    )
@functools.lru_cache(maxsize=None)
def idxs_to_idx(shape: Tuple[int, ...], idxs: Tuple[Node, ...]) -> Node:
    """Flatten per-dimension indices into one row-major index expression.

    Results are memoised with ``lru_cache``.

    Args:
        shape: Size of each dimension.
        idxs: One index expression per dimension of ``shape``.

    Returns:
        ``Variable.sum`` of each index multiplied by the product of the sizes
        of all dimensions to its right.

    Raises:
        AssertionError: If ``len(idxs)`` differs from ``len(shape)``.
    """
    assert len(idxs) == len(shape), "need an idx for all dimensions"

    weight = 1
    terms = []
    # Innermost dimension first, so `weight` accumulates the trailing sizes.
    for index, size in list(zip(idxs, shape))[::-1]:
        terms.append(index * weight)
        weight *= size
    return Variable.sum(terms)
@dataclass(frozen=True)
class ShapeTracker:
    """An immutable stack of views describing how a buffer is indexed.

    Movement operations return a new ShapeTracker; they rewrite the last view
    or, when a reshape cannot be expressed in that view, push a new one.

    Attributes:
        views: The stacked views. The last one defines ``shape``.
    """

    views: Tuple[View, ...]

    def __post_init__(self):
        """Check that ``views`` is a tuple containing only View objects.

        Raises:
            AssertionError: If ``views`` is not a tuple of Views.
        """
        assert isinstance(self.views, tuple) and all(
            isinstance(v, View) for v in self.views
        ), "ShapeTracker must be created with a tuple of Views"

    @staticmethod
    def from_shape(shape: Tuple[sint, ...]) -> ShapeTracker:
        """Create a tracker holding the single view ``View.create(shape)``.

        Args:
            shape: The shape handed to ``View.create``.

        Returns:
            A new single-view ShapeTracker.
        """
        return ShapeTracker((View.create(shape),))

    @property
    def contiguous(self) -> bool:
        """True when there is exactly one view and that view is contiguous."""
        return len(self.views) == 1 and self.views[0].contiguous

    @property
    def shape(self) -> Tuple[sint, ...]:
        """The shape of the last view."""
        return self.views[-1].shape

    def size(self) -> int:
        """Return one more than the largest address the tracker can produce.

        Returns:
            0 if any dimension of ``shape`` is 0, otherwise the maximum of the
            index expression plus one.
        """
        if 0 in self.shape:
            return 0
        ret = self.expr_idxs()[0].max
        # TODO: this is a while loop?!? it should be more clear what max does
        # `.max` may itself be a non-int object, so keep taking `.max` until
        # an int is reached; the loop condition guarantees `ret` is an int.
        while not isinstance(ret, int):
            ret = ret.max
        return ret + 1

    def vars(self) -> Set[Variable]:
        """Return the union of ``vars()`` over all views."""
        # The trailing set() keeps set.union valid when there are no views.
        return set.union(*[v.vars() for v in self.views], set())

    @property
    def var_vals(self) -> Dict[Variable, int]:
        """Merge the pair returned by ``unbind()`` of every variable into one dict."""
        return merge_dicts([dict([v.unbind()]) for v in self.vars()])

    def unbind(self) -> ShapeTracker:
        """Return a tracker built from ``unbind()`` of each view."""
        return ShapeTracker(tuple(v.unbind() for v in self.views))

    def to_movement_ops(self) -> List[Tuple[MovementOps, Tuple]]:
        """Express the tracker as a sequence of movement operations.

        For each view this emits an AS_STRIDED op, then optionally a PAD for
        masked non-zero-stride axes, an EXPAND for zero-stride axes, and a PAD
        for masked zero-stride axes.

        Returns:
            A list of ``(MovementOps, argument)`` pairs in application order.
        """
        to_apply: List[Tuple[MovementOps, Tuple]] = []
        for v in self.views:
            # With a mask, only the unmasked window is materialised first.
            real_shape = tuple(y - x for x, y in v.mask) if v.mask else v.shape
            real_offset = (
                0
                if 0 in real_shape
                else (
                    v.offset
                    + (
                        sum(x * st for (x, _), st in zip(v.mask, v.strides))
                        if v.mask
                        else 0
                    )
                )
            )
            # first, we apply the offset
            # then, we make it the correct shape
            # then, we apply permutations
            to_apply.append(
                (
                    MovementOps.AS_STRIDED,
                    (
                        # zero-stride axes start at size 1 and are expanded below
                        tuple(
                            [
                                s if st != 0 else 1
                                for s, st in zip(real_shape, v.strides)
                            ]
                        ),
                        v.strides,
                        real_offset,
                    ),
                )
            )
            # then, we apply pre expand pads
            if v.mask is not None:
                pre_expand_pads = tuple(
                    (x, s - y) if st != 0 else (0, 0)
                    for (x, y), s, st in zip(v.mask, v.shape, v.strides)
                )
                post_expand_pads = tuple(
                    (x, s - y) if st == 0 else (0, 0)
                    for (x, y), s, st in zip(v.mask, v.shape, v.strides)
                )
                if any(x != (0, 0) for x in pre_expand_pads):
                    to_apply.append((MovementOps.PAD, pre_expand_pads))
                    real_shape = tuple(
                        x + s[0] + s[1] for x, s in zip(real_shape, pre_expand_pads)
                    )
            # then, we do any expands
            # NOTE: this is a good idea even without masks, since torch doesn't
            # support negative strides and has to make a copy
            if any(s != 1 and st == 0 for s, st in zip(real_shape, v.strides)):
                to_apply.append((MovementOps.EXPAND, real_shape))
            # lastly, we apply post expand pads
            if v.mask is not None and any(x != (0, 0) for x in post_expand_pads):
                to_apply.append((MovementOps.PAD, post_expand_pads))
        return to_apply

    # NOTE: if a stride is not always valid, it will be None
    def real_strides(self, ignore_valid=False) -> Tuple[Optional[sint], ...]:
        """Compute the effective stride of every dimension of ``shape``.

        Args:
            ignore_valid: When False, any dimension whose index variable
                appears in the validity expression gets ``None``.

        Returns:
            One entry per dimension: the stride, 0 when the index variable
            does not appear in the address expression, or ``None`` when no
            single stride could be determined.
        """
        # Fast path: a single unmasked view already stores its strides.
        if len(self.views) == 1 and self.views[-1].mask is None:
            return self.views[-1].strides
        idxs: List[Node] = [
            Variable(f"idx{i}", 0, s - 1) for i, s in enumerate(self.shape)
        ]
        idx, valid = self.expr_idxs(idxs)
        ret: List[Optional[sint]] = [None] * len(self.views[-1].shape)
        for this_dim in idx.nodes if isinstance(idx, SumNode) else [idx]:
            idx_maybe, stride_maybe = (
                (this_dim.a, this_dim.b)
                if isinstance(this_dim, MulNode)
                else (this_dim, 1)
            )
            try:
                ret[idxs.index(idx_maybe)] = stride_maybe
            except ValueError:
                # Term is not one of the plain index variables (e.g. a
                # constant or a compound expression); it carries no stride.
                pass
        idx_vars, valid_vars = idx.vars(), valid.vars()
        for i, tidx in enumerate(idxs):
            if tidx in valid_vars and not ignore_valid:
                ret[i] = None
            elif tidx not in idx_vars:
                ret[i] = 0
        return tuple(ret)

    def unit_stride_axes(self, ignore_valid=False) -> List[int]:
        """Return the axes whose real stride equals 1.

        Args:
            ignore_valid: Forwarded to ``real_strides``.

        Returns:
            The indices of all dimensions with stride 1.
        """
        return [i for i, st in enumerate(self.real_strides(ignore_valid)) if st == 1]

    def _expr_idx(self, idx: Node, valid: Node) -> Tuple[Node, Node]:
        """Push an index/validity pair through every view except the last.

        Args:
            idx: Address expression produced by the last view.
            valid: Validity expression produced by the last view.

        Returns:
            The ``(idx, valid)`` pair after all earlier views are applied, or
            ``(NumNode(-1), valid)`` as soon as ``valid.max`` is 0.
        """
        for v in reversed(self.views[0:-1]):
            if valid.max == 0:
                return NumNode(-1), valid
            # The mask must be evaluated on the index *before* it is remapped.
            valid = expr_node_mask(v, idx, valid)
            idx = expr_node(v, idx)
        return idx, valid

    def simplify(self) -> ShapeTracker:
        """Merge the last two views repeatedly while ``merge_views`` succeeds.

        Returns:
            A tracker with the trailing views merged, or ``self`` when there
            are fewer than two views or the last two cannot be merged.
        """
        if len(self.views) >= 2:
            if (new_view := merge_views(self.views[-2], self.views[-1])) is not None:
                if DEBUG >= 4:
                    print(
                        f"st simplify : {self.views[-2]} + {self.views[-1]} = {new_view}"
                    )
                return ShapeTracker(self.views[:-2] + (new_view,)).simplify()
        return self

    def expr_idxs(self, idxs: Optional[Iterable[Node]] = None) -> Tuple[Node, Node]:
        """Compute the address and validity expressions from per-axis indices.

        Args:
            idxs: One index expression per dimension of ``shape``. When
                omitted, variables ``idx0``, ``idx1``, ... are created, each
                ranging over ``0 .. size - 1`` of its dimension.

        Returns:
            The ``(idx, valid)`` pair produced by ``_expr_idx``.
        """
        if idxs is None:
            idxs = [Variable(f"idx{i}", 0, s - 1) for i, s in enumerate(self.shape)]
        # Materialise once: `idxs` may be a one-shot iterator and is needed
        # for both the address and the validity expression.
        idx_tuple = tuple(idxs)
        idx = expr_idxs(self.views[-1], idx_tuple)
        valid = expr_node_mask(
            self.views[-1], idxs_to_idx(self.views[-1].shape, idx_tuple)
        )
        return self._expr_idx(idx, valid)

    def expr_node(self, idx: Union[Node, str] = "idx") -> Tuple[Node, Node]:
        """Compute the address and validity expressions from one flat index.

        Args:
            idx: A flat index expression, or a name from which a variable
                ranging over ``0 .. prod(shape) - 1`` is created.

        Returns:
            The ``(idx, valid)`` pair produced by ``_expr_idx``.
        """
        if isinstance(idx, str):
            idx = Variable(idx, 0, prod(self.shape) - 1)
        return self._expr_idx(
            expr_node(self.views[-1], idx), expr_node_mask(self.views[-1], idx)
        )

    def axis_is_masked(self, axis: int) -> bool:
        """Report whether the validity expression depends on ``idx{axis}``.

        Args:
            axis: The dimension to check.

        Returns:
            True if a variable whose ``expr`` is ``f"idx{axis}"`` appears in
            the validity expression.
        """
        _, valid = self.expr_idxs()
        return f"idx{axis}" in [v.expr for v in valid.vars()]

    # *** under this line are the movement ops ***

    def pad(self, arg: Tuple[Tuple[int, int], ...]) -> ShapeTracker:
        """Return a tracker whose last view is replaced by ``last.pad(arg)``."""
        return ShapeTracker(self.views[0:-1] + (self.views[-1].pad(arg),))

    def shrink(self, arg: Tuple[Tuple[sint, sint], ...]) -> ShapeTracker:
        """Return a tracker whose last view is replaced by ``last.shrink(arg)``."""
        return ShapeTracker(self.views[0:-1] + (self.views[-1].shrink(arg),))

    def expand(self, new_shape: Tuple[sint, ...]) -> ShapeTracker:
        """Return a tracker whose last view is replaced by ``last.expand(new_shape)``."""
        return ShapeTracker(self.views[0:-1] + (self.views[-1].expand(new_shape),))

    def permute(self, axis: Tuple[int, ...]) -> ShapeTracker:
        """Return a tracker whose last view is replaced by ``last.permute(axis)``."""
        return ShapeTracker(self.views[0:-1] + (self.views[-1].permute(axis),))

    def stride(self, mul: Tuple[int, ...]) -> ShapeTracker:
        """Return a tracker whose last view is replaced by ``last.stride(mul)``."""
        return ShapeTracker(self.views[0:-1] + (self.views[-1].stride(mul),))

    def reshape(self, new_shape: Tuple[sint, ...]) -> ShapeTracker:
        """Reshape the tracker to ``new_shape``.

        Args:
            new_shape: The requested shape.

        Returns:
            A tracker with the last view replaced when ``last.reshape`` returns
            a view; otherwise a tracker with ``View.create(new_shape)``
            appended on top of the existing views.
        """
        if (new_view := self.views[-1].reshape(new_shape)) is not None:
            return ShapeTracker(self.views[0:-1] + (new_view,))
        return ShapeTracker(self.views + (View.create(new_shape),))
# returns the axes to create new_shape if new_shape can be created by combining axes from old_shape
# TODO: if we remove movementops from lazy.py we can delete this
def get_contraction(
    old_shape: Tuple[sint, ...], new_shape: Tuple[sint, ...]
) -> Optional[List[List[int]]]:
    """Find which axes of ``old_shape`` combine into each axis of ``new_shape``.

    The running products of both shapes are compared: a new axis ends where
    its running product first appears among the running products of
    ``old_shape``.

    Args:
        old_shape: The original shape.
        new_shape: The shape to build by merging adjacent axes of ``old_shape``.

    Returns:
        One list of ``old_shape`` axis indices per entry of ``new_shape`` (the
        last list always extends to the final old axis), or ``None`` when a
        running product of ``new_shape`` does not occur in ``old_shape``.
        No exception is raised for an impossible contraction.

    Example:
        >>> get_contraction((2, 3, 4), (6, 4))
        [[0, 1], [2]]
        >>> get_contraction((2, 3, 4), (4, 6)) is None
        True
    """
    acc_old = list(itertools.accumulate(old_shape, operator.mul))
    acc_new = list(itertools.accumulate(new_shape, operator.mul))
    try:
        # A running product of 1 consumes no old axes (split point 0);
        # otherwise split just after the first old axis reaching that product.
        split = [acc_old.index(acc) + 1 if acc != 1 else 0 for acc in acc_new]
    except ValueError:
        return None
    # Consecutive split points delimit each group; the final group is forced
    # to end at len(old_shape) so trailing old axes are never dropped.
    return [
        list(range(st, ed))
        for st, ed in zip([0] + split[:-1], split[:-1] + [len(old_shape)])
    ]