Source code for tinygrad.renderer.cstyle

from typing import Dict, List, Optional, NamedTuple, Tuple, Union, DefaultDict, cast
import math, functools
from collections import defaultdict
from tinygrad.codegen.linearizer import UOps, UOp
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps
from tinygrad.helpers import ImageDType, dtypes, prod, DType, strip_parens


class CStyleLanguage(NamedTuple):
    """Settings and rendering hooks shared by the C-like kernel backends in this module.

    Backends subclass this NamedTuple and override the fields below as class attributes;
    ``uops_to_cstyle`` reads them while turning a UOp list into kernel source.

    Attributes:
        size_prefix: type used when ``uops_to_cstyle`` declares SPECIAL index variables.
        generic_var_prefix: when non-empty, used instead of the dtype name to declare
            variables in ``uops_to_cstyle`` (e.g. ``"var "``).
        kernel_prefix: text emitted before ``void`` in the kernel signature.
        buffer_prefix / buffer_suffix: wrap ``<dtype>*`` in buffer parameter types;
            ``buffer_prefix`` is also used in vector pointer casts of non-local buffers.
        smem_align / smem_prefix: prepended to local-buffer declarations in ``render_local``.
        smem_prefix_for_cast: if True, vector pointer casts on local buffers use
            ``smem_prefix``; otherwise they use ``buffer_prefix``.
        arg_int_prefix: full parameter type for ``dtypes._arg_int32`` arguments.
        barrier: statement emitted for a BARRIER uop.
        xid / gid / lid: per-dimension index expressions; ``uops_to_cstyle`` picks ``gid``
            for SPECIAL names starting with "g", ``xid`` for "i", and ``lid`` otherwise.
        global_max / local_max: not referenced in this module; presumably read by callers.
        extra_args: extra parameter declarations appended to the kernel signature.
        float4: vector constructor spelling; the substring ``"float4"`` in it is replaced by
            the target type name for vectorized casts. ``None`` disables vectorized casts.
        half_prekernel: code prepended to the program when any buffer is float16.
        uses_vload: render ``vload_half*``/``vstore_half*`` when a float16 buffer is accessed
            with a non-half value type.
        external_local_bufs: ``uops_to_cstyle`` puts local-buffer declarations into the
            ``prekernel`` list instead of the kernel body.
        uses_ptr_arithmetic: render scalar accesses as ``*(buf+idx)`` instead of ``buf[idx]``.
        launch_bounds: emit ``__launch_bounds__ (<prod(local_size)>, 1)`` in the signature.
        code_for_op: maps each Unary/Binary/TernaryOp to a lambda that takes the rendered
            operand strings plus the dtype and returns the rendered expression.
    """

    size_prefix: str = "int"
    generic_var_prefix: str = ""
    kernel_prefix: str = ""
    buffer_prefix: str = ""
    buffer_suffix: str = ""
    smem_align: str = ""
    smem_prefix: str = ""
    smem_prefix_for_cast: bool = True
    arg_int_prefix: str = ""
    barrier: str = ""
    xid: List[str] = []
    gid: List[str] = []
    lid: List[str] = []
    global_max: List[int] = []
    local_max: List[int] = []
    extra_args: List[str] = []
    float4: Optional[str] = None
    half_prekernel: Optional[str] = None
    uses_vload: bool = False
    external_local_bufs: bool = False
    uses_ptr_arithmetic: bool = False
    launch_bounds: bool = False
    code_for_op: Dict = {
        UnaryOps.NEG: lambda x, dtype: f"(-{x})" if dtype != dtypes.bool else f"(!{x})",
        UnaryOps.EXP2: lambda x, dtype: f"exp2({x})",
        UnaryOps.LOG2: lambda x, dtype: f"log2({x})",
        UnaryOps.SIN: lambda x, dtype: f"sin({x})",
        UnaryOps.SQRT: lambda x, dtype: f"sqrt({x})",
        BinaryOps.ADD: lambda a, b, dtype: f"({a}+{b})",
        BinaryOps.SUB: lambda a, b, dtype: f"({a}-{b})",
        BinaryOps.MUL: lambda a, b, dtype: f"({a}*{b})",
        BinaryOps.DIV: lambda a, b, dtype: f"({a}/{b})",
        BinaryOps.MAX: lambda a, b, dtype: f"max({a},{b})",
        BinaryOps.MOD: lambda a, b, dtype: f"({a}%{b})",
        BinaryOps.CMPLT: lambda a, b, dtype: f"({a}<{b})",
        TernaryOps.MULACC: lambda a, b, c, dtype: f"(({a}*{b})+{c})",
        TernaryOps.WHERE: lambda a, b, c, dtype: f"({a}!=0?{b}:{c})",
    }

    def render_cast(self, x: List[str], var_dtype: DType) -> str:
        """Return an expression casting the rendered value(s) ``x`` to ``var_dtype``.

        One element is rendered as a C cast ``(<name>)(x)``. Several elements are passed to
        the vector constructor derived from ``float4``.

        Raises:
            AssertionError: if a multi-element cast has ``len(x) != var_dtype.sz``, or the
                language has no ``float4`` constructor.
        """
        if len(x) == 1:
            return f"({var_dtype.name})({x[0]})"
        assert len(x) == var_dtype.sz, f"cast is wrong size {len(x)} != {var_dtype.sz}"
        assert self.float4 is not None, "vectorized cast is not supported on this platform"
        return f"{self.float4.replace('float4', var_dtype.name)}({','.join(x)})"

    def render_const(self, x: Union[float, int, bool], var_dtype: DType) -> str:
        """Return the literal for constant ``x`` typed as ``var_dtype``.

        NaN and infinities use the ``NAN``/``INFINITY`` names. Floats get an ``f`` suffix;
        bools render as ``true``/``false``. Vector dtypes, and scalars other than
        float/int/bool, are wrapped in ``render_cast``.
        """
        if math.isnan(x):
            val = "NAN"
        elif math.isinf(x):
            val = ("-" if x < 0 else "") + "INFINITY"
        elif dtypes.is_float(var_dtype):
            val = f"{float(x)}f"
        elif dtypes.is_int(var_dtype):
            val = f"{int(x)}"
        else:
            val = f"{bool(x)}".lower()
        if var_dtype.sz > 1 or var_dtype not in [dtypes.float, dtypes.int, dtypes.bool]:
            return self.render_cast([val] * var_dtype.sz, var_dtype)
        return val

    def render_load(self, output_dtype, buf_name, buf_dtype, idx, local=False) -> str:
        """Return an expression that loads ``output_dtype`` from ``buf_name`` at ``idx``.

        Args:
            output_dtype: dtype of the loaded value.
            buf_name: rendered buffer name.
            buf_dtype: dtype of the buffer (an ``ImageDType`` selects ``read_imagef``).
            idx: rendered index expression.
            local: True when the buffer is a DEFINE_LOCAL buffer.

        Raises:
            AssertionError: if an image is loaded as anything but float4.
        """
        if isinstance(buf_dtype, ImageDType):
            assert output_dtype == dtypes.float.vec(4), f"images must be float4, getting {output_dtype}"
            return f"read_imagef({buf_name}, smp, {idx})"
        if self.uses_vload and buf_dtype.scalar() == dtypes.float16 and output_dtype.scalar() != dtypes.float16:
            return f"vload_half{'' if output_dtype.sz == 1 else str(output_dtype.sz)}(0, {buf_name}+{idx})"
        if output_dtype.sz > 1:
            # vector load: reinterpret the element pointer as a pointer to the vector type
            out_val = f"*(({self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix}{buf_dtype.name}{output_dtype.sz}*)({buf_name}+{idx}))"
        else:
            out_val = f"*({buf_name}+{idx})" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}]"
        return self.render_cast([out_val], output_dtype) if output_dtype != buf_dtype else out_val

    def render_local(self, name: str, size: int):
        """Return the declaration of a local float array ``name[size]``."""
        return self.smem_align + self.smem_prefix + f"float {name}[{size}];"

    def render_for(self, expr: str, _min: Union[int, str], _max: Union[int, str]) -> str:
        """Return the opening of a loop running ``expr`` from ``_min`` up to (excluding) ``_max``."""
        return f"for (int {expr} = {_min}; {expr} < {_max}; ++{expr}) {{"

    def render_if(self, cond: str):
        """Return the opening of an ``if`` block guarded by ``cond``."""
        return f"if ({cond}) {{"

    def render_conditional(self, cond: str, x: str, y: str) -> str:
        """Return a ternary expression yielding ``x`` when ``cond`` holds, else ``y``."""
        return f"({cond})?({x}):{y}"

    def render_kernel(self, function_name: str, kernel: List[str], bufs: List[Tuple[str, DType]], local_size: List[int], prekernel: List[str]) -> str:
        """Assemble the complete program text for one kernel.

        Builds the signature from ``kernel_prefix``, optional ``__launch_bounds__``, the
        buffer parameters and ``extra_args``, then appends the body. An OpenCL sampler
        declaration is added when any buffer is an image. ``half_prekernel`` is prepended
        when any buffer is float16.

        Args:
            function_name: kernel name.
            kernel: body statements, already indented.
            bufs: ``(name, dtype)`` for each kernel parameter, in order. The first buffer
                is the non-``const`` (output) one.
            local_size: local dimensions; only used for ``__launch_bounds__``.
            prekernel: ignored by this implementation. Languages that set
                ``external_local_bufs`` must override this method to emit it.

        Returns:
            The program source as a string.
        """
        tmp = "const sampler_t smp = CLK_NORMALIZED_COORDS_FALSE | CLK_ADDRESS_CLAMP | CLK_FILTER_NEAREST;\n" if any(isinstance(dtype, ImageDType) for _, dtype in bufs) else ""
        # Same image test as above and as render_load/render_store.
        buftypes = [
            (
                name,
                f"{'read_only' if i > 0 else 'write_only'} image2d_t"
                if isinstance(dtype, ImageDType)
                else self.arg_int_prefix
                if dtype == dtypes._arg_int32
                else ("const " if i > 0 else "") + self.buffer_prefix + dtype.name + "*" + self.buffer_suffix,
            )
            for i, (name, dtype) in enumerate(bufs)
        ]
        bounds = f"__launch_bounds__ ({prod(local_size)}, 1) " if self.launch_bounds else ""
        params = ", ".join([f"{t} {name}" for name, t in buftypes] + self.extra_args)
        prg = f"{self.kernel_prefix}void {bounds}{function_name}(" + params + ") {\n" + tmp + "\n".join(kernel) + "\n}"
        if self.half_prekernel and any(dtype == dtypes.float16 for _, dtype in bufs):
            prg = f"{self.half_prekernel}\n{prg}"
        return prg

    def render_store(self, buf_name: str, buf_dtype: DType, var_name: str, var_dtype: DType, idx: str, local=False) -> str:
        """Return a statement storing ``var_name`` (of ``var_dtype``) into ``buf_name`` at ``idx``.

        Args:
            buf_name: rendered buffer name.
            buf_dtype: dtype of the buffer (an ``ImageDType`` selects ``write_imagef``).
            var_name: rendered value to store.
            var_dtype: dtype of the value.
            idx: rendered index expression.
            local: True when the buffer is a DEFINE_LOCAL buffer.

        Raises:
            AssertionError: if an image store value is not float4.
        """
        if isinstance(buf_dtype, ImageDType):
            assert var_dtype == dtypes.float.vec(4), "images must be float4"
            return f"write_imagef({buf_name}, {idx}, {var_name});"
        if self.uses_vload and buf_dtype.scalar() == dtypes.float16 and var_dtype.scalar() != dtypes.float16:
            return f"vstore_half{'' if var_dtype.sz == 1 else str(var_dtype.sz)}({var_name}, 0, {buf_name}+{idx});"
        if var_dtype.sz > 1:
            return f"*(({self.smem_prefix if local and self.smem_prefix_for_cast else self.buffer_prefix}{buf_dtype.name}{var_dtype.sz}*)({buf_name}+{idx})) = ({buf_dtype.name}{var_dtype.sz}){var_name};"
        return f"*({buf_name}+{idx}) = {var_name};" if self.uses_ptr_arithmetic else f"{buf_name}[{idx}] = {var_name};"
