import numpy as np
from typing import Callable, Dict, Tuple
from tinygrad.helpers import dtypes, flat_mv
from tinygrad.ops import (
BufferOps,
UnaryOps,
BinaryOps,
MovementOps,
ReduceOps,
TernaryOps,
Op,
)
from tinygrad.device import Interpreted, Allocator
def shape_to_axis(
    old_shape: Tuple[int, ...], new_shape: Tuple[int, ...]
) -> Tuple[int, ...]:
    """
    Return the axes along which two equal-rank shapes disagree.

    Used by the reduce ops to find which axes collapse: comparing ``(4, 3, 5)``
    with ``(4, 1, 5)`` yields ``(1,)``.

    Parameters:
        old_shape (Tuple[int, ...]): Shape before the change.
        new_shape (Tuple[int, ...]): Shape after the change; must have the same
            number of dimensions as ``old_shape``.

    Returns:
        Tuple[int, ...]: Increasing indices of the mismatching axes (empty when
        the shapes are identical).

    Raises:
        AssertionError: If the shapes have a different number of dimensions.
    """
    assert len(old_shape) == len(new_shape), "reduce shapes must have same dimensions"
    mismatched = []
    for axis, (before, after) in enumerate(zip(old_shape, new_shape)):
        if before != after:
            mismatched.append(axis)
    return tuple(mismatched)
# TODO: this should be global infrastructure
def output_type(x, y):
    """
    Pick the dtype that two numpy arrays should be promoted to.

    Each array's numpy dtype is mapped to a tinygrad dtype with
    ``dtypes.from_np`` and the ``priority`` values are compared.

    Parameters:
        x (numpy.ndarray): First operand.
        y (numpy.ndarray): Second operand.

    Returns:
        numpy.dtype: ``x.dtype`` if its priority is strictly higher, otherwise
        ``y.dtype`` (ties resolve to ``y``).
    """
    x_rank = dtypes.from_np(x.dtype).priority
    y_rank = dtypes.from_np(y.dtype).priority
    if x_rank > y_rank:
        return x.dtype
    return y.dtype
def match_types(x, y):
    """
    Cast two numpy arrays to their common (higher-priority) dtype.

    The target dtype comes from ``output_type(x, y)``. Because ``astype`` is
    called with ``copy=False``, an operand that already has the target dtype is
    returned without copying.

    Parameters:
        x (numpy.ndarray): First operand.
        y (numpy.ndarray): Second operand.

    Returns:
        Tuple[numpy.ndarray, numpy.ndarray]: ``(x, y)`` cast to the common dtype.
    """
    common = output_type(x, y)
    return tuple(arr.astype(common, copy=False) for arr in (x, y))
