from typing import Dict, List, Optional, NamedTuple, Tuple, Union, DefaultDict, cast
import math, functools
from collections import defaultdict
from tinygrad.codegen.linearizer import UOps, UOp
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
from tinygrad.helpers import ImageDType, dtypes, prod, DType, strip_parens
class CStyleLanguage(NamedTuple):
    """
    Description of a C-like target language consumed by ``uops_to_cstyle``.

    Backends subclass this and override the fields below (as plain class attributes)
    and/or the ``render_*`` methods.

    Attributes:
        size_prefix: Type used when declaring SPECIAL index variables. Defaults to "int".
        generic_var_prefix: If non-empty, used instead of the dtype name when declaring
            temporaries. Defaults to "".
        kernel_prefix: Text emitted before ``void`` in the kernel signature.
        buffer_prefix: Qualifier placed before the element type of buffer parameters
            and of vector pointer casts.
        buffer_suffix: Text placed after the ``*`` of buffer parameters.
        smem_align: Text prepended to local-memory declarations.
        smem_prefix: Qualifier prepended to local-memory declarations.
        smem_prefix_for_cast: If True, vector loads/stores on a local buffer use
            ``smem_prefix`` in the pointer cast; otherwise ``buffer_prefix`` is used.
        arg_int_prefix: Parameter type used for ``dtypes._arg_int32`` arguments.
        barrier: Statement emitted for a BARRIER uop.
        xid: Index expressions for SPECIAL variables whose name starts with "i".
        gid: Index expressions for SPECIAL variables whose name starts with "g".
        lid: Index expressions for all other SPECIAL variables.
        global_max: Maximum global sizes (not read in this file).
        local_max: Maximum local sizes (not read in this file).
        extra_args: Extra parameter declarations appended to the kernel signature.
        float4: Vector-constructor template containing the text "float4", which is
            replaced by the target dtype name; None disables vectorized casts.
        half_prekernel: Text prepended to the program when any buffer is float16.
        uses_vload: Use ``vload_half*``/``vstore_half*`` when a float16 buffer is
            accessed with a non-float16 value.
        external_local_bufs: Emit local-buffer declarations into ``prekernel`` instead
            of the kernel body.
        uses_ptr_arithmetic: Render scalar accesses as ``*(buf+idx)`` instead of
            ``buf[idx]``.
        launch_bounds: Emit ``__launch_bounds__ (prod(local_size), 1)`` in the signature.
        code_for_op: Maps Unary/Binary/TernaryOps to callables that take the operand
            strings plus the result dtype and return the rendered expression.

    NOTE: the list defaults and the ``code_for_op`` dict are single objects shared by
    every instance that does not override them; treat them as read-only.
    """

    size_prefix: str = "int"
    generic_var_prefix: str = ""
    kernel_prefix: str = ""
    buffer_prefix: str = ""
    buffer_suffix: str = ""
    smem_align: str = ""
    smem_prefix: str = ""
    smem_prefix_for_cast: bool = True
    arg_int_prefix: str = ""
    barrier: str = ""
    xid: List[str] = []
    gid: List[str] = []
    lid: List[str] = []
    global_max: List[int] = []
    local_max: List[int] = []
    extra_args: List[str] = []
    float4: Optional[str] = None
    half_prekernel: Optional[str] = None
    uses_vload: bool = False
    external_local_bufs: bool = False
    uses_ptr_arithmetic: bool = False
    launch_bounds: bool = False
    code_for_op: Dict = {
        UnaryOps.NEG: lambda x, dtype: f"(-{x})" if dtype != dtypes.bool else f"(!{x})",
        UnaryOps.EXP2: lambda x, dtype: f"exp2({x})",
        UnaryOps.LOG2: lambda x, dtype: f"log2({x})",
        UnaryOps.SIN: lambda x, dtype: f"sin({x})",
        UnaryOps.SQRT: lambda x, dtype: f"sqrt({x})",
        BinaryOps.ADD: lambda a, b, dtype: f"({a}+{b})",
        BinaryOps.SUB: lambda a, b, dtype: f"({a}-{b})",
        BinaryOps.MUL: lambda a, b, dtype: f"({a}*{b})",
        BinaryOps.DIV: lambda a, b, dtype: f"({a}/{b})",
        BinaryOps.MAX: lambda a, b, dtype: f"max({a},{b})",
        BinaryOps.MOD: lambda a, b, dtype: f"({a}%{b})",
        BinaryOps.CMPLT: lambda a, b, dtype: f"({a}<{b})",
        TernaryOps.MULACC: lambda a, b, c, dtype: f"(({a}*{b})+{c})",
        TernaryOps.WHERE: lambda a, b, c, dtype: f"({a}!=0?{b}:{c})",
    }

    def render_cast(self, x: List[str], var_dtype: DType) -> str:
        """
        Return an expression casting the value(s) ``x`` to ``var_dtype``.

        A single element is rendered as a C-style cast. Several elements are combined
        with the ``float4`` constructor template, with "float4" replaced by the
        dtype name.

        Raises:
            AssertionError: If more than one element is given and ``len(x)`` differs
                from ``var_dtype.sz``, or if ``float4`` is None.
        """
        if len(x) == 1:
            return f"({var_dtype.name})({x[0]})"
        assert len(x) == var_dtype.sz, f"cast is wrong size {len(x)} != {var_dtype.sz}"
        assert self.float4 is not None, "vectorized cast is not supported on this platform"
        return f"{self.float4.replace('float4', var_dtype.name)}({','.join(x)})"

    def render_const(self, x: Union[float, int, bool], var_dtype: DType) -> str:
        """
        Return an expression for the constant ``x`` of type ``var_dtype``.

        NaN and +/-inf render as ``NAN`` / ``INFINITY`` (possibly negated). Otherwise
        the literal form depends on the dtype: ``<float>f``, an integer, or
        ``true``/``false``. Vector dtypes, and scalar dtypes other than float/int/bool,
        are wrapped with ``render_cast`` (the literal is repeated ``sz`` times).
        """
        if math.isnan(x):
            val = "NAN"
        elif math.isinf(x):
            val = ("-" if x < 0 else "") + "INFINITY"
        elif dtypes.is_float(var_dtype):
            val = f"{float(x)}f"
        elif dtypes.is_int(var_dtype):
            val = f"{int(x)}"
        else:
            val = f"{bool(x)}".lower()
        if var_dtype.sz > 1 or var_dtype not in [dtypes.float, dtypes.int, dtypes.bool]:
            return self.render_cast([val] * var_dtype.sz, var_dtype)
        return val

    def render_load(self, output_dtype, buf_name, buf_dtype, idx, local=False) -> str:
        """
        Return an expression reading ``buf_name`` at ``idx`` as ``output_dtype``.

        Args:
            output_dtype: Dtype of the loaded value.
            buf_name: Rendered name of the buffer.
            buf_dtype: Dtype of the buffer.
            idx (str): Rendered index expression.
            local: True when the buffer is a DEFINE_LOCAL (shared-memory) buffer.

        Raises:
            AssertionError: If an image buffer is read as anything but float4.
        """
        if isinstance(buf_dtype, ImageDType):
            assert output_dtype == dtypes.float.vec(4), f"images must be float4, getting {output_dtype}"
            return f"read_imagef({buf_name}, smp, {idx})"
        if self.uses_vload and buf_dtype.scalar() == dtypes.float16 and output_dtype.scalar() != dtypes.float16:
            return f"vload_half{'' if output_dtype.sz == 1 else str(output_dtype.sz)}(0, {buf_name}+{idx})"
        if output_dtype.sz > 1:
            # vector load: reinterpret the element pointer as a pointer to the vector type
            prefix = self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix
            out_val = f"*(({prefix}{buf_dtype.name}{output_dtype.sz}*)({buf_name}+{idx}))"
        elif self.uses_ptr_arithmetic:
            out_val = f"*({buf_name}+{idx})"
        else:
            out_val = f"{buf_name}[{idx}]"
        return self.render_cast([out_val], output_dtype) if output_dtype != buf_dtype else out_val

    def render_local(self, name: str, size: int):
        """Return the declaration of a local float array ``name`` with ``size`` elements."""
        return self.smem_align + self.smem_prefix + f"float {name}[{size}];"

    def render_for(self, expr: str, _min: Union[int, str], _max: Union[int, str]) -> str:
        """Return the opening line of a loop running ``expr`` over ``[_min, _max)``."""
        return f"for (int {expr} = {_min}; {expr} < {_max}; ++{expr}) {{"

    def render_if(self, cond: str):
        """Return the opening line of an ``if`` block on ``cond``."""
        return f"if ({cond}) {{"

    def render_conditional(self, cond: str, x: str, y: str) -> str:
        """Return a ternary expression selecting ``x`` if ``cond`` holds, else ``y``."""
        return f"({cond})?({x}):{y}"

    def render_kernel(
        self,
        function_name: str,
        kernel: List[str],
        bufs: List[Tuple[str, DType]],
        local_size: List[int],
        prekernel: List[str],
    ) -> str:
        """
        Assemble the complete program text for one kernel.

        Args:
            function_name: Name of the kernel function.
            kernel: Already-indented body lines of the kernel.
            bufs: (name, dtype) pairs for the kernel parameters, in signature order.
                The first one is not marked ``const`` (and is ``write_only`` for
                images); all later ones are inputs.
            local_size: Local sizes; their product is used for ``__launch_bounds__``
                when ``launch_bounds`` is set.
            prekernel: Declarations that must appear before the kernel function,
                collected by uops_to_cstyle when ``external_local_bufs`` is set. They
                are emitted right before the kernel signature; an empty list adds nothing.

        Returns:
            The program source as a string.
        """
        # image reads reference a sampler named ``smp``; declare it at the top of the body
        sampler = (
            "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n"
            if any(isinstance(dtype, ImageDType) for _, dtype in bufs)
            else ""
        )

        def arg_type(i: int, dtype: DType) -> str:
            # index 0 is the output buffer; every later buffer is read-only
            if dtype.name.startswith("image"):
                return f"{'read_only' if i > 0 else 'write_only'} image2d_t"
            if dtype == dtypes._arg_int32:
                return self.arg_int_prefix
            return ("const " if i > 0 else "") + self.buffer_prefix + dtype.name + "*" + self.buffer_suffix

        params = [f"{arg_type(i, dtype)} {name}" for i, (name, dtype) in enumerate(bufs)] + self.extra_args
        launch = f"__launch_bounds__ ({prod(local_size)}, 1) " if self.launch_bounds else ""
        prg = (
            f"{self.kernel_prefix}void {launch}{function_name}("
            + ", ".join(params)
            + ") {\n"
            + sampler
            + "\n".join(kernel)
            + "\n}"
        )
        # declarations routed outside the kernel body must not be dropped
        if prekernel:
            prg = "\n".join(prekernel) + "\n" + prg
        if self.half_prekernel and any(dtype == dtypes.float16 for _, dtype in bufs):
            prg = f"{self.half_prekernel}\n{prg}"
        return prg

    def render_store(
        self,
        buf_name: str,
        buf_dtype: DType,
        var_name: str,
        var_dtype: DType,
        idx: str,
        local=False,
    ) -> str:
        """
        Return a statement storing ``var_name`` into ``buf_name`` at ``idx``.

        Args:
            buf_name: Rendered name of the buffer.
            buf_dtype: Dtype of the buffer.
            var_name: Rendered value to store.
            var_dtype: Dtype of the value.
            idx: Rendered index expression.
            local: True when the buffer is a DEFINE_LOCAL (shared-memory) buffer.

        Raises:
            AssertionError: If a non-float4 value is written to an image.
        """
        if isinstance(buf_dtype, ImageDType):
            assert var_dtype == dtypes.float.vec(4), "images must be float4"
            return f"write_imagef({buf_name}, {idx}, {var_name});"
        if self.uses_vload and buf_dtype.scalar() == dtypes.float16 and var_dtype.scalar() != dtypes.float16:
            return f"vstore_half{'' if var_dtype.sz == 1 else str(var_dtype.sz)}({var_name}, 0, {buf_name}+{idx});"
        if var_dtype.sz > 1:
            # vector store through a reinterpreted pointer, casting the value to the buffer's vector type
            prefix = self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix
            return f"*(({prefix}{buf_dtype.name}{var_dtype.sz}*)({buf_name}+{idx})) = ({buf_dtype.name}{var_dtype.sz}){var_name};"
        if self.uses_ptr_arithmetic:
            return f"*({buf_name}+{idx}) = {var_name};"
        return f"{buf_name}[{idx}] = {var_name};"
