import math
from typing import Tuple, Optional, cast
from tinygrad.helpers import argsort, DType
from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, ReduceOps
from tinygrad.tensor import Function
from tinygrad.lazy import LazyBuffer
from tinygrad.shape.symbolic import sint
class Contiguous(Function):
"""
This class defines a function for ensuring a tensor is contiguous in memory.
Attributes:
forward (method): This method accepts a LazyBuffer and returns a LazyBuffer that is contiguous.
backward (method): This method accepts a LazyBuffer and returns it without any modification.
"""
def forward(self, x: LazyBuffer) -> LazyBuffer:
"""
Returns a contiguous version of the input tensor.
Args:
x (LazyBuffer): The input tensor.
Returns:
LazyBuffer: A contiguous version of the input tensor.
"""
return x.contiguous()
def backward(self, grad_output: LazyBuffer) -> LazyBuffer:
"""
Returns the gradient output without any modification.
Args:
grad_output (LazyBuffer): The gradient tensor.
Returns:
LazyBuffer: The same as the input grad_output tensor.
"""
return grad_output
class ContiguousBackward(Function):
"""
This class defines a function for ensuring a tensor is contiguous in memory during backpropagation.
Attributes:
forward (method): This method accepts a LazyBuffer and returns it without any modification.
backward (method): This method accepts a LazyBuffer and returns its contiguous version.
"""
def forward(self, x: LazyBuffer) -> LazyBuffer:
"""
Returns the input tensor without any modifications.
Args:
x (LazyBuffer): The input tensor.
Returns:
LazyBuffer: The same as the input tensor.
"""
return x
def backward(self, grad_output: LazyBuffer) -> LazyBuffer:
"""
Returns a contiguous version of the gradient tensor.
Args:
grad_output (LazyBuffer): The gradient tensor.
Returns:
LazyBuffer: A contiguous version of the input grad_output tensor.
"""
return grad_output.contiguous()
class Cast(Function):
"""
This class defines a function for casting tensors to a specified data type.
Attributes:
forward (method): This method accepts a LazyBuffer, a DType and a boolean flag bitcast. It returns the cast tensor.
backward (method): This method accepts a LazyBuffer and casts it back to the original input dtype, honoring bitcast.
"""
def forward(self, x: LazyBuffer, dtype: DType, bitcast: bool = False) -> LazyBuffer:
"""
Casts the input tensor to a specified data type.
Args:
x (LazyBuffer): The input tensor.
dtype (DType): The target data type.
bitcast (bool): Whether to perform a bitcast operation. Default is False.
Returns:
LazyBuffer: The input tensor cast to the target data type.
"""
self.input_dtype, self.bitcast = x.dtype, bitcast
return x.cast(dtype, bitcast)
def backward(self, grad_output: LazyBuffer) -> LazyBuffer:
"""
Casts the gradient tensor to the input data type and performs a bitcast operation if necessary.
Args:
grad_output (LazyBuffer): The gradient tensor.
Returns:
LazyBuffer: The gradient tensor cast back to the original input data type, using bitcast if it was used in the forward pass.
"""
return grad_output.cast(self.input_dtype, self.bitcast)
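# NOTE: backward mirrors forward: the gradient is cast back to the saved input dtype,
# reinterpreting bits (rather than converting values) when bitcast=True was used.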
# ************* unary ops *************
class Zero(Function):
"""
Zero class: a Function whose forward and backward passes both return tensors of zeros.
This class inherits from the Function class and overrides its methods.
Attributes:
forward (method): Applies the forward pass of the operation.
backward (method): Applies the backward pass of the operation.
"""
def forward(self, x: LazyBuffer) -> LazyBuffer:
"""
Forward method for applying the forward pass of the operation.
This method takes a lazy buffer as input and returns another lazy buffer with all elements set to zero.
Parameters:
x (LazyBuffer): The input lazy buffer.
Returns:
LazyBuffer: The output lazy buffer with all elements set to zero.
"""
return x.const(0)
def backward(self, grad: LazyBuffer) -> LazyBuffer:
"""
Backward method for applying the backward pass of the operation.
This method takes a lazy buffer as input and returns another lazy buffer with all elements set to zero.
Parameters:
grad (LazyBuffer): The gradient lazy buffer from previous operations.
Returns:
LazyBuffer: The output lazy buffer with all elements set to zero.
"""
return grad.const(0)
class Neg(Function):
"""
Neg class for negating a given value.
Attributes:
forward (method): Method to perform the forward pass of the negation operation on the input x.
backward (method): Method to perform the backward pass of the negation operation, used during backpropagation.
"""
def forward(self, x: LazyBuffer) -> LazyBuffer:
"""
Forward method for negating a given value during the forward pass.
Args:
x (LazyBuffer): Input buffer to be negated.
Returns:
LazyBuffer: The negation of the input buffer x.
"""
return x.e(UnaryOps.NEG)
def backward(self, grad: LazyBuffer) -> LazyBuffer:
"""
Backward method for negating a given gradient during the backward pass.
Args:
grad (LazyBuffer): Gradient to be negated.
Returns:
LazyBuffer: The negation of the input gradient grad.
"""
return grad.e(UnaryOps.NEG)
class Sin(Function):
"""
Implementation of the sine function in radians.
Attributes:
x (LazyBuffer): Input buffer to which the sine operation is applied.
"""
def forward(self, x: LazyBuffer) -> LazyBuffer:
"""
Apply the sine function to the input LazyBuffer and return the result.
Args:
x (LazyBuffer): Input buffer to which the sine operation is applied.
Returns:
LazyBuffer: Output buffer after applying the sine operation.
"""
self.x = x
return x.e(UnaryOps.SIN)
def backward(self, grad: LazyBuffer) -> LazyBuffer:
"""
Compute the gradient of the sine function with respect to its input.
The derivative of sin(x) is cos(x).
Args:
grad (LazyBuffer): Gradient buffer for which the operation is applied.
Returns:
LazyBuffer: Output buffer after applying the gradient operation.
"""
return (
self.x.const(math.pi / 2)
.e(BinaryOps.SUB, self.x)
.e(UnaryOps.SIN)
.e(BinaryOps.MUL, grad)
)
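# Derivative sketch: d/dx sin(x) = cos(x); since only a SIN unary op is available here,
# backward computes cos(x) via the identity cos(x) = sin(pi/2 - x), then applies the
# chain rule by multiplying with the incoming gradient.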
# NOTE: maximum(x, 0) behaves differently where x=0
class Relu(Function):
"""
The Relu class implements the Rectified Linear Unit (ReLU) activation function.
Attributes:
ret (LazyBuffer): Output tensor after applying ReLU, saved for the backward pass.
"""
def forward(self, x: LazyBuffer) -> LazyBuffer:
"""
Applies the ReLU function element-wise: max(x, 0).
Args:
x (LazyBuffer): The input tensor.
Returns:
LazyBuffer: The input with negative values replaced by zero.
"""
self.ret = x.e(BinaryOps.MAX, x.const(0))
return self.ret
def backward(self, grad_output: LazyBuffer) -> LazyBuffer:
"""
Passes the gradient through where the forward output was positive and zeroes it elsewhere.
Args:
grad_output (LazyBuffer): The gradient tensor.
Returns:
LazyBuffer: grad_output masked to the positions where relu(x) > 0.
"""
return (
self.ret.const(0).e(BinaryOps.CMPLT, self.ret).e(BinaryOps.MUL, grad_output)
)
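# Backward sketch: 0 < ret (via CMPLT) yields a 0/1 mask that is 1 exactly where the
# forward output was positive; e.g. for x = [-2.0, 3.0] the mask is [0, 1], so the
# gradient of the negative entry is dropped.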
class Log(Function):
"""
Function for calculating the natural logarithm of a given input.
Attributes:
x (LazyBuffer): Input buffer on which operations are performed.
"""
def forward(self, x: LazyBuffer) -> LazyBuffer:
"""
Forward method to compute the natural logarithm of the input.
Args:
x (LazyBuffer): Input buffer.
Returns:
LazyBuffer: Output buffer after performing operations.
"""
self.x = x
return x.e(UnaryOps.LOG2).e(BinaryOps.MUL, x.const(math.log(2)))
def backward(self, grad_output: LazyBuffer) -> LazyBuffer:
"""
Backward method to compute the gradient of the logarithm function.
Args:
grad_output (LazyBuffer): Gradient output buffer.
Returns:
LazyBuffer: Gradient buffer after performing operations.
"""
return grad_output.e(BinaryOps.DIV, self.x)
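# ln(x) is built from the LOG2 primitive: log2(x) * ln(2) == ln(x). The backward DIV
# follows from d/dx ln(x) = 1/x; e.g. at x = 2.0 the local gradient is 0.5.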
class Exp(Function):
"""
This class represents an exponential function. It has methods for forward and backward operations.
"""
def forward(self, x: LazyBuffer) -> LazyBuffer:
"""
Forward method:
Performs the exponential operation on input x.
Args:
x (LazyBuffer): Input buffer
Returns:
LazyBuffer: Output buffer after applying the exponential function
"""
self.ret = x.e(BinaryOps.MUL, x.const(1 / math.log(2))).e(UnaryOps.EXP2)
return self.ret
def backward(self, grad_output: LazyBuffer) -> LazyBuffer:
"""
Backward method:
Performs the backward operation to calculate gradient of the exponential function.
Args:
grad_output (LazyBuffer): Gradient buffer from previous operation
Returns:
LazyBuffer: Output buffer after applying the backward exponential operation
"""
return self.ret.e(BinaryOps.MUL, grad_output)
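# exp(x) is built from the EXP2 primitive: 2**(x / ln(2)) == e**x. Since
# d/dx e**x = e**x, backward just multiplies grad_output by the saved forward result.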
class Sqrt(Function):
"""
Function for calculating the square root of a given LazyBuffer.
Attributes:
ret (LazyBuffer): The result of the forward operation, which is the input value's square root.
"""
def forward(self, x: LazyBuffer) -> LazyBuffer:
"""
Calculate the square root of the input LazyBuffer and store it in `ret`.
Args:
x (LazyBuffer): The input value to calculate the square root for.
Returns:
LazyBuffer: The result of the forward operation, which is the input value's square root.
"""
self.ret = x.e(UnaryOps.SQRT)
return self.ret
def backward(self, grad_output: LazyBuffer) -> LazyBuffer:
"""
Calculate the gradient of the square root function with respect to its input.
Args:
grad_output (LazyBuffer): The gradient output from previous operations.
Returns:
LazyBuffer: The gradient with respect to the input of the square root function, calculated as `grad_output / (2 * sqrt(x))`.
"""
return grad_output.e(
BinaryOps.DIV, self.ret.e(BinaryOps.MUL, self.ret.const(2))
)
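# Derivative sketch: d/dx sqrt(x) = 1 / (2 * sqrt(x)); reusing self.ret avoids
# recomputing the square root, e.g. at x = 4.0 the local gradient is 1 / (2 * 2) = 0.25.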
# NOTE: the implicit derivative of sigmoid is not stable
# https://towardsdatascience.com/derivative-of-the-sigmoid-function-536880cf918e
# TODO: have the backend automatically find this
class Sigmoid(Function):
"""
Sigmoid Activation Function.
This class implements the sigmoid activation function and its derivative. The forward method computes
the sigmoid of the input tensor `x`. The backward method, on the other hand, computes the gradient of
the loss function with respect to the input of the sigmoid function.
Attributes:
ret (LazyBuffer): The result of the forward pass computation.
"""
def forward(self, x: LazyBuffer) -> LazyBuffer:
"""
Forward method for computing sigmoid activation.
This method computes the sigmoid function of input tensor `x`. It uses the following mathematical
formula to compute the result:
sigmoid(x) = 1 / (1 + exp2(x * (-1/log(2)))) = 1 / (1 + exp(-x))
Args:
x (LazyBuffer): The input tensor.
Returns:
LazyBuffer: The output tensor after applying the sigmoid function.
"""
self.ret = x.const(1).e(
BinaryOps.DIV,
x.const(1).e(
BinaryOps.ADD,
x.e(BinaryOps.MUL, x.const(-1 / math.log(2))).e(UnaryOps.EXP2),
),
)
return self.ret
def backward(self, grad_output: LazyBuffer) -> LazyBuffer:
"""
Backward method for computing the gradient of the loss function.
This method computes the gradient of the loss function with respect to the input of the sigmoid
function. It uses the following mathematical formula to compute the result:
grad_input = sigmoid(x) * (1 - sigmoid(x)) * grad_output
Args:
grad_output (LazyBuffer): The gradient of the loss function with respect to the output.
Returns:
LazyBuffer: The gradient of the loss function with respect to the input.
"""
return self.ret.e(
BinaryOps.MUL, self.ret.const(1).e(BinaryOps.SUB, self.ret)
).e(BinaryOps.MUL, grad_output)
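# Worked check (illustrative): sigmoid(0) = 0.5, so backward at x = 0 scales
# grad_output by 0.5 * (1 - 0.5) = 0.25, the maximum of the sigmoid derivative.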
# ************* binary ops *************
class Less(Function):
"""
This class represents an element-wise less-than comparison. It takes two
LazyBuffer arguments, x and y; the forward method compares them with the
BinaryOps.CMPLT operator and returns the result as a LazyBuffer. The comparison
is not differentiable, so no backward method is defined.
"""
def forward(self, x: LazyBuffer, y: LazyBuffer) -> LazyBuffer:
"""Return the element-wise result of x < y."""
return x.e(BinaryOps.CMPLT, y)
class Add(Function):
"""
This class represents an element-wise addition. It takes two LazyBuffer
arguments, x and y; the forward method adds them with the BinaryOps.ADD operator
and returns the result as a LazyBuffer. The backward method passes the gradient
through to each input that requires it.
"""
def forward(self, x: LazyBuffer, y: LazyBuffer) -> LazyBuffer:
"""Return the element-wise sum x + y."""
return x.e(BinaryOps.ADD, y)
def backward(
self, grad_output: LazyBuffer
) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
"""
This method calculates the gradient of the output with respect to the input tensors.
Args:
grad_output (LazyBuffer): The gradient output buffer.
Returns:
Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]: A tuple containing
the gradients with respect to the first and second input tensors, if they are required.
"""
return (
grad_output if self.needs_input_grad[0] else None,
grad_output if self.needs_input_grad[1] else None,
)
class Sub(Function):
"""
This class represents an element-wise subtraction. It takes two LazyBuffer
arguments, x and y; the forward method subtracts y from x with the BinaryOps.SUB
operator and returns the result as a LazyBuffer. The backward method passes the
gradient to x unchanged and negates it for y.
"""
def forward(self, x: LazyBuffer, y: LazyBuffer) -> LazyBuffer:
"""Return the element-wise difference x - y."""
return x.e(BinaryOps.SUB, y)
def backward(
self, grad_output: LazyBuffer
) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
"""
This method calculates the gradient of the output with respect to the input tensors.
Args:
grad_output (LazyBuffer): The gradient output buffer.
Returns:
Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]: A tuple containing
the gradients with respect to the first and second input tensors, if they are required.
"""
return (
grad_output if self.needs_input_grad[0] else None,
grad_output.e(UnaryOps.NEG) if self.needs_input_grad[1] else None,
)
class Mul(Function):
"""
Function class for element-wise multiplication operation.
Attributes:
x (LazyBuffer): First input tensor.
y (LazyBuffer): Second input tensor.
"""
def forward(self, x: LazyBuffer, y: LazyBuffer) -> LazyBuffer:
"""
Forward pass of the element-wise multiplication operation.
Args:
x (LazyBuffer): First input tensor.
y (LazyBuffer): Second input tensor.
Returns:
LazyBuffer: Output tensor after applying the multiplication operation.
"""
self.x, self.y = x, y
return x.e(BinaryOps.MUL, y)
def backward(
self, grad_output: LazyBuffer
) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
"""
Backward pass of the element-wise multiplication operation.
Args:
grad_output (LazyBuffer): Gradient tensor from previous operation.
Returns:
Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]: Gradients with respect to input tensors x and y.
If the gradient is not needed for a certain input tensor, None is returned instead.
"""
return (
self.y.e(BinaryOps.MUL, grad_output) if self.needs_input_grad[0] else None,
self.x.e(BinaryOps.MUL, grad_output) if self.needs_input_grad[1] else None,
)
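# Product rule: d(x*y)/dx = y and d(x*y)/dy = x, so each input's gradient is the
# other input times grad_output.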
class Div(Function):
"""
The Div class represents a division function. It has methods for forward and backward propagation.
Attributes:
x (LazyBuffer): The input buffer representing the numerator.
y (LazyBuffer): The input buffer representing the denominator.
"""
def forward(self, x: LazyBuffer, y: LazyBuffer) -> LazyBuffer:
"""
Perform forward propagation for the division operation.
Args:
x (LazyBuffer): The input buffer representing the numerator.
y (LazyBuffer): The input buffer representing the denominator.
Returns:
LazyBuffer: The result of the division operation.
"""
self.x, self.y = x, y
return x.e(BinaryOps.DIV, y)
def backward(
self, grad_output: LazyBuffer
) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
"""
Perform backward propagation for the division operation.
Args:
grad_output (LazyBuffer): The gradient of the output buffer.
Returns:
Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]: A tuple containing the gradients with respect to the numerator and denominator. If an input does not require a gradient, None is returned for that position.
"""
return (
grad_output.e(BinaryOps.DIV, self.y) if self.needs_input_grad[0] else None,
grad_output.e(UnaryOps.NEG)
.e(BinaryOps.MUL, self.x)
.e(BinaryOps.DIV, self.y.e(BinaryOps.MUL, self.y))
if self.needs_input_grad[1]
else None,
)
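# Quotient rule: d(x/y)/dx = 1/y and d(x/y)/dy = -x/y**2, matching the two branches
# above; e.g. at x = 1.0, y = 2.0 the gradients are 0.5 and -0.25.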
# ************* ternary ops *************
class Where(Function):
"""
The Where class is used to perform element-wise conditional operations.
Attributes:
x (LazyBuffer): The condition buffer, saved for routing gradients in the backward pass.
"""
def forward(self, x: LazyBuffer, y: LazyBuffer, z: LazyBuffer) -> LazyBuffer:
"""
Forward method for the Where class.
This method performs the element-wise conditional operation during forward propagation.
Args:
x (LazyBuffer): The condition buffer.
y (LazyBuffer): Values selected where the condition is true.
z (LazyBuffer): Values selected where the condition is false.
Returns:
LazyBuffer: The result of the element-wise conditional operation.
"""
self.x = x
return x.e(TernaryOps.WHERE, y, z)
def backward(
self, grad_output: LazyBuffer
) -> Tuple[None, Optional[LazyBuffer], Optional[LazyBuffer]]:
"""
Backward method for the Where class.
This method performs the backward pass of the element-wise conditional operation.
Args:
grad_output (LazyBuffer): The gradient output used in the backward pass calculation.
Returns:
Tuple[None, Optional[LazyBuffer], Optional[LazyBuffer]]: A tuple containing None as the first element and optional LazyBuffer
objects for the second and third operands, depending on whether their gradients are needed or not.
"""
return (
None,
self.x.e(TernaryOps.WHERE, grad_output, grad_output.const(0))
if self.needs_input_grad[1]
else None,
self.x.e(TernaryOps.WHERE, grad_output.const(0), grad_output)
if self.needs_input_grad[2]
else None,
)
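# Gradient routing sketch: grad_output flows to y where the condition x is true and
# to z where it is false; the condition itself is non-differentiable, hence None.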
# ************* reduce ops *************
class Sum(Function):
"""
Class for summing the elements of a LazyBuffer; the backward pass expands the gradient back to the input shape.
Attributes:
input_shape (Tuple[int, ...]): The shape of the original LazyBuffer used in the forward method.
"""
def forward(self, x: LazyBuffer, new_shape: Tuple[int, ...]) -> LazyBuffer:
"""
Forward method to sum elements of a LazyBuffer based on the new shape provided.
Args:
x (LazyBuffer): The input LazyBuffer.
new_shape (Tuple[int, ...]): The desired output shape after the summation operation.
Returns:
LazyBuffer: The result of the summation operation with the specified new shape.
"""
self.input_shape = x.shape
return x.r(ReduceOps.SUM, new_shape)
def backward(self, grad_output: LazyBuffer) -> LazyBuffer:
"""
Backward method to expand the gradient (grad_output) according to the input shape of the forward method.
Args:
grad_output (LazyBuffer): The gradient from previous operation.
Returns:
LazyBuffer: The expanded gradient according to the input shape in the forward method.
"""
return grad_output.expand(self.input_shape)
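# The adjoint of a sum-reduction is a broadcast: every input element contributes to
# the sum with weight 1, so the gradient is just grad_output expanded to the input shape.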
class Max(Function):
"""
Class for the max-reduction Function.
Attributes:
x (LazyBuffer): The input tensor.
ret (LazyBuffer): The result of applying the max function to x.
"""
def forward(self, x: LazyBuffer, new_shape: Tuple[int, ...]) -> LazyBuffer:
"""
Forward pass for the Max Function.
Args:
x (LazyBuffer): Input tensor.
new_shape (Tuple[int, ...]): The desired shape of the output.
Returns:
LazyBuffer: Result of applying the max function to x.
"""
self.x, self.ret = x, x.r(ReduceOps.MAX, new_shape)
return self.ret
def backward(self, grad_output: LazyBuffer) -> LazyBuffer:
"""
Backward pass for the Max Function.
Args:
grad_output (LazyBuffer): The gradient output tensor.
Returns:
LazyBuffer: Gradient input tensor.
"""
# 1s in locations where the max was chosen (can be two locations)
max_is_1s = self.x.const(1.0).e(
BinaryOps.SUB, self.x.e(BinaryOps.CMPLT, self.ret.expand(self.x.shape))
)
div = max_is_1s.r(ReduceOps.SUM, grad_output.shape).expand(self.x.shape)
return max_is_1s.e(BinaryOps.DIV, div).e(
BinaryOps.MUL, grad_output.expand(self.x.shape)
)
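# Worked example (illustrative, 1-D reduce): for x = [1, 3, 3] with ret = 3,
# max_is_1s = [0, 1, 1] and div expands to [2, 2, 2], so the incoming gradient is
# split evenly across the tied maxima: [0, 0.5, 0.5] * grad_output.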
# ************* movement ops *************
# NOTE: this is sum in reverse
class Expand(Function):
"""
Class for expanding a tensor to a given shape.
Attributes:
input_shape (Tuple[int, ...]): The original shape of the input tensor before expansion.
Methods:
forward: Expands a LazyBuffer to a specified shape.
backward: Reduces the gradient back to the original input shape using a SUM operation.
"""
def forward(self, x: LazyBuffer, shape: Tuple[int, ...]) -> LazyBuffer:
"""
Expands a LazyBuffer to a specified shape.
Args:
x (LazyBuffer): The input tensor to be expanded.
shape (Tuple[int, ...]): The target shape for the expansion.
Returns:
LazyBuffer: The expanded tensor.
"""
self.input_shape = x.shape
return x.expand(shape)
def backward(self, grad_output: LazyBuffer) -> LazyBuffer:
"""
Reduces the gradient back to the original input shape using a SUM operation.
Args:
grad_output (LazyBuffer): The gradient tensor to be reduced.
Returns:
LazyBuffer: The reduced gradient tensor.
"""
return grad_output.r(ReduceOps.SUM, self.input_shape)
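# Expand and Sum are adjoint: expanding broadcasts a size-1 dim, so the gradient
# sums grad_output back over the broadcast dims (the "sum in reverse" noted above).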
class Reshape(Function):
"""
This class represents a function for reshaping tensors.
Attributes:
input_shape (Tuple[int, ...]): The shape of the input tensor before reshaping.
"""
def forward(self, x: LazyBuffer, shape: Tuple[int, ...]) -> LazyBuffer:
"""
Reshape the input tensor during the forward pass.
Args:
x (LazyBuffer): The input tensor to be reshaped.
shape (Tuple[int, ...]): The desired shape of the output tensor.
Returns:
LazyBuffer: The reshaped output tensor.
"""
self.input_shape = x.shape
return x.reshape(shape)
def backward(self, grad_output: LazyBuffer) -> LazyBuffer:
"""
Reshape the gradient tensor during the backward pass.
Args:
grad_output (LazyBuffer): The gradient tensor to be reshaped.
Returns:
LazyBuffer: The reshaped gradient tensor with the same shape as the input tensor from the forward pass.
"""
return grad_output.reshape(self.input_shape)
class Permute(Function):
"""
The Permute class is a subclass of the Function class, used for permutation operations on tensors.
Attributes:
input_order (Tuple[int, ...]): The axis permutation applied in the forward pass, saved so the backward pass can invert it.
"""
def forward(self, x: LazyBuffer, order: Tuple[int, ...]) -> LazyBuffer:
"""
Permute the input tensor according to a specified order.
Args:
x (LazyBuffer): The input tensor to be permuted.
order (Tuple[int, ...]): The permutation of axes to apply.
Returns:
LazyBuffer: The output tensor after permutation.
"""
self.input_order = order
return x.permute(order)
def backward(self, grad_output: LazyBuffer) -> LazyBuffer:
"""
Permute the gradient tensor according to the inverse of the input order.
Args:
grad_output (LazyBuffer): The gradient tensor to be permuted.
Returns:
LazyBuffer: The permuted gradient tensor.
"""
return grad_output.permute(argsort(self.input_order))
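# argsort(order) yields the inverse permutation; e.g. (illustrative) order = (2, 0, 1)
# gives argsort(order) = (1, 2, 0), which maps the permuted axes back to the originals.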
class Pad(Function):
"""
The Pad class is used to add padding to a given tensor. It supports forward and backward propagation.
Attributes:
narg (tuple): Per-dimension (start, end) coordinates used to shrink the gradient back to the input shape in the backward pass.
"""
def forward(self, x: LazyBuffer, arg: Tuple[Tuple[int, int], ...]) -> LazyBuffer:
"""
The forward method applies padding to the input tensor.
Args:
x (LazyBuffer): The input tensor.
arg (tuple of tuples): A tuple containing two-element tuples that specify the padding for each dimension.
Returns:
LazyBuffer: The output tensor after padding.
"""
self.narg = tuple([(p[0], s + p[0]) for s, p in zip(x.shape, arg)])
return x.pad(arg)
def backward(self, grad_output: LazyBuffer) -> LazyBuffer:
"""
The backward method removes the padding from the gradient tensor.
Args:
grad_output (LazyBuffer): The gradient tensor with padding.
Returns:
LazyBuffer: The gradient tensor after removing padding.
"""
return grad_output.shrink(self.narg)
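# Worked example (illustrative): for x.shape = (4,) and arg = ((1, 2),), forward pads
# to shape (7,) and narg = ((1, 5),), so backward shrinks back to indices [1, 5),
# recovering the original four elements.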
class Shrink(Function):
"""
The Shrink class implements the forward and backward methods for a shrinking function.
Attributes:
narg (tuple): Per-dimension (pad_before, pad_after) amounts, derived from x.shape and arg, used to pad the gradient in the backward pass.
"""
def forward(self, x: LazyBuffer, arg: Tuple[Tuple[sint, sint], ...]) -> LazyBuffer:
"""
The forward method shrinks `x` according to `arg`, where each inner 2-tuple gives
the (start, end) range kept along that dimension. The padding needed to undo the
shrink is computed from `x.shape` and `arg` and stored in `self.narg` for the
backward pass.
Args:
x (LazyBuffer): The input lazy buffer.
arg (Tuple[Tuple[sint, sint], ...]): A tuple of 2-tuples of (possibly symbolic) ints.
Returns:
LazyBuffer: The result of calling the `shrink` method on `x` with `arg` as an argument.
"""
self.narg = tuple([(p[0], s - p[1]) for s, p in zip(x.shape, arg)])
return x.shrink(arg)
def backward(self, grad_output: LazyBuffer) -> LazyBuffer:
"""
The backward method pads `grad_output` back to the input shape using the class
attribute `narg`. All elements of `narg` must be concrete integers; otherwise an
assertion error is raised, since symbolic shrink does not support backward.
Args:
grad_output (LazyBuffer): The input lazy buffer for the backward operation.
Returns:
LazyBuffer: The result of calling the `pad` method on `grad_output` with `self.narg` as an argument.
Raises:
AssertionError: If not all elements in `narg` are integers.
"""
assert all(
isinstance(x[0], int) and isinstance(x[1], int) for x in self.narg
), "symbolic shrink does not support backward"
# need this cast because mypy cannot narrow the type even with assert
return grad_output.pad(cast(Tuple[Tuple[int, int], ...], self.narg))
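# Worked example (illustrative): for x.shape = (7,) and arg = ((1, 5),), forward keeps
# indices [1, 5) giving shape (4,); narg = ((1, 2),), so backward pads one element on
# the left and two on the right to restore shape (7,).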
class Flip(Function):
"""
The Flip class, a subclass of Function. This class defines two methods: forward and backward.
Attributes:
arg (tuple): Per-dimension strides used for flipping: -1 for dimensions listed in `axis`, 1 otherwise. Set during the forward pass.
"""
def forward(self, x: LazyBuffer, axis: Tuple[int, ...]) -> LazyBuffer:
"""
Flip the input tensor along specified axes.
Parameters:
x (LazyBuffer): The input tensor to be flipped.
axis (Tuple[int, ...]): A tuple of integers representing the axes along which to flip the tensor.
Returns:
LazyBuffer: The flipped output tensor.
"""
self.arg = tuple([-1 if i in set(axis) else 1 for i in range(len(x.shape))])
return x.stride(self.arg)
def backward(self, grad_output: LazyBuffer) -> LazyBuffer:
"""
Flip the gradient tensor along specified axes, mirroring the forward operation.
Parameters:
grad_output (LazyBuffer): The gradient tensor to be flipped.
Returns:
LazyBuffer: The flipped gradient tensor.
"""
return grad_output.stride(self.arg)
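# A stride of -1 walks a dimension in reverse, so flipping is an involution: applying
# the same strides in backward maps each gradient element back to its source position.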