def uops_to_cstyle(lang: CStyleLanguage, function_name: str, uops: List[UOp]) -> Tuple[str, Dict]:
    """Render a linearized UOp list into kernel source for the C-style backend ``lang``.

    The uops are visited in order. Each one either emits statement(s) into the kernel body
    (loops, ifs, barriers, stores, declarations) or records in ``r`` the expression string
    that later uops use to refer to its value. ALU values with at most one consumer, and
    integer ALU values, are inlined (MAX never is). CAST values with at most one consumer
    are also inlined.

    :param lang: backend settings and rendering hooks.
    :param function_name: name given to the generated kernel function.
    :param uops: the linearized micro-operations to render.
    :return: ``(source, {})``: the program text from ``lang.render_kernel`` and an empty dict.
    :raises RuntimeError: if a uop type has no rendering rule.
    :raises NotImplementedError: for WMMA variants other than METAL and HIP.
    """
    local_size: List[int] = []  # local dims, taken from SPECIAL uops whose name starts with "l"
    kernel: List[str] = []  # kernel body statements
    prekernel: List[str] = []  # declarations emitted outside the body (external_local_bufs)
    bufs: List = []  # DEFINE_GLOBAL args, in order, handed to render_kernel
    depth = 1  # current indentation depth of emitted statements

    def kk(s: str) -> None:
        """Append statement ``s`` to the kernel body at the current indentation depth."""
        kernel.append(" " * depth + s)

    c: DefaultDict[str, int] = defaultdict(int)  # per-prefix counters for fresh names
    r: Dict[UOp, str] = {}  # uop -> expression or variable name referring to its value

    def ssa(u, prefix="t"):
        """Give ``u`` a fresh name ``<prefix><n>`` (n counts from 0 per prefix), record it in ``r``, and return it."""
        c[prefix] += 1
        r[u] = f"{prefix}{c[prefix]-1}"
        return r[u]

    def type_decl(dt) -> str:
        """Type spelling for a declaration: the language's generic prefix if set, else the dtype name."""
        return lang.generic_var_prefix if lang.generic_var_prefix else dt.name

    # Count consumers of each uop so single-use values can be inlined.
    child_count: DefaultDict[UOp, int] = defaultdict(int)
    for ru in uops:
        for v in ru.vin:
            child_count[v] += 1

    for u in uops:
        uop, dtype, vin, args = u.uop, u.dtype, u.vin, u.arg
        if uop == UOps.LOOP:
            kk(lang.render_for(ssa(u, "ridx"), r[vin[0]], r[vin[1]]))
            depth += 1
        elif uop == UOps.IF:
            kk(lang.render_if(r[vin[0]]))
            depth += 1
        elif uop == UOps.BARRIER:
            kk(lang.barrier)
        elif uop == UOps.END:
            depth -= 1
            kk("}")
        elif uop == UOps.WMMA:
            if args[0] == "METAL":
                assert dtype == dtypes.float.vec(2), "output dtype of METAL TC is _float2"
                # ((lidx2*32)+(lidx3*4)+(lidx4*16)+(lidx5*8)+(lidx6*2))
                output = ssa(u, "wmma")
                kk(f"{type_decl(dtype)} {output};")
                kk("{ simdgroup_float8x8 a,b,c;")
                kk(f"a.thread_elements()[0] = {r[vin[0]]}; a.thread_elements()[1] = {r[vin[1]]};")
                kk(f"b.thread_elements()[0] = {r[vin[2]]}; b.thread_elements()[1] = {r[vin[3]]};")
                kk(f"c.thread_elements()[0] = {r[vin[4]]}; c.thread_elements()[1] = {r[vin[5]]};")
                kk("simdgroup_multiply_accumulate(c, a, b, c);")
                kk(f"{output}.x = c.thread_elements()[0]; {output}.y = c.thread_elements()[1]; }}")
            elif args[0] == "HIP":
                assert dtype == dtypes.float.vec(8), "output dtype of HIP TC is _float8"
                kk(f"{type_decl(dtype)} {ssa(u, 'wmma')} = __builtin_amdgcn_wmma_f32_16x16x16_f16_w32({r[vin[0]]}, {r[vin[1]]}, {r[vin[2]]});")
            else:
                raise NotImplementedError(f"WMMA not implemented for {args}")
        elif uop == UOps.ALU:
            assert dtype is not None
            # remove parens if ALU types are the same. TODO: can do more here
            if vin[0].uop == UOps.ALU and vin[0].arg == args and args in {BinaryOps.ADD, BinaryOps.SUB, BinaryOps.MUL}:
                val = lang.code_for_op[args](strip_parens(r[vin[0]]), *[r[x] for x in vin[1:]], dtype)
            else:
                val = lang.code_for_op[args](*[r[x] for x in vin] + [dtype])
            assert child_count[u] != 0, f"childless ALU op found {u}"
            # fix index rendering issue (ints are always inlined). fix clang nested max macro issue (MAX never is)
            if (child_count[u] <= 1 or dtypes.is_int(dtype)) and args != BinaryOps.MAX:
                r[u] = val
            else:
                kk(f"{type_decl(dtype)} {ssa(u,'alu')} = {val};")
        elif uop == UOps.DEFINE_ACC:
            assert dtype is not None
            kk(f"{type_decl(dtype)} {ssa(u,'acc')} = {lang.render_const(args, dtype)};")
        elif uop == UOps.SPECIAL:
            # args: (dimension, name, size); the name's first letter picks the index source
            xid = lang.gid if args[1].startswith("g") else (lang.xid if args[1].startswith("i") else lang.lid)
            kk(f"{lang.size_prefix} {args[1]} = {xid[args[0]]}; /* {args[2]} */")
            if args[1].startswith("l"):
                local_size.append(args[2])
            r[u] = args[1]
        elif uop == UOps.CONST:
            # negative constants are parenthesized so they compose safely inside expressions
            r[u] = lang.render_const(args, dtype) if args >= 0 else f"({lang.render_const(args, dtype)})"
        elif uop == UOps.LOAD:
            assert dtype is not None
            val = lang.render_load(dtype, r[vin[0]], vin[0].dtype, strip_parens(r[vin[1]]), vin[0].uop == UOps.DEFINE_LOCAL)
            if len(vin) > 3:
                # gated load: vin[2] is the condition, vin[3] the fallback value
                val = lang.render_conditional(r[vin[2]], val, r[vin[3]])
            kk(f"{type_decl(dtype)} {ssa(u,'val')} = {val};")
        elif uop == UOps.PHI:
            kk(f"{r[vin[0]]} = {r[vin[1]]};")
            r[u] = r[vin[0]]
        elif uop == UOps.STORE:
            assert vin[0].dtype is not None and vin[2].dtype is not None
            if len(vin) > 3:
                kk(lang.render_if(r[vin[3]]))  # gated store
            kk(lang.render_store(r[vin[0]], vin[0].dtype, r[vin[2]], vin[2].dtype, strip_parens(r[vin[1]]), vin[0].uop == UOps.DEFINE_LOCAL))
            if len(vin) > 3:
                kk("}")
        elif uop == UOps.CAST and dtype is not None:
            val = lang.render_cast([r[x] for x in vin], dtype)
            if child_count[u] <= 1:
                r[u] = val
            else:
                kk(f"{type_decl(dtype)} {ssa(u,'cast')} = {val};")
        elif uop == UOps.DEFINE_LOCAL:
            if lang.external_local_bufs:
                prekernel.append(lang.render_local(args[0], args[1]))
            else:
                kk(lang.render_local(args[0], args[1]))
            r[u] = args[0]
        elif uop == UOps.DEFINE_GLOBAL:
            bufs.append(args)
            r[u] = args[0]
        elif uop == UOps.GEP:
            if cast(DType, vin[0].dtype).sz > 4:
                r[u] = f"({r[vin[0]]})[{args}]"  # this is correct for HIP
            else:
                r[u] = f"({r[vin[0]]}).{'xyzw'[args]}"
        else:
            raise RuntimeError(f"failed to render {uop}")

    return lang.render_kernel(function_name, kernel, bufs, local_size, prekernel), {}