def uops_to_cstyle(
    lang: CStyleLanguage, function_name: str, uops: List[UOp]
) -> Tuple[str, Dict]:
    """
    Render a linearized list of uops as a kernel in the C-style language ``lang``.

    Args:
        lang: Target language; supplies every ``render_*`` hook and ``code_for_op``.
        function_name: Name given to the emitted kernel function.
        uops: The uops, in emission order.

    Returns:
        ``(source, {})``: the program text produced by ``lang.render_kernel`` and an
        always-empty dict.

    Raises:
        RuntimeError: If a uop kind is not handled (including a CAST without dtype).
        NotImplementedError: For WMMA on a backend other than METAL or HIP.
        AssertionError: On malformed uops (e.g. an ALU result that has no users).
    """
    local_size: List[int] = []
    kernel: List[str] = []
    prekernel: List[str] = []
    bufs: list = []
    depth = 1

    def kk(s: str) -> None:
        """Append ``s`` to the kernel body, prefixed by ``depth`` spaces."""
        kernel.append(" " * depth + s)

    # c counts the names handed out per prefix; r maps each rendered uop to its expression/name
    c: DefaultDict[str, int] = defaultdict(int)
    r: Dict[UOp, str] = {}

    def ssa(u, prefix: str = "t") -> str:
        """
        Allocate a fresh variable name for ``u``.

        The name is ``f"{prefix}{n}"`` where ``n`` counts earlier names with the same
        prefix (starting at 0); it is recorded in ``r[u]`` and returned.
        """
        c[prefix] += 1
        r[u] = f"{prefix}{c[prefix]-1}"
        return r[u]

    def var_type(dtype) -> str:
        """Declaration type for a temporary: ``lang.generic_var_prefix`` if set, else ``dtype.name``."""
        return lang.generic_var_prefix if lang.generic_var_prefix else dtype.name

    # how many uops consume each uop; single-use expressions get inlined instead of stored
    child_count: DefaultDict[UOp, int] = defaultdict(int)
    for ru in uops:
        for v in ru.vin:
            child_count[v] += 1

    for u in uops:
        uop, dtype, vin, args = u.uop, u.dtype, u.vin, u.arg
        if uop == UOps.LOOP:
            kk(lang.render_for(ssa(u, "ridx"), r[vin[0]], r[vin[1]]))
            depth += 1
        elif uop == UOps.IF:
            kk(lang.render_if(r[vin[0]]))
            depth += 1
        elif uop == UOps.BARRIER:
            kk(lang.barrier)
        elif uop == UOps.END:
            depth -= 1
            kk("}")
        elif uop == UOps.WMMA:
            if args[0] == "METAL":
                assert dtype == dtypes.float.vec(2), "output dtype of METAL TC is _float2"
                # ((lidx2*32)+(lidx3*4)+(lidx4*16)+(lidx5*8)+(lidx6*2))
                output = ssa(u, "wmma")
                kk(f"{var_type(dtype)} {output};")
                kk("{ simdgroup_float8x8 a,b,c;")
                kk(f"a.thread_elements()[0] = {r[vin[0]]}; a.thread_elements()[1] = {r[vin[1]]};")
                kk(f"b.thread_elements()[0] = {r[vin[2]]}; b.thread_elements()[1] = {r[vin[3]]};")
                kk(f"c.thread_elements()[0] = {r[vin[4]]}; c.thread_elements()[1] = {r[vin[5]]};")
                kk("simdgroup_multiply_accumulate(c, a, b, c);")
                kk(f"{output}.x = c.thread_elements()[0]; {output}.y = c.thread_elements()[1]; }}")
            elif args[0] == "HIP":
                assert dtype == dtypes.float.vec(8), "output dtype of HIP TC is _float8"
                kk(
                    f"{var_type(dtype)} {ssa(u, 'wmma')} = "
                    f"__builtin_amdgcn_wmma_f32_16x16x16_f16_w32({r[vin[0]]}, {r[vin[1]]}, {r[vin[2]]});"
                )
            else:
                raise NotImplementedError(f"WMMA not implemented for {args}")
        elif uop == UOps.ALU:
            assert dtype is not None
            operands = [r[x] for x in vin]
            # remove parens if ALU types are the same (left operand only, so the
            # left-associative ADD/SUB/MUL keep their meaning). TODO: can do more here
            if vin[0].uop == UOps.ALU and vin[0].arg == args and args in {BinaryOps.ADD, BinaryOps.SUB, BinaryOps.MUL}:
                operands[0] = strip_parens(operands[0])
            val = lang.code_for_op[args](*operands, dtype)
            assert child_count[u] != 0, f"childless ALU op found {u}"
            # inline single-use and integer (index) expressions to fix index rendering;
            # MAX is always materialized to avoid clang's nested max macro issue
            if (child_count[u] <= 1 or dtypes.is_int(dtype)) and args != BinaryOps.MAX:
                r[u] = val
            else:
                kk(f"{var_type(dtype)} {ssa(u, 'alu')} = {val};")
        elif uop == UOps.DEFINE_ACC:
            assert dtype is not None
            kk(f"{var_type(dtype)} {ssa(u, 'acc')} = {lang.render_const(args, dtype)};")
        elif uop == UOps.SPECIAL:
            # args[0] indexes the id expression list, args[1] is the variable name (its
            # first letter selects gid/xid/lid) and args[2] is its extent
            if args[1].startswith("g"):
                xid = lang.gid
            elif args[1].startswith("i"):
                xid = lang.xid
            else:
                xid = lang.lid
            kk(f"{lang.size_prefix} {args[1]} = {xid[args[0]]}; /* {args[2]} */")
            if args[1].startswith("l"):
                local_size.append(args[2])
            r[u] = args[1]
        elif uop == UOps.CONST:
            rendered = lang.render_const(args, dtype)
            # negative values (and NaN, since nan >= 0 is False) are parenthesized
            r[u] = rendered if args >= 0 else f"({rendered})"
        elif uop == UOps.LOAD:
            assert dtype is not None
            val = lang.render_load(
                dtype,
                r[vin[0]],
                vin[0].dtype,
                strip_parens(r[vin[1]]),
                vin[0].uop == UOps.DEFINE_LOCAL,
            )
            if len(vin) > 3:
                # gated load: vin[2] is the condition, vin[3] the value used otherwise
                val = lang.render_conditional(r[vin[2]], val, r[vin[3]])
            kk(f"{var_type(dtype)} {ssa(u, 'val')} = {val};")
        elif uop == UOps.PHI:
            # assign vin[1] back into vin[0]'s variable; this uop then aliases that variable
            kk(f"{r[vin[0]]} = {r[vin[1]]};")
            r[u] = r[vin[0]]
        elif uop == UOps.STORE:
            assert vin[0].dtype is not None and vin[2].dtype is not None
            gated = len(vin) > 3
            if gated:
                kk(lang.render_if(r[vin[3]]))
            kk(
                lang.render_store(
                    r[vin[0]],
                    vin[0].dtype,
                    r[vin[2]],
                    vin[2].dtype,
                    strip_parens(r[vin[1]]),
                    vin[0].uop == UOps.DEFINE_LOCAL,
                )
            )
            if gated:
                kk("}")
        elif uop == UOps.CAST and dtype is not None:
            val = lang.render_cast([r[x] for x in vin], dtype)
            if child_count[u] <= 1:
                r[u] = val
            else:
                kk(f"{var_type(dtype)} {ssa(u, 'cast')} = {val};")
        elif uop == UOps.DEFINE_LOCAL:
            # args[0] is the buffer name, args[1] its element count
            if lang.external_local_bufs:
                prekernel.append(lang.render_local(args[0], args[1]))
            else:
                kk(lang.render_local(args[0], args[1]))
            r[u] = args[0]
        elif uop == UOps.DEFINE_GLOBAL:
            # args is passed to render_kernel as one (name, dtype) entry of bufs
            bufs.append(args)
            r[u] = args[0]
        elif uop == UOps.GEP:
            if cast(DType, vin[0].dtype).sz > 4:
                r[u] = f"({r[vin[0]]})[{args}]"  # this is correct for HIP
            else:
                r[u] = f"({r[vin[0]]}).{'xyzw'[args]}"
        else:
            raise RuntimeError(f"failed to render {uop}")
    return lang.render_kernel(function_name, kernel, bufs, local_size, prekernel), {}
