# Source code for tinygrad.runtime.ops_webgpu

from wgpu.utils.device import get_default_device
from tinygrad.device import Compiled, Allocator
from tinygrad.codegen.kernel import LinearizerOptions
from tinygrad.renderer.cstyle import WGSLRenderer
import wgpu

# Module-level WebGPU device handle shared by every program, allocation and
# queue submission in this backend.
wgpu_device = get_default_device()


class WebGPUProgram:
  """A compiled WebGPU compute program.

  Compilation happens at construction time: the WGSL source in `lib` is turned
  into a shader module on the module-level `wgpu_device`. Calling the instance
  dispatches that shader.

  Attributes:
    name: Entry-point name of the kernel inside the shader module.
    lib: The WGSL source the module was built from.
    prg: The compiled `GPUShaderModule`.
  """

  def __init__(self, name: str, lib: bytes):
    """Compile `lib` into a shader module for kernel `name`."""
    self.name = name
    self.lib = lib
    self.prg = wgpu_device.create_shader_module(code=lib)  # NOTE: this is the compiler

  def __call__(self, *bufs, global_size, local_size, vals=(), wait=False):
    """Dispatch the kernel over `global_size` workgroups with `bufs` bound as storage buffers.

    Args:
      *bufs: GPU buffers bound at bindings 0..len(bufs)-1 (at most 8).
      global_size: (x, y, z) workgroup counts for the dispatch.
      local_size: accepted for interface compatibility; not used here.
      vals: accepted for interface compatibility; not used here.
      wait: accepted for interface compatibility; not used here.
    """
    assert len(bufs) <= 8, "WEBGPU only supports 8 buffers"
    # One storage-buffer binding per input buffer, visible to the compute stage.
    layout_entries = [
      {"binding": idx, "visibility": wgpu.ShaderStage.COMPUTE,
       "buffer": {"type": wgpu.BufferBindingType.storage}}
      for idx in range(len(bufs))
    ]
    resource_entries = [
      {"binding": idx, "resource": {"buffer": buf, "offset": 0, "size": buf.size}}
      for idx, buf in enumerate(bufs)
    ]
    bg_layout = wgpu_device.create_bind_group_layout(entries=layout_entries)
    pipe_layout = wgpu_device.create_pipeline_layout(bind_group_layouts=[bg_layout])
    bind_group = wgpu_device.create_bind_group(layout=bg_layout, entries=resource_entries)
    pipeline = wgpu_device.create_compute_pipeline(
      layout=pipe_layout,
      compute={"module": self.prg, "entry_point": self.name},
    )
    # Record and submit a single compute pass.
    encoder = wgpu_device.create_command_encoder()
    cpass = encoder.begin_compute_pass()
    cpass.set_pipeline(pipeline)
    cpass.set_bind_group(0, bind_group, [], 0, 999999)  # last 2 not used
    cpass.dispatch_workgroups(*global_size)  # x y z
    cpass.end()
    wgpu_device.queue.submit([encoder.finish()])
class WebGpuAllocator(Allocator):
  """Allocator backed by WebGPU storage buffers on the module-level `wgpu_device`."""

  def _alloc(self, size: int):
    """Create a `size`-byte GPU buffer usable as storage and as copy source/destination.

    Args:
      size: Number of bytes to allocate.

    Returns:
      The buffer created by `wgpu_device.create_buffer`.
    """
    usage_flags = (wgpu.BufferUsage.STORAGE
                   | wgpu.BufferUsage.COPY_DST
                   | wgpu.BufferUsage.COPY_SRC)
    return wgpu_device.create_buffer(size=size, usage=usage_flags)

  def copyin(self, dest, src: memoryview):
    """Upload host memory `src` into GPU buffer `dest` via the device queue."""
    wgpu_device.queue.write_buffer(dest, 0, src)

  def copyout(self, dest, src: memoryview):
    """Download GPU buffer `src` into host memory `dest`."""
    dest[:] = wgpu_device.queue.read_buffer(src, 0)  # TODO: remove this copy
class WebGpuDevice(Compiled):
  """The WEBGPU backend device.

  Wires the WebGPU allocator, linearizer options, WGSL renderer and program
  class into the `Compiled` interface.
  """

  def __init__(self, device: str):
    """Initialize the device.

    Args:
      device: Device identifier string (unused beyond the `Compiled` contract).
    """
    options = LinearizerOptions(
      device="WEBGPU",
      supports_float4=False,
      local_max=[256, 256, 64],
      global_max=[65535, 65535, 65535],
    )
    # lambda x: x — the renderer output is already the final source, no post-compile step.
    super().__init__(WebGpuAllocator(), options, WGSLRenderer, lambda x: x, WebGPUProgram)