class OpenCLLanguage(CStyleLanguage):
    """C-style settings for OpenCL C kernels.

    Kernels are declared ``__kernel`` and buffers ``__global``. Local memory is
    ``__local`` and 16-byte aligned, and work-group sync uses ``barrier``. The
    ``cl_khr_fp16`` pragma is prepended for half buffers, and half buffers are
    accessed with ``vload_half``/``vstore_half``. Index expressions come from
    ``get_group_id``/``get_local_id``/``get_global_id``, and MULACC renders as ``mad``.
    """

    # signature and memory-space qualifiers
    kernel_prefix = "__kernel "
    buffer_prefix = "__global "
    arg_int_prefix = "const int"

    # local memory and synchronization
    smem_align = "__attribute__ ((aligned (16))) "
    smem_prefix = "__local "
    barrier = "barrier(CLK_LOCAL_MEM_FENCE);"

    # half precision and vector construction
    half_prekernel = "#pragma OPENCL EXTENSION cl_khr_fp16 : enable"
    float4 = "(float4)"
    uses_vload = True

    # index expressions for dimensions 0, 1, 2
    gid = [f"get_group_id({dim})" for dim in (0, 1, 2)]
    lid = [f"get_local_id({dim})" for dim in (0, 1, 2)]
    xid = [f"get_global_id({dim})" for dim in (0, 1, 2)]

    # NOTE: mad is used so the loads aren't reordered into the math on 845
    code_for_op = dict(CStyleLanguage().code_for_op)
    code_for_op[TernaryOps.MULACC] = lambda a, b, c, dtype: f"mad({a},{b},{c})"


# renderer entry point: (function_name, uops) -> (source, {})
OpenCLRenderer = functools.partial(uops_to_cstyle, OpenCLLanguage())
class MetalLanguage(CStyleLanguage):
    """C-style settings for Metal Shading Language kernels.

    The kernel prefix pulls in ``metal_stdlib``. Buffers live in ``device`` memory and
    local buffers in ``threadgroup`` memory. Scalar accesses use pointer arithmetic.
    Group/thread positions arrive through the ``gid``/``lid`` ``uint3`` parameters that
    are appended to every kernel signature.
    """

    kernel_prefix = "#include <metal_stdlib>\nusing namespace metal;\nkernel "
    buffer_prefix = "device "
    arg_int_prefix = "constant int&"
    smem_prefix = "threadgroup "
    barrier = "threadgroup_barrier(mem_flags::mem_threadgroup);"
    float4 = "float4"
    uses_ptr_arithmetic = True

    # components of the uint3 kernel parameters declared in extra_args
    gid = [f"gid.{axis}" for axis in "xyz"]
    lid = [f"lid.{axis}" for axis in "xyz"]
    extra_args = [
        "uint3 gid [[threadgroup_position_in_grid]]",
        "uint3 lid [[thread_position_in_threadgroup]]",
    ]


# renderer entry point: (function_name, uops) -> (source, {})
MetalRenderer = functools.partial(uops_to_cstyle, MetalLanguage())
class CUDALanguage(CStyleLanguage):
    """C-style settings for CUDA kernels.

    Attributes:
        kernel_prefix: defines ``INFINITY``/``NAN`` via ``__int_as_float`` and declares the
            kernel ``extern "C" __global__``.
        smem_prefix: ``__shared__`` for local buffers.
        smem_prefix_for_cast: False, so vector pointer casts on shared buffers use
            ``buffer_prefix`` (empty) instead of ``__shared__``.
        arg_int_prefix: parameter type for integer arguments.
        barrier: ``__syncthreads();``.
        float4: ``make_float4`` constructor used for vectorized casts.
        gid / lid / xid: ``blockIdx``, ``threadIdx`` and the combined global index
            ``blockIdx*blockDim+threadIdx`` per axis x, y, z.
        code_for_op: base ops, with MAX on half rendered as ``__hmax``.
        half_prekernel: includes ``cuda_fp16.h`` and defines ``half4``/``make_half4``.
            It is multi-line source; the ``#include`` must end its own line, otherwise the
            following definitions become extra tokens of the directive.
    """

    kernel_prefix = '#define INFINITY (__int_as_float(0x7f800000))\n#define NAN (__int_as_float(0x7fffffff))\nextern "C" __global__ '
    smem_prefix = "__shared__ "
    smem_prefix_for_cast = False
    arg_int_prefix = "const int"
    barrier = "__syncthreads();"
    float4 = "make_float4"
    gid = [f"blockIdx.{a}" for a in "xyz"]
    lid = [f"threadIdx.{a}" for a in "xyz"]
    xid = [f"(blockIdx.{a}*blockDim.{a}+threadIdx.{a})" for a in "xyz"]
    code_for_op = {
        **CStyleLanguage().code_for_op,
        BinaryOps.MAX: lambda a, b, dtype: f"max({a},{b})" if dtype != dtypes.half else f"__hmax({a},{b})",
    }
    half_prekernel = """
#include <cuda_fp16.h>
struct half4 { half x, y, z, w; };
__device__ half4 make_half4(half x, half y, half z, half w) { half4 ret; ret.x = x; ret.y = y; ret.z = z; ret.w = w; return ret; }
"""


