Source code for tinygrad.shape.view

from __future__ import annotations
import functools, operator
from dataclasses import dataclass
from typing import Tuple, List, Optional, Dict, cast
from tinygrad.helpers import prod, all_int
from tinygrad.shape.symbolic import Node, NumNode, Variable, VariableOrNum, Set, sint


@functools.lru_cache(maxsize=None)
def filter_strides(shape: Tuple[int, ...], strides: Tuple[int, ...]) -> Tuple[int, ...]:
    """
    Normalise the stride of every extent-1 dimension to 0.

    A dimension of extent 1 is only ever indexed at position 0, so its stride can
    never contribute to an address. Forcing it to 0 gives a canonical form, so that
    stride tuples can be compared directly (as the contiguity check in
    ``View.create`` does).

    :param shape: Extent of each dimension.
    :param strides: Stride of each dimension, paired positionally with ``shape``.
    :return: ``strides`` with every entry whose dimension has extent 1 replaced by 0.
    """
    kept = []
    for extent, step in zip(shape, strides):
        kept.append(step if extent != 1 else 0)
    return tuple(kept)
@functools.lru_cache(maxsize=None)
def strides_for_shape(shape: Tuple[int, ...]) -> Tuple[int, ...]:
    """
    Compute row-major (C-order) strides for ``shape``.

    The innermost dimension gets stride 1. Each outer dimension's stride is the
    product of all extents to its right. Extent-1 dimensions are then normalised
    to stride 0 via ``filter_strides``.

    :param shape: Extent of each dimension.
    :return: One stride per dimension; ``()`` for an empty shape.
    """
    acc: List[int] = [1] if shape else []
    # shape[:0:-1] walks every extent except the outermost, innermost first
    for extent in shape[:0:-1]:
        acc.append(extent * acc[-1])
    acc.reverse()
    return filter_strides(shape, tuple(acc))