class OpenCLLanguage(CStyleLanguage):
    """
    OpenCL C flavour of CStyleLanguage.

    Kernels are marked ``__kernel`` and take ``__global`` buffers; local memory is a
    16-byte-aligned ``__local`` array synchronized with ``barrier(CLK_LOCAL_MEM_FENCE)``.
    Work-item indices come from get_group_id / get_local_id / get_global_id, float16
    buffers are accessed through vload_half/vstore_half, and the cl_khr_fp16 pragma is
    prepended when any buffer is float16. MULACC is rendered with ``mad``.
    """

    # kernel signature
    kernel_prefix = "__kernel "
    buffer_prefix = "__global "
    arg_int_prefix = "const int"
    # local memory and synchronization
    smem_align = "__attribute__ ((aligned (16))) "
    smem_prefix = "__local "
    barrier = "barrier(CLK_LOCAL_MEM_FENCE);"
    # half precision and vector types
    half_prekernel = "#pragma OPENCL EXTENSION cl_khr_fp16 : enable"
    float4 = "(float4)"
    uses_vload = True
    # index expressions for dimensions 0, 1 and 2
    gid = list(map("get_group_id({})".format, range(3)))
    lid = list(map("get_local_id({})".format, range(3)))
    xid = list(map("get_global_id({})".format, range(3)))
    # NOTE: mad is used so the loads aren't reordered into the math on 845
    code_for_op = dict(CStyleLanguage().code_for_op)
    code_for_op[TernaryOps.MULACC] = lambda a, b, c, dtype: f"mad({a},{b},{c})"
