from __future__ import annotations
import os, math, itertools
from typing import NamedTuple, Optional, List, Tuple, cast, Dict, Union
from tinygrad.lazy import vars_from_ast
from tinygrad.ops import (
LazyOp,
FlopCounter,
get_lazyop_info,
UnaryOps,
BinaryOps,
ReduceOps,
MemBuffer,
ConstBuffer,
BufferOps,
)
from tinygrad.device import Device, Compiled
from tinygrad.helpers import (
dedup,
dtypes,
colored,
ImageDType,
DType,
ansilen,
getenv,
prod,
DEBUG,
round_up,
)
from tinygrad.shape.shapetracker import ShapeTracker, get_contraction
from tinygrad.shape.symbolic import sint
from tinygrad.shape.view import View, strides_for_shape
from dataclasses import dataclass
from enum import Enum, auto
class OptOps(Enum):
"""
This class represents an enumeration of optimization operations.
Attributes:
UPCAST (auto()): Represents the operation to upcast a data type.
UPCASTMID (auto()): Represents the operation to upcast data types in the middle of a sequence.
UNROLL (auto()): Represents the operation to unroll a loop.
LOCAL (auto()): Represents the operation to make a variable local.
LASTLOCAL (auto()): Represents the operation to make the last variable in a sequence local.
GROUP (auto()): Represents the operation to group variables.
GROUPTOP (auto()): Represents the operation to group variables at the top level.
NOLOCALS (auto()): Represents the operation to remove all local variables.
PADTO (auto()): Represents the operation to pad a sequence to a specific length.
"""
UPCAST = auto()
UPCASTMID = auto()
UNROLL = auto()
LOCAL = auto()
LASTLOCAL = auto()
GROUP = auto()
GROUPTOP = auto()
NOLOCALS = auto()
PADTO = auto() # noqa: E702
def __lt__(self, x: OptOps):
"""
Compares this instance's value with another instance's value.
Args:
x (OptOps): The other instance of the OptOps enumeration to compare against.
Returns:
bool: True if this instance's value is less than the other instance's value, False otherwise.
"""
return self.value < x.value
@dataclass(frozen=True, order=True)
class Opt:
"""
Data class for operation options.
Attributes:
op (OptOps): The operation to perform.
axis (Optional[int]): The axis along which the operation is performed. Defaults to None.
amt (Optional[int]): The amount or value used in the operation. Defaults to None.
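Example (illustrative):
Opt(op=OptOps.UPCAST, axis=0, amt=4) requests that axis 0 be upcast by a factor of 4;
Opt(op=OptOps.UNROLL, axis=0, amt=0) unrolls the first reduce axis, where amt=0 means "use the full size of that axis".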
"""
op: OptOps
axis: Optional[int] = None
amt: Optional[int] = None
def __repr__(self):
"""
Return a string representation of the object.
Returns:
str: A string in the format "Opt(op=<op>, axis=<axis>, amt=<amt>)".
"""
return f"Opt(op={self.op}, axis={self.axis}, amt={self.amt})"
@dataclass(frozen=True)
class TensorCore:
"""
Data class for the Tensor Core.
Attributes:
device (str): The device on which the tensor core will be used.
dims (List[int]): List of integers representing dimensions.
dtype_in (DType): Input data type.
dtype_out (DType): Output data type.
threads (List[Tuple[int, int]]): List of tuples where each tuple contains a TC dimension and an amount that constructs the warp thread structure.
upcast_dim (int): The TC dimension to upcast.
thread_local_aliases (List[List[List[int]]]): A list of lists of lists containing integers defining alias for each TC dimension.
For example: [threads_1, ..., threads_n, upcast_1(unrolled), upcast_2(upcast)] where 1 is warp threads, -1 is upcast, and 0 is unrolled.
thread_local_sizes (List[int]): List of integers representing the number of elements stored in registers for each TC dimension in each thread.
arch (Optional[str]): Optional architecture parameter. Default is None.
"""
device: str
dims: List[int]
dtype_in: DType
dtype_out: DType
threads: List[
Tuple[int, int]
] # list of (TC dim,amt) that construct the warp thread structure
upcast_dim: int # which TC dim to upcast
thread_local_aliases: List[
List[List[int]]
] # a list of [threads_1, ..., threads_n, upcast_1(unrolled), upcast_2(upcast)] defining the alias (-1 is upcast, 1-n is warp threads) for each TC dim
thread_local_sizes: List[
int
] # in each thread, the number of elements stored in registers for each TC dim
arch: Optional[str] = None
def __str__(self):
return f"tensor_core<{self.device}, {self.dims}, {self.dtype_in}, {self.dtype_out}>"
tensor_cores: Dict[str, List[TensorCore]] = {
"METAL": [
TensorCore(
device="METAL",
dims=[8, 8, 8],
dtype_in=dtypes.float,
dtype_out=dtypes.float,
upcast_dim=0,
threads=[(0, 2), (1, 4), (0, 2), (1, 2)],
thread_local_sizes=[2, 2, 2],
thread_local_aliases=[
[[4], [0], [2], [0], [-1, 1, 3], [0]],
[[0], [3], [0], [1], [2, 4], [-1]],
[[4], [3], [2], [1], [0], [-1]],
],
arch="arm64",
),
TensorCore(
device="METAL",
dims=[8, 8, 8],
dtype_in=dtypes.half,
dtype_out=dtypes.half,
upcast_dim=0,
threads=[(0, 2), (1, 4), (0, 2), (1, 2)],
thread_local_sizes=[2, 2, 2],
thread_local_aliases=[
[[4], [0], [2], [0], [-1, 1, 3], [0]],
[[0], [3], [0], [1], [2, 4], [-1]],
[[4], [3], [2], [1], [0], [-1]],
],
arch="arm64",
),
],
"HIP": [
TensorCore(
device="HIP",
dims=[16, 16, 16],
dtype_in=dtypes.half,
dtype_out=dtypes.float,
upcast_dim=1,
threads=[(0, 16), (1, 2)],
thread_local_sizes=[16, 16, 8],
thread_local_aliases=[
[[0], [0], [-1], [1]],
[[0], [1], [-1], [0]],
[[0], [1], [0], [2, -1]],
],
),
TensorCore(
device="HIP",
dims=[16, 16, 16],
dtype_in=dtypes.half,
dtype_out=dtypes.half,
upcast_dim=1,
threads=[(0, 16), (1, 2)],
thread_local_sizes=[16, 16, 8],
thread_local_aliases=[
[[0], [0], [-1], [1]],
[[0], [1], [-1], [0]],
[[0], [1], [0], [2, -1]],
],
),
],
}
class LocalBuffer(NamedTuple):
"""
A named tuple for representing a local buffer in memory.
Attributes:
name (str): The name of the local buffer.
size (int): The size of the local buffer.
dtype (DType, optional): The data type of the elements in the local buffer. Defaults to dtypes.float32.
realized (None, optional): Always None; present so LocalBuffer matches the interface of realized buffers. Defaults to None.
"""
name: str
size: int
dtype: DType = dtypes.float32
realized: None = None
def __str__(self):
return f"localbuffer<{self.name}[{self.size}]>"
class LinearizerOptions(NamedTuple):
"""
A named tuple describing the target device's capabilities and limits used by the linearizer.
Attributes:
device (str, optional): The target device for the linearization. Defaults to "".
supports_float4 (bool, optional): Whether the target device supports float4 data type. Defaults to True.
supports_float4_alu (bool, optional): Whether the target device supports float4 ALU operations. Defaults to True.
has_local (bool, optional): Whether the target device has local memory. Defaults to True.
has_shared (bool, optional): Whether the target device has shared memory. Defaults to True.
global_max (Optional[List[int]], optional): The maximum global dimensions for linearization. Defaults to None.
local_max (Optional[List[int]], optional): The maximum local dimensions for linearization. Defaults to None.
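Example (illustrative values only, not tied to a real backend):
LinearizerOptions(device="FAKE", has_local=False, has_shared=False) describes a target without local or shared memory, so optimizations such as LOCAL, GROUP and GROUPTOP will be rejected by apply_opt.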
"""
device: str = ""
# TODO: make this generic with a list of supported types
supports_float4: bool = True
supports_float4_alu: bool = True
has_local: bool = True
has_shared: bool = True
# NOTE: these two should be in z,y,x(reversed) order for cstyle backends, they are flipped when kernel is rendered
global_max: Optional[List[int]] = None
local_max: Optional[List[int]] = None
class Kernel:
"""
The Kernel class represents a single kernel in the linearizer. It contains information about
the AST (Abstract Syntax Tree), options, and various buffers and shape trackers used during
the linearization process. This class also provides methods for simplifying and optimizing
the linearized code.
Attributes:
ast (LazyOp): The abstract syntax tree representing the kernel's operations.
opts (Optional[LinearizerOptions]): The options used during the linearization process.
info (FlopCounter): Information about the floating-point operations in the kernel.
reduceop (Optional[LazyOp]): The single allowed reduce operation in the AST, if it exists.
bufs (List[Union[MemBuffer, ConstBuffer, LocalBuffer]]): The list of unique buffers used by the kernel.
earlybufs (List[Union[MemBuffer, ConstBuffer]]): The buffers accessed inside the reduce op (i.e. before the reduce), if any.
full_buf_index (int): The index of the buffer with all axes.
sts (List[ShapeTracker]): The shape trackers for each buffer in the kernel.
applied_opts (List[Opt]): The list of optimization options that have been applied to the kernel.
group_for_reduce (List[int]): The sizes of the local groups created for grouped (shared-memory) reductions.
upcasted (int): The number of upcasted dimensions, counted from the end of the shape.
local_dims (int): The number of local dimensions in the kernel.
local_alias (Dict[int, LocalBuffer]): A dictionary mapping buffer indices to their local (shared-memory) aliases.
tensor_core (Optional[TensorCore]): Information about the tensor core being used, if any.
dont_use_locals (bool): A flag indicating that local dimensions should not be used in the kernel.
applied_opts_cache (Optional[List[Opt]]): A cache of optimization options that have been applied to the kernel.
"""
def __init__(self, ast: LazyOp, opts: Optional[LinearizerOptions] = None):
self.opts = (
opts
if opts
else (
cast(Compiled, Device[Device.DEFAULT]).linearizer_opts
if isinstance(Device[Device.DEFAULT], Compiled)
else LinearizerOptions()
)
)
self.ast = ast
assert (
ast.op == BufferOps.STORE
), f"kernels must have a store as the output, got {ast.op}"
# fetch lazyop info
self.info: FlopCounter = get_lazyop_info(self.ast)
# there's only allowed to be one reduceop
reduceops = [x for x in self.ast.get_lazyops() if x.op in ReduceOps]
assert len(dedup(reduceops)) <= 1, "max one reduce op in an ast"
self.reduceop = reduceops[0] if reduceops else None
# create new shapetrackers inside this kernel, we will permute them
self.bufs: List[Union[MemBuffer, ConstBuffer, LocalBuffer]] = dedup(
[x.arg for x in self.ast.get_lazyops() if x.op in BufferOps]
)
assert (
isinstance(self.bufs[0], MemBuffer) and self.bufs[0].idx == 0
), f"buffer 0 is not the store buffer {self.bufs[0]}"
# get earlybufs, before the one reduce op
self.earlybufs = (
[x.arg for x in self.reduceop.get_lazyops() if x.op in BufferOps]
if self.reduceop
else []
)
self.full_buf_index: int = (
self.bufs.index(self.earlybufs[0]) if self.earlybufs else 0
)
# create the (permuted) shapetrackers
self.sts: List[ShapeTracker] = [
x.st for x in cast(List[Union[MemBuffer, ConstBuffer]], self.bufs)
]
# move all reduce axes to the end
reduce = list(enumerate(zip(self.full_shape, self.sts[0].shape)))
permute = tuple(
[i for i, (s, n) in reduce if s == n]
+ [i for i, (s, n) in reduce if s != n]
)
self.reshape_and_permute(None, permute)
# parameters for optimization
self.applied_opts: List[Opt] = []
self.group_for_reduce: List[int] = []
self.upcasted: int = 0
self.local_dims: int = 0
self.local_alias: Dict[int, LocalBuffer] = {}
self.tensor_core: Optional[TensorCore] = None
self.dont_use_locals: bool = False
# group simplifies
self.simplify_ones()
self.simplify_merge_adjacent()
# cache
self.applied_opts_cache: Optional[List[Opt]] = None
def copy(self):
"""
Creates a deep copy of the current Kernel object. This can be useful for creating new kernels based on existing ones
without modifying the original kernel.
Returns:
A deep copy of the current Kernel object.
"""
ret = type(self).__new__(type(self))
# base linearizer params
ret.opts, ret.ast = self.opts, self.ast
# things downstream of the AST
# NOTE: we copy bufs for local buffers and sts for optimizations
ret.info, ret.reduceop, ret.bufs, ret.earlybufs, ret.full_buf_index, ret.sts = (
self.info,
self.reduceop,
self.bufs[:],
self.earlybufs,
self.full_buf_index,
self.sts[:],
)
# parameters for optimizations
(
ret.applied_opts,
ret.group_for_reduce,
ret.upcasted,
ret.local_dims,
ret.local_alias,
ret.tensor_core,
ret.dont_use_locals,
) = (
self.applied_opts[:],
self.group_for_reduce[:],
self.upcasted,
self.local_dims,
self.local_alias.copy(),
self.tensor_core,
self.dont_use_locals,
)
# uncached since linearize didn't run
ret.applied_opts_cache = None
return ret
@property
def membufs(self) -> List[MemBuffer]:
"""
The subset of self.bufs that are MemBuffers (global memory buffers).
Returns:
List[MemBuffer]: A list of MemBuffer objects.
"""
return [x for x in self.bufs if isinstance(x, MemBuffer)]
# TODO: these need more tests or it might silently be no-op
def shape_offsets(self, i: int):
"""
Iterate over the offsets of buffer i's upcasted dimensions.
Args:
i (int): The buffer index for which to compute the offsets.
Returns:
An iterator over offset tuples covering every upcasted element (the cartesian product of the upcasted dimensions, last dimension first), or [tuple()] if nothing is upcasted.
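Example (illustrative): if the upcasted tail of buffer i's shape is (2, 3), this yields itertools.product(range(3), range(2)), i.e. six offset tuples with the last dimension's index first; with nothing upcasted it yields [tuple()].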
"""
return (
itertools.product(
*[
list(range(cast(int, s)))
for s in self.sts[i].shape[self.shape_len - self.upcasted :][::-1]
]
)
if self.upcasted > 0
else [tuple()]
)
def float4_axis(self, i: int):
"""
Compute which upcasted axes of buffer i are float4-friendly.
Args:
i (int): The buffer index.
Returns:
List[int]: Upcasted axes of buffer i with unit stride and a size divisible by 4, indexed relative to the start of the upcasted dimensions.
"""
return [
x - (self.shape_len - self.upcasted)
for x in self.sts[i].unit_stride_axes()
if x >= self.shape_len - self.upcasted and self.sts[i].shape[x] % 4 == 0
]
def upcasted_axis(self, i: int):
"""
Describe the upcasted dimensions of buffer i.
Args:
i (int): The buffer index.
Returns:
List[Tuple[sint, Optional[sint], bool]]: For each upcasted dimension, a (size, stride, is_reduce) tuple, where is_reduce marks dimensions that differ between the output shape and the full shape.
"""
return list(
zip(
self.sts[i].shape[self.shape_len - self.upcasted :],
self.sts[i].real_strides()[self.shape_len - self.upcasted :],
[
x != y
for x, y in zip(
self.sts[0].shape[self.shape_len - self.upcasted :],
self.full_shape[self.shape_len - self.upcasted :],
)
],
)
)
# TODO: is there a better way to write this?
def acc_offsets(self, i: int) -> List[int]:
"""
Calculate accumulator offsets for buffer i (the offsets into the accumulator registers for each upcasted element).
Attributes:
i (int): The buffer index to calculate accumulator offsets for.
Returns:
List[int]: A list of accumulator offsets.
"""
if self.upcasted == 0:
return [0]
upcasted_i = self.upcasted_axis(i)
acc_strides = [
x * (1 - upcasted_i[::-1][i][2])
for i, x in enumerate(
strides_for_shape(tuple(1 if r else s for s, _, r in upcasted_i[::-1]))
)
]
return [
sum(t)
for t in itertools.product(
*[
[y * acc_strides[i] for y in range(x[0])]
for i, x in enumerate(upcasted_i[::-1])
]
)
]
def get_upcast_dim(self, i: int) -> List[int]:
"""
Get dimensions of buffer i that are candidates for float4-style upcasting.
Attributes:
i (int): The buffer index to check.
Returns:
List[int]: Unit-stride upcasted axes of buffer i with size > 1, provided the buffer's dtype supports vectorization.
"""
should_upcast = self.opts.supports_float4 and (
self.bufs[i].dtype in [dtypes.float32, dtypes.float16]
or isinstance(self.bufs[i].dtype, ImageDType)
)
return [
x
for x in self.sts[i].unit_stride_axes()
if should_upcast
and x >= self.shape_len - self.upcasted
and self.sts[i].shape[x] > 1
]
@property
def first_reduce(self) -> int:
"""
Calculate the index of the first reduction axis.
Attributes:
self.sts (List[ShapeTracker]): The shape trackers for each buffer.
self.shape_len (int): The length of the shape attribute of an object in `self.sts`.
self.upcasted (int): The number of upcasted dimensions.
Returns:
int: The index of the first reduction axis.
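Example (illustrative): with full_shape (4, 8, 16), an output shape of (4, 8, 1) and nothing upcasted, index 2 is the first axis where the two shapes differ, so first_reduce == 2.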
"""
return [
x != y
for x, y in zip(
self.sts[0].shape[: self.shape_len - self.upcasted] + (0,),
self.full_shape[: self.shape_len - self.upcasted] + (1,),
)
].index(True)
@property
def output_shape(self) -> Tuple[sint, ...]:
"""
Get the shape of the first object in `self.sts`.
Attributes:
self.sts (List[ShapeTracker]): The shape trackers for each buffer.
Returns:
Tuple[sint, ...]: The shape of the first object in `self.sts`.
"""
return self.sts[0].shape
@property
def full_shape(self) -> Tuple[sint, ...]:
"""
Get the shape of the object at index `self.full_buf_index` in `self.sts`.
Attributes:
self.sts (List[ShapeTracker]): The shape trackers for each buffer.
self.full_buf_index (int): The index of the object to get the shape from.
Returns:
Tuple[sint, ...]: The shape of the object at `self.full_buf_index` in `self.sts`.
"""
return self.sts[self.full_buf_index].shape
@property
def full_unupcasted_shape(self) -> Tuple[sint, ...]:
"""
Get the unupcasted shape of the object at index `self.full_buf_index` in `self.sts`.
Attributes:
self.full_shape (Tuple[sint, ...]): The shape of the object at `self.full_buf_index` in `self.sts`.
self.upcasted (int): The number of upcasted dimensions.
Returns:
Tuple[sint, ...]: The unupcasted shape of the object at `self.full_buf_index` in `self.sts`.
"""
return self.full_shape[: self.shape_len - self.upcasted]
@property
def shape_len(self) -> int:
"""
Get the length of the shape attribute of an object in `self.sts`.
Attributes:
self.sts (List[ShapeTracker]): The shape trackers for each buffer.
Returns:
int: The length of the shape attribute of an object in `self.sts`.
"""
return len(self.sts[0].shape)
@property
def upcast_in_mid_reduce_axes(self) -> List[int]:
"""
Get the indices within the group-for-reduce range where self.full_shape and self.sts[0].shape agree (axes that can be upcast mid-reduce).
Attributes:
self.first_reduce (int): The index of the first reduction axis.
self.group_for_reduce (List[int]): A list of integers representing groups for reduction.
self.full_shape (Tuple[sint, ...]): The shape of the object at `self.full_buf_index` in `self.sts`.
self.sts[0].shape (Tuple[sint, ...]): The shape of the first object in `self.sts`.
Returns:
List[int]: The indices in [first_reduce, first_reduce + len(group_for_reduce)) where self.full_shape and self.sts[0].shape are equal.
"""
return [
j
for j in range(
self.first_reduce, self.first_reduce + len(self.group_for_reduce)
)
if self.full_shape[j] == self.sts[0].shape[j]
]
@property
def global_dims(self) -> int:
"""
Calculate and return the difference between first_reduce and local_dims attributes.
Attributes:
self.first_reduce (int): The first reduced dimension.
self.local_dims (int): The local dimensions.
Returns:
int: The number of global dimensions, i.e. self.first_reduce - self.local_dims.
Notes:
This is a property; the eight shape "chunks" it relates to are described in the comment block below.
"""
return self.first_reduce - self.local_dims
# there's eight chunks of the shape
# blue -- global dims
# cyan -- local dims (warp ones first)
# *** self.first_reduce
# green -- reduce-local dims
# white -- reduce-late upcasted dim (self.upcast_in_mid_reduce_axes)
# red -- reduce loops
# *** self.upcasted
# purple -- reduce upcasted
# yellow -- normal upcasted dimensions
def colors(self) -> List[str]:
"""
Generate a list of color codes based on the dimensions of the object.
Attributes:
global_dims (int): Number of global dimensions
local_dims (int): Number of local dimensions
first_reduce (int): Index of the first reduce dimension
group_for_reduce (list): List of grouped dimensions for reduction
upcast_in_mid_reduce_axes (List[int]): Axes where upcasting occurs during mid-reduction
shape_len (int): Length of the shape vector
upcasted (int): Number of upcasted dimensions
full_shape (list): Full shape of the object
sts (List[ShapeTracker]): The shape trackers for each buffer
Returns:
list: A list of color codes representing different types of dimensions.
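Example (illustrative): a kernel with 2 global dims, 1 local dim, 1 reduce loop and 1 normal upcasted dim yields ["blue", "blue", "cyan", "red", "yellow"].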
"""
# first non local non reduce dims are global (blue)
colors = (
["blue"] * self.global_dims
if not self.dont_use_locals
else ["BLUE"] * self.global_dims
)
# after global are local_dims; warp ones used in tensor cores must be closest to first_reduce (cyan)
colors += ["cyan"] * self.local_dims
# between first_reduce and first_reduce + group_for_reduce, they are either upcast mid reduce (white), or late upcasted (green)
colors += [
"white" if i in self.upcast_in_mid_reduce_axes else "green"
for i in range(
self.first_reduce, self.first_reduce + len(self.group_for_reduce)
)
]
# between first_reduce + group_for_reduce and upcasted, they are reduce (red)
colors += ["red"] * (
(self.shape_len - self.upcasted)
- (self.first_reduce + len(self.group_for_reduce))
)
# upcasted dimensions are reduce (magenta) or normal (yellow)
colors += [
"magenta" if self.full_shape[i] != self.sts[0].shape[i] else "yellow"
for i in range(self.shape_len - self.upcasted, self.shape_len)
]
assert len(colors) == self.shape_len, "colors size mismatch"
return colors
def colored_shape(self, pad: Optional[int] = None, dense=False) -> str:
"""
Generate a string representation of the shape with each dimension colored according to its position.
:param pad: Pad the resulting string with trailing spaces to this visible width. If not provided, no padding is added.
:type pad: Optional[int]
:param dense: If True, print integer dimensions without the default 4-character padding. Defaults to False.
:type dense: bool
:return: A string representation of the shape with each dimension colored according to its position.
:rtype: str
Attributes:
self.full_shape (List[Union[int, str]]): The full shape of the object.
self.colors (Callable[[], List[str]]): A function that returns a list of colors for each dimension.
ansilen (Callable[[str], int]): A function that calculates the length of a string in terminal characters.
"""
ret = " ".join(
colored(s, color)
for s, color in zip(
[
f"{s:4d}" if isinstance(s, int) and not dense else s
for s in self.full_shape
],
self.colors(),
)
)
if pad:
ret += " " * (pad - ansilen(ret))
return ret
# ******************** base simplifiers ********************
# apply reshape and permute to all shapetrackers
def reshape_and_permute(self, new_shape_fxn, axis):
"""
Apply reshape and permute to all shapetrackers.
Parameters:
new_shape_fxn (Optional[Callable]): Function mapping each shapetracker's shape to its new shape, or None to skip the reshape.
axis (Optional[Tuple[int, ...]]): Permutation of axis indices to apply, or None to skip the permute.
Returns:
None
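Example (illustrative):
self.reshape_and_permute(lambda shape: [prod(shape)], None) flattens every shapetracker to one dimension (when such a reshape is valid for all of them), and
self.reshape_and_permute(None, tuple(range(1, self.shape_len)) + (0,)) rotates the first axis to the end.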
"""
new_sts = []
for st in self.sts:
if new_shape_fxn is not None:
st = st.reshape(tuple(new_shape_fxn(st.shape)))
if axis is not None:
st = st.permute(tuple(axis))
new_sts.append(st)
self.sts = new_sts
# drops the final dimension
def upcast(self):
"""
Drop the final dimension.
Parameters:
None
Returns:
None
Raises:
AssertionError: If the final dimension size is 1, as it cannot be upcasted.
"""
assert self.full_shape[-1] != 1, "can't upcast a dimension with size 1"
self.upcasted += 1
# axis : the axis to pull from
# amount : the amount to take
# top : if you want to pull that amount from the top
# insert_before : place to insert the new stuff
def shift_to(self, axis, amount, top=False, insert_before=None):
"""
Shift elements to a specified location.
Parameters:
axis (int): The axis to pull from.
amount (int): The amount to take.
top (bool): If you want to pull that amount from the top. Default is False.
insert_before (int): Place to insert the new stuff. Default is None, which means end of list.
Returns:
None
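Example (illustrative): on a kernel whose full_shape is (6, 4), shift_to(axis=0, amount=2) splits axis 0 into 3*2 and moves the new size-2 axis to the end, giving shape (3, 4, 2); passing insert_before=0 would instead place the new axis at the front, giving (2, 3, 4).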
"""
if insert_before is None:
insert_before = self.shape_len
move_axis = axis if top else axis + 1
if move_axis < insert_before:
insert_before += 1
self.reshape_and_permute(
lambda x: list(x[0:axis])
+ (
([amount, x[axis] // amount] if top else [x[axis] // amount, amount])
if x[axis] > 1
else [1, 1]
)
+ list(x[axis + 1 :]),
[i for i in range(insert_before) if i != move_axis]
+ [move_axis]
+ [i for i in range(insert_before, self.shape_len + 1) if i != move_axis],
)
# ******************** complex simplifiers ********************
def simplify_ones(self) -> bool:
"""
Simplify by removing axes where the full shape is 1. Adjusts the local_dims and upcasted
counts for any removed axes and reshapes all shapetrackers accordingly. Returns True if
any size-1 axis was removed, False otherwise.
:return: bool
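Example (illustrative): a kernel with full_shape (1, 4, 1, 8) is reshaped to (4, 8) and the call returns True; if no axis has size 1 the shapes are left untouched and the call returns False.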
"""
# remove places where the shape is all ones
# TODO: this should be factored in to multi shape stride
if self.shape_len == 0:
return False
all_ones = [s == 1 for s in self.full_shape]
self.local_dims -= sum(
all_ones[self.first_reduce - self.local_dims : self.first_reduce]
)
self.upcasted -= sum(all_ones[self.shape_len - self.upcasted :])
self.reshape_and_permute(
lambda shape: [x for i, x in enumerate(shape) if not all_ones[i]], None
)
return any(all_ones)
def simplify_merge_adjacent(self):
"""
Simplify by merging adjacent dimensions wherever every buffer's strides allow the pair to be fused.
For image dtypes, fake strides are inserted so that merging never crosses image axes.
:return: None
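Example (illustrative): if every buffer views the shape (4, 8) contiguously (strides (8, 1)), the two axes are merged into a single axis of size 32; if any buffer's strides do not line up across that boundary, or the boundary is first_reduce, the axes stay separate.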
"""
if self.shape_len == 0:
return
shapes, strides = [x.shape for x in self.sts], [
x.real_strides() for x in self.sts
]
# if it's an image, insert fake strides such that this fusion doesn't happen across image axes
if isinstance(self.bufs[0].dtype, ImageDType):
base_shape = self.bufs[0].dtype.shape
if shape_idx_groups := get_contraction(self.output_shape, base_shape):
special_strides: Tuple[int, ...] = tuple()
for i, g in enumerate(shape_idx_groups):
shape_piece = tuple(self.output_shape[x] for x in g)
assert (
prod(shape_piece) == base_shape[i]
), f"get_contraction was wrong? {shape_piece} != {base_shape[i]}"
special_strides += strides_for_shape(shape_piece)
# adding the fake image shape
shapes.append(self.output_shape)
strides.append(special_strides)
# merge dimensions if we can, multi get_shape_strides
# TODO: does this always preserve the reduce dimension, NO
# TODO: move this into shapetracker, with tests!
rets = [[(shapes[j][0], strides[j][0])] for j in range(len(shapes))]
for i in range(1, len(shapes[0])):
can_merge = []
for j in range(len(shapes)):
# TODO: added the always mergeability of 1s, is this right? if so, add to shapetracker in the 1 case
can_merge.append(
strides[j][i] is not None
and (
(
strides[j][i] != 0
and rets[j][-1][1]
== shapes[j][i] * cast(int, strides[j][i])
)
or (strides[j][i] == 0 and rets[j][-1][1] == 0)
)
)
# more can merge than this
mergeable = all(can_merge) and i != self.first_reduce
for j in range(len(shapes)):
if mergeable:
rets[j][-1] = (rets[j][-1][0] * shapes[j][i], strides[j][i])
else:
rets[j].append((shapes[j][i], strides[j][i]))
# do the reshapes
for i, x in enumerate(rets[: len(self.sts)]):
self.sts[i] = self.sts[i].reshape(tuple([y[0] for y in x]))
# ******************** GPU simplifiers ********************
def _limit_size(self, x: Tuple[int], max_size: List) -> Tuple[int, ...]:
"""
Limit the size of tensor dimensions.
:param x: Tuple of integers representing the shape of a tensor.
:type x: Tuple[int]
:param max_size: List of maximum allowed sizes for each dimension.
:type max_size: List
:return: Tuple of integers representing the new shape with dimensions limited by max_size.
:rtype: Tuple[int, ...]
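Example (illustrative): _limit_size((256, 4), [64, 64]) repeatedly halves the oversized dimension while doubling a neighbour that still has room, returning (64, 16).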
"""
new_shape, dims = list(x), len(x)
for i in range(dims):
next_idx = (i + 1) % dims
while new_shape[i] > max_size[i]:
new_shape[i] = new_shape[i] // 2
if new_shape[next_idx] <= max_size[next_idx]:
new_shape[next_idx] = new_shape[next_idx] * 2
else:
next_idx = (next_idx + 1) % dims
new_shape[next_idx] = new_shape[next_idx] * 2
return tuple(new_shape)
def limit_dims_to_max(self, global_max: List[int], local_max: List[int]):
"""
Limit dimensions to maximum allowed sizes.
:param global_max: List of maximum allowed global dimension sizes.
:type global_max: List[int]
:param local_max: List of maximum allowed local dimension sizes.
:type local_max: List[int]
"""
# Check the global allocation limit; currently the global_size will be flipped during codegen
# and then padded right with 1s if its length < 3 which makes this part a bit awkward to write
global_dims = self.first_reduce - self.local_dims
if global_dims > 0:
if global_max:
tmp = global_max[:global_dims] + (
local_max[: self.local_dims] if local_max else []
)
if max(global_max) < max(self.full_shape[:global_dims]):
self.reshape_and_permute(
lambda x: self._limit_size(
x, tmp + [math.inf] * (len(self.full_shape) - len(tmp))
),
None,
)
assert max(global_max) >= max(
self.full_shape[:global_dims]
), f"device max allocation {max(self.full_shape[:global_dims])} exceeds global dim maximum {max(global_max)}"
for i in range(global_dims - 1):
if i < len(global_max) and self.full_shape[i] > global_max[i]:
order = list(range(len(self.full_shape)))
order[i], order[global_dims - 1] = order[global_dims - 1], order[i]
self.reshape_and_permute(None, order)
if DEBUG >= 3:
print(
"permuted global dim",
order,
"due to allocation exceeds global limit",
)
def alias_buffer(self, i, pattern):
"""
Alias a buffer.
:param i: Index of the buffer to be aliased.
:type i: int
:param pattern: Per-axis priority pattern (one entry per shape dimension); 0 marks axes not included in the local buffer, and higher priorities are laid out with progressively larger strides.
:type pattern: List
"""
assert len(pattern) == len(
self.sts[i].shape
), f"must include a pattern for each shape {pattern} {self.sts[i].shape}"
bst = 1
real_strides = self.sts[i].real_strides()
shp, stride = [
(s if p != 0 else 1) for s, p in zip(self.sts[i].shape, pattern)
], [0] * len(pattern)
for priority in range(
1, max(pattern) + 1
): # priority. 0 is non local and ignored
for j, p in enumerate(pattern):
if priority == p and real_strides[j] != 0:
stride[j] = bst
bst *= shp[j]
self.sts.append(ShapeTracker((View.create(tuple(shp), tuple(stride)),)))
self.bufs.append(LocalBuffer(name=f"ldata{i}", size=self.sts[-1].size()))
if DEBUG >= 4:
print("aliasing buffer", self.sts[i])
self.local_alias[i] = cast(LocalBuffer, self.bufs[-1])
# ******************** high level optimizers ********************
def apply_tensor_cores(
self, use_tensor_cores=1, extra_opts: Optional[List[Opt]] = None
):
"""
Apply tensor cores to the computation.
Attributes:
use_tensor_cores (int): Flag indicating whether to apply tensor cores or not. Default is 1.
extra_opts (Optional[List[Opt]]): Optional list of extra options. Default is None.
This function checks if the following conditions are met for applying tensor cores:
1) use_tensor_cores flag is True.
2) The current device has local memory support.
3) Reduction operation exists and it's a summation (ReduceOps.SUM).
4) The current device supports tensor cores.
If these conditions are met, the function iterates over all available tensor cores for the current device.
It then checks if certain conditions hold true to apply tensor cores:
1) Tensor core architecture is compatible with the current system.
2) The reduction operation's source is a LazyOp and its operation is UnaryOps.CAST with the correct dtype_out.
3) The multiplication operation (LazyOp with BinaryOps.MUL) exists and its sources are two LazyOps with
BufferOps.LOAD operations and compatible dtypes with the tensor core configuration.
4) The strides of both source buffers for the multiplication operation are zero for the first reduction dimension.
5) The shape of the buffers is compatible with the tensor core dimensions.
If all these conditions are met, it selects the axes for buffer 0 and buffer 1 and applies tensor cores.
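It returns True if tensor cores were applied and False otherwise; an illustrative usage is `if not kernel.apply_tensor_cores(1): kernel.hand_coded_optimizations()`, falling back to the generic heuristics when no tensor core matched.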
"""
if (
use_tensor_cores
and self.opts.has_local
and self.reduceop
and self.reduceop.op == ReduceOps.SUM
and self.opts.device in tensor_cores
):
for tc in tensor_cores[self.opts.device]:
if not (
(tc.arch is None or tc.arch == os.uname().machine)
and isinstance(self.reduceop.src[0], LazyOp)
):
continue
has_cast = tc.dtype_in != tc.dtype_out
if has_cast and not (
isinstance(self.reduceop.src[0], LazyOp)
and self.reduceop.src[0].op == UnaryOps.CAST
and self.reduceop.src[0].arg[0] == tc.dtype_out
):
continue
mul_op = (
self.reduceop.src[0].src[0] if has_cast else self.reduceop.src[0]
)
if not (isinstance(mul_op, LazyOp) and mul_op.op == BinaryOps.MUL):
continue
if not (
isinstance(mul_op.src[0], LazyOp)
and mul_op.src[0].op == BufferOps.LOAD
and mul_op.src[0].arg.dtype == tc.dtype_in
):
continue
if not (
isinstance(mul_op.src[1], LazyOp)
and mul_op.src[1].op == BufferOps.LOAD
and mul_op.src[1].arg.dtype == tc.dtype_in
):
continue
buf0, buf1 = self.bufs.index(
cast(MemBuffer, mul_op.src[0].arg)
), self.bufs.index(cast(MemBuffer, mul_op.src[1].arg))
buf0_strides, buf1_strides = (
self.sts[buf0].real_strides(),
self.sts[buf1].real_strides(),
)
axis_buf0 = [
(i, self.full_shape[i], buf1_strides[i])
for i, s in enumerate(buf0_strides[: self.first_reduce])
if s == 0 and self.full_shape[i] % tc.dims[0] == 0
]
axis_buf1 = [
(i, self.full_shape[i], buf0_strides[i])
for i, s in enumerate(buf1_strides[: self.first_reduce])
if s == 0 and self.full_shape[i] % tc.dims[1] == 0
]
if not (
axis_buf0
and axis_buf1
and self.full_shape[self.first_reduce] % tc.dims[2] == 0
and self.full_shape[self.first_reduce] >= tc.dims[2]
and (self.shape_len - self.first_reduce) == 1
):
continue
if DEBUG >= 3:
print("TENSOR CORES", axis_buf0, axis_buf1, tc)
s0, s1 = (
axis_buf0[-1][0],
axis_buf1[-1][0],
) # TODO: select axis in smart way
s0_exists, s1_exists = True, True
assert (
s0 != s1
and self.full_shape[s0] % tc.dims[0] == 0
and self.full_shape[s1] % tc.dims[1] == 0
)
def fix(needed, ax):
"""
Adjust the tracked tensor-core axis indices after an optimization is applied.
If the applied opt removed axis `ax` (i.e. `needed` is True), the corresponding
tracked index is invalidated and the other index is shifted down by one.
Attributes:
needed (bool): Whether the applied opt removed a dimension (the return value of apply_opt).
ax (int): The axis the opt was applied to.
s0, s1 (int): The tensor-core axis indices being tracked (nonlocal).
s0_exists, s1_exists (bool): Flags indicating whether `s0` and `s1` respectively are still valid.
"""
nonlocal s0, s1, s0_exists, s1_exists
if not needed:
return
if s0_exists and ax == s0:
if s1_exists and s0 < s1:
s1 -= 1
s0_exists = False
elif s1_exists and ax == s1:
if s0_exists and s1 < s0:
s0 -= 1
s1_exists = False
# tensor core -- unroll the reduce dim, upcast input, then create the correct thread pattern
self.apply_opt(Opt(OptOps.UNROLL, 0, tc.dims[2]))
self.apply_opt(
Opt(
OptOps.UPCAST,
s0 if tc.upcast_dim == 0 else s1,
(tc.dims[0] * tc.dims[2]) // prod([a[1] for a in tc.threads]),
)
)
for tc_dim, tc_amt in tc.threads:
fix(
self.apply_opt(
Opt(OptOps.LASTLOCAL, s0 if tc_dim == 0 else s1, tc_amt)
),
s0 if tc_dim == 0 else s1,
)
# assert tensor core and prevent extra_opts from altering the key shape structure
if use_tensor_cores == 1:
self.tensor_core = tc # TC=2 will do the shape ops without the WMMA
if extra_opts is not None:
for opt in extra_opts:
self.apply_opt(opt)
else:
# hand-coded TC opts
if s1_exists:
s1_div = [
upc
for upc in [5, 4, 3, 2, 1]
if self.full_shape[s1] % upc == 0
][0]
if s1_div != 1:
fix(self.apply_opt(Opt(OptOps.UPCAST, s1, s1_div)), s1)
if s0_exists:
s0_div = [
upc
for upc in [5, 4, 3, 2, 1]
if self.full_shape[s0] % upc == 0
][0]
if s0_div != 1:
fix(self.apply_opt(Opt(OptOps.UPCAST, s0, s0_div)), s0)
if self.tensor_core and s0_exists:
for upc in [4, 2]:
if self.full_shape[s0] % upc == 0:
self.apply_opt(Opt(OptOps.LASTLOCAL, s0, upc))
break
# alias buffer
alias_pattern = (
[0] * (self.global_dims + (self.local_dims - len(tc.threads)))
+ [2] * (len(tc.threads))
+ [0] * (self.shape_len - self.upcasted - self.first_reduce)
+ [1, 1]
+ [3] * (self.upcasted - 2)
)
self.alias_buffer(buf0, alias_pattern)
self.alias_buffer(buf1, alias_pattern)
return True
return False
def apply_opt(self, opt: Opt):
"""
Apply an optimization to the current object.
This method checks if the optimization operation is applicable based on the 'dont_use_locals' attribute and the type of the operation. It then appends the optimization to a list of applied optimizations. The axis for the optimization is calculated based on certain conditions and defaulted to -1 if no specific axis is given.
Args:
opt (Opt): The optimization operation to apply.
Raises:
AssertionError: If 'dont_use_locals' attribute is True and the optimization operation is one of LOCAL, LASTLOCAL, GROUP, GROUPTOP, or UPCASTMID.
Attributes:
applied_opts (List[Opt]): A list of previously applied optimization operations.
dont_use_locals (bool): If True, some optimization operations are not allowed.
first_reduce (int): The index of the first reduction operation.
group_for_reduce (list): A list of groups for reduction operations.
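Example (illustrative):
kernel.apply_opt(Opt(OptOps.UPCAST, 0, 4)) upcasts global axis 0 by 4, and
kernel.apply_opt(Opt(OptOps.UNROLL, 0, 0)) fully unrolls the first reduce axis (amt=0 means the whole axis).
Both calls finish by running simplify_ones(), whose boolean result is returned.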
"""
assert not self.dont_use_locals or opt.op not in {
OptOps.LOCAL,
OptOps.LASTLOCAL,
OptOps.GROUP,
OptOps.GROUPTOP,
OptOps.UPCASTMID,
}, "not using locals"
self.applied_opts.append(opt)
if opt.axis is not None:
axis = opt.axis + (
self.first_reduce
if opt.op == OptOps.UNROLL
else (
self.first_reduce + len(self.group_for_reduce)
if opt.op == OptOps.GROUP or opt.op == OptOps.GROUPTOP
else 0
)
)
else:
axis = -1
if opt.amt is not None:
amt = opt.amt if opt.amt != 0 else self.full_shape[axis]
assert (
isinstance(amt, int) and amt != 1
), "shift/padto of amt 1 or Node is meaningless"
if opt.op != OptOps.PADTO:
assert self.full_shape[axis] % amt == 0, "no longer valid shift"
else:
amt = -1
if opt.op == OptOps.LOCAL: # cyan
assert self.opts.has_local, "target does not support local"
assert axis < self.first_reduce, "can't local a reduce"
assert not (self.tensor_core), "can't local with tensor cores"
self.shift_to(axis, amt, insert_before=self.first_reduce)
self.local_dims += 1
elif opt.op == OptOps.LASTLOCAL: # cyan
assert self.opts.has_local, "target does not support local"
assert axis < self.first_reduce, "can't local a reduce"
self.shift_to(axis, amt, insert_before=self.first_reduce - self.local_dims)
self.local_dims += 1
elif opt.op == OptOps.GROUP: # green
assert (
self.opts.has_local and self.opts.has_shared
), "target does not support local or shared mem"
assert (
axis >= self.first_reduce + len(self.group_for_reduce)
and axis < self.shape_len - self.upcasted
), "must be reduce axis to group"
assert not (self.tensor_core), "can't group with tensor cores"
self.shift_to(
axis, amt, insert_before=self.first_reduce + len(self.group_for_reduce)
)
self.group_for_reduce.append(amt)
elif opt.op == OptOps.GROUPTOP: # green
assert (
self.opts.has_local and self.opts.has_shared
), "target does not support local or shared mem"
assert (
axis >= self.first_reduce + len(self.group_for_reduce)
and axis < self.shape_len - self.upcasted
), "must be reduce axis to group"
assert not (self.tensor_core), "can't group with tensor cores"
self.shift_to(
axis,
amt,
top=True,
insert_before=self.first_reduce + len(self.group_for_reduce),
)
self.group_for_reduce.append(amt)
elif opt.op == OptOps.UNROLL: # purple
assert (
axis < self.shape_len - self.upcasted
), "can't upcasted already upcasted"
assert amt <= 32, "don't unroll more than 32"
self.shift_to(axis, amt, insert_before=None)
self.upcast()
elif opt.op == OptOps.UPCAST: # yellow
assert axis < self.first_reduce, "upcast is for non-reduce"
assert amt <= 8, "don't upcast more than 8"
self.shift_to(axis, amt, insert_before=None)
self.upcast()
elif opt.op == OptOps.UPCASTMID: # white
assert (
self.bufs[0].dtype.name.startswith("image")
and not self.float4_axis(0)
and self.group_for_reduce
and self.first_reduce <= 2
and prod(self.sts[0].shape) > 1
), "invalid upcast mid reduce"
axes = self.sts[0].unit_stride_axes()
assert len(axes) == 1, f"wrong number of stride 1 axis : {axes}"
assert axes[0] == axis, "wrong axis"
assert amt == 4, "don't upcast mid anything but 4"
self.shift_to(
axis, amt, insert_before=self.first_reduce + len(self.group_for_reduce)
)
self.group_for_reduce.append(amt)
elif opt.op == OptOps.NOLOCALS:
assert (
self.opts.has_local
), "target does not support local, so this optimization is meaningless"
assert (
self.local_dims == 0 and len(self.group_for_reduce) == 0
), "can't have no locals with locals"
assert not self.dont_use_locals, "already not using locals"
self.dont_use_locals = True
elif opt.op == OptOps.PADTO:
assert not vars_from_ast(self.ast), "does not work with symbolic shape"
assert all(
op.op is not ReduceOps.MAX for op in self.ast.get_lazyops()
), "cannot pad with MAX"
padded = False
for i, st in enumerate(self.sts):
if self.sts[i].shape[axis] != 1:
assert (
self.sts[i].shape[axis] > amt // 2
), "pad adds more than double the work"
if (
ru := round_up(self.sts[i].shape[axis], amt)
- self.sts[i].shape[axis]
):
# pad right seems to be faster
self.sts[i] = st.pad(
((0, 0),) * axis
+ ((0, ru),)
+ ((0, 0),) * (len(st.shape) - axis - 1)
)
padded = True
assert padded, "nothing was padded"
return self.simplify_ones()
def hand_coded_optimizations(self):
"""
This method handles the application of hand-coded optimizations.
Attributes:
MV_BLOCKSIZE (int): The block size for matrix-vector multiplication.
MV_THREADS_PER_ROW (int): The number of threads per row for matrix-vector multiplication.
MV_ROWS_PER_THREAD (int): The number of rows per thread for matrix-vector multiplication.
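Example (illustrative): the matvec path can be tuned from the environment, e.g. MV_BLOCKSIZE=4 MV_THREADS_PER_ROW=8 MV_ROWS_PER_THREAD=4, or disabled entirely with MV=0.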
"""
# should use matvec - TODO: adjust/tune based on the wide vs tall/large vs small mat
MV_BLOCKSIZE, MV_THREADS_PER_ROW, MV_ROWS_PER_THREAD = (
getenv("MV_BLOCKSIZE", 4),
getenv("MV_THREADS_PER_ROW", 8),
getenv("MV_ROWS_PER_THREAD", 4),
)
if (
self.opts.has_local
and getenv("MV", 1) != 0
and (MV_BLOCKSIZE > 1 or MV_THREADS_PER_ROW > 1 or MV_ROWS_PER_THREAD > 1)
and self.reduceop
and self.reduceop.op == ReduceOps.SUM
and len(self.full_shape) >= 2
and self.opts.has_shared
and isinstance(self.reduceop.src[0], LazyOp)
and self.reduceop.src[0].op == BinaryOps.MUL
and self.reduceop.src[0].src[0].op == BufferOps.LOAD
and self.reduceop.src[0].src[1].op == BufferOps.LOAD
):
buf0 = self.bufs.index(self.reduceop.src[0].src[0].arg)
buf1 = self.bufs.index(self.reduceop.src[0].src[1].arg)
buf0_strides = self.sts[buf0].real_strides()
buf1_strides = self.sts[buf1].real_strides()
def has_expanded_axis(s, st):
return any(x > 1 and y == 0 for x, y in zip(s, st))
if buf0_strides[self.first_reduce] == 1 and not (
has_expanded_axis(self.sts[buf0].shape, buf0_strides)
and has_expanded_axis(self.sts[buf1].shape, buf1_strides)
):
for global_idx in range(self.global_dims):
if (
self.full_shape[self.first_reduce] % MV_THREADS_PER_ROW == 0
and self.full_shape[global_idx]
% (MV_BLOCKSIZE * MV_ROWS_PER_THREAD)
== 0
):
if DEBUG >= 3:
print(
f"MATVEC: full_shape={self.full_shape} first_reduce={self.first_reduce} buf0_strides={buf0_strides} blocksize={MV_BLOCKSIZE} threads_per_row={MV_THREADS_PER_ROW} rows_per_thread={MV_ROWS_PER_THREAD}"
)
if MV_THREADS_PER_ROW > 1:
self.apply_opt(Opt(OptOps.GROUP, 0, MV_THREADS_PER_ROW))
if MV_BLOCKSIZE > 1:
self.apply_opt(Opt(OptOps.LOCAL, global_idx, MV_BLOCKSIZE))
if MV_ROWS_PER_THREAD > 1:
self.apply_opt(
Opt(OptOps.UPCAST, global_idx, MV_ROWS_PER_THREAD)
)
return
if (
self.opts.has_local
and self.opts.has_shared
and all(isinstance(s, int) for s in self.sts[0].shape[: self.first_reduce])
):
# are we grouping? (requires local shape support)
if (
not self.float4_axis(0)
and self.first_reduce <= 2
and self.first_reduce + 1 <= self.shape_len
and prod(self.sts[0].shape[: self.first_reduce]) <= 2048
):
# TODO: use 1024 if it's allowed in a smarter way
for sz in (
([256, 16])
if prod(self.sts[0].shape[: self.first_reduce]) <= 32
else [16]
):
if all(
st.shape[self.first_reduce] % sz == 0
or st.shape[self.first_reduce] == 1
for st in self.sts
):
self.apply_opt(Opt(OptOps.GROUPTOP, 0, sz))
break
# are we upcasting in mid reduce? (only for images)
if (
self.bufs[0].dtype.name.startswith("image")
and not self.float4_axis(0)
and self.group_for_reduce
and self.first_reduce <= 2
and prod(self.sts[0].shape) > 1
):
axes = self.sts[0].unit_stride_axes()
assert len(axes) == 1, f"wrong number of stride 1 axis : {axes}"
if self.sts[0].shape[axes[0]] % 4 == 0:
self.apply_opt(Opt(OptOps.UPCASTMID, axes[0], 4))
# upcast float4 images
for buf_index, buf in enumerate(self.bufs):
unit_stride_axes_mul_4 = [
i
for i in self.sts[buf_index].unit_stride_axes(ignore_valid=True)
if self.sts[buf_index].shape[i] % 4 == 0
]
if buf.dtype.__class__ is ImageDType:
# assert len(unit_stride_axes_mul_4) >= 1, f"needs a unit stride axis in {self.bufs[buf_index]}"
if (
len(unit_stride_axes_mul_4)
and all(
x < (self.shape_len - self.upcasted)
for x in unit_stride_axes_mul_4
)
and unit_stride_axes_mul_4[0] not in self.upcast_in_mid_reduce_axes
):
if unit_stride_axes_mul_4[0] < self.first_reduce:
self.apply_opt(Opt(OptOps.UPCAST, unit_stride_axes_mul_4[0], 4))
else:
self.apply_opt(
Opt(
OptOps.UNROLL,
unit_stride_axes_mul_4[0] - self.first_reduce,
4,
)
)
# no more opt if we are grouping
if self.group_for_reduce:
return
# **** below this line need to be optional and benchmarked ****
# TODO: doing extra upcasts with images doesn't work for some reason (maybe has to do with to_image_idx)
# to trigger the above bug, remove prod(self.full_shape[self.shape_len - self.upcasted:]) from the below
# expression and run test/test_ops.py with IMAGE=2
# if there are small dims with lots of valid masks, upcast them (they might be from Tensor.stack)
# this can be made much smarter
to_upcast: List[int] = []
# upcast leading axes first (hack-ish for winograd; we actually want to upcast masked axes with low stride first)
for axis in range(self.first_reduce):
# we might want to be able to split axes that are masked, or refuse to merge them in simplify_merge_adjacent
# for now skip upcasting here if there is a symbolic axis
if (
isinstance(self.full_shape[axis], int)
and self.full_shape[axis] <= 7
and any(st.axis_is_masked(axis) for st in self.sts)
and prod(self.full_shape[self.shape_len - self.upcasted :])
* prod(self.full_shape[j] for j in to_upcast)
* self.full_shape[axis]
<= 7 * 7
):
if DEBUG >= 4:
print(f"upcasting masked axis : {axis}")
to_upcast.append(axis)
for axis in to_upcast[::-1]:
self.apply_opt(Opt(OptOps.UPCAST, axis, 0))
# potentially do more upcasts of non reduce axes based on a heuristic
upcasted_axis = set()
while prod(self.sts[0].shape[: self.first_reduce]) >= 1024:
xb_choices = []
for axis, upcast_amount in itertools.product(
range(self.first_reduce), [3, 4]
): # consider all the non reduce axes, and a 3 or 4 reduce
# if we haven't upcasted it, it's not symbolic, it mods, and some buffer has stride 0 on axis while having no stride 0 in the upcasted axis already
if (
axis not in upcasted_axis
and isinstance(self.full_shape[axis], int)
and self.full_shape[axis] % upcast_amount == 0
and any(
st.views[-1].strides[axis] == 0
and not any(x[1] == 0 for x in self.upcasted_axis(buf_index))
for buf_index, st in enumerate(self.sts)
)
):
xb_choices.append(
(
sum(st.views[-1].strides[axis] > 0 for st in self.sts),
sum(st.views[-1].strides[axis] for st in self.sts),
axis,
upcast_amount,
)
)
if xb_choices:
xb_choices = sorted(xb_choices)
if DEBUG >= 4:
print(f"float4 merging axis : {xb_choices}")
self.apply_opt(Opt(OptOps.UPCAST, xb_choices[0][2], xb_choices[0][3]))
upcasted_axis.add(xb_choices[0][2])
else:
break
# if last dim is small(ish) and it's a reduce dim, upcast the reduce (loop unrolling). no simplify needed since it's just an upcast. NOTE: careful, this has broken VALIDHACKS
if (
self.first_reduce < (self.shape_len - self.upcasted)
and (
len(list(self.shape_offsets(self.full_buf_index))) <= 4
or not any(r for _, _, r in self.upcasted_axis(self.full_buf_index))
)
and (self.upcasted == 0 or prod(self.full_shape[-self.upcasted :]) < 64)
):
if (s := self.full_unupcasted_shape[-1]) <= 32 and isinstance(
s, int
): # NOTE: cannot loop unroll symbolic axis
self.apply_opt(
Opt(
OptOps.UNROLL,
len(self.full_unupcasted_shape) - 1 - self.first_reduce,
0,
)
)
# if it's small, upcast a second reduce dimension too
if (
self.first_reduce < (self.shape_len - self.upcasted)
and s <= 3
and (s2 := self.full_unupcasted_shape[-1]) <= 3
and isinstance(s2, int)
):
self.apply_opt(
Opt(
OptOps.UNROLL,
len(self.full_unupcasted_shape) - 1 - self.first_reduce,
0,
)
)
else:
for splits in [4]:
if self.full_unupcasted_shape[-1] % splits == 0:
self.apply_opt(
Opt(
OptOps.UNROLL,
len(self.full_unupcasted_shape) - 1 - self.first_reduce,
splits,
)
)
break
# if nothing at all is upcasted and it's easy to, do an upcast
# TODO: this is breaking the tests
for splits in [4]:
if (
self.upcasted == 0
and self.full_unupcasted_shape
and self.full_unupcasted_shape[-1] % splits == 0
):
self.apply_opt(
Opt(OptOps.UPCAST, len(self.full_unupcasted_shape) - 1, splits)
)
# **** local groups ****
if self.opts.has_local:
if (
getenv("NOLOCALS")
and self.local_dims == 0
and not self.group_for_reduce
):
self.apply_opt(Opt(OptOps.NOLOCALS))
else:
# prioritize making expand axes local
local_axis_ranking = [
(
any(
self.sts[buf_index].views[-1].strides[axis] == 0
for buf_index in range(len(self.sts))
),
axis,
)
for axis in range(len(self.full_shape[: self.first_reduce]))
]
to_local: List[Tuple[int, int]] = []
for _, axis in sorted(local_axis_ranking, key=lambda x: (-x[0], -x[1])):
local_size = prod(sz for _, sz in to_local)
local_sz: Optional[int] = next(
(
x
for x in ([32] * (axis == 0) + [16, 8, 4, 3, 2])
if self.full_shape[axis] % x == 0 and local_size * x <= 128
),
None,
)
if local_sz is not None:
to_local.append((axis, local_sz))
deleted_shape = 0
for axis, local_sz in sorted(to_local[:3]):
axis = axis - deleted_shape
will_delete_shape = local_sz == self.full_shape[axis]
self.apply_opt(Opt(OptOps.LOCAL, axis, local_sz))
if will_delete_shape:
deleted_shape += 1