def einsum_mulacc(einsum, get_strides, expand):
    """
    Build a multiply-accumulate function backed by an einsum implementation.

    The returned ``mulacc(a, b, new_shape)`` computes ``sum(a * b)`` over the
    axes where ``a.shape`` differs from ``new_shape``. Axes with a zero stride
    (broadcast axes, constant along that axis) are indexed away before calling
    ``einsum``, and the result is reshaped and expanded back to ``new_shape``.

    Parameters:
        einsum (Callable): Called as ``einsum(subscripts, a, b)``; the
            subscripts string may contain spaces around ``,`` and ``->``.
        get_strides (Callable): Returns an array's strides; only whether each
            stride is zero is used.
        expand (Callable): Called as ``expand(array, shape)`` to broadcast the
            result up to ``new_shape``.

    Returns:
        Callable: ``mulacc(a, b, new_shape)``.
    """
    # einsum subscripts accept both lowercase and uppercase ASCII letters, so
    # up to 52 axes can be labelled (lowercase alone capped this at 26).
    labels = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"

    def einscripts(x):
        """Map each axis index in ``x`` to its einsum label, e.g. [0, 1, 27] -> 'abB'."""
        return "".join(labels[i] for i in x)

    def axes_slice(keep):
        """
        Split axes into kept and dropped ones.

        Args:
            keep (Sequence[bool]): One flag per axis; True keeps the axis.

        Returns:
            list, tuple: Indices of the kept axes, and an index tuple that keeps
            those axes (``slice(None)``) and selects element 0 of the others.
        """
        return [i for i, k in enumerate(keep) if k], tuple(
            slice(None) if k else 0 for k in keep
        )

    def mulacc(a, b, new_shape):
        """
        Compute ``sum(a * b)`` over the axes where ``a.shape`` != ``new_shape``.

        Args:
            a (array): First operand.
            b (array): Second operand. NOTE(review): assumed to have the same
                shape as ``a``; only ``a.shape`` is consulted.
            new_shape (tuple): Desired output shape.

        Returns:
            array: The reduced product, broadcast to ``new_shape``.
        """
        a_strides, b_strides = get_strides(a), get_strides(b)
        reduced = {i for i in range(len(new_shape)) if a.shape[i] != new_shape[i]}
        # A zero stride means the operand is constant along that axis, so it can
        # be indexed away. Exception: an axis that is reduced AND broadcast in
        # both operands must still be summed (a.shape[i] equal terms), so it is
        # kept on `a` instead of being dropped from both.
        a_axes, a_slices = axes_slice(
            [s != 0 or (i in reduced and b_strides[i] == 0) for i, s in enumerate(a_strides)]
        )
        b_axes, b_slices = axes_slice([s != 0 for s in b_strides])
        present = set(a_axes) | set(b_axes)
        # Output keeps the non-reduced axes that at least one operand spans.
        out = [i for i in range(len(new_shape)) if i not in reduced and i in present]
        ret = einsum(
            f"{einscripts(a_axes)}, {einscripts(b_axes)} -> {einscripts(out)}",
            a[a_slices],
            b[b_slices],
        )
        # Axes absent from both operands never reached einsum: restore them as
        # size-1 dims, then broadcast up to new_shape.
        return expand(
            ret.reshape([s if i in present else 1 for i, s in enumerate(new_shape)]),
            new_shape,
        )

    return mulacc
