from __future__ import annotations
from typing import (
List,
Tuple,
Any,
Optional,
cast,
DefaultDict,
Dict,
Union,
Sequence,
Final,
Set,
)
import itertools, math, functools
from collections import defaultdict
from enum import Enum, auto
from dataclasses import dataclass
from tinygrad.helpers import (
colored,
ImageDType,
DEBUG,
dtypes,
DType,
prod,
PtrDType,
getenv,
all_same,
to_function_name,
flatten,
)
from tinygrad.ops import (
LazyOp,
UnaryOps,
BinaryOps,
TernaryOps,
ReduceOps,
ConstBuffer,
MemBuffer,
BufferOps,
)
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.symbolic import (
Variable,
NumNode,
VariableOrNum,
Node,
SumNode,
MulNode,
DivNode,
ModNode,
LtNode,
AndNode,
)
from tinygrad.codegen.kernel import LocalBuffer, Kernel
from tinygrad.lazy import vars_from_ast
from tinygrad.features.image import to_image_idx
# bottom ones are asm only
class UOps(Enum):
"""
Enumeration of the micro-operations (UOps) that make up the Linearizer's linear IR.
Attributes:
LOOP: Indicates a loop operation.
IF: Indicates an if operation.
END: Indicates the end of an operation block.
SPECIAL: Indicates a special operation, which can be global, local, or other type.
DEFINE_GLOBAL: Indicates defining a global variable.
DEFINE_LOCAL: Indicates defining a local variable.
DEFINE_ACC: Indicates defining an accumulator register for a reduce.
LOAD: Indicates loading data.
STORE: Indicates storing data.
CONST: Indicates a constant operation.
BARRIER: Indicates a barrier operation.
PHI: Indicates a phi operation, which is used in SSA form for merging control flow.
ALU: Indicates an arithmetic logic unit operation.
WMMA: Indicates a warp matrix multiplication and accumulation operation.
CAST: Indicates a cast operation.
GEP: Indicates a get element pointer operation, which is used to compute an address of a data element in memory.
"""
LOOP = auto()
IF = auto()
END = auto()
SPECIAL = auto() # loops can be global, local, or other # noqa: E702
DEFINE_GLOBAL = auto()
DEFINE_LOCAL = auto()
DEFINE_ACC = auto() # this defines buffers # noqa: E702
LOAD = auto()
STORE = auto()
CONST = auto()
BARRIER = auto()
PHI = auto() # noqa: E702
ALU = auto()
WMMA = auto()
CAST = auto()
GEP = auto() # noqa: E702
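# Illustrative sketch (not exhaustive): a simple elementwise kernel linearizes to
# DEFINE_GLOBAL uops for its buffers, SPECIAL uops for the global/local indices,
# LOAD uops for the inputs, ALU uops for the computation and a STORE uop for the
# output; reduce kernels additionally emit DEFINE_ACC, LOOP/END and PHI uops.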
@dataclass(eq=False)
class UOp:
"""
Data class for UOp.
Attributes:
uop : UOps - Operation to be performed.
dtype : Optional[DType] - Datatype of the operation result.
vin : Tuple[UOp, ...] - Inputs for the operation.
arg : Any - Additional argument for the operation.
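Example usage (illustrative):
::
UOp(UOps.CONST, dtypes.int32, tuple(), 0)  # the integer constant 0
Note: UOps are normally created through Linearizer.uop(), which caches and simplifies them.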
"""
uop: UOps
dtype: Optional[DType]
vin: Tuple[UOp, ...]
arg: Any
def __repr__(self):
return f"{str(self.uop):20s}: {str(self.dtype) if self.dtype is not None else '':25s} {str([x.uop for x in self.vin]):32s} {self.arg}"
def get_grouped_dims(prefix, start_dim, local_dims, maxdim: int = 0):
"""
Generate grouped dimensions and loop local indices.
Args:
prefix : str - Prefix for variable names.
start_dim : int - Start dimension index.
local_dims : Sequence[int] - Sizes of the dimensions to group.
maxdim : int, optional - Maximum number of grouped dimensions to emit; when exceeded, the trailing dimensions are merged into one. 0 means no limit. Default is 0.
Returns:
tuple[list, list] - The per-dimension index expressions and the loop index Variables actually emitted.
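Example usage (illustrative; hypothetical dimension sizes):
::
idxs, loop_idxs = get_grouped_dims("lidx", 0, (2, 3, 4, 5), 3)
# loop_idxs holds 3 Variables (sizes 2, 3 and 4*5=20) for the hardware launch,
# while idxs re-derives all 4 per-dimension indices from them via // and %.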
"""
local_idxs = loop_local_idxs = [
Variable(f"{prefix}{start_dim+i}", 0, s - 1)
for i, s in enumerate(
local_dims[0 : maxdim - 1] + (prod(local_dims[maxdim - 1 :]),)
if len(local_dims) > maxdim
else local_dims
)
]
if maxdim != 0 and len(local_dims) > maxdim:
dd = local_idxs[maxdim - 1]
nli = []
for s in local_dims[maxdim - 1 :][::-1]:
nli.append(dd % s)
dd //= s
local_idxs = local_idxs[0 : maxdim - 1] + nli[::-1]
return local_idxs, [x for x in loop_local_idxs if not isinstance(x, NumNode)]
class Linearizer(Kernel):
"""
The Linearizer. A subclass of Kernel that lowers the kernel's AST into a linear list of UOps.
Methods:
uop_alu_idx: Renders an ALU uop combining a UOp with a symbolic operand.
const: Returns a (cached) CONST uop for a scalar value.
cast: Returns a CAST uop to the given dtype, or val unchanged if it already has that dtype.
get_reduce_acc: Returns the initial accumulator value for a given reduce operation and dtype.
"""
def uop_alu_idx(self, a: UOp, b, ops, ctx: Linearizer, op, dtype=dtypes.int32):
"""
Renders the ALU operation with given parameters and returns the result.
Args:
a (UOp): The first input UOp.
b: The second input.
ops: The operations.
ctx (Linearizer): The linearizer context.
op: The operation to perform.
dtype (DType, optional): The data type of the result. Defaults to dtypes.int32.
Returns:
UOp: The result of the ALU operation.
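Example usage (as used by render_ops below):
::
MulNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.MUL)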
"""
render_b: UOp = cast(
UOp, (NumNode(b) if not isinstance(b, Node) else b).render(ops, ctx)
)
return self.uop(UOps.ALU, dtype, (a, render_b), op)
# NOTE: the consts have to be cached for deduping of downstream uops to work
def const(
self, b: Union[int, float], dtype=dtypes.int32, insert_before=None
) -> UOp:
"""
Creates a CONST operation with the given parameters and returns the resultant UOp.
Args:
b (Union[int, float]): The constant value.
dtype (DType, optional): The data type of the result. Defaults to dtypes.int32.
insert_before (optional): For insertion before a certain operation. Defaults to None.
Returns:
UOp: The resultant UOp of the CONST operation.
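Example usage (illustrative):
::
zero = self.const(0)                     # int32 constant 0
onef = self.const(1.0, dtypes.float32)   # float32 constant 1.0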
"""
return self.uop(UOps.CONST, dtype, tuple(), b, insert_before=insert_before)
def cast(self, val: UOp, dtype) -> UOp:
"""
If val.dtype is not equal to dtype, creates a CAST operation with the given parameters and returns the resultant UOp, else returns val.
Args:
val (UOp): The input UOp.
dtype: The data type to cast to.
Returns:
UOp: The resultant UOp of the CAST operation, or val if val.dtype is equal to dtype.
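Example usage (illustrative):
::
v = self.cast(val, dtypes.float32)  # returns val unchanged if it is already float32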
"""
return self.uop(UOps.CAST, dtype, (val,)) if val.dtype != dtype else val
def get_reduce_acc(self, op, dtype: DType):
"""
Returns the initial accumulator value for the given reduce operation.
Args:
op: The reduce operation.
dtype (DType): The data type of the result.
Returns:
The identity value for the reduce: for ReduceOps.SUM, 0.0 (float) or 0 (int); for ReduceOps.MAX, -math.inf (float), -(2**31) (int), or False (bool).
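Example usage (illustrative):
::
self.get_reduce_acc(ReduceOps.SUM, dtypes.float32)  # 0.0
self.get_reduce_acc(ReduceOps.MAX, dtypes.int32)    # -(2**31)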
"""
if op == ReduceOps.SUM:
return 0.0 if dtypes.is_float(dtype) else 0
elif op == ReduceOps.MAX:
return (
-math.inf
if dtypes.is_float(dtype)
else -(2**31)
if dtypes.is_int(dtype)
else False
)
"""
Render operations dictionary.
Attributes:
Variable: A function that returns the loop or index uop registered for the variable in ctx.loop_uops.
NumNode: A function that returns a constant value for a given numeric node.
MulNode, DivNode, ModNode: Functions that perform multiplication, division and modulo operations respectively.
LtNode: A function that performs a less-than comparison between two nodes.
SumNode: A function that adds up all the nodes in a list.
AndNode: A function that performs an AND operation on all nodes in a list.
"""
render_ops: Any = {
Variable: lambda self, ops, ctx: ctx.loop_uops[self.expr],
NumNode: lambda self, ops, ctx: ctx.const(self.b),
MulNode: lambda self, ops, ctx: ctx.uop_alu_idx(
self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.MUL
),
DivNode: lambda self, ops, ctx: ctx.uop_alu_idx(
self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.DIV
),
ModNode: lambda self, ops, ctx: ctx.uop_alu_idx(
self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.MOD
),
LtNode: lambda self, ops, ctx: ctx.uop_alu_idx(
self.a.render(ops, ctx),
self.b,
ops,
ctx,
BinaryOps.CMPLT,
dtype=dtypes.bool,
),
SumNode: lambda self, ops, ctx: functools.reduce(
lambda a, b: ctx.uop_alu_idx(a, b, ops, ctx, BinaryOps.ADD),
self.nodes[1:],
self.nodes[0].render(ops, ctx),
),
AndNode: lambda self, ops, ctx: functools.reduce(
lambda a, b: ctx.uop_alu_idx(
a, b, ops, ctx, BinaryOps.MUL, dtype=dtypes.bool
),
self.nodes[1:],
self.nodes[0].render(ops, ctx),
),
}
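# Illustrative sketch: rendering a symbolic index such as
#   (Variable("gidx0", 0, 15) * 4 + 2).render(self.render_ops, self)
# dispatches on the node classes above: the Variable resolves to the uop stored in
# self.loop_uops, while the MulNode and SumNode emit ALU uops (MUL, then ADD with a CONST).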
def global_load(
self, i: int, idxs: Sequence[Node], acc=None, barrier: Optional[UOp] = None
) -> List[UOp]:
"""
Render the UOps that load values from buffer `i` at the given indices (or define an accumulator when `acc` is given).
:param i: The index of the buffer to load.
:type i: int
:param idxs: A sequence of nodes representing indices for each dimension.
:type idxs: Sequence[Node]
:param acc: An accumulator value, defaults to None.
:type acc: Optional[Any], optional
:param barrier: A UOp barrier, defaults to None.
:type barrier: Optional[UOp], optional
:return: A list of UOps representing the load operation.
:rtype: List[UOp]
:Attributes:
- `buf` (Buffer): The buffer from which data is loaded.
- `const` (Union[ConstBuffer, Any]): The constant value if the buffer is a ConstBuffer, otherwise the accumulator.
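Example usage (as in linearize below, defining the reduce accumulator):
::
acc = self.global_load(0, global_idxs + local_idxs + fake_reduce_idxs + upcast_idxs, self.get_reduce_acc(self.reduceop.op, self.bufs[0].dtype))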
"""
buf = self.bufs[i]
const = buf.val if isinstance(buf, ConstBuffer) else acc
def rename_var(v: VariableOrNum, expr: str):
"""
Rename a variable or number node with a new expression.
:param v: The variable or number node to be renamed.
:type v: Union[VariableOrNum, NumNode]
:param expr: The new expression for the renamed node.
:type expr: str
:return: The renamed variable or number node.
:rtype: Union[VariableOrNum, NumNode]
Example usage:
::
new_var = rename_var(old_var, "_uidx0")
"""
return v if isinstance(v, NumNode) else Variable(expr, v.min, v.max)
amt, dim = 1, None
upcast_dim = self.get_upcast_dim(i)
if len(upcast_dim) == 1 and len(
float4_expand := idxs[upcast_dim[0]].expand()
) in [4, 2]:
dim, amt = upcast_dim[0], len(float4_expand)
expand_vars = tuple(
[rename_var(idx.expand_idx(), f"_uidx{j}") for j, idx in enumerate(idxs)]
)
fake_idxs = [
idx.substitute({idx.expand_idx(): ev}) for idx, ev in zip(idxs, expand_vars)
]
if dim is not None:
g_idx, g_valid = self.sts[i].expr_idxs(
fake_idxs[:dim] + [float4_expand[0]] + fake_idxs[dim + 1 :]
)
if (g_idx // amt * amt).render() != g_idx.render():
(g_idx, g_valid), amt, dim = self.sts[i].expr_idxs(fake_idxs), 1, None
else:
g_idx, g_valid = self.sts[i].expr_idxs(fake_idxs)
localtype = buf.dtype if amt == 1 else buf.dtype.vec(amt)
if isinstance(buf.dtype, ImageDType):
localtype = dtypes.float if amt == 1 else dtypes.float.vec(amt)
e_idxs, e_valids = g_idx.expand(expand_vars), g_valid.expand(expand_vars)
ret = []
invalid_value = 0 if dtypes.is_int(buf.dtype) else 0.0
for idx, valid, rep_idx in zip(e_idxs, e_valids, Node.iter_idxs(expand_vars)):
this_const, idx, valid = (
(invalid_value, NumNode(0), NumNode(1))
if valid.max == 0
else (const, idx, valid)
)
key = f"{acc}{localtype}{this_const if this_const is not None and acc is None else (buf.idx if isinstance(buf, MemBuffer) else cast(LocalBuffer, buf).name)}{idx.render()}{valid.render()}"
if key not in self.load_cache:
if acc is not None:
self.load_cache[key] = self.uop(
UOps.DEFINE_ACC, localtype, (), this_const, cachable=False
)
elif this_const is not None:
self.load_cache[key] = self.const(this_const, localtype)
if valid.min == 0 and valid.max == 1:
valid_rendered = valid.render(self.render_ops, self)
self.load_cache[key] = self.uop(
UOps.ALU,
localtype,
(
valid_rendered,
self.load_cache[key],
self.const(invalid_value, localtype),
),
TernaryOps.WHERE,
)
elif isinstance(buf.dtype, ImageDType):
buf_uop = self.buf_uops[i]
assert buf_uop is not None, f"buffer {i} wasn't UOped"
image_idx, valid = to_image_idx(buf.dtype.shape, idx, valid)
rendered_idx = self.uop(
UOps.CAST,
dtypes.int.vec(2),
(
image_idx[0].render(self.render_ops, self),
image_idx[1].render(self.render_ops, self),
),
)
valid_tuple = (
(
valid.render(self.render_ops, self),
self.const(invalid_value, dtypes.float32.vec(4)),
)
if valid.min == 0
else tuple()
)
self.load_cache[key] = self.uop(
UOps.LOAD,
dtypes.float32.vec(4),
(buf_uop, rendered_idx)
+ valid_tuple
+ ((barrier,) if barrier else ()),
)
idx_small = idx % 4
res = idx_small.render(self.render_ops, self)
if localtype == localtype.scalar():
out = self.uop(
UOps.GEP, localtype, (self.load_cache[key],), idx_small.max
)
for ix in range(idx_small.max, idx_small.min, -1):
rvv = self.uop(
UOps.GEP, localtype, (self.load_cache[key],), ix - 1
)
sel = self.uop(
UOps.ALU,
res.dtype,
(res, self.const(ix)),
BinaryOps.CMPLT,
)
out = self.uop(
UOps.ALU, localtype, (sel, rvv, out), TernaryOps.WHERE
)
self.load_cache[key] = out
else:
buf_uop = self.buf_uops[i]
assert buf_uop is not None, f"buffer {i} wasn't UOped"
rendered_idx = idx.render(self.render_ops, self)
valid_tuple = (
(
valid.render(self.render_ops, self),
self.const(invalid_value, localtype),
)
if valid.min == 0
else tuple()
)
self.load_cache[key] = self.uop(
UOps.LOAD,
localtype,
(buf_uop, rendered_idx)
+ valid_tuple
+ ((barrier,) if barrier else ()),
)
ret.append(
self.uop(
UOps.GEP, localtype.scalar(), (self.load_cache[key],), rep_idx[dim]
)
if dim is not None
else self.load_cache[key]
)
return ret
def global_store(self, i: int, idxs: List[Node], store: List[UOp]) -> List[UOp]:
"""
Perform a global store operation on the buffer at index `i`.
Parameters:
i (int): The index of the buffer to perform the store operation on.
idxs (List[Node]): The list of Nodes representing indices for the store operation.
store (List[UOp]): The list of UOps containing the values to be stored.
Returns:
List[UOp]: A list of UOps representing the stored values.
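Example usage (as in linearize below, storing the final values):
::
self.global_store(0, global_idxs + local_idxs + fake_reduce_idxs + upcast_idxs, val)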
"""
buf = self.bufs[i]
buf_uop = self.buf_uops[i]
assert buf_uop is not None, f"buffer {i} wasn't UOped"
expanded_nodes = [idx.expand() for idx in idxs]
_idxs = [x[::-1] for x in itertools.product(*expanded_nodes[::-1])]
store_offset = dict(zip(_idxs, store))
# float4 grouping
upcast_dim = self.get_upcast_dim(i)
if len(upcast_dim) == 1 and len(expanded_nodes[upcast_dim[0]]) in [2, 4]:
grouped_store_offset = defaultdict(list)
for k in store_offset:
_idx = (
k[: upcast_dim[0]]
+ (expanded_nodes[upcast_dim[0]][0],)
+ k[upcast_dim[0] + 1 :]
)
grouped_store_offset[_idx].append(store_offset[k])
store_offset_new = {}
for k, out_tokens in grouped_store_offset.items():
amt = len(out_tokens)
idx, valid = self.sts[i].expr_idxs(k)
assert (
idx.render() == ((idx // amt) * amt).render()
), "float4 stores are always aligned"
store_offset_new[k] = self.uop(
UOps.CAST, dtypes.float.vec(amt), tuple(out_tokens)
)
store_offset = store_offset_new
stores = []
for idx, var in store_offset.items():
idx, valid = self.sts[i].expr_idxs(idx)
if isinstance(buf.dtype, ImageDType):
idx, valid = to_image_idx(buf.dtype.shape, idx, valid)
rendered_idx = self.uop(
UOps.CAST,
dtypes.int.vec(2),
tuple(x.render(self.render_ops, self) for x in idx),
)
else:
rendered_idx = idx.render(self.render_ops, self)
if valid.min == 1:
stores.append(self.uop(UOps.STORE, None, (buf_uop, rendered_idx, var)))
else:
stores.append(
self.uop(
UOps.STORE,
None,
(
buf_uop,
rendered_idx,
var,
valid.render(self.render_ops, self),
),
)
)
return stores
kernel_cnt: Final[DefaultDict[str, int]] = defaultdict(int)
def linearize(self):
"""
Linearize the computation graph.
If no new optimizations have been applied since the last run, relinearization is skipped. Otherwise the shape state is backed up, the global uop cache is reset, dimensions are limited if needed, and buffer, loop, load, compute and store uops are emitted into self.uops.
Attributes:
self.applied_opts (Dict): The newly applied options.
self.applied_opts_cache (Dict): The previously applied options.
sts_backup (List): A backup of the shape trackers (self.sts).
gfr_backup (List): A backup of the group for reduce.
upc_backup (Any): A backup of the upcasted value.
self.saved_exprs (Dict[Tuple, UOp]): A dictionary storing expressions as keys and UOps as values.
self.opts.global_max (int): The maximum global dimension.
self.opts.local_max (int): The maximum local dimension.
self.uops (List[UOp]): A list of UOp objects, representing the computation operations.
self.buf_uops (List[Optional[UOp]]): A list of optional UOp objects for buffering.
self.loop_uops (Dict[str, UOp]): A dictionary with loop names as keys and UOp objects as values.
Returns:
self, with self.uops populated.
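Example usage (illustrative sketch; assumes the Kernel constructor signature Linearizer(ast, opts), with `ast` the kernel's LazyOp and `opts` the device's LinearizerOptions):
::
lin = Linearizer(ast, opts)
lin.linearize()
for u in lin.uops: print(u)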
"""
# no new opts and we already ran? skip relinearizing
if self.applied_opts == self.applied_opts_cache:
return self
# save backups
sts_backup, gfr_backup, upc_backup = (
self.sts[:],
self.group_for_reduce[:],
self.upcasted,
)
# global uop cache
self.saved_exprs: Dict[Tuple, UOp] = dict()
# limit dims if we need to
if self.opts.global_max and self.opts.local_max:
self.limit_dims_to_max(self.opts.global_max, self.opts.local_max)
# uops
self.uops: List[UOp] = []
self.buf_uops: List[Optional[UOp]] = [None] * len(self.bufs)
self.loop_uops: Dict[str, UOp] = {}
# add global buffers
for i, buf in enumerate(self.bufs):
if isinstance(buf, MemBuffer):
self.buf_uops[i] = self.uop(
UOps.DEFINE_GLOBAL,
PtrDType(buf.dtype)
if not isinstance(buf.dtype, ImageDType)
else buf.dtype,
(),
(f"data{buf.idx}", buf.dtype),
)
# add var vals
for var in vars_from_ast(self.ast):
assert var.expr is not None
self.loop_uops[var.expr] = self.uop(
UOps.DEFINE_GLOBAL, dtypes.int32, (), (var.expr, dtypes._arg_int32)
)
# define local buffers
for lb in self.local_alias.values():
self.buf_uops[self.bufs.index(lb)] = self.uop(
UOps.DEFINE_LOCAL,
PtrDType(dtypes.float32),
(),
(lb.name, self.sts[self.bufs.index(lb)].size()),
)
# add a local buffer for multistage reduce. # TODO: use local alias
if self.group_for_reduce:
# TODO: the strides of this can be controlled
self.sts.append(
ShapeTracker.from_shape(
tuple(
[1] * self.global_dims
+ list(
self.full_shape[
self.global_dims : self.global_dims
+ self.local_dims
+ len(self.group_for_reduce)
]
)
+ [1]
* (
self.shape_len
- self.upcasted
- len(self.group_for_reduce)
- self.first_reduce
)
+ [x[0] for x in self.upcasted_axis(0)]
)
)
)
self.bufs.append(LocalBuffer("temp", self.sts[-1].size()))
self.buf_uops.append(
self.uop(
UOps.DEFINE_LOCAL,
PtrDType(dtypes.float32),
(),
("temp", self.sts[-1].size()),
)
)
# kernel name (before late upcast)
self.name = ("r_" if self.reduceop else "E_") + colored("_", "BLACK").join(
[colored(str(x), c) for x, c in zip(self.full_shape, self.colors())]
)
# name the function something unique
Linearizer.kernel_cnt[(function_name := to_function_name(self.name))] += 1
suffix = (
f"{'n'+str(Linearizer.kernel_cnt[function_name]-1)}"
if Linearizer.kernel_cnt[function_name] > 1
else ""
)
self.name = self.name + colored(suffix, "BLACK")
# define indexes
global_idxs, loop_global_idxs = get_grouped_dims(
"gidx",
0,
self.full_shape[: self.global_dims],
3 if self.opts.has_local else 0,
)
local_idxs, loop_local_idxs = get_grouped_dims(
"lidx",
self.global_dims,
self.full_shape[
self.global_dims : self.first_reduce + len(self.group_for_reduce)
],
3 if self.opts.has_local else 0,
)
full_upcast_idxs = [
Variable(None, 0, s - 1)
for s in self.full_shape[self.shape_len - self.upcasted :]
]
upcast_idxs = [
Variable(None, 0, s - 1)
for s in self.output_shape[self.shape_len - self.upcasted :]
]
# global and local loops
def render_loop(xx: List[Variable]) -> Tuple[UOp, ...]:
"""
This function is responsible for rendering the loop.
Attributes:
xx (List[Variable]): List of variables to be processed.
Returns:
Tuple[UOp, ...]: A tuple of UOp objects representing the rendered loops.
"""
new_loops = {
x.expr: self.uop(
UOps.LOOP,
dtypes.int32,
(
self.const(x.min)
if isinstance(x.min, int)
else cast(Node, x.min).render(self.render_ops, self),
self.const(x.max + 1)
if isinstance(x.max, int)
else cast(Node, x.max + 1).render(self.render_ops, self),
),
cachable=False,
)
for x in xx
if not isinstance(x, NumNode) and x.expr is not None
}
self.loop_uops.update(new_loops)
return tuple(new_loops.values())
# set global/local size
self.global_size: Optional[List[int]] = None
self.local_size: Optional[List[int]] = None
if self.dont_use_locals:
self.global_size = [x.max + 1 for x in loop_global_idxs][::-1]
self.loop_uops.update(
{
x.expr: self.uop(
UOps.SPECIAL,
dtypes.int32,
(),
(
len(loop_global_idxs) - 1 - i,
x.expr.replace("gidx", "idx"),
x.max + 1,
),
)
for i, x in enumerate(loop_global_idxs)
}
)
elif self.opts.has_local:
self.global_size, self.local_size = [x.max + 1 for x in loop_global_idxs][
::-1
], [x.max + 1 for x in loop_local_idxs][::-1]
self.loop_uops.update(
{
x.expr: self.uop(
UOps.SPECIAL,
dtypes.int32,
(),
(len(loop_global_idxs) - 1 - i, x.expr, x.max + 1),
)
for i, x in enumerate(loop_global_idxs)
}
)
self.loop_uops.update(
{
x.expr: self.uop(
UOps.SPECIAL,
dtypes.int32,
(),
(len(loop_local_idxs) - 1 - i, x.expr, x.max + 1),
)
for i, x in enumerate(loop_local_idxs)
}
)
else:
render_loop(loop_global_idxs + loop_local_idxs)
# parse AST
loaded_buffers = {}
acc: List[UOp] = []
self.load_cache: Dict[str, UOp] = {}
# reduce op
fake_reduce_idxs: List[Variable] = []
if self.reduceop is not None:
# define indexes
reduce_idxs = [
Variable(f"ridx{i}", 0, self.full_shape[i] - 1)
for i in range(
self.first_reduce + len(self.group_for_reduce),
self.shape_len - self.upcasted,
)
]
fake_reduce_idxs = [x * 0 for x in reduce_idxs]
# define accumulator
acc = self.global_load(
0,
global_idxs + local_idxs + fake_reduce_idxs + upcast_idxs,
self.get_reduce_acc(self.reduceop.op, self.bufs[0].dtype),
)
if self.tensor_core:
def calc_tc_idxs(local_size: int, aliases: List[List[int]]):
replace_idxs = []
for alias in aliases:
full_var, full_var_sz = NumNode(0), 1
if alias[0] != 0:
for i in alias:
next_var = (
local_idxs[-i]
if i > 0
else Variable(None, 0, local_size - 1)
)
full_var += next_var * full_var_sz
full_var_sz *= next_var.max + 1
replace_idxs.append(full_var)
return replace_idxs
replace_acc_idxs = calc_tc_idxs(
self.tensor_core.thread_local_sizes[2],
self.tensor_core.thread_local_aliases[2],
)
for n in range(len(self.tensor_core.threads)):
local_idxs[
self.local_dims - len(self.tensor_core.threads) + n
] = replace_acc_idxs[
n
] # replace locals
for n in range(len(replace_acc_idxs) - len(self.tensor_core.threads)):
upcast_idxs[n] = replace_acc_idxs[
len(self.tensor_core.threads) + n
] # replace upcasts
# reduce loop
loop_ctx = render_loop(reduce_idxs)
# barrier for fast GEMM
if self.tensor_core:
self.uop(UOps.BARRIER, None, (), cachable=False)
# compute local aliases
locals_to_store = []
for i in self.local_alias:
localbuf_idx = self.bufs.index(self.local_alias[i])
buf_idxs = [
idx * 0 if s == 0 else idx
for idx, s in zip(
global_idxs + local_idxs + reduce_idxs + full_upcast_idxs,
self.sts[i].real_strides(),
)
]
if self.tensor_core:
min_alias_idx = min(self.local_alias.keys())
replace_input_idxs = calc_tc_idxs(
self.tensor_core.thread_local_sizes[i - min_alias_idx],
self.tensor_core.thread_local_aliases[i - min_alias_idx],
)
for n in range(len(self.tensor_core.threads)):
buf_idxs[
self.first_reduce - len(self.tensor_core.threads) + n
] = replace_input_idxs[
n
] # replace locals
for n in range(
len(replace_input_idxs) - len(self.tensor_core.threads)
):
buf_idxs[
self.shape_len - self.upcasted + n
] = replace_input_idxs[
len(self.tensor_core.threads) + n
] # replace upcasts
if DEBUG >= 3:
print(f"{localbuf_idx} alias {i}: idxs=", buf_idxs)
ll = self.global_load(i, buf_idxs)
locals_to_store.append((localbuf_idx, buf_idxs, ll))
# copy in any global buffers
if self.tensor_core:
wmma_sz = self.tensor_core.thread_local_sizes
# calculate the number of local accumulator reduces and render WMMAs: this is bad... this needs to come from someplace else
nx, ny, nacc = (
(len(locals_to_store[0][2]) // wmma_sz[0]),
(len(locals_to_store[1][2]) // wmma_sz[1]),
(len(acc) // wmma_sz[2]),
)
acc_reds = math.isqrt((nx * ny) // nacc)
i, bx, by = 0, nx // acc_reds, ny // acc_reds
for y in range(by):
for x in range(bx):
for j in range(acc_reds):
op1, op2, op3 = (
locals_to_store[0][2][
(x + (j * bx))
* wmma_sz[0] : (x + (j * bx) + 1)
* wmma_sz[0]
],
locals_to_store[1][2][
(y + (j * by))
* wmma_sz[1] : (y + (j * by) + 1)
* wmma_sz[1]
],
acc[i : i + wmma_sz[2]],
)
if self.opts.device != "HIP":
ops = tuple(op1 + op2 + op3)
else:
ops = (
self.uop(
UOps.CAST, dtypes.half.vec(16), tuple(op1)
),
self.uop(
UOps.CAST, dtypes.half.vec(16), tuple(op2)
),
self.uop(
UOps.CAST, dtypes.float.vec(8), tuple(op3)
),
)
ret = self.uop(
UOps.WMMA,
dtypes.float.vec(2)
if wmma_sz[2] == 2
else dtypes.float.vec(8),
ops,
(
self.opts.device,
self.tensor_core.dtype_in,
self.tensor_core.dtype_out,
),
)
for z in range(cast(DType, ret.dtype).sz):
acc[i + z] = self.uop(
UOps.PHI,
dtypes.float,
(
op3[z],
self.uop(UOps.GEP, dtypes.float, (ret,), z),
)
+ loop_ctx,
)
i += wmma_sz[2]
else:
if locals_to_store:
self.uop(UOps.BARRIER, None, (), cachable=False)
for i, idxs, ll in locals_to_store:
self.global_store(i, idxs, ll)
self.uop(UOps.BARRIER, None, (), cachable=False)
# load earlybufs
loaded_buffers.update(
{
b: self.global_load(
self.bufs.index(self.local_alias[i])
if i in self.local_alias
else i,
global_idxs + local_idxs + reduce_idxs + full_upcast_idxs,
)
for i, b in enumerate(self.bufs[1:], start=1)
if b in self.earlybufs
}
)
# run early AST (with reduce)
self.ast_parse(
self.reduceop,
acc,
self.acc_offsets(self.full_buf_index),
loaded_buffers,
do_reduce=True,
loop_ctx=loop_ctx,
)
# end the reduce loop
self.load_cache.clear()
# end the local loop, do the local reduce
if self.group_for_reduce:
fake_global_idxs = [x * 0 for x in global_idxs]
stores = self.global_store(
-1,
fake_global_idxs + local_idxs + fake_reduce_idxs + upcast_idxs,
acc,
) # store accumulators
barrier = self.uop(UOps.BARRIER, None, tuple(stores), cachable=False)
if self.opts.has_local:
fake_idxs = [NumNode(0)] * len(self.sts[-1].shape)
fake_idxs[
self.global_dims
+ self.local_dims : self.global_dims
+ len(local_idxs)
] = local_idxs[self.local_dims :]
if_cond: UOp = (self.sts[-1].expr_idxs(fake_idxs)[0] < 1).render(
self.render_ops, self
)
barrier = self.uop(
UOps.IF, None, (if_cond, barrier), cachable=False
)
# create new late reduce local loops and replace local_idxs that have been used
end_local_idxs = [
Variable(
f"tidx{i}",
0,
self.full_shape[i] - 1
if i >= self.first_reduce
and i not in self.upcast_in_mid_reduce_axes
else 0,
)
for i in range(0, self.first_reduce + len(self.group_for_reduce))
]
local_idxs = (
local_idxs[: self.local_dims]
+ end_local_idxs[self.global_dims + self.local_dims :]
)
# if any group_for_reduce items aren't reduces, upcast them here
for j in self.upcast_in_mid_reduce_axes:
self.reshape_and_permute(
None, [i for i in range(self.shape_len) if i != j] + [j]
)
self.upcast()
self.group_for_reduce.pop()
local_idxs = local_idxs[:-1]
end_local_idxs = end_local_idxs[:-1]
# regenerate upcast_idxs
upcast_idxs = [
Variable(None, 0, s - 1)
for s in self.output_shape[self.shape_len - self.upcasted :]
]
# NOTE: this structure is the same as the reduce op above
# define late accumulator
acc = self.global_load(
-1,
fake_global_idxs + local_idxs + fake_reduce_idxs + upcast_idxs,
self.get_reduce_acc(self.reduceop.op, self.bufs[-1].dtype),
)
# late reduce loop
loop_ctx = render_loop(end_local_idxs)
# load localbufs
loaded_buffers[self.bufs[-1]] = self.global_load(
-1,
fake_global_idxs + local_idxs + fake_reduce_idxs + upcast_idxs,
barrier=barrier,
)
# there's no AST here (and there's no shape for the reduce LazyOp)
self.ast_parse(
LazyOp(
self.reduceop.op, (LazyOp(BufferOps.LOAD, (), self.bufs[-1]),)
),
acc,
self.acc_offsets(-1),
loaded_buffers,
do_reduce=True,
loop_ctx=loop_ctx,
)
# end the late reduce loop
self.load_cache.clear()
# load latebufs
loaded_buffers.update(
{
b: self.global_load(
i, global_idxs + local_idxs + fake_reduce_idxs + upcast_idxs
)
for i, b in enumerate(self.bufs)
if b not in self.earlybufs and i != 0 and b.__class__ is not LocalBuffer
}
)
# run late AST (without the store)
val = self.ast_parse(cast(LazyOp, self.ast.src[0]), acc, None, loaded_buffers)
# store
self.global_store(
0, global_idxs + local_idxs + fake_reduce_idxs + upcast_idxs, val
)
# graph helper functions
@functools.lru_cache(None)
def get_recursive_parents(x: UOp) -> Set[UOp]:
"""
Get recursive parents of a given UOp object.
This function uses a cache to store previously calculated results for
better performance. It returns the union of direct parent UOps and
all recursive parents of those direct parents.
:param x: The input UOp object.
:type x: UOp
:return: A set containing all recursive parent UOps.
:rtype: Set[UOp]
"""
return set.union(set(x.vin), *[get_recursive_parents(p) for p in x.vin])
def get_recursive_children(x: UOp) -> Set[UOp]:
"""
Get recursive children of a given UOp object.
This function starts with the initial set containing only the input
UOp and iteratively adds all its direct children to the set until no
new children can be added. It returns the final set containing all
recursive children UOps.
:param x: The input UOp object.
:type x: UOp
:return: A set containing all recursive child UOps.
:rtype: Set[UOp]
"""
deps = set([x])
ssize = 0
while ssize != len(deps):
ssize = len(deps)
for u in self.uops:
if len(deps.intersection([x for x in u.vin if x.uop != UOps.PHI])):
deps.add(u)
return deps
def replace_op(old: UOp, new: UOp):
"""
Replace a given UOp with another UOp in the list of all UOps.
This function iterates over all UOps and replaces every occurrence
of the old UOp with the new one. It also removes the old UOp from
the list.
:param old: The UOp to be replaced.
:type old: UOp
:param new: The replacement UOp.
:type new: UOp
"""
for u in self.uops:
u.vin = tuple(new if x is old else x for x in u.vin)
self.uops.remove(old)
# fix loop scope, push CONST and ALU upward out of loop if it does not depend on the loop
loop_stack: List[List[UOp]] = [[]]
for u in self.uops:
if not loop_stack[-1]:
loop_stack[-1].append(u)
elif u.uop == UOps.LOOP:
loop_stack.append([u])
elif u.uop not in [UOps.CONST, UOps.ALU, UOps.CAST]:
loop_stack[-1].append(u)
else:
parents = get_recursive_parents(u)
for i in reversed(range(len(loop_stack))):
# check backwards and put the uop in the first encounter with some dependency
if any(x in parents for x in loop_stack[i]) or i == 0:
loop_stack[i].append(u)
break
self.uops = flatten(loop_stack)
# uops optimization
changed_something = True
while changed_something:
changed_something = False
for u in self.uops:
if u.uop == UOps.PHI and len(u.vin) == 3:
# if the parents of the PHI node don't have the LOOP in their parents, it can be folded
# TODO: ADD becomes a MUL, MAX can just become nothing
if (
all(
x.uop != UOps.LOOP
for x in get_recursive_parents(
UOp(u.uop, u.dtype, u.vin[0:2], u.arg)
)
)
and u.vin[1].arg == BinaryOps.ADD
):
if DEBUG >= 4:
print(f"removing PHI node {u}")
del self.saved_exprs[(u.uop, u.dtype, u.vin, u.arg)]
# NOTE: assuming u.vin[2].vin[1] and u.vin[2].vin[0] have the same dtype
loop_len = self.uop(
UOps.ALU,
u.vin[2].vin[1].dtype,
(u.vin[2].vin[1], u.vin[2].vin[0]),
BinaryOps.SUB,
insert_before=self.uops.index(u),
)
if loop_len.dtype != u.dtype:
loop_len = self.uop(
UOps.CAST,
u.dtype,
(loop_len,),
insert_before=self.uops.index(u),
)
replace_op(
u,
self.uop(
UOps.ALU,
u.dtype,
(
u.vin[1],
loop_len,
),
BinaryOps.MUL,
insert_before=self.uops.index(u),
),
)
changed_something = True
# (recursively) remove childless uops
# NOTE: DEFINE_GLOBAL should be removable, but we'd have to propagate that
UOPS_W_SIDE_EFFECTS = {UOps.STORE, UOps.BARRIER, UOps.DEFINE_GLOBAL}
while 1:
has_child: Set[UOp] = set()
for ru in self.uops:
for vu in ru.vin:
has_child.add(vu)
nu: List[UOp] = [
x for x in self.uops if x in has_child or x.uop in UOPS_W_SIDE_EFFECTS
]
if len(nu) == len(self.uops):
break
if DEBUG >= 4:
print(f"reduced UOp count from {len(self.uops)} to {len(nu)}")
self.uops = nu
del nu
# add UOps.END
for u in self.uops:
if u.uop == UOps.LOOP:
# add END of loops after the last thing that (recursively) depends on them
self.uop(
UOps.END,
None,
(u,),
cachable=False,
insert_before=self.uops.index(
sorted(list(get_recursive_children(u)), key=self.uops.index)[-1]
)
+ 1,
)
elif u.uop == UOps.IF:
# END any if statements at the end of the uops
self.uop(UOps.END, None, (u,), cachable=False)
# maybe graph the uops
if DEBUG >= 5:
for u in self.uops:
print(
f"{self.uops.index(u):4d} {str(u.uop):20s}: {str(u.dtype) if u.dtype is not None else '':25s} {str([self.uops.index(x) for x in u.vin]):32s} {u.arg}"
)
if getenv("GRAPHUOPS"):
from tinygrad.graph import graph_uops
graph_uops(self.uops)
# restore backups
self.sts, self.group_for_reduce, self.upcasted = (
sts_backup,
gfr_backup,
upc_backup,
)
# set cache and return
self.applied_opts_cache = self.applied_opts[:]
return self
def uop(
self,
uop: UOps,
dtype: Optional[DType] = None,
vin: Tuple[UOp, ...] = tuple(),
arg: Any = None,
cachable=True,
insert_before=None,
simplify=True,
) -> UOp:
"""
Creates a micro-operation (UOp), applying simplification rules and reusing a cached equivalent when possible.
Args:
uop: The type of operation to perform (e.g., UOps.ALU, UOps.PHI)
dtype: The desired output data type (optional)
vin: A tuple of input UOps (default is an empty tuple)
arg: An optional argument that can be used in certain uops (default is None)
cachable: A boolean flag indicating whether the operation should be cached for faster access (default is True)
insert_before: An optional parameter to specify where this operation should be inserted (default is None)
simplify: A boolean flag indicating whether to attempt to simplify the uop before creating it (default is True)
Returns:
The created UOp, or an existing cached or simplified UOp equivalent to the request.
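Example usage (illustrative; a and b are assumed to be existing float32 UOps):
::
summed = self.uop(UOps.ALU, dtypes.float32, (a, b), BinaryOps.ADD)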
"""
key = (uop, dtype, vin, arg)
if uop == UOps.PHI and vin[1].dtype != dtype:
vin = (vin[0], self.cast(vin[1], dtype)) + vin[2:]
if uop == UOps.ALU: # upcast vins to the same dtype
upcast_dtype = (
dtypes.float
if arg == TernaryOps.MULACC
else max(cast(DType, x.dtype) for x in vin)
) # MULACC is only supported in float
if arg == TernaryOps.WHERE:
vin = (vin[0],) + tuple(
self.cast(x, upcast_dtype) for x in vin[1:]
) # the first arg is always bool
else:
vin = tuple(self.cast(x, upcast_dtype) for x in vin)
dtype = dtype or upcast_dtype # some ops like BinaryOps.CMPLT return bool
if simplify:
if uop == UOps.PHI and len(vin) == 2:
return vin[1] # a phi without loops is a noop
if uop == UOps.GEP and vin[0].uop == UOps.CONST:
return self.const(vin[0].arg, dtype, insert_before)
if (
uop == UOps.CAST
and all(x.uop == UOps.CONST for x in vin)
and all_same([x.arg for x in vin])
):
return self.const(vin[0].arg, dtype, insert_before)
if uop == UOps.ALU:
# rewrites. NOTE: the rewritten NEG op is still around...
if (
arg == BinaryOps.ADD
and vin[1].uop == UOps.ALU
and vin[1].arg == UnaryOps.NEG
):
return self.uop(
UOps.ALU,
dtype,
(vin[0], vin[1].vin[0]),
BinaryOps.SUB,
cachable=cachable,
insert_before=insert_before,
)
# constant folding
if arg == UnaryOps.NEG and vin[0].uop == UOps.CONST:
return self.const(-vin[0].arg, dtype, insert_before)
if arg == TernaryOps.WHERE and vin[1] == vin[2]:
return vin[
1
] # a conditional with the same results either way is a noop
# zero folding
for x in [0, 1]:
if (
arg == BinaryOps.ADD
and vin[x].uop == UOps.CONST
and vin[x].arg == 0.0
):
return vin[1 - x]
if (
arg == BinaryOps.MUL
and vin[x].uop == UOps.CONST
and vin[x].arg == 1.0
):
return vin[1 - x]
if (
arg == BinaryOps.MUL
and vin[x].uop == UOps.CONST
and vin[x].arg == 0.0
):
return vin[x]
if (
arg == BinaryOps.SUB
and vin[1].uop == UOps.CONST
and vin[1].arg == 0.0
):
return vin[0]
if (
arg == BinaryOps.DIV
and vin[1].uop == UOps.CONST
and vin[1].arg == 1.0
):
return vin[0]
# When insert_before is set, need to check if the cached expr is valid with the given insert place.
if (
cachable
and (expr := self.saved_exprs.get(key, None)) is not None
and (insert_before is None or self.uops.index(expr) <= insert_before)
):
return expr
ret = UOp(uop, dtype, vin, arg)
if insert_before is not None:
self.uops.insert(insert_before, ret)
else:
self.uops.append(ret)
if cachable:
self.saved_exprs[key] = ret
return ret
def ast_parse(
self,
x: LazyOp,
acc: List[UOp],
offs: Optional[List[int]],
loaded_buffers: Dict[Union[MemBuffer, ConstBuffer, LocalBuffer], List[UOp]],
do_reduce=False,
loop_ctx=tuple(),
) -> List[UOp]:
"""
Parses abstract syntax tree (AST) for operations.
Attributes:
x (LazyOp): The operation to parse.
acc (List[UOp]): The list of UOps accumulated so far.
offs (Optional[List[int]]): List of offsets, if any.
loaded_buffers (Dict[Union[MemBuffer, ConstBuffer, LocalBuffer], List[UOp]]): A dictionary containing loaded buffers.
do_reduce (bool): Flag indicating whether to perform a reduction operation. Default is False.
loop_ctx (tuple): The loop context, if any.
Returns:
List[UOp]: The list of UOps resulting from parsing the AST.
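Example usage (as in linearize above, running the late AST):
::
val = self.ast_parse(cast(LazyOp, self.ast.src[0]), acc, None, loaded_buffers)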
"""
if x.op in BufferOps:
return loaded_buffers[x.arg]
if x.op == UnaryOps.CAST:
return [
self.uop(UOps.CAST, x.arg[0], (u,), x.arg)
if not isinstance(x.arg[0], ImageDType)
else u
for u in self.ast_parse(
cast(LazyOp, x.src[0]), acc, offs, loaded_buffers
)
]
if x.op in ReduceOps and not do_reduce:
assert offs is None, "not available if we aren't doing reduce"
return acc
# MULACC fusion. TODO: this is copied from Interpreted
if (
x.op == ReduceOps.SUM
and x.src[0].__class__ is LazyOp
and x.src[0].op == BinaryOps.MUL
):
x = LazyOp(TernaryOps.MULACC, x.src[0].src, x.arg)
if (
x.op == ReduceOps.SUM
and x.src[0].__class__ is LazyOp
and x.src[0].op == UnaryOps.CAST
and x.src[0].src[0].__class__ is LazyOp
and x.src[0].src[0].op == BinaryOps.MUL
):
x = LazyOp(TernaryOps.MULACC, x.src[0].src[0].src, x.arg)
values = [
self.ast_parse(
cast(LazyOp, v), acc, offs, loaded_buffers, loop_ctx=loop_ctx
)
for v in x.src
]
ops = {
ReduceOps.SUM: BinaryOps.ADD,
ReduceOps.MAX: BinaryOps.MAX,
TernaryOps.MULACC: TernaryOps.MULACC,
}
if x.op in ops:
ret: List[UOp] = []
input_acc = acc[:]
for val, off in zip(zip(*values), cast(List[int], offs)):
acc[off] = self.uop(UOps.ALU, vin=val + (acc[off],), arg=ops[x.op])
ret.append(acc[off])
for off in range(len(acc)):
if input_acc[off] != acc[off]:
acc[off] = self.uop(
UOps.PHI,
input_acc[off].dtype,
(input_acc[off], acc[off]) + tuple(loop_ctx),
)
else:
ret = [
self.uop(
UOps.ALU,
dtype=dtypes.bool if x.op == BinaryOps.CMPLT else None,
vin=val,
arg=x.op,
)
for val in zip(*values)
]
return ret