from typing import List, Dict, Optional
from tinygrad.ops import LoadOps, ScheduleItem
from tinygrad.device import Device, Buffer, BufferCopy, JITRunner
from tinygrad.graph import log_schedule_item, print_tree
from tinygrad.helpers import prod
from tinygrad.shape.symbolic import Variable
[docs]
class CustomOp(JITRunner):
    """
    A JITRunner that wraps an arbitrary Python callable.

    Calling an instance forwards the supplied buffers to the wrapped callable
    as positional arguments.

    Attributes:
      fxn (function): The callable invoked each time the operation runs.
    """

    def __init__(self, fxn):
        """
        Store the callable, then run the JITRunner initializer.

        Parameters:
          fxn (function): The callable to invoke when this operation runs.
        """
        self.fxn = fxn
        super().__init__()

    def __call__(self, rawbufs: List[Buffer], var_vals: Dict[Variable, int], wait=False, jit=False):
        """
        Invoke the wrapped callable with the given buffers.

        Every entry of rawbufs is unpacked into a positional argument of the
        callable. The callable's result is discarded; this method returns None.

        Parameters:
          rawbufs (List[Buffer]): Buffers handed to the callable, in order.
          var_vals (Dict[Variable, int]): Accepted but not read by this method.
          wait (bool): Accepted but not read by this method. Default is False.
          jit (bool): Accepted but not read by this method. Default is False.
        """
        callback = self.fxn
        callback(*rawbufs)
[docs]
def lower_schedule_item(si: ScheduleItem) -> Optional[JITRunner]:
    """
    Pick the runner that executes a single schedule item.

    The result depends on the op at the root of the item's AST:

    * LoadOps.EMPTY  -> None (nothing to run)
    * LoadOps.FROM   -> BufferCopy
    * LoadOps.CUSTOM -> CustomOp wrapping si.ast.arg
    * anything else  -> the runner that Device[si.out.device] returns for the AST

    Before dispatching, it asserts that every input lives on the output's
    device. LoadOps.FROM is exempt from that check. Building the assertion
    message also calls print_tree on the AST.

    :param ScheduleItem si: The schedule item to lower.
    :return: Optional[JITRunner]
    """
    root_op = si.ast.op
    inputs_on_out_device = all(si.out.device == x.device for x in si.inputs)
    assert (
        inputs_on_out_device or root_op is LoadOps.FROM
    ), f"all devices must be the same, {si.out.device} != {[x.device for x in si.inputs]} {print_tree(si.ast) or ''}"

    if root_op is LoadOps.EMPTY:
        runner = None
    elif root_op is LoadOps.FROM:
        runner = BufferCopy
    elif root_op is LoadOps.CUSTOM:
        runner = CustomOp(si.ast.arg)
    else:
        runner = Device[si.out.device].get_runner(si.ast)
    return runner
[docs]
def run_schedule(schedule: List[ScheduleItem], disable_logging=False):
    """
    Execute every item of a schedule, front to back.

    The list is consumed in place: items are popped from the front until it is
    empty. For each item the function:

    1. logs it, unless disable_logging is set,
    2. asserts that every input is already realized,
    3. lowers it to a runner with lower_schedule_item,
    4. sets the output's realized buffer, reusing output_buffer when present
       and otherwise allocating a new Buffer,
    5. deletes the ``op`` attribute of the output and of each of its views,
    6. executes the runner, if there is one, on the output buffer followed by
       the input buffers, passing the item's var_vals.

    :param schedule: A list of ScheduleItems to be executed; emptied by this call.
    :type schedule: List[ScheduleItem]
    :param disable_logging: Whether or not to disable logging, defaults to False
    :type disable_logging: bool, optional
    """
    while schedule:
        item = schedule.pop(0)
        if not disable_logging:
            log_schedule_item(item)
        assert all(
            src.realized for src in item.inputs
        ), "can't run schedule, some inputs aren't realized"

        # lower first; the output buffer is set up afterwards
        runner = lower_schedule_item(item)

        out = item.out
        target = out.output_buffer
        if target is None:
            # no preassigned buffer: size a new one from the shape, where any
            # non-int dimension contributes its .max
            numel = prod(dim if isinstance(dim, int) else dim.max for dim in out.shape)
            target = Buffer(out.device, numel, out.dtype)
        out.realized = target

        # drop the op from the output and from each of its views
        del out.op
        for view in out.views:
            del view.op

        # a falsy runner means there is nothing to execute for this item
        if runner:
            buffers = [out.realized]
            buffers.extend(src.realized for src in item.inputs)
            runner.exec(buffers, item.var_vals)