# OpenCL entry point: uops_to_cstyle with an OpenCLLanguage() instance bound as `lang`.
OpenCLRenderer = functools.partial(uops_to_cstyle, OpenCLLanguage())
# NOTE(review): MetalLanguage is neither defined nor imported anywhere in this file, so
# this line raises NameError when the module is imported. Presumably a MetalLanguage
# class definition is missing here — confirm against upstream before relying on it.
MetalRenderer = functools.partial(uops_to_cstyle, MetalLanguage())
class CUDALanguage(CStyleLanguage):
    """
    CUDA C++ flavour of CStyleLanguage.

    The kernel prefix defines INFINITY/NAN from bit patterns and declares the kernel
    ``extern "C" __global__``. Local memory is ``__shared__``, but that qualifier is not
    used in vector pointer casts (smem_prefix_for_cast is False). Index expressions are
    built from blockIdx / threadIdx / blockDim. MAX on half renders as ``__hmax``, and
    half_prekernel includes cuda_fp16.h plus a half4 struct with make_half4.
    """

    # kernel signature
    kernel_prefix = '#define INFINITY (__int_as_float(0x7f800000))\n#define NAN (__int_as_float(0x7fffffff))\nextern "C" __global__ '
    arg_int_prefix = "const int"
    # shared memory and synchronization
    smem_prefix = "__shared__ "
    smem_prefix_for_cast = False
    barrier = "__syncthreads();"
    # vector constructor template
    float4 = "make_float4"
    # index expressions for the x, y and z dimensions
    gid = [f"blockIdx.{axis}" for axis in "xyz"]
    lid = [f"threadIdx.{axis}" for axis in "xyz"]
    xid = [f"(blockIdx.{axis}*blockDim.{axis}+threadIdx.{axis})" for axis in "xyz"]
    code_for_op = dict(CStyleLanguage().code_for_op)
    code_for_op[BinaryOps.MAX] = (
        lambda a, b, dtype: f"max({a},{b})" if dtype != dtypes.half else f"__hmax({a},{b})"
    )
    half_prekernel = """
#include <cuda_fp16.h>
struct half4 { half x, y, z, w; };
__device__ half4 make_half4(half x, half y, half z, half w) { half4 ret; ret.x = x; ret.y = y; ret.z = z; ret.w = w; return ret; }
"""
# CUDA entry point: uops_to_cstyle with a CUDALanguage() instance bound as `lang`.
CUDARenderer = functools.partial(uops_to_cstyle, CUDALanguage())
class HIPLanguage(CStyleLanguage):
    """
    HIP (AMD) flavour of CStyleLanguage.

    The kernel prefix includes hip_common.h, defines INFINITY/NAN and float4 overloads
    of max/pow/log2/exp2/sin plus a float8 vector type, then opens the
    ``extern "C" __global__`` declaration. Kernels carry ``__launch_bounds__``; local
    memory is ``__shared__`` and its qualifier is not used in vector pointer casts.
    Scalar accesses use pointer arithmetic, and float16 buffers go through the
    vload_half*/vstore_half* helpers defined in half_prekernel. MAX on half renders as
    ``hmax``, and WHERE on half is wrapped in a ``(half)`` cast.
    """

    kernel_prefix = (
        '#include <hip/hip_common.h>\n#define INFINITY (__builtin_inff())\n#define NAN (__builtin_nanf(""))'
        + """
__device__ float4 max(float4 x, float4 y) { return float4(max(x.x, y.x), max(x.y, y.y), max(x.z, y.z), max(x.w, y.w)); }
__device__ float4 pow(float x, float4 y) { return float4(pow(x, y.x), pow(x, y.y), pow(x, y.z), pow(x, y.w)); }
__device__ float4 pow(float4 x, float4 y) { return float4(pow(x.x, y.x), pow(x.y, y.y), pow(x.z, y.z), pow(x.w, y.w)); }
__device__ float4 log2(float4 x) { return float4(log2(x.x), log2(x.y), log2(x.z), log2(x.w)); }
__device__ float4 exp2(float4 x) { return float4(exp2(x.x), exp2(x.y), exp2(x.z), exp2(x.w)); }
__device__ float4 sin(float4 x) { return float4(sin(x.x), sin(x.y), sin(x.z), sin(x.w)); }
typedef float float8 __attribute__((ext_vector_type(8))); __device__ float8 make_float8(float x, float y, float z, float w, float a, float b, float c, float d) { return {x, y, z, w, a, b, c, d}; }
extern "C" __global__
"""
    )
    launch_bounds = True
    arg_int_prefix = "const int"
    # shared memory and synchronization
    smem_prefix = "__shared__ "
    smem_prefix_for_cast = False
    barrier = "__syncthreads();"
    # memory access style
    float4 = "make_float4"
    uses_vload = True
    uses_ptr_arithmetic = True
    half_prekernel = (
        "#include <hip/hip_fp16.h>\n"
        + """
typedef union { struct { half x, y, z, w; } __attribute__((aligned(8))); half data[4]; } half4; __device__ half4 make_half4(half x, half y, half z, half w) { return {x, y, z, w}; }
typedef union { struct { half x, y, z, w, a, b, c, d; } __attribute__((aligned(16))); half data[8]; } half8; __device__ half8 make_half8(half x, half y, half z, half w, half a, half b, half c, half d) { return {x, y, z, w, a, b, c, d}; }
typedef _Float16 half16 __attribute__((ext_vector_type(16))); __device__ half16 make_half16(half x, half y, half z, half w, half a, half b, half c, half d, half e, half f, half g, half h, half i, half j, half k, half l) { return {x, y, z, w, a, b, c, d, e, f, g, h, i, j, k, l}; }
__device__ float vload_half(size_t offset, const half *p) { return (float)*(p + offset); }
__device__ float2 vload_half2(size_t offset, const half *p) { return make_float2((float)*(p + offset*2), (float)*(p + offset*2 + 1)); }
__device__ float4 vload_half4(size_t offset, const half *p) { return make_float4((float)*(p + offset*4), (float)*(p + offset*4 + 1), (float)*(p + offset*4 + 2), (float)*(p + offset*4 + 3)); }
__device__ void vstore_half(float data, size_t offset, half *p) { *(p + offset) = (half)data; }
__device__ void vstore_half2(float2 data, size_t offset, half *p) { *(p + offset*2) = (half)data.x; *(p + offset*2 + 1) = (half)data.y; }
__device__ void vstore_half4(float4 data, size_t offset, half *p) { *(p + offset*4) = (half)data.x; *(p + offset*4 + 1) = (half)data.y; *(p + offset*4 + 2) = (half)data.z; *(p + offset*4 + 3) = (half)data.w; }
__device__ half exp2(half x) { return hexp2(x); }
__device__ half log2(half x) { return hlog2(x); }
__device__ half sin(half x) { return hsin(x); }
__device__ half sqrt(half x) { return hsqrt(x); }
__device__ half hmax(half a, half b) { return __hgt(a, b) ? a : b; }
__device__ half operator%(const half &a, const half &b) { return __hsub(a, __hmul(b, __float2half(floorf(__half2float(a) / __half2float(b))))); }
__device__ bool operator!=(const half &a, const int &b) { return (float)a != b; }
// HACKS for ALU ops on half and result of half2 GEP
__device__ half operator+(const half &a, const unsigned short &b) { return __hadd(a, (half)(b)); }
__device__ half operator-(const half &a, const unsigned short &b) { return __hsub(a, (half)(b)); }
__device__ half operator*(const half &a, const unsigned short &b) { return __hmul(a, (half)(b)); }
__device__ half operator/(const half &a, const unsigned short &b) { return __hdiv(a, (half)(b)); }
__device__ bool operator<(const half &a, const unsigned short &b) { return __hlt(a, (half)(b)); }
// now the other way
__device__ half operator+(const unsigned short &a, const half &b) { return __hadd((half)(a), b); }
__device__ half operator-(const unsigned short &a, const half &b) { return __hsub((half)(a), b); }
__device__ half operator*(const unsigned short &a, const half &b) { return __hmul((half)(a), b); }
__device__ half operator/(const unsigned short &a, const half &b) { return __hdiv((half)(a), b); }
__device__ bool operator<(const unsigned short &a, const half &b) { return __hlt((half)(a), b); }
"""
    )
    # index expressions for the x, y and z dimensions
    gid = [f"blockIdx.{axis}" for axis in "xyz"]
    lid = [f"threadIdx.{axis}" for axis in "xyz"]
    xid = [f"(blockIdx.{axis}*blockDim.{axis}+threadIdx.{axis})" for axis in "xyz"]
    code_for_op = dict(CStyleLanguage().code_for_op)
    code_for_op[BinaryOps.MAX] = (
        lambda a, b, dtype: f"max({a},{b})" if dtype != dtypes.half else f"hmax({a},{b})"
    )
    code_for_op[TernaryOps.WHERE] = (
        lambda a, b, c, dtype: f"({a}!=0?{b}:{c})" if dtype != dtypes.half else f"(half)({a}!=0?{b}:{c})"
    )