@functools.lru_cache(maxsize=None) def _merge_dims( shape: Tuple[int, ...], strides: Tuple[int, ...], mask: Optional[Tuple[Tuple[int, int], ...]] = None, ) -> Tuple[Tuple[int, int, int], ...]: """ Merge contiguous subparts or zero strided dims. Attributes: shape (Tuple[int, ...]): The shape of the array to be merged. strides (Tuple[int, ...]): The strides of the array corresponding to the shape. mask (Optional[Tuple[Tuple[int, int], ...]]): An optional mask indicating which dimensions to merge. Returns: Tuple[Tuple[int, int, int], ...]: A tuple containing merged dimensions, their strides and a count of non-zero stride dimensions. """ # merge contiguous subparts or zero strided dims. ret = List[(merged_dims, stride, merged dims w/o zero stride), ...] if not shape: return tuple() assert len(shape) == len( strides ) # state (0, 1, 2) -> (none, in-progress, done). wrt merging zero strided dimensions. ret = [(shape[0], strides[0], shape[0] if strides[0] else 0)] state = ( 1 if mask and strides[0] == 0 and shape[0] != 1 and mask[0][1] - mask[0][0] == 1 else 0 ) for i, (sh, st) in enumerate(zip(shape[1:], strides[1:]), start=1): if sh == 1: continue if state == 1 or ret[-1][1] == sh * st: # mergeable ret[-1] = ( ret[-1][0] * sh, st, (sh if state == 1 else ret[-1][2] * sh) if st else 0, ) else: ret.append((sh, st, sh if st else 0)) # begin new # merging ends with either non-zero strided dim or zero strided dim with mask range > 1 state = ( 1 if mask and st == 0 and mask[i][1] - mask[i][0] == 1 else (2 if state != 0 else 0) ) return tuple(ret) @functools.lru_cache(maxsize=None) def _reshape_mask( view: View, new_shape: Tuple[sint, ...] ) -> Tuple[Optional[Tuple[Tuple[sint, sint], ...]], Optional[Tuple[sint, ...]], bool]: """ Reshapes the mask of a given view to match the new shape. Attributes: view (View): The view with the current mask. new_shape (Tuple[sint, ...]): The desired new shape for the view. 
Returns: Tuple[Optional[Tuple[Tuple[sint, sint], ...]], Optional[Tuple[sint, ...]], bool]: A tuple containing the reshaped mask, offsets, and a boolean indicating if the reshape is valid or not. """ if view.mask is None: return view.mask, tuple(), False new_mask: List[Tuple[int, int]] = [] r_masks, r_shape, r_new_shape = ( reversed(view.mask), reversed(view.shape), reversed(new_shape), ) curr_stride, off, offsets, old_dim, new_dim, mask = ( 1, 0, [], next(r_shape, 1), next(r_new_shape, 1), next(r_masks, (0, 1)), ) # off represents offset while combining masks of range one & zero stride if mask[1] - mask[0] < 1: return ((0, 0),) * len(new_shape), tuple(), False # invalid mask while len(new_mask) < len(new_shape): (l, r), next_stride = (mask[0], mask[1]), new_dim * curr_stride if old_dim >= new_dim: # need to split mask. offsets.append(off) if ( old_dim == next_stride ): # simply copy the mask and get next batch for merging new_mask.append((l // curr_stride, (r - 1) // curr_stride + 1)) curr_stride, off, old_dim, new_dim, mask = ( 1, 0, next(r_shape, 1), next(r_new_shape, 1), next(r_masks, (0, 1)), ) if mask[1] - mask[0] < 1: return ((0, 0),) * len(new_shape), tuple(), False # invalid mask else: # mask can only be splitted if reshape doesn't cut across the mask. 
if (l % (ns := next_stride) != 0 or r % ns != 0) and l // ns != ( r - 1 ) // ns: return view.mask, tuple(), True new_mask.append( (l % ns // curr_stride, (r - 1) % ns // curr_stride + 1) ) curr_stride, new_dim = next_stride, next( r_new_shape, 1 ) # need to get mask for next dimension elif old_dim < new_dim * curr_stride: next_mask = next(r_masks, (0, 1)) # combine if the mask can unfold continuously if (l != 0 or r != old_dim) and next_mask[1] - next_mask[0] != 1: return view.mask, tuple(), True if ( next_mask != (0, 1) and mask != (0, 1) and (next_mask[1] - next_mask[0] == 1) ): off += next_mask[0] * old_dim mask, old_dim = ( next_mask[0] * old_dim + l, (next_mask[1] - 1) * old_dim + r, ), old_dim * next(r_shape, 1) for ( mask ) in ( r_masks ): # if the old shape has leading 1s, need to make sure their mask is (0,1) if mask != (0, 1): return ((0, 0),) * len(new_shape), tuple(), False return tuple(reversed(new_mask)), tuple(offsets), False
@dataclass(frozen=True)
class View:
    """
    Immutable strided view over a flat buffer.

    The fields follow the usual strided-view model: ``shape`` gives the logical
    extents, ``strides``/``offset`` locate elements, and ``mask`` restricts the
    valid region.

    NOTE(review): the actual index expression is generated outside this class;
    confirm against that code.

    The movement ops below (pad, shrink, expand, permute, stride, reshape) never
    mutate ``self``. They return a new ``View``, or ``self`` when nothing changes.
    """

    # logical extent of each dimension
    shape: Tuple[sint, ...]
    # per-dimension stride; extent-1 dimensions are normalised to 0 by View.create
    strides: Tuple[sint, ...]
    # base position added to every strided address
    offset: sint
    # optional per-dimension half-open (start, end) valid range; None means unmasked
    mask: Optional[Tuple[Tuple[sint, sint], ...]]
    # True iff offset == 0, mask is None and strides are the row-major strides for shape
    contiguous: bool

    @staticmethod
    @functools.lru_cache(maxsize=None)
    def create(
        shape: Tuple[sint, ...],
        strides: Optional[Tuple[sint, ...]] = None,
        offset: sint = 0,
        mask: Optional[Tuple[Tuple[sint, sint], ...]] = None,
    ):
        """
        Build a (cached) view, normalising strides and deriving ``contiguous``.

        :param shape: Extent of each dimension.
        :param strides: Strides, or None/empty to use row-major strides for ``shape``.
        :param offset: Base offset. Default 0.
        :param mask: Optional per-dimension ``(start, end)`` valid ranges. Default None.
        :return: A new ``View``.
        """
        strides = filter_strides(shape, strides) if strides else strides_for_shape(shape)
        contiguous = offset == 0 and mask is None and strides == strides_for_shape(shape)
        return View(shape, strides, offset, mask, contiguous)

    def vars(self) -> Set[Variable]:
        """
        Collect the symbolic variables used anywhere in the view.

        :return: Union of ``.vars()`` of every ``Node`` among shape, strides,
            offset and mask bounds (empty set if all are plain ints).
        """
        flatten_mask = tuple(x for m in self.mask for x in m) if self.mask is not None else tuple()
        return functools.reduce(
            operator.or_,
            [x.vars() for x in self.shape + self.strides + (self.offset,) + flatten_mask if isinstance(x, Node)],
            set(),
        )

    def unbind(self) -> View:
        """
        Return a copy with every bound variable replaced by its unbound form.

        Each variable with ``val is not None`` is substituted by ``v.unbind()[0]``
        in shape, strides, offset and mask bounds; plain ints are kept as-is.
        """
        unbound_vars: Dict[VariableOrNum, Node] = {v: v.unbind()[0] for v in self.vars() if v.val is not None}
        new_shape = tuple([s if isinstance(s, int) else s.substitute(unbound_vars) for s in self.shape])
        new_strides = tuple([s if isinstance(s, int) else s.substitute(unbound_vars) for s in self.strides])
        new_offset = self.offset if isinstance(self.offset, int) else self.offset.substitute(unbound_vars)
        new_mask = (
            tuple(
                (
                    a if isinstance(a, int) else a.substitute(unbound_vars),
                    b if isinstance(b, int) else b.substitute(unbound_vars),
                )
                for (a, b) in self.mask
            )
            if self.mask is not None
            else None
        )
        return View.create(new_shape, new_strides, new_offset, new_mask)

    # MovementOps live here now
    def __unsafe_resize(self, arg: Tuple[Tuple[sint, sint], ...], mask=None) -> View:
        """
        Re-window the view to ``[start, end)`` per dimension without bounds checks.

        :param arg: Per-dimension ``(start, end)``; ``start`` may be negative (used by ``pad``).
        :param mask: Optional extra mask to intersect with the moved existing mask.
        :return: View of extents ``end - start``, offset shifted by ``sum(stride * start)``.
        """
        offset = sum([s * x[0] for s, x in zip(self.strides, arg)])
        if self.mask:
            # move the old mask into the new window's coordinates, clamped to [0, new extent]
            nmask = tuple(
                [(max(0, min(mx - ax, ay - ax)), max(0, min(my - ax, ay - ax))) for (mx, my), (ax, ay) in zip(self.mask, arg)]
            )
            # merge the masks if we have two
            mask = (
                tuple([(max(mx1, mx2), min(my1, my2)) for (mx1, my1), (mx2, my2) in zip(nmask, mask)])
                if mask is not None
                else nmask
            )
        shape = [y - x for x, y in arg]
        # unwrap constant NumNode extents back to plain ints
        return View.create(tuple(s.b if isinstance(s, NumNode) else s for s in shape), self.strides, self.offset + offset, mask)

    @functools.lru_cache(maxsize=None)  # pylint: disable=method-cache-max-size-none
    def pad(self, arg: Tuple[Tuple[int, int], ...]) -> View:
        """
        Pad each dimension by ``(before, after)``; padded positions are masked out.

        :param arg: Non-negative ``(before, after)`` per dimension.
        :return: Padded view, or ``self`` when all padding is zero.
        """
        assert all((b >= 0 and e >= 0) for b, e in arg) and len(arg) == len(self.shape), f"{self.shape=}, {arg=}"
        if any(b or e for b, e in arg):
            zvarg = tuple([(-b, s + e) for s, (b, e) in zip(self.shape, arg)])
            # the original data occupies [b, s + b) in each padded dimension
            mask = tuple([(b, s + b) for s, (b, _) in zip(self.shape, arg)])
            return self.__unsafe_resize(zvarg, mask=mask)
        return self

    @functools.lru_cache(maxsize=None)  # pylint: disable=method-cache-max-size-none
    def shrink(self, arg: Tuple[Tuple[sint, sint], ...]) -> View:
        """
        Restrict each dimension to ``[start, end)``.

        :param arg: ``(start, end)`` per dimension with ``start >= 0`` and ``end <= extent``.
        :return: The shrunk view.
        """
        assert all((b >= 0 and e <= s) for s, (b, e) in zip(self.shape, arg)) and len(arg) == len(
            self.shape
        ), f"{self.shape=}, {arg=}"
        return self.__unsafe_resize(arg)

    @functools.lru_cache(maxsize=None)  # pylint: disable=method-cache-max-size-none
    def expand(self, new_shape: Tuple[sint, ...]) -> View:
        """
        Broadcast dimensions to ``new_shape``.

        Only dimensions that already match, or have extent 1 and stride 0, may change.
        A zero-size view instead returns a fresh contiguous view of ``new_shape``.
        An expanded dimension whose mask was not ``(0, 1)`` becomes fully masked ``(0, 0)``.

        :param new_shape: Target extents; must have the same rank as ``self.shape``.
        :raises ValueError: if the ranks differ.
        :return: The expanded view.
        """
        if len(new_shape) != len(self.shape):
            raise ValueError(f"expand arg {new_shape=} must have same number of dimensions as shape {self.shape=}")
        if 0 in self.shape:
            assert all(
                (s == x == 0) or (s > 0 and (x % s) == 0) for s, x in zip(self.shape, new_shape)
            ), f"can't expand {self.shape} into {new_shape}"
            return View.create(new_shape)
        assert all(
            (s == x or (s == 1 and st == 0)) for s, x, st in zip(self.shape, new_shape, self.strides)
        ), f"can't expand {self.shape} into {new_shape}"
        # NOTE: can the mask ever be (0,0)?
        mask = (
            tuple(
                [
                    (((0, 0) if m != (0, 1) else (0, ns)) if s != ns else m)
                    for m, s, ns in zip(self.mask, self.shape, new_shape)
                ]
            )
            if self.mask
            else None
        )
        return View.create(new_shape, self.strides, self.offset, mask)

    @functools.lru_cache(maxsize=None)  # pylint: disable=method-cache-max-size-none
    def permute(self, axis: Tuple[int, ...]) -> View:
        """
        Reorder dimensions so that new dimension ``i`` is old dimension ``axis[i]``.

        :param axis: A permutation of ``range(len(self.shape))``.
        :return: View with shape, strides and mask reordered; offset unchanged.
        """
        assert all(
            isinstance(x, int) and x >= 0 and x < len(self.shape) for x in axis
        ), f"invalid permute {axis} for {self.shape}"
        assert len(set(axis)) == len(axis) and len(axis) == len(self.shape), f"can't permute {self.shape} with {axis}"
        return View.create(
            tuple([self.shape[a] for a in axis]),
            tuple([self.strides[a] for a in axis]),
            self.offset,
            tuple([self.mask[a] for a in axis]) if self.mask is not None else None,
        )

    @functools.lru_cache(maxsize=None)  # pylint: disable=method-cache-max-size-none
    def stride(self, mul: Tuple[int, ...]) -> View:
        """
        Take every ``|m|``-th element per dimension, reversing the dimension when ``m < 0``.

        * Strides are multiplied by ``mul``.
        * Extents become ``ceil(extent / |m|)``.
        * For each negative ``m`` the offset moves to the last element, via
          ``(extent - 1) * old_stride``.
        * Mask bounds are flipped (for negative ``m``) and rescaled the same way.

        :param mul: Non-zero integer step per dimension.
        :return: The strided view.
        """
        # except for the negative case, you can build this from the others. invertible in the negative case
        assert all(isinstance(x, int) and x != 0 for x in mul), f"invalid stride {mul} for {self.shape}"
        strides = tuple([z * m for z, m in zip(self.strides, mul)])
        new_shape = tuple([(s + (abs(m) - 1)) // abs(m) for s, m in zip(self.shape, mul)])
        offset = sum([(s - 1) * z for s, z, m in zip(self.shape, self.strides, mul) if m < 0])
        mask = (
            tuple(
                [
                    (
                        ((mx if m > 0 else s - my) + (abs(m) - 1)) // abs(m),
                        ((my if m > 0 else s - mx) + (abs(m) - 1)) // abs(m),
                    )
                    for (mx, my), s, m in zip(self.mask, self.shape, mul)
                ]
            )
            if self.mask is not None
            else None
        )
        return View.create(new_shape, strides, self.offset + offset, mask)

    @functools.lru_cache(maxsize=None)  # pylint: disable=method-cache-max-size-none
    def reshape(self, new_shape: Tuple[sint, ...]) -> Optional[View]:
        """
        Reshape without moving data, if the current layout allows it.

        :param new_shape: Target extents (non-negative).
        :raises ValueError: if the (concrete) element counts differ.
        :return: The reshaped view, ``self`` if the shape is unchanged, or None when
            the strides/mask cannot be expressed for ``new_shape``.
        """
        if self.shape == new_shape:
            return self
        assert all(x >= 0 for x in new_shape), f"shape can't contain negative numbers {new_shape}"
        if 0 in self.shape:
            assert 0 in new_shape, f"cannot reshape 0 size to {new_shape}"
            return View.create(new_shape)
        # check for the same size
        if all_int(self.shape):
            assert all(
                isinstance(s, (int, Variable)) for s in new_shape
            ), f"{self.shape=} -> {new_shape=} contains non (int, Variable) dim"
            if prod(self.shape) != prod([s if isinstance(s, int) else cast(Variable, s).val for s in new_shape]):
                raise ValueError(f"size mismatched, can't reshape {self.shape=} -> {new_shape=}")
        if new_shape == () and self.mask and any(mx == my for (mx, my) in self.mask):
            return None
        # after the asserts, it's okay to check contiguous
        if self.contiguous:
            return View.create(new_shape)
        # Assign new-shape dims (innermost first) to merged old groups, deriving their strides.
        strides, r_new_shape = [], reversed(new_shape)
        for merged_dim, s, real_dim in reversed(_merge_dims(self.shape, self.strides, self.mask)):
            acc, new_stride = 1, s
            while acc <= merged_dim and acc != merged_dim and (new_dim := next(r_new_shape, None)):
                strides.append(new_stride if new_dim != 1 else 0)
                if new_dim == 1:
                    continue
                new_stride *= new_dim if (acc := acc * new_dim) < real_dim else 0
            # a new dim straddles two merged groups: not representable
            if acc != merged_dim:
                break
        else:
            # remaining (leading extent-1) new dims get stride 0
            strides += [0] * (len(new_shape) - len(strides))
            mask, off_mask, extra = _reshape_mask(self, new_shape)
            total_offset = sum([off * s for off, s in zip(off_mask, strides)]) if off_mask else 0
            if not extra:
                return View.create(new_shape, tuple(reversed(strides)), self.offset - total_offset, mask)
        return None