# Dispatch table mapping each tinygrad Op to the numpy callable that
# implements it for the interpreted CPU backend (CPUDevice below).
# NOTE(review): the exact argument convention (source arrays, then the op's
# arg) is defined by `Interpreted`, which is not visible here — confirm there.
numpy_fxn_for_op: Dict[Op, Callable] = {
    # Materialize a constant with the numpy dtype carried by the tinygrad dtype.
    BufferOps.CONST: lambda val, dtype: np.array(val, dtype=dtype.np),
    UnaryOps.EXP2: np.exp2,
    UnaryOps.LOG2: np.log2,
    UnaryOps.SIN: np.sin,
    # y[0] is the target dtype; a truthy y[1] requests a bitcast (reinterpret
    # the bytes via `view`), otherwise a value-converting `astype`.
    UnaryOps.CAST: lambda x, y: x.view(y[0].np)
        if y[1]
        else x.astype(y[0].np, copy=False),
    # np.negative does not support booleans, so boolean negation is logical NOT.
    UnaryOps.NEG: lambda x: np.logical_not(x)
        if x.dtype == np.bool_
        else np.negative(x),
    BinaryOps.MAX: np.maximum,
    # The boolean comparison result is cast to the promoted operand dtype.
    BinaryOps.CMPLT: lambda x, y: (x < y).astype(output_type(x, y)),
    # Arithmetic first promotes both operands to the common dtype (match_types).
    BinaryOps.ADD: lambda x, y: np.add(*match_types(x, y)),
    BinaryOps.SUB: lambda x, y: np.subtract(*match_types(x, y)),
    BinaryOps.MUL: lambda x, y: np.multiply(*match_types(x, y)),
    # np.divide is true division; casting back to the promoted dtype truncates
    # integer quotients toward zero.
    BinaryOps.DIV: lambda x, y: np.divide(*match_types(x, y)).astype(
        output_type(x, y), copy=False
    ),
    UnaryOps.SQRT: np.sqrt,
    # Reductions collapse every axis whose size differs between x.shape and
    # new_shape, keeping dims so the rank is preserved; identical shapes pass
    # through unchanged. SUM accumulates in x's own dtype.
    ReduceOps.SUM: lambda x, new_shape: x.sum(
        shape_to_axis(x.shape, new_shape), dtype=x.dtype, keepdims=True
    )
    if x.shape != new_shape
    else x,
    ReduceOps.MAX: lambda x, new_shape: x.max(
        shape_to_axis(x.shape, new_shape), keepdims=True
    )
    if x.shape != new_shape
    else x,
    # arg = (shape, strides, offset), strides and offset counted in elements;
    # numpy wants bytes, hence the itemsize scaling. np.require copies x into a
    # C-contiguous buffer first if it is not already one.
    MovementOps.AS_STRIDED: lambda x, arg: np.ndarray(
        arg[0],
        buffer=np.require(x, requirements="C"),
        dtype=x.dtype,
        offset=arg[2] * x.dtype.itemsize,
        strides=tuple(y * x.dtype.itemsize for y in arg[1]),
    ),
    MovementOps.PAD: np.pad,
    MovementOps.EXPAND: np.broadcast_to,
    # Fused multiply + sum via einsum; operands are copied and promoted to a
    # common dtype before the einsum call.
    TernaryOps.MULACC: einsum_mulacc(
        lambda s, a, b: np.einsum(s, *match_types(a.copy(), b.copy()), optimize=True),
        lambda x: x.strides,
        np.broadcast_to,
    ),
    TernaryOps.WHERE: np.where,
}
class NumpyAllocator(Allocator):
    """
    Allocator for the CPU backend that stores buffers as numpy arrays.

    ``_alloc`` creates flat ``uint8`` arrays; ``as_buffer``, ``copyin`` and
    ``copyout`` work with arrays of any dtype and shape.

    Methods:
        _alloc: Allocate an uninitialized byte array of a given size.
        as_buffer: Expose an array's bytes as a memoryview.
        copyin: Copy bytes from a memoryview into an existing array.
        copyout: Copy an array's contents into a memoryview.
    """

    def _alloc(self, size: int):
        """
        Allocate an uninitialized buffer.

        Parameters:
            size (int): Number of bytes to allocate.

        Returns:
            np.ndarray: A 1-D uninitialized array of ``size`` ``uint8`` elements.
        """
        return np.empty(size, dtype=np.uint8)

    def as_buffer(self, src: np.ndarray) -> memoryview:
        """
        Expose the bytes of an array as a memoryview (passed through ``flat_mv``).

        Parameters:
            src (np.ndarray): The array to expose.

        Returns:
            memoryview: A view of ``src``'s data when ``src`` is already
            C-contiguous. Otherwise ``np.require`` makes a C-contiguous copy and
            the view refers to that copy, so writes through it do not reach ``src``.
        """
        return flat_mv(np.require(src, requirements="C").data)

    def copyin(self, dest: np.ndarray, src: memoryview):
        """
        Copy raw bytes into an existing array.

        The bytes of ``src`` are reinterpreted as ``dest.dtype`` and reshaped
        to ``dest.shape`` before copying, so ``src`` must hold exactly
        ``dest.nbytes`` bytes (numpy raises ``ValueError`` otherwise).

        Parameters:
            dest (np.ndarray): Destination array, written in place.
            src (memoryview): Source bytes.
        """
        np.copyto(dest, np.frombuffer(src, dest.dtype).reshape(dest.shape))

    def copyout(self, dest: memoryview, src: np.ndarray):
        """
        Copy an array's contents into a raw byte buffer.

        ``dest`` is viewed as ``src.dtype`` with ``src.shape`` and written in
        place, so it must be writable and hold exactly ``src.nbytes`` bytes.

        Parameters:
            dest (memoryview): Destination buffer, written in place.
            src (np.ndarray): Source array.
        """
        np.copyto(np.frombuffer(dest, src.dtype).reshape(src.shape), src)
# The CPU device: an interpreted backend that stores buffers via
# NumpyAllocator and executes ops through the numpy dispatch table.
CPUDevice = Interpreted(
    NumpyAllocator(),
    numpy_fxn_for_op,
)