# HIP entry point: uops_to_cstyle with a HIPLanguage() instance bound as `lang`.
HIPRenderer = functools.partial(uops_to_cstyle, HIPLanguage())
# TODO: how much of this can be merged with above?
class WGSLLanguage(CStyleLanguage):
    """
    WebGPU Shading Language (WGSL) flavour of CStyleLanguage.

    WGSL syntax differs from C, so besides overriding fields this class replaces most
    of the render_* hooks.

    Attributes:
        gid: ``i32(gindex.x|y|z)``; gindex is the workgroup_id builtin declared by
            render_kernel.
        lid: ``i32(lindex.x|y|z)``; lindex is the local_invocation_id builtin.
        size_prefix: "let", used to declare SPECIAL index variables.
        barrier: "workgroupBarrier();".
        generic_var_prefix: "var ", used for every temporary instead of a dtype name.
        external_local_bufs: True, so workgroup arrays are collected into ``prekernel``
            and emitted at module scope by render_kernel.
        code_for_op: The base ops with CMPLT, MULACC and WHERE replaced by WGSL forms.
        type_map: dtype -> WGSL scalar type name.
    """

    gid = [f"i32(gindex.{'xyz'[x]})" for x in range(3)]
    lid = [f"i32(lindex.{'xyz'[x]})" for x in range(3)]
    size_prefix = "let"
    barrier = "workgroupBarrier();"
    generic_var_prefix = "var "
    external_local_bufs = True
    code_for_op = {
        **CStyleLanguage().code_for_op,
        BinaryOps.CMPLT: lambda x, y, dtype: f"f32({x}<{y})",
        TernaryOps.MULACC: lambda x, y, z, dtype: f"fma({x},{y},{z})",
        TernaryOps.WHERE: lambda a, b, c, dtype: f"select({c},{b},{a}!=0.)",
    }
    type_map = {
        dtypes.float: "f32",
        dtypes.half: "f16",
        dtypes.int32: "i32",
        dtypes.uint32: "u32",
        dtypes.bool: "bool",
    }

    def render_local(self, name: str, size: int):
        """Return a module-scope ``var<workgroup>`` f32 array declaration of ``size`` elements."""
        return f"var<workgroup> {name}: array<f32,{size}>;"

    def render_const(self, x: Union[float, int], var_dtype) -> str:
        """
        Return a WGSL expression for the constant ``x``.

        NaN calls the ``nan()`` helper that render_kernel defines at the top of every
        program. +/-inf is rendered as +/-0x1.fffffep+127f, the largest finite f32
        (presumably because WGSL lacks an infinity literal). Everything else uses the
        C-style rendering wrapped in parentheses.
        """
        if math.isnan(x):
            return "nan()"
        if math.isinf(x):
            return ("-" if x < 0 else "") + "0x1.fffffep+127f"
        return f"({super().render_const(x, var_dtype)})"

    def render_kernel(
        self,
        function_name: str,
        kernel: List[str],
        bufs: List[Tuple[str, DType]],
        local_size: List[int],
        prekernel: List[str],
    ) -> str:
        """
        Assemble a complete WGSL compute program.

        Args:
            function_name: Name of the ``@compute`` entry point.
            kernel: Already-indented body lines.
            bufs: (name, dtype) pairs; each becomes a read_write storage array bound at
                ``@group(0) @binding(i)`` in list order.
            local_size: Local sizes, reversed to form ``@workgroup_size``; [1] if empty.
            prekernel: Module-scope declarations (workgroup arrays) emitted before the
                buffer bindings.

        Returns:
            The program source as a string.

        Raises:
            KeyError: If a buffer dtype has no entry in ``type_map``.
        """
        local_size = local_size[::-1] if local_size else [1]
        bindings = [
            f"@group(0) @binding({binding}) var<storage,read_write> {name}: array<{self.type_map[dtype]}>;"
            for binding, (name, dtype) in enumerate(bufs)
        ]
        # NaN helper referenced by render_const: an f32 built from an all-ones bit pattern
        prg = "fn nan() -> f32 { let bits = 0xffffffffu; return bitcast<f32>(bits); }\n"
        prg += "\n".join(prekernel + bindings)
        prg += (
            f"\n@compute @workgroup_size({','.join([str(x) for x in local_size])}) fn {function_name}(@builtin(workgroup_id) gindex: vec3<u32>, @builtin(local_invocation_id) lindex: vec3<u32>) {{\n"
            + "\n".join(kernel)
            + "\n}"
        )
        return prg

    def render_for(self, expr: str, _min: Union[int, str], _max: Union[int, str]) -> str:
        """Return the opening line of a loop running ``expr`` over ``[_min, _max)``."""
        return f"for(var {expr} = {_min}; {expr} < {_max}; {expr}++) {{"

    def render_if(self, cond: str):
        """Return the opening line of an ``if`` block; ``cond`` is converted with ``bool(...)``."""
        return f"if (bool({cond})) {{"

    def render_conditional(self, cond: str, x: str, y: str) -> str:
        """Return ``select`` choosing ``x`` when ``cond`` is true, else ``f32(y)``."""
        return f"select(f32({y}), {x}, bool({cond}))"

    def render_cast(self, x: List[str], var_dtype: DType) -> str:
        """
        Return a WGSL value-constructor cast of ``x[0]`` to ``var_dtype``.

        Only the first element of ``x`` is used.

        Raises:
            NotImplementedError: If ``var_dtype`` has no entry in ``type_map``.
        """
        # membership test (not indexing) so unmapped dtypes reach the intended error
        if var_dtype in self.type_map:
            return f"{self.type_map[var_dtype]}({x[0]})"
        raise NotImplementedError(f"no cast for {var_dtype}")

    def render_store(
        self,
        buf_name: str,
        buf_dtype: DType,
        var_name: str,
        var_dtype: DType,
        idx,
        local=False,
    ) -> str:
        """
        Return ``buf[idx] = value;``, casting the value to ``buf_dtype`` when its dtype differs.

        ``local`` is accepted for interface compatibility and is not used.
        """
        return f"{buf_name}[{idx}] = {self.render_cast([var_name], buf_dtype) if var_dtype != buf_dtype else var_name};"
# WGSL entry point: uops_to_cstyle with a WGSLLanguage() instance bound as `lang`.
WGSLRenderer = functools.partial(uops_to_cstyle, WGSLLanguage())