Source code for tinygrad.features.graph.cuda

import ctypes
from typing import Any, Optional, Tuple, Dict, List, cast
import gpuctypes.cuda as cuda
from tinygrad.helpers import init_c_var, encode_args_cuda_style
from tinygrad.device import CompiledASTRunner, update_stats, Buffer
from tinygrad.runtime.ops_cuda import check, cu_time_execution
from tinygrad.shape.symbolic import Variable
from tinygrad.jit import (
    JitItem,
    get_input_replace,
    get_jit_stats,
    get_jc_idxs_with_updatable_launch_dims,
    get_jc_idxs_with_updatable_var_vals,
    GraphException,
)


[docs] class CUDAGraph: """ This class represents a CUDA Graph. It is initialized with a list of JitItems, input raw buffers, and a dictionary mapping variables to integers. Attributes: jit_cache (List[JitItem]): A list of JitItems. input_rawbuffers (List[Buffer]): A list of input raw buffers. var_vals (Dict[Variable, int]): A dictionary mapping variables to integers. """ def __init__( self, jit_cache: List[JitItem], input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], ): """Initialize the object. Args: jit_cache (List[JitItem]): A list of JitItems to be used for initialization. input_rawbuffers (List[Buffer]): A list of raw buffers to be used as inputs. var_vals (Dict[Variable, int]): A dictionary mapping variables to integer values. Attributes: jit_cache (List[JitItem]): The JIT cache. input_replace (Dict[Tuple[int, int], str]): A dictionary mapping tuples of integers to strings, representing the inputs to replace. op_estimate (float): An estimation of the operations involved in the process. mem_estimate (float): An estimation of the memory used during the process. jc_idxs_with_updatable_launch_dims (List[int]): A list of indices with updatable launch dimensions. jc_idxs_with_updatable_var_vals (List[int]): A list of indices with updatable variable values. jc_idxs_with_updatable_rawbufs (List[int]): A list of indices with updatable raw buffers. updatable_nodes (Dict[int, Tuple[Any, Any, Any]]): A dictionary mapping integers to tuples of any type, representing the updatable nodes. graph: The created graph. instance: The instantiated graph. """ if not all(isinstance(ji.prg, CompiledASTRunner) for ji in jit_cache): raise GraphException self.jit_cache = jit_cache self.input_replace = get_input_replace(jit_cache, input_rawbuffers) self.op_estimate, self.mem_estimate = get_jit_stats(jit_cache) self.jc_idxs_with_updatable_launch_dims = ( get_jc_idxs_with_updatable_launch_dims(jit_cache) ) self.jc_idxs_with_updatable_var_vals = get_jc_idxs_with_updatable_var_vals( jit_cache ) self.jc_idxs_with_updatable_rawbufs = list( set([x[0] for x in self.input_replace.keys()]) ) self.updatable_nodes: Dict[ int, Tuple[Any, Any, Any] ] = {} # Dict[jc index] = tuple(graph node, node params, input kernel params) self.graph = self.graph_create() graph_node: Optional[ctypes._CData] = None for (j, i), input_name in self.input_replace.items(): self.jit_cache[j].rawbufs[i] = input_rawbuffers[input_name] for j, ji in enumerate(self.jit_cache): prg: CompiledASTRunner = cast(CompiledASTRunner, ji.prg) c_deps = ( (type(graph_node) * 1)(*(graph_node,)) if graph_node is not None else None ) c_kernel_input_config, c_input_params = encode_args_cuda_style( [cast(Buffer, x)._buf for x in ji.rawbufs], [var_vals[x] for x in prg.vars], *self.encode_args_info(), ) c_node_params = self.build_kernel_node_params( prg, *cast(Tuple[List[int], List[int]], prg.launch_dims(var_vals)), c_kernel_input_config, ) graph_node = self.graph_add_kernel_node(self.graph, c_deps, c_node_params) if ( j in self.jc_idxs_with_updatable_launch_dims or j in self.jc_idxs_with_updatable_var_vals or j in self.jc_idxs_with_updatable_rawbufs ): self.updatable_nodes[j] = (graph_node, c_node_params, c_input_params) self.instance = self.graph_instantiate(self.graph) def __call__( self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int], wait=False, jit=False, ) -> Optional[float]: """ Call method for the class. Updates rawbuffers and var_vals in the c_input_params struct, updates launch dims in the c_node_params struct, and finally updates graph nodes with the updated structs. Executes the graph and returns the execution time. Attributes: self (Any): The instance of the class. input_rawbuffers (List[Buffer]): List of raw buffers to be used as input. var_vals (Dict[Variable, int]): Dictionary containing variable values. wait (bool): Whether or not to wait for the execution to finish before returning. Default is False. jit (bool): Whether or not to use Just-In-Time compilation. Default is False. Returns: Optional[float]: The execution time of the graph, if wait is True. Otherwise, None. """ # Update rawbuffers in the c_input_params struct. for (j, i), input_idx in self.input_replace.items(): setattr( self.updatable_nodes[j][2], f"f{i}", input_rawbuffers[input_idx]._buf ) # Update var_vals in the c_input_params struct. for j in self.jc_idxs_with_updatable_var_vals: for i, v in enumerate(cast(CompiledASTRunner, self.jit_cache[j].prg).vars): setattr( self.updatable_nodes[j][2], f"f{len(self.jit_cache[j].rawbufs) + i}", var_vals[v], ) # Update launch dims in the c_node_params struct. for j in self.jc_idxs_with_updatable_launch_dims: self.set_kernel_node_launch_dims( self.updatable_nodes[j][1], *cast(CompiledASTRunner, self.jit_cache[j].prg).launch_dims(var_vals), ) # Update graph nodes with the updated structs. for node, c_node_params, _ in self.updatable_nodes.values(): self.graph_exec_kernel_node_set_params( self.instance, node, ctypes.byref(c_node_params) ) et = self.graph_launch(self.instance, None, wait=wait) update_stats( f"<batched {len(self.jit_cache)}>", self.op_estimate, self.mem_estimate, var_vals, et, buf_count=len(input_rawbuffers), jit=jit, num_kernels=len(self.jit_cache), ) return et def __del__(self): """ Destructor for the class. This method is called when an instance of the class is about to be destroyed. It destroys the cuGraph and cuGraphExec objects. Attributes: self.graph (cuda.CUgraph): The graph object to destroy. self.instance (cuda.CUgraphExec): The graph execution instance to destroy. """ check(cuda.cuGraphDestroy(self.graph)) check(cuda.cuGraphExecDestroy(self.instance))
[docs] def encode_args_info(self): """ Encodes the arguments information for the class. This method encodes the arguments information required by other methods in the class. Returns: tuple: A tuple containing the encoded arguments information. """ return (cuda.CUdeviceptr_v2, (1, 2, 0))
[docs] def graph_create(self): """ Creates a new cuGraph object. This method initializes and creates a new cuGraph object. Returns: cuda.CUgraph: The newly created cuGraph object. """ return init_c_var( cuda.CUgraph(), lambda x: check(cuda.cuGraphCreate(ctypes.byref(x), 0)) )
[docs] def graph_instantiate(self, graph): """ Instantiates a cuGraphExec object from a cuGraph object. This method creates an instance of the cuGraph execution from a given cuGraph object. Args: graph (cuda.CUgraph): The cuGraph object to instantiate. Returns: cuda.CUgraphExec: The instantiated cuGraph execution object. """ return init_c_var( cuda.CUgraphExec(), lambda x: check( cuda.cuGraphInstantiate_v2(ctypes.byref(x), graph, None, None, 0) ), )
[docs] def graph_add_kernel_node(self, graph, c_deps, c_node_params): """ Adds a kernel node to the cuGraph object. This method adds a new kernel node to the given cuGraph object. Args: graph (cuda.CUgraph): The cuGraph object to add the kernel node to. c_deps (ctypes.c_void_p): The dependencies of the kernel node. c_node_params (ctypes.c_void_p): The parameters for the new kernel node. Returns: cuda.CUgraphNode: The newly created kernel node. """ return init_c_var( cuda.CUgraphNode(), lambda x: check( cuda.cuGraphAddKernelNode( ctypes.byref(x), graph, c_deps, ctypes.sizeof(c_deps) // 8 if c_deps else 0, ctypes.byref(c_node_params), ) ), )
[docs] def graph_launch(self, *args, wait=False): """ Launches the cuGraph execution. This method launches the execution of the cuGraph object. Args: *args (tuple): The arguments for the graph execution. wait (bool, optional): Whether to wait for the execution to finish or not. Defaults to False. """ return cu_time_execution(lambda: check(cuda.cuGraphLaunch(*args)), enable=wait)
[docs] def graph_exec_kernel_node_set_params(self, *args): """ Set parameters for the kernel node. Parameters: *args: Variable length argument list. Returns: Result of check function with CUDA call. """ return check(cuda.cuGraphExecKernelNodeSetParams(*args))
[docs] def build_kernel_node_params(self, prg, global_size, local_size, c_kernel_config): """ Build parameters for the kernel node. Parameters: prg (Program): CUDA program object. global_size (Tuple[int, int, int]): Global size of thread blocks. local_size (Tuple[int, int, int]): Local size of thread blocks. c_kernel_config: Kernel configuration parameters. Returns: CUDA_KERNEL_NODE_PARAMS object with specified parameters. """ return cuda.CUDA_KERNEL_NODE_PARAMS( prg.clprg.prg, *global_size, *local_size, 0, None, c_kernel_config )
[docs] def set_kernel_node_launch_dims( self, node, global_size: Tuple[int, int, int], local_size: Tuple[int, int, int] ): """ Set launch dimensions for the kernel node. Parameters: node: Kernel node object. global_size (Tuple[int, int, int]): Global size of thread blocks. local_size (Tuple[int, int, int]): Local size of thread blocks. Returns: None Attributes: node.blockDimX, node.blockDimY, node.blockDimZ: Local dimensions. node.gridDimX, node.gridDimY, node.gridDimZ: Global dimensions. """ ( node.blockDimX, node.blockDimY, node.blockDimZ, node.gridDimX, node.gridDimY, node.gridDimZ, ) = (*local_size, *global_size)