# renderer entry point: (function_name, uops) -> (source, {})
CUDARenderer = functools.partial(uops_to_cstyle, CUDALanguage())
class HIPLanguage(CStyleLanguage):
    """C-style settings for AMD HIP kernels.

    ``kernel_prefix`` includes ``hip_common.h`` and defines ``INFINITY``/``NAN``. It also
    defines float4 overloads of max/pow/log2/exp2/sin and a ``float8`` vector type, and
    declares the kernel ``extern "C" __global__``. ``half_prekernel`` includes
    ``hip_fp16.h`` and defines half vector types (half4/half8/half16), the
    ``vload_half*``/``vstore_half*`` helpers, and half math and operator overloads.

    Both strings are multi-line C++ source. Every ``#include``/``#define`` and every ``//``
    comment must end its own line; otherwise the rest of that line becomes part of the
    directive or the comment.

    Attributes:
        launch_bounds: True, so ``render_kernel`` emits ``__launch_bounds__``.
        smem_prefix: ``__shared__`` for local buffers.
        smem_prefix_for_cast: False, so vector pointer casts on shared buffers use
            ``buffer_prefix`` (empty) instead of ``__shared__``.
        barrier: ``__syncthreads();``.
        float4: ``make_float4`` constructor used for vectorized casts.
        uses_vload: float16 buffers accessed as non-half use the ``vload_half*``/``vstore_half*``
            helpers defined in ``half_prekernel``.
        uses_ptr_arithmetic: scalar accesses render as ``*(buf+idx)``.
        arg_int_prefix: parameter type for integer arguments.
        gid / lid / xid: ``blockIdx``, ``threadIdx`` and ``blockIdx*blockDim+threadIdx`` per axis.
        code_for_op: base ops, with half MAX rendered as ``hmax`` and half WHERE cast to ``half``.
    """

    kernel_prefix = '#include <hip/hip_common.h>\n#define INFINITY (__builtin_inff())\n#define NAN (__builtin_nanf(""))' + """
__device__ float4 max(float4 x, float4 y) { return float4(max(x.x, y.x), max(x.y, y.y), max(x.z, y.z), max(x.w, y.w)); }
__device__ float4 pow(float x, float4 y) { return float4(pow(x, y.x), pow(x, y.y), pow(x, y.z), pow(x, y.w)); }
__device__ float4 pow(float4 x, float4 y) { return float4(pow(x.x, y.x), pow(x.y, y.y), pow(x.z, y.z), pow(x.w, y.w)); }
__device__ float4 log2(float4 x) { return float4(log2(x.x), log2(x.y), log2(x.z), log2(x.w)); }
__device__ float4 exp2(float4 x) { return float4(exp2(x.x), exp2(x.y), exp2(x.z), exp2(x.w)); }
__device__ float4 sin(float4 x) { return float4(sin(x.x), sin(x.y), sin(x.z), sin(x.w)); }
typedef float float8 __attribute__((ext_vector_type(8)));
__device__ float8 make_float8(float x, float y, float z, float w, float a, float b, float c, float d) { return {x, y, z, w, a, b, c, d}; }
extern "C" __global__ """
    launch_bounds = True
    smem_prefix = "__shared__ "
    smem_prefix_for_cast = False
    barrier = "__syncthreads();"
    float4 = "make_float4"
    uses_vload = True
    uses_ptr_arithmetic = True
    arg_int_prefix = "const int"
    half_prekernel = "#include <hip/hip_fp16.h>\n" + """
typedef union { struct { half x, y, z, w; } __attribute__((aligned(8))); half data[4]; } half4;
__device__ half4 make_half4(half x, half y, half z, half w) { return {x, y, z, w}; }
typedef union { struct { half x, y, z, w, a, b, c, d; } __attribute__((aligned(16))); half data[8]; } half8;
__device__ half8 make_half8(half x, half y, half z, half w, half a, half b, half c, half d) { return {x, y, z, w, a, b, c, d}; }
typedef _Float16 half16 __attribute__((ext_vector_type(16)));
__device__ half16 make_half16(half x, half y, half z, half w, half a, half b, half c, half d, half e, half f, half g, half h, half i, half j, half k, half l) { return {x, y, z, w, a, b, c, d, e, f, g, h, i, j, k, l}; }
__device__ float vload_half(size_t offset, const half *p) { return (float)*(p + offset); }
__device__ float2 vload_half2(size_t offset, const half *p) { return make_float2((float)*(p + offset*2), (float)*(p + offset*2 + 1)); }
__device__ float4 vload_half4(size_t offset, const half *p) { return make_float4((float)*(p + offset*4), (float)*(p + offset*4 + 1), (float)*(p + offset*4 + 2), (float)*(p + offset*4 + 3)); }
__device__ void vstore_half(float data, size_t offset, half *p) { *(p + offset) = (half)data; }
__device__ void vstore_half2(float2 data, size_t offset, half *p) { *(p + offset*2) = (half)data.x; *(p + offset*2 + 1) = (half)data.y; }
__device__ void vstore_half4(float4 data, size_t offset, half *p) { *(p + offset*4) = (half)data.x; *(p + offset*4 + 1) = (half)data.y; *(p + offset*4 + 2) = (half)data.z; *(p + offset*4 + 3) = (half)data.w; }
__device__ half exp2(half x) { return hexp2(x); }
__device__ half log2(half x) { return hlog2(x); }
__device__ half sin(half x) { return hsin(x); }
__device__ half sqrt(half x) { return hsqrt(x); }
__device__ half hmax(half a, half b) { return __hgt(a, b) ? a : b; }
__device__ half operator%(const half &a, const half &b) { return __hsub(a, __hmul(b, __float2half(floorf(__half2float(a) / __half2float(b))))); }
__device__ bool operator!=(const half &a, const int &b) { return (float)a != b; }
// HACKS for ALU ops on half and result of half2 GEP
__device__ half operator+(const half &a, const unsigned short &b) { return __hadd(a, (half)(b)); }
__device__ half operator-(const half &a, const unsigned short &b) { return __hsub(a, (half)(b)); }
__device__ half operator*(const half &a, const unsigned short &b) { return __hmul(a, (half)(b)); }
__device__ half operator/(const half &a, const unsigned short &b) { return __hdiv(a, (half)(b)); }
__device__ bool operator<(const half &a, const unsigned short &b) { return __hlt(a, (half)(b)); }
// now the other way
__device__ half operator+(const unsigned short &a, const half &b) { return __hadd((half)(a), b); }
__device__ half operator-(const unsigned short &a, const half &b) { return __hsub((half)(a), b); }
__device__ half operator*(const unsigned short &a, const half &b) { return __hmul((half)(a), b); }
__device__ half operator/(const unsigned short &a, const half &b) { return __hdiv((half)(a), b); }
__device__ bool operator<(const unsigned short &a, const half &b) { return __hlt((half)(a), b); }
"""
    gid = [f"blockIdx.{a}" for a in "xyz"]
    lid = [f"threadIdx.{a}" for a in "xyz"]
    xid = [f"(blockIdx.{a}*blockDim.{a}+threadIdx.{a})" for a in "xyz"]
    code_for_op = {
        **CStyleLanguage().code_for_op,
        BinaryOps.MAX: lambda a, b, dtype: f"max({a},{b})" if dtype != dtypes.half else f"hmax({a},{b})",
        TernaryOps.WHERE: lambda a, b, c, dtype: f"({a}!=0?{b}:{c})" if dtype != dtypes.half else f"(half)({a}!=0?{b}:{c})",
    }


# renderer entry point: (function_name, uops) -> (source, {})
HIPRenderer = functools.partial(uops_to_cstyle, HIPLanguage())

# TODO: how much of this can be merged with above?
class WGSLLanguage(CStyleLanguage):
    """C-style settings and rendering overrides for WGSL (WebGPU Shading Language).

    Attributes:
        gid / lid: ``gindex``/``lindex`` components (the workgroup id and local invocation
            id builtins declared by ``render_kernel``), converted to ``i32``.
        size_prefix: ``let``, used to declare SPECIAL index variables.
        barrier: ``workgroupBarrier();``.
        generic_var_prefix: ``var ``, used instead of the dtype name in declarations.
        external_local_bufs: True, so local buffers go into ``prekernel``, which
            ``render_kernel`` emits before the compute function.
        code_for_op: base ops, with CMPLT converted to ``f32``, MULACC as ``fma`` and WHERE
            as ``select``.
        type_map: WGSL scalar type names for the supported dtypes.
    """

    gid = [f"i32(gindex.{'xyz'[x]})" for x in range(3)]
    lid = [f"i32(lindex.{'xyz'[x]})" for x in range(3)]
    size_prefix = "let"
    barrier = "workgroupBarrier();"
    generic_var_prefix = "var "
    external_local_bufs = True
    code_for_op = {
        **CStyleLanguage().code_for_op,
        BinaryOps.CMPLT: lambda x, y, dtype: f"f32({x}<{y})",
        TernaryOps.MULACC: lambda x, y, z, dtype: f"fma({x},{y},{z})",
        TernaryOps.WHERE: lambda a, b, c, dtype: f"select({c},{b},{a}!=0.)",
    }
    type_map = {
        dtypes.float: "f32",
        dtypes.half: "f16",
        dtypes.int32: "i32",
        dtypes.uint32: "u32",
        dtypes.bool: "bool",
    }

    def render_local(self, name: str, size: int):
        """Return a workgroup-scoped ``array<f32,size>`` declaration named ``name``."""
        return f"var<workgroup> {name}: array<f32,{size}>;"

    def render_const(self, x: Union[float, int, bool], var_dtype) -> str:
        """Return a parenthesized WGSL literal for ``x``.

        NaN calls the ``nan()`` helper defined by ``render_kernel``. Infinities are
        rendered as +/-``0x1.fffffep+127f``, the largest finite f32. Everything else
        defers to the base rendering.
        """
        if math.isnan(x):
            return "nan()"
        elif math.isinf(x):
            return ("-" if x < 0 else "") + "0x1.fffffep+127f"
        return f"({super().render_const(x, var_dtype)})"

    def render_kernel(self, function_name: str, kernel: List[str], bufs: List[Tuple[str, DType]], local_size: List[int], prekernel: List[str]) -> str:
        """Assemble a WGSL module containing one compute entry point.

        Emits the ``nan()`` helper, then the ``prekernel`` lines, then one storage binding
        per buffer (binding indices 0..len(bufs)-1 in order), then the ``@compute``
        function with the body.

        Args:
            function_name: entry point name.
            kernel: body statements.
            bufs: ``(name, dtype)`` per buffer; dtypes must be keys of ``type_map``.
            local_size: workgroup size, emitted in reverse order; ``[1]`` when empty.
            prekernel: module-scope declarations (local buffers).

        Returns:
            The module source as a string.
        """
        local_size = local_size[::-1] if local_size else [1]
        bindings = [
            f"@group(0) @binding({i}) var<storage,read_write> {name}: array<{self.type_map[dtype]}>;"
            for i, (name, dtype) in enumerate(bufs)
        ]
        prg = "fn nan() -> f32 { let bits = 0xffffffffu; return bitcast<f32>(bits); }\n"
        prg += "\n".join(prekernel + bindings)
        prg += (
            f"\n@compute @workgroup_size({','.join([str(x) for x in local_size])}) fn {function_name}(@builtin(workgroup_id) gindex: vec3<u32>, @builtin(local_invocation_id) lindex: vec3<u32>) {{\n"
            + "\n".join(kernel)
            + "\n}"
        )
        return prg

    def render_for(self, expr: str, _min: Union[int, str], _max: Union[int, str]) -> str:
        """Return the opening of a WGSL loop over ``expr`` from ``_min`` up to (excluding) ``_max``."""
        return f"for(var {expr} = {_min}; {expr} < {_max}; {expr}++) {{"

    def render_if(self, cond: str):
        """Return the opening of an ``if`` block; ``cond`` is converted to ``bool``."""
        return f"if (bool({cond})) {{"

    def render_conditional(self, cond: str, x: str, y: str) -> str:
        """Return ``select`` yielding ``x`` when ``cond`` holds, else ``y`` converted to f32."""
        return f"select(f32({y}), {x}, bool({cond}))"

    def render_cast(self, x: List[str], var_dtype: DType) -> str:
        """Cast ``x[0]`` to ``var_dtype`` with a WGSL type constructor, e.g. ``f32(x)``.

        Only the first element of ``x`` is used.

        Raises:
            NotImplementedError: if ``var_dtype`` has no entry in ``type_map``.
        """
        # Test membership first: indexing an unmapped dtype would raise KeyError and
        # make the NotImplementedError below unreachable.
        if var_dtype in self.type_map:
            return f"{self.type_map[var_dtype]}({x[0]})"
        raise NotImplementedError(f"no cast for {var_dtype}")

    def render_store(self, buf_name: str, buf_dtype: DType, var_name: str, var_dtype: DType, idx, local=False) -> str:
        """Return ``buf_name[idx] = value;``, casting the value to ``buf_dtype`` when the dtypes differ.

        ``local`` is accepted for interface compatibility and does not affect the output.
        """
        return f"{buf_name}[{idx}] = {self.render_cast([var_name], buf_dtype) if var_dtype != buf_dtype else var_name};"


# renderer entry point: (function_name, uops) -> (source, {})
WGSLRenderer = functools.partial(uops_to_cstyle, WGSLLanguage())