import ctypes
from typing import Any, Optional, Tuple, Dict, List, cast
import gpuctypes.cuda as cuda
from tinygrad.helpers import init_c_var, encode_args_cuda_style
from tinygrad.device import CompiledASTRunner, update_stats, Buffer
from tinygrad.runtime.ops_cuda import check, cu_time_execution
from tinygrad.shape.symbolic import Variable
from tinygrad.jit import (
JitItem,
get_input_replace,
get_jit_stats,
get_jc_idxs_with_updatable_launch_dims,
get_jc_idxs_with_updatable_var_vals,
GraphException,
)
class CUDAGraph:
    """A batched, replayable CUDA Graph built from a JIT cache of compiled kernels.

    Every kernel in ``jit_cache`` is recorded once as a node of a single CUDA
    graph (each node depending on the previous one, so execution order is
    preserved).  Subsequent calls replay the whole batch with one graph launch,
    patching only the pieces that can change between calls: input raw buffers,
    symbolic variable values, and launch dimensions.
    """

    def __init__(self, jit_cache: List[JitItem], input_rawbuffers: List[Buffer],
                 var_vals: Dict[Variable, int]):
        """Capture every kernel in ``jit_cache`` into a freshly created CUDA graph.

        Args:
            jit_cache: JIT items to batch. Every ``ji.prg`` must be a
                ``CompiledASTRunner`` — only compiled kernels can be captured.
            input_rawbuffers: call-time input buffers; the (kernel, arg) slots that
                must be re-patched on every ``__call__`` come from ``get_input_replace``.
            var_vals: current values of symbolic variables (feed kernel args and
                launch dimensions).

        Raises:
            GraphException: if any cache entry is not a ``CompiledASTRunner``.
        """
        if not all(isinstance(ji.prg, CompiledASTRunner) for ji in jit_cache):
            raise GraphException
        self.jit_cache = jit_cache
        # (jit-cache index, arg index) -> input buffer index to splice in per call.
        self.input_replace = get_input_replace(jit_cache, input_rawbuffers)
        self.op_estimate, self.mem_estimate = get_jit_stats(jit_cache)
        self.jc_idxs_with_updatable_launch_dims = get_jc_idxs_with_updatable_launch_dims(jit_cache)
        self.jc_idxs_with_updatable_var_vals = get_jc_idxs_with_updatable_var_vals(jit_cache)
        self.jc_idxs_with_updatable_rawbufs = list({x[0] for x in self.input_replace.keys()})
        # Dict[jc index] = tuple(graph node, node params, input kernel params)
        self.updatable_nodes: Dict[int, Tuple[Any, Any, Any]] = {}

        self.graph = self.graph_create()
        graph_node: Optional[ctypes._CData] = None

        # Bind the initial input buffers before capturing the kernel arguments.
        for (j, i), input_name in self.input_replace.items():
            self.jit_cache[j].rawbufs[i] = input_rawbuffers[input_name]
        for j, ji in enumerate(self.jit_cache):
            prg: CompiledASTRunner = cast(CompiledASTRunner, ji.prg)

            # Chain each node after the previous one so kernels run in cache order.
            c_deps = (type(graph_node) * 1)(*(graph_node,)) if graph_node is not None else None
            c_kernel_input_config, c_input_params = encode_args_cuda_style(
                [cast(Buffer, x)._buf for x in ji.rawbufs],
                [var_vals[x] for x in prg.vars], *self.encode_args_info())
            c_node_params = self.build_kernel_node_params(
                prg, *cast(Tuple[List[int], List[int]], prg.launch_dims(var_vals)), c_kernel_input_config)
            graph_node = self.graph_add_kernel_node(self.graph, c_deps, c_node_params)

            # Keep handles to every node whose params must be patched per call.
            if j in self.jc_idxs_with_updatable_launch_dims or j in self.jc_idxs_with_updatable_var_vals \
                    or j in self.jc_idxs_with_updatable_rawbufs:
                self.updatable_nodes[j] = (graph_node, c_node_params, c_input_params)

        self.instance = self.graph_instantiate(self.graph)

    def __call__(self, input_rawbuffers: List[Buffer], var_vals: Dict[Variable, int],
                 wait=False, jit=False) -> Optional[float]:
        """Replay the captured graph with fresh inputs.

        Patches raw buffers and variable values into the per-node kernel-argument
        structs, updates launch dims where needed, pushes the updated structs into
        the instantiated graph, then launches it once.

        Args:
            input_rawbuffers: buffers to substitute into the recorded input slots.
            var_vals: current symbolic variable values.
            wait: if True, time the launch and return the elapsed seconds.
            jit: passed through to ``update_stats`` for bookkeeping.

        Returns:
            Elapsed execution time when ``wait`` is True, otherwise None.
        """
        # Update rawbuffers in the c_input_params struct.
        for (j, i), input_idx in self.input_replace.items():
            setattr(self.updatable_nodes[j][2], f"f{i}", input_rawbuffers[input_idx]._buf)

        # Update var_vals in the c_input_params struct (they sit after the buffers).
        for j in self.jc_idxs_with_updatable_var_vals:
            for i, v in enumerate(cast(CompiledASTRunner, self.jit_cache[j].prg).vars):
                setattr(self.updatable_nodes[j][2], f"f{len(self.jit_cache[j].rawbufs) + i}", var_vals[v])

        # Update launch dims in the c_node_params struct.
        for j in self.jc_idxs_with_updatable_launch_dims:
            self.set_kernel_node_launch_dims(
                self.updatable_nodes[j][1],
                *cast(CompiledASTRunner, self.jit_cache[j].prg).launch_dims(var_vals))

        # Push the updated param structs into the instantiated graph.
        for node, c_node_params, _ in self.updatable_nodes.values():
            self.graph_exec_kernel_node_set_params(self.instance, node, ctypes.byref(c_node_params))

        et = self.graph_launch(self.instance, None, wait=wait)
        update_stats(f"<batched {len(self.jit_cache)}>", self.op_estimate, self.mem_estimate, var_vals, et,
                     buf_count=len(input_rawbuffers), jit=jit, num_kernels=len(self.jit_cache))
        return et

    def __del__(self):
        """Destroy the underlying cuGraph and cuGraphExec handles.

        Guarded with ``hasattr`` because ``__init__`` can raise (GraphException)
        before either handle is assigned, and ``__del__`` still runs then.
        """
        if hasattr(self, "graph"):
            check(cuda.cuGraphDestroy(self.graph))
        if hasattr(self, "instance"):
            check(cuda.cuGraphExecDestroy(self.instance))

    def encode_args_info(self):
        """Return (device-pointer ctype, arg-layout tuple) for ``encode_args_cuda_style``."""
        return (cuda.CUdeviceptr_v2, (1, 2, 0))

    def graph_create(self):
        """Create and return a new empty ``cuda.CUgraph``."""
        return init_c_var(cuda.CUgraph(), lambda x: check(cuda.cuGraphCreate(ctypes.byref(x), 0)))

    def graph_instantiate(self, graph):
        """Instantiate ``graph`` into an executable ``cuda.CUgraphExec``."""
        return init_c_var(cuda.CUgraphExec(),
                          lambda x: check(cuda.cuGraphInstantiate_v2(ctypes.byref(x), graph, None, None, 0)))

    def graph_add_kernel_node(self, graph, c_deps, c_node_params):
        """Append a kernel node to ``graph``.

        Args:
            graph: the ``cuda.CUgraph`` to extend.
            c_deps: ctypes array of dependency nodes, or None for no dependencies.
            c_node_params: the ``CUDA_KERNEL_NODE_PARAMS`` struct for the node.

        Returns:
            The newly created ``cuda.CUgraphNode``.
        """
        # sizeof(c_deps) // 8: dependency count, assuming 8-byte node handles.
        return init_c_var(cuda.CUgraphNode(),
                          lambda x: check(cuda.cuGraphAddKernelNode(
                              ctypes.byref(x), graph, c_deps,
                              ctypes.sizeof(c_deps) // 8 if c_deps else 0,
                              ctypes.byref(c_node_params))))

    def graph_launch(self, *args, wait=False):
        """Launch the instantiated graph; time it (and return seconds) when ``wait``."""
        return cu_time_execution(lambda: check(cuda.cuGraphLaunch(*args)), enable=wait)

    def graph_exec_kernel_node_set_params(self, *args):
        """Forward updated node params to ``cuGraphExecKernelNodeSetParams``."""
        return check(cuda.cuGraphExecKernelNodeSetParams(*args))

    def build_kernel_node_params(self, prg, global_size, local_size, c_kernel_config):
        """Build a ``CUDA_KERNEL_NODE_PARAMS`` for one compiled kernel.

        Args:
            prg: the ``CompiledASTRunner`` whose ``clprg.prg`` is the CUfunction.
            global_size: grid dimensions (x, y, z).
            local_size: block dimensions (x, y, z).
            c_kernel_config: encoded kernel argument config from ``encode_args_cuda_style``.
        """
        return cuda.CUDA_KERNEL_NODE_PARAMS(prg.clprg.prg, *global_size, *local_size, 0, None, c_kernel_config)

    def set_kernel_node_launch_dims(self, node,
                                    global_size: Tuple[int, int, int], local_size: Tuple[int, int, int]):
        """Write new grid (global) and block (local) dims onto a node-params struct in place."""
        node.blockDimX, node.blockDimY, node.blockDimZ, node.gridDimX, node.gridDimY, node.gridDimZ = \
            (*local_size, *global_size)