from __future__ import annotations
import os, functools, platform, time, re, contextlib, operator, hashlib, pickle, sqlite3, cProfile, pstats, tempfile, pathlib, string, ctypes
import numpy as np
from urllib import request
from tqdm import tqdm
from typing import (
Dict,
Tuple,
Union,
List,
NamedTuple,
Final,
ClassVar,
Optional,
Iterable,
Any,
TypeVar,
TYPE_CHECKING,
Callable,
)
if (
TYPE_CHECKING
): # TODO: remove this and import TypeGuard from typing once minimum python supported version is 3.10
from typing_extensions import TypeGuard
T = TypeVar("T")  # generic type variable for element/key types used by the helper signatures below
U = TypeVar("U")  # second generic type variable (the dict value type in merge_dicts)
[docs]
def prod(x: Iterable[T]) -> Union[T, int]:
    """
    Calculate the product of all elements in an iterable.

    The elements are multiplied left to right starting from the integer ``1``. An empty
    iterable therefore yields ``1``. A single-element iterable yields ``1 * element``, which is
    the element itself for ordinary numbers.

    :param x: An iterable of elements to calculate the product for.
    :type x: Iterable[T]
    :return: The product of all elements in the iterable, or 1 if the iterable is empty.
    :rtype: Union[T, int]
    """
    return functools.reduce(operator.__mul__, x, 1)
# NOTE: helpers is not allowed to import from anything else in tinygrad
# Platform / environment flags, evaluated once at import time.
OSX = platform.system() == "Darwin"  # True when running on macOS
CI = bool(os.getenv("CI", ""))  # True when the CI environment variable is set to a non-empty value
[docs]
def dedup(x: Iterable[T]) -> List[T]:
    """
    Return the distinct elements of ``x`` in order of first appearance.

    Parameters:
        x (Iterable[T]): The input iterable; its elements must be hashable.
    Returns:
        List[T]: A new list in which each distinct element appears once, at the position of its
        first occurrence.
    """
    # dict keys are unique and keep insertion order, which gives an order-preserving de-duplication
    unique = dict.fromkeys(x)
    return list(unique)
[docs]
def argfix(*x):
    """
    Normalize variadic arguments so that ``f(1, 2)``, ``f((1, 2))`` and ``f([1, 2])`` look the same.

    If the first positional argument is exactly a ``tuple`` or ``list`` (subclasses do not match),
    its elements are returned as a tuple and any further arguments are ignored. Otherwise the
    argument tuple ``x`` is returned unchanged.

    Parameters:
        *x (Any): The input arguments.
    Returns:
        Tuple: ``tuple(x[0])`` when ``x[0]`` is a tuple or list, otherwise ``x`` itself.
    """
    return tuple(x[0]) if x and x[0].__class__ in (tuple, list) else x
[docs]
def argsort(x):
    """
    Return the indices that would sort ``x`` in ascending order.

    The result is built with ``type(x)``, so a list input gives a list and a tuple input gives a
    tuple. Equal elements keep their original relative order because ``sorted`` is stable.

    Parameters:
        x (Sequence[Any]): The input sequence.
    Returns:
        Sequence[int]: The sorting permutation of ``range(len(x))``, in the same container type as ``x``.
    """
    # https://stackoverflow.com/questions/3382352/equivalent-of-numpy-argsort-in-basic-python
    positions = range(len(x))
    ordered = sorted(positions, key=x.__getitem__)
    return type(x)(ordered)
[docs]
def all_same(items: List[T]) -> bool:
    """
    Check whether every element of ``items`` compares equal to the first one.

    An empty list is considered uniform; in that case ``items[0]`` is never accessed.

    Parameters:
        items (List[T]): The input list.
    Returns:
        bool: True if all elements equal ``items[0]`` (or the list is empty), False otherwise.
    """
    for element in items:
        if not (element == items[0]):
            return False
    return True
[docs]
def all_int(t: Tuple[Any, ...]) -> TypeGuard[Tuple[int, ...]]:
    """
    Type guard: report whether every element of ``t`` is an ``int`` instance.

    ``bool`` values count as ints because ``bool`` subclasses ``int``. An empty tuple passes.

    Parameters:
        t (Tuple[Any, ...]): The tuple to check.
    Returns:
        bool: True if no element fails ``isinstance(_, int)``; lets type checkers narrow ``t`` to
        ``Tuple[int, ...]``.
    """
    return not any(not isinstance(s, int) for s in t)
[docs]
def colored(st, color: Optional[str], background=False):
    """
    Wrap ``st`` in an ANSI SGR colour escape sequence (replaces the termcolor library).

    The colour code is ``30 + index`` in the 8-colour palette below. Spelling the colour name in
    all caps adds 60, giving the bright variant. ``background=True`` adds 10, selecting the
    background instead of the foreground.

    Parameters:
        st (str): The text to colour.
        color (Optional[str]): A palette name (case selects normal/bright), or None to return
            ``st`` unchanged.
        background (bool): Colour the background instead of the text.
    Returns:
        str: ``st`` surrounded by the colour escape and a reset escape (``ESC[0m``).
    Raises:
        ValueError: If ``color`` is not one of the palette names (case-insensitive).
    """
    if color is None:
        return st
    palette = ['black', 'red', 'green', 'yellow', 'blue', 'magenta', 'cyan', 'white']
    code = 30 + palette.index(color.lower())
    code += 60 * (color.upper() == color)  # all-caps name -> bright variant
    code += 10 * background  # background codes sit 10 above the foreground ones
    return f"\u001b[{code}m{st}\u001b[0m"
[docs]
def ansistrip(s: str):
    """
    Remove ANSI escape sequences from ``s``.

    A sequence is ``ESC [`` followed either by ``K`` or by the shortest run of characters that
    ends in ``m``.

    Parameters:
        s (str): The input string.
    Returns:
        str: ``s`` with every matching escape sequence deleted.
    """
    escape_pattern = "\x1b\\[(K|.*?m)"
    return re.sub(escape_pattern, "", s)
[docs]
def ansilen(s: str):
    """
    Return the number of characters in ``s`` once ANSI escape sequences are removed.

    Parameters:
        s (str): The input string, possibly containing escape sequences.
    Returns:
        int: ``len(ansistrip(s))``, i.e. the visible length of the string.
    """
    visible = ansistrip(s)
    return len(visible)
[docs]
def make_pair(x: Union[int, Tuple[int, ...]], cnt=2) -> Tuple[int, ...]:
    """
    Broadcast a single int to a tuple of ``cnt`` copies; pass anything else through unchanged.

    Parameters:
        x (Union[int, Tuple[int, ...]]): An int, or an already-formed tuple.
        cnt (int): Number of copies to produce when ``x`` is an int.
    Returns:
        Tuple[int, ...]: ``(x,) * cnt`` for an int, otherwise ``x`` itself (no length check).
    """
    if not isinstance(x, int):
        return x
    return (x,) * cnt
[docs]
def flatten(l: Iterable[Iterable[T]]):
    """
    Concatenate the elements of nested iterables into one flat list (one level deep).

    :param l: An iterable whose items are themselves iterables.
    :type l: Iterable[Iterable[T]]
    :return: A new list with the inner elements in iteration order.
    :rtype: List[T]
    """
    flat: List[T] = []
    for inner in l:
        flat.extend(inner)
    return flat
[docs]
def fromimport(mod, frm):
    """
    Dynamically perform ``from mod import frm`` and return the imported object.

    :param mod: Dotted name of the module to import.
    :type mod: str
    :param frm: Name of the attribute or submodule to take from that module.
    :type frm: str
    :return: The object bound to ``frm`` in ``mod``.
    :raises ImportError: If ``mod`` cannot be imported.
    :raises AttributeError: If ``frm`` is not found after the import.
    """
    module = __import__(mod, fromlist=[frm])  # a non-empty fromlist makes __import__ return the leaf module
    return getattr(module, frm)
[docs]
def strip_parens(fst: str):
    """
    Remove one pair of enclosing parentheses from ``fst`` if they wrap the whole string.

    The parentheses are removed only when ``fst`` starts with ``(`` and the parenthesis opened at
    index 0 is closed by the final character. ``"(a+b)"`` becomes ``"a+b"``. ``"(a)+(b)"`` and
    ``"(a*(b))+(c)"`` are returned unchanged because their first ``(`` closes before the end.
    Strings shorter than two characters (including ``""``) and unbalanced strings are returned as is.

    :param fst: The input string.
    :type fst: str
    :return: ``fst[1:-1]`` if the outer parentheses match each other, otherwise ``fst``.
    :rtype: str
    """
    if len(fst) < 2 or fst[0] != "(" or fst[-1] != ")":
        return fst
    depth = 0
    last = len(fst) - 1
    for i, ch in enumerate(fst):
        if ch == "(":
            depth += 1
        elif ch == ")":
            depth -= 1
            # The paren opened at index 0 closed before the end, so it does not wrap the whole string.
            if depth == 0 and i != last:
                return fst
    # depth != 0 here means the string is unbalanced; leave it alone.
    return fst[1:-1] if depth == 0 else fst
[docs]
def round_up(num, amt: int):
    """
    Round ``num`` up to a multiple of ``amt`` using integer arithmetic.

    Computed as ``(num + amt - 1) // amt * amt``. This is the ceiling multiple for integer
    ``num`` and positive ``amt``.

    :param num: The value to round.
    :type num: int
    :param amt: The step to round up to.
    :type amt: int
    :return: The smallest multiple of ``amt`` that is >= ``num`` (for integer inputs, ``amt > 0``).
    :rtype: int
    """
    bumped = num + amt - 1
    return (bumped // amt) * amt
[docs]
def merge_dicts(ds: Iterable[Dict[T, U]]) -> Dict[T, U]:
    """
    Merge an iterable of dictionaries into one dictionary.

    The same key may appear in several dicts only if it maps to the same value everywhere.
    Conflicting values trigger an AssertionError. Values must be hashable, because (key, value)
    pairs are collected into a set for the check.

    :param ds: An iterable (list, tuple, generator, ...) of dictionaries with keys of type T and values of type U.
    :return: A merged dictionary with keys of type T and values of type U.
    :raises AssertionError: If two dictionaries map the same key to different values.
    """
    # Materialize first: ds is traversed twice below, and a one-shot iterator (e.g. a generator)
    # would be exhausted by the check, silently producing an empty result.
    ds = list(ds)
    assert len(kvs := set([(k, v) for d in ds for k, v in d.items()])) == len(
        set(kv[0] for kv in kvs)
    ), f"cannot merge, {kvs} contains different values for the same key"
    return {k: v for d in ds for k, v in d.items()}
[docs]
def partition(lst: List[T], fxn: Callable[[T], bool]):
    """
    Split ``lst`` into the elements that satisfy ``fxn`` and those that do not.

    ``fxn`` is called once per element, in order, and its result is used for its truthiness.
    Relative order is preserved within both output lists.

    :param lst: A list of type T.
    :param fxn: A predicate taking an element of type T.
    :return: A tuple ``(matching, non_matching)`` of two new lists.
    """
    matching: List[T] = []
    rest: List[T] = []
    for item in lst:
        if fxn(item):
            matching.append(item)
        else:
            rest.append(item)
    return matching, rest
[docs]
def unwrap(x: Optional[T]) -> T:
    """
    Narrow an ``Optional`` value to its non-None type.

    Falsy values such as ``0`` or ``""`` are returned as-is; only ``None`` is rejected. The check
    is an ``assert``, so it disappears under ``python -O``.

    :param x: An optional value of type T.
    :return: ``x`` itself.
    :raises AssertionError: If ``x`` is None.
    """
    value = x
    assert value is not None
    return value
[docs]
def unwrap2(x: Tuple[T, Any]) -> T:
    """
    Unpack a ``(value, error)`` pair and return the value when there is no error.

    ``x`` must have exactly two elements; otherwise tuple unpacking raises ValueError.

    :param x: A pair of a value of type T and an error object (None means success).
    :return: The value part of the pair.
    :raises AssertionError: If the error part is not None; the message is ``str(error)``.
    """
    (value, error) = x
    assert error is None, str(error)
    return value
[docs]
def get_child(obj, key):
    """
    Walk a dot-separated path (e.g. ``"layers.0.weight"``) starting from ``obj``.

    For each segment:
    - a numeric segment indexes with ``int(segment)``;
    - otherwise a dict is indexed by the segment string;
    - anything else is read with ``getattr``.

    Parameters:
        obj (Any): The root object.
        key (str): Dot-separated path of segments.
    Returns:
        Any: The object reached after the last segment.
    """
    current = obj
    for segment in key.split("."):
        if segment.isnumeric():
            current = current[int(segment)]
            continue
        current = current[segment] if isinstance(current, dict) else getattr(current, segment)
    return current
[docs]
@functools.lru_cache(maxsize=None)
def to_function_name(s: str):
    """
    Turn an arbitrary (possibly ANSI-coloured) string into an identifier-safe name.

    ANSI escape sequences are stripped first. Then every character outside ``[A-Za-z0-9_]`` is
    replaced by its code point as at least two uppercase hex digits. Results are memoized per input.

    Parameters:
        s (str): Input string.
    Returns:
        str: A string containing only letters, digits, underscores and hex digits.
    """
    allowed = string.ascii_letters + string.digits + "_"
    pieces = []
    for ch in ansistrip(s):
        pieces.append(ch if ch in allowed else f"{ord(ch):02X}")
    return "".join(pieces)
[docs]
@functools.lru_cache(maxsize=None)
def getenv(key: str, default=0):
    """
    Read an environment variable and coerce it to the type of ``default``.

    If the variable is unset, ``default`` itself is returned (after ``type(default)(default)``).
    Results are memoized per ``(key, default)`` with ``lru_cache``. Changes to ``os.environ``
    made after the first lookup of a pair are therefore not observed.

    Parameters:
        key (str): Environment variable name.
        default (Union[int, float, str]): Fallback value; its type is used to convert the result.
    Returns:
        Union[int, float, str]: The converted environment value, or the default.
    Raises:
        ValueError: If the variable's text cannot be converted to ``type(default)``.
    """
    convert = type(default)
    raw = os.getenv(key, default)
    return convert(raw)
[docs]
def temp(x: str) -> str:
    """
    Build a path inside the system temporary directory.

    Parameters:
        x (str): File name or relative path to append to ``tempfile.gettempdir()``.
    Returns:
        str: The joined path, rendered with forward slashes (``Path.as_posix``).
    """
    base_dir = pathlib.Path(tempfile.gettempdir())
    return base_dir.joinpath(x).as_posix()
[docs]
class Context(contextlib.ContextDecorator):
    """
    Temporarily override the values of registered ``ContextVar`` instances.

    Usable as a ``with`` block or as a decorator, e.g. ``with Context(DEBUG=2): ...``. Contexts
    nest, and on exit each one restores the values that were active when it was entered.
    """

    # Saved states. __enter__ overwrites the top entry with a snapshot of every ContextVar's value
    # and then pushes its own kwargs; __exit__ pops that entry and restores each popped key from
    # the entry that is then on top.
    stack: ClassVar[List[dict[str, int]]] = [{}]

    def __init__(self, **kwargs):
        """
        Args:
            kwargs (dict[str, int]): ContextVar keys mapped to the temporary values to set while
                the context is active. Every key must already be a registered ContextVar.
        """
        self.kwargs = kwargs

    def __enter__(self):
        """
        Snapshot the current ContextVar values, apply the temporary values and push them.

        Returns:
            Context: This instance, so ``with Context(...) as ctx`` works.
        Raises:
            KeyError: If a kwarg names a ContextVar that has not been created. This is raised before
                any value is changed, so a bad key no longer leaves earlier overrides in place with
                no __exit__ to undo them.
        """
        Context.stack[-1] = {k: o.value for k, o in ContextVar._cache.items()}  # Store current state.
        # Resolve every key first so an unknown key fails before anything is mutated.
        targets = [(ContextVar._cache[k], v) for k, v in self.kwargs.items()]
        for var, v in targets:
            var.value = v  # Update to new temporary state.
        Context.stack.append(self.kwargs)  # Store the temporary state so we know what to undo later.
        return self

    def __exit__(self, *args):
        """
        Pop this context's overrides and restore each overridden ContextVar from the saved state.

        Exceptions raised inside the block are not suppressed (returns None).

        Args:
            args (tuple): The standard ``(exc_type, exc_value, traceback)`` triple, unused.
        """
        for k in Context.stack.pop():
            ContextVar._cache[k].value = Context.stack[-1].get(k, ContextVar._cache[k].value)
[docs]
class ContextVar:
    """
    A named, process-wide setting initialised from the environment.

    Instances are interned by key: ``ContextVar(key, default)`` returns the instance already
    registered under ``key`` if there is one, and the second ``default`` is then ignored.
    Comparisons (``>=``, ``>``, ``<``, ``<=``) and truthiness delegate to ``value``, so
    ``DEBUG >= 2`` reads naturally. ``Context`` changes ``value`` temporarily.
    """

    # Registry of every ContextVar created so far, keyed by its name.
    _cache: ClassVar[Dict[str, ContextVar]] = {}
    # Current value, set to getenv(key, default_value) when the instance is first created.
    # Its runtime type follows type(default_value) (a str default yields a str value).
    value: int

    def __new__(cls, key, default_value):
        """
        Return the ContextVar registered under ``key``, creating and registering it if needed.

        :param key: Name of the variable; also the environment variable read on creation.
        :param default_value: Value (and type) used when the environment variable is unset.
        """
        if key in ContextVar._cache:
            return ContextVar._cache[key]
        instance = ContextVar._cache[key] = super().__new__(cls)
        instance.value = getenv(key, default_value)
        return instance

    def __bool__(self):
        """Return ``bool(self.value)``."""
        return bool(self.value)

    def __ge__(self, x):
        """Return ``self.value >= x``."""
        return self.value >= x

    def __gt__(self, x):
        """Return ``self.value > x``."""
        return self.value > x

    def __lt__(self, x):
        """Return ``self.value < x``."""
        return self.value < x

    def __le__(self, x):
        """
        Return ``self.value <= x``.

        Without this, ``var <= n`` (and the reflected ``n >= var``) raised TypeError, because
        neither operand implemented that comparison.
        """
        return self.value <= x
# Settings that Context(...) can override temporarily (each defaults to 0).
DEBUG = ContextVar("DEBUG", 0)
IMAGE = ContextVar("IMAGE", 0)
BEAM = ContextVar("BEAM", 0)
NOOPT = ContextVar("NOOPT", 0)
# Plain settings read once through getenv; they are not ContextVars.
GRAPH = getenv("GRAPH", 0)
GRAPHPATH = getenv("GRAPHPATH", "/tmp/net")
[docs]
class Timing(contextlib.ContextDecorator):
    """
    Context manager / decorator that measures the elapsed time of a block with ``time.perf_counter_ns``.

    On exit the elapsed time is stored in ``self.et`` (nanoseconds). When ``enabled``, it is also
    printed as ``f"{prefix}{ms:.2f} ms"``, followed by ``on_exit(et)`` if provided.

    Attributes:
        prefix (str): Text printed before the elapsed time. Defaults to "".
        on_exit (Optional[Callable[[int], str]]): Called with the elapsed nanoseconds, and its return
            value is appended to the printed message. It is only called when ``enabled`` is True.
            Defaults to None.
        enabled (bool): If False, nothing is printed; the time is still measured. Defaults to True.
    """

    def __init__(self, prefix="", on_exit=None, enabled=True):
        self.prefix, self.on_exit, self.enabled = prefix, on_exit, enabled

    def __enter__(self):
        """Record the start time and return this instance (so ``with Timing() as t`` gives access to ``t.et``)."""
        self.st = time.perf_counter_ns()
        return self

    def __exit__(self, *exc):
        """Store the elapsed nanoseconds in ``self.et`` and print the report if enabled. Exceptions propagate."""
        self.et = time.perf_counter_ns() - self.st
        if self.enabled:
            print(f"{self.prefix}{self.et*1e-6:.2f} ms" + (self.on_exit(self.et) if self.on_exit else ""))
[docs]
class Profiling(contextlib.ContextDecorator):
    """
    Context manager / decorator that profiles a block with ``cProfile`` and prints ``pstats`` output.

    Attributes:
        enabled (bool): If False, the profiler is created but never started and nothing is printed.
            Default is True.
        sort (str): ``pstats`` sort key for the report. Default is 'cumtime'.
        frac (float): Passed to ``print_stats``; a float restricts the listing to that fraction of
            the sorted entries. Default is 0.2.
    """

    def __init__(self, enabled=True, sort="cumtime", frac=0.2):
        """
        Args:
            enabled (bool): If True, the profiler is enabled. Default is True.
            sort (str): How to sort the statistic output. Default is 'cumtime'.
            frac (float): The fraction of the sorted entries to print. Default is 0.2.
        """
        self.enabled, self.sort, self.frac = enabled, sort, frac

    def __enter__(self):
        """
        Create the profiler, start it if enabled, and return this instance.

        Returns:
            Profiling: The current instance of this class.
        """
        # NOTE(review): the timer returns wall-clock nanoseconds but timeunit=1e-6 declares microsecond
        # units, so the "seconds" pstats prints are presumably 1000x the real value (effectively
        # milliseconds). Kept as-is; confirm whether this scaling is intentional.
        self.pr = cProfile.Profile(timer=lambda: int(time.time() * 1e9), timeunit=1e-6)
        if self.enabled:
            self.pr.enable()
        return self

    def __exit__(self, *exc):
        """
        Stop the profiler and print the sorted statistics (only when enabled). Exceptions propagate.

        Args:
            exc (*Any): The standard exception triple, unused.
        """
        if self.enabled:
            self.pr.disable()
            pstats.Stats(self.pr).strip_dirs().sort_stats(self.sort).print_stats(self.frac)
[docs]
class DType(NamedTuple):
    """
    Description of a scalar or vector element type.

    Vector types are made with ``vec``. They keep the scalar's priority, multiply its itemsize by
    ``sz`` and append ``sz`` to its name (e.g. ``float`` -> ``float4``).
    """

    priority: int  # this determines when things get upcasted
    itemsize: int  # bytes per element (for vectors: scalar itemsize * sz)
    name: str  # type name, e.g. "float" or "unsigned char"; vectors append sz ("float4")
    np: Optional[type]  # TODO: someday this will be removed with the "remove numpy" project
    sz: int = 1  # number of lanes; 1 for scalar types

    def __repr__(self):
        """
        Render as ``dtypes.<attr>`` for scalars and ``dtypes._<scalar attr><sz>`` for vectors.

        ``<attr>`` is the attribute name INVERSE_DTYPES_DICT records for the value. For aliased
        types this is the alias assigned last in ``dtypes``, e.g. ``dtypes.float``.
        """
        if self.sz == 1:
            return f"dtypes.{INVERSE_DTYPES_DICT[self]}"
        return f"dtypes._{INVERSE_DTYPES_DICT[self.scalar()]}{self.sz}"

    def vec(self, sz: int):
        """
        Create the ``sz``-lane vector type built from this scalar type.

        Args:
            sz (int): Number of lanes; must be greater than 1.
        Returns:
            DType: Same priority, ``itemsize * sz``, name with ``sz`` appended, ``np`` set to None.
        Raises:
            AssertionError: If ``sz <= 1`` or this type is already a vector.
        """
        assert sz > 1 and self.sz == 1, f"can't vectorize {self} with size {sz}"
        return DType(self.priority, self.itemsize * sz, self.name + str(sz), None, sz)

    def scalar(self):
        """
        Return the scalar type this vector was built from, or ``self`` if it is already scalar.

        The scalar is found by *type name* (the vector name with its lane count removed) among the
        registered dtypes. The previous implementation looked the name up as a DTYPES_DICT key.
        Keys are attribute names ("int8", "uint32", ...) and only coincide with type names for a
        few aliases ("float", "half", "int", ...), so it raised KeyError for e.g. ``char4`` or
        ``long2``. Results for those aliases are unchanged.

        Returns:
            DType: The scalar version of the current DType object.
        Raises:
            KeyError: If no registered scalar dtype has the derived name.
        """
        if self.sz == 1:
            return self
        base_name = self.name[: -len(str(self.sz))]
        for candidate in DTYPES_DICT.values():
            if candidate.sz == 1 and candidate.name == base_name:
                return candidate
        raise KeyError(base_name)
# dependent typing?
[docs]
class ImageDType(DType):
    """
    A DType for images that additionally carries a ``shape``.

    ``shape`` is not a tuple field. It is stored as an instance attribute, so equality and
    hashing are overridden to take it into account.

    Attributes:
        priority (int): The priority level of the data type.
        itemsize (int): The size of each element in bytes.
        name (str): The name of the data type.
        np (Optional[type]): The numpy scalar type (e.g. ``np.float16``).
        shape (tuple): The shape of the image.
    """

    def __new__(cls, priority, itemsize, name, np, shape):
        """Build the underlying tuple from the DType fields; ``shape`` is attached in ``__init__``."""
        return super().__new__(cls, priority, itemsize, name, np)

    def __init__(self, priority, itemsize, name, np, shape):
        """Store ``shape`` on the instance (the constructor arguments are otherwise consumed by ``__new__``)."""
        self.shape: Tuple[int, ...] = shape  # arbitrary arg for the dtype, used in image for the shape
        super().__init__()

    def __repr__(self):
        """Render as ``dtypes.<name>(<shape>)``."""
        return f"dtypes.{self.name}({self.shape})"

    # TODO: fix this to not need these
    def __hash__(self):
        """Hash the tuple fields together with ``shape``."""
        return hash((super().__hash__(), self.shape))

    def __eq__(self, x):
        """
        True only for another ImageDType with equal fields and equal shape.

        The isinstance check comes first. Previously ``tuple.__eq__`` returned True (equal plain
        DType) or NotImplemented (truthy, e.g. for None), and ``x.shape`` then raised
        AttributeError.
        """
        return isinstance(x, ImageDType) and super().__eq__(x) and self.shape == x.shape

    def __ne__(self, x):
        """Logical negation of ``__eq__``."""
        return not self.__eq__(x)
[docs]
class PtrDType(DType):
    """
    A pointer to elements of another DType.

    It copies all fields of the pointee (priority, itemsize, name, np, sz). Only ``repr``
    differs: it gains a ``ptr.`` prefix. No ``__eq__``/``__hash__`` is defined, so a PtrDType
    compares and hashes like the tuple of those fields.
    """

    def __new__(cls, dt: DType):
        """Create a PtrDType whose fields are copied from ``dt``."""
        pointee_fields = (dt.priority, dt.itemsize, dt.name, dt.np, dt.sz)
        return super().__new__(cls, *pointee_fields)

    def __repr__(self):
        """Return the pointee's repr prefixed with ``ptr.``."""
        base = super().__repr__()
        return f"ptr.{base}"
[docs]
class dtypes:
    """
    Namespace of the predefined DType constants plus helper predicates and constructors.

    Attributes such as ``dtypes.float32`` (alias ``dtypes.float``) are the canonical DType objects.
    ``is_int`` / ``is_float`` / ``is_unsigned`` test membership by equality. ``from_np`` and
    ``fields`` expose the DTYPES_DICT registry. ``imageh`` / ``imagef`` build ImageDTypes.
    """

    # NOTE: the static methods come first on purpose. Once `bool`, `int` and `float` are assigned
    # below they shadow the builtins inside this class body.
    @staticmethod  # static methods on top, or bool in the type info will refer to dtypes.bool
    def is_int(x: DType) -> bool:
        """Return True if ``x`` equals one of the signed or unsigned 8/16/32/64-bit integer types."""
        signed = (dtypes.int8, dtypes.int16, dtypes.int32, dtypes.int64)
        unsigned = (dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)
        return x in signed + unsigned

    @staticmethod
    def is_float(x: DType) -> bool:
        """
        Return True if ``x`` is float16/32/64 or one of the vectors half4, float2, float4.

        Note: bfloat16 is not in this set.
        """
        scalars = (dtypes.float16, dtypes.float32, dtypes.float64)
        vectors = (dtypes.half.vec(4), dtypes.float.vec(2), dtypes.float.vec(4))
        return x in scalars + vectors

    @staticmethod
    def is_unsigned(x: DType) -> bool:
        """Return True if ``x`` equals one of the unsigned 8/16/32/64-bit integer types."""
        unsigned_types = (dtypes.uint8, dtypes.uint16, dtypes.uint32, dtypes.uint64)
        return x in unsigned_types

    @staticmethod
    def from_np(x) -> DType:
        """Return the DType whose DTYPES_DICT key equals ``np.dtype(x).name`` (e.g. "float32")."""
        numpy_name = np.dtype(x).name
        return DTYPES_DICT[numpy_name]

    @staticmethod
    def fields() -> Dict[str, DType]:
        """Return DTYPES_DICT, mapping attribute names (including aliases) to DType objects."""
        return DTYPES_DICT

    # Predefined types. The definition order matters: DTYPES_DICT is built from this class's
    # __dict__, and for aliased values INVERSE_DTYPES_DICT keeps the name assigned last.
    #   bool; float16/half; float32/float; float64/double; int8; int16; int32/int; int64;
    #   uint8; uint16; uint32; uint64; bfloat16 (no numpy type); _arg_int32 (internal).
    bool: Final[DType] = DType(0, 1, "bool", np.bool_)
    float16: Final[DType] = DType(9, 2, "half", np.float16)
    half = float16
    float32: Final[DType] = DType(10, 4, "float", np.float32)
    float = float32
    float64: Final[DType] = DType(11, 8, "double", np.float64)
    double = float64
    int8: Final[DType] = DType(1, 1, "char", np.int8)
    int16: Final[DType] = DType(3, 2, "short", np.int16)
    int32: Final[DType] = DType(5, 4, "int", np.int32)
    int = int32
    int64: Final[DType] = DType(7, 8, "long", np.int64)
    uint8: Final[DType] = DType(2, 1, "unsigned char", np.uint8)
    uint16: Final[DType] = DType(4, 2, "unsigned short", np.uint16)
    uint32: Final[DType] = DType(6, 4, "unsigned int", np.uint32)
    uint64: Final[DType] = DType(8, 8, "unsigned long", np.uint64)
    # NOTE: bfloat16 isn't supported in numpy
    bfloat16: Final[DType] = DType(9, 2, "__bf16", None)
    # NOTE: these are internal dtypes, should probably check for that
    _arg_int32: Final[DType] = DType(2, 4, "_arg_int32", None)

    # NOTE: these are image dtypes
    @staticmethod
    def imageh(shp):
        """Return an ImageDType named "imageh" (2-byte items, np.float16, priority 100) with shape ``shp``."""
        return ImageDType(100, 2, "imageh", np.float16, shp)

    @staticmethod
    def imagef(shp):
        """Return an ImageDType named "imagef" (4-byte items, np.float32, priority 100) with shape ``shp``."""
        return ImageDType(100, 4, "imagef", np.float32, shp)
# Registry of the DType attributes defined on `dtypes`, in definition order (aliases included).
# HACK: staticmethods are not callable in 3.8 so we have to compare the class
DTYPES_DICT = {
    name: value
    for name, value in dtypes.__dict__.items()
    if not (name.startswith("__") or callable(value) or value.__class__ is staticmethod)
}
# Reverse lookup DType -> attribute name; for aliased values the later (alias) name wins.
INVERSE_DTYPES_DICT = {dtype: name for name, dtype in DTYPES_DICT.items()}
[docs]
class GlobalCounters:
    """
    Process-wide counters, stored as class attributes.

    ``reset()`` zeroes ``global_ops``, ``global_mem``, ``time_sum_s`` and ``kernel_count``.
    ``mem_used`` and ``mem_cached`` are deliberately left untouched by ``reset()``.
    """

    global_ops: ClassVar[int] = 0  # count of global operations
    global_mem: ClassVar[int] = 0  # total memory used globally
    time_sum_s: ClassVar[float] = 0.0  # cumulative time in seconds
    kernel_count: ClassVar[int] = 0  # number of kernels executed
    mem_used: ClassVar[int] = 0  # amount of memory used; NOTE: not reset
    mem_cached: ClassVar[int] = 0  # amount of cached memory; NOTE: not reset

    @staticmethod
    def reset():
        """Set global_ops, global_mem and kernel_count to 0 and time_sum_s to 0.0."""
        GlobalCounters.global_ops = 0
        GlobalCounters.global_mem = 0
        GlobalCounters.time_sum_s = 0.0
        GlobalCounters.kernel_count = 0
"""
On-disk cache configuration.
The settings below locate tinygrad's SQLite cache: the cache directory, the database file path and the cache level, each of which can be overridden through an environment variable.
The database connection itself is opened lazily by db_connection().
"""
# Cache directory: $XDG_CACHE_HOME if set, else ~/Library/Caches on macOS or ~/.cache elsewhere.
_cache_dir: str = getenv("XDG_CACHE_HOME", os.path.expanduser("~/Library/Caches" if OSX else "~/.cache"))
# Database file: $CACHEDB if set, else the absolute path <cache dir>/tinygrad/cache.db.
CACHEDB: str = getenv("CACHEDB", os.path.abspath(os.path.join(_cache_dir, "tinygrad", "cache.db")))
# Cache level from $CACHELEVEL (default 2); diskcache_get returns None when it is 0.
CACHELEVEL = getenv("CACHELEVEL", 2)
# Schema version appended to cache table names (table "x" is queried as "x_10").
VERSION = 10
# Shared sqlite3 connection, created on first use by db_connection().
_db_connection = None
[docs]
def db_connection():
    """
    Return the module's shared SQLite connection, opening it on first use.

    The connection is opened from ``CACHEDB`` and cached in the module-global ``_db_connection``,
    so all callers share it. The parent directory of the database file is created if it is
    missing. When ``DEBUG >= 7``, every SQL statement the connection executes is printed.

    :return: The ``sqlite3.Connection`` for the cache database.
    """
    global _db_connection
    if _db_connection is None:
        # os.path.dirname instead of CACHEDB.rsplit(os.sep, 1)[0]: for a bare file name ("cache.db")
        # or ":memory:" the split produced the name itself, so makedirs created a *directory* with
        # the database's name. For "/cache.db" it produced "" and makedirs failed.
        parent_dir = os.path.dirname(CACHEDB)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        _db_connection = sqlite3.connect(CACHEDB)
        if DEBUG >= 7:
            _db_connection.set_trace_callback(print)
    return _db_connection
[docs]
def diskcache_get(table: str, key: Union[Dict, str, int]) -> Any:
    """
    Look up a value in the disk cache.

    Parameters:
        table (str): Logical table name; the SQL table queried is ``f"{table}_{VERSION}"``.
        key (Union[Dict, str, int]): Column/value pairs identifying the row. A bare ``str``
            or ``int`` is treated as ``{"key": key}``.
    Returns:
        Any: The unpickled cached value, or None when caching is disabled
        (``CACHELEVEL == 0``), the query raises ``sqlite3.OperationalError`` (e.g. the table
        has not been created yet), or no row matches.
    Note:
        Stored values are restored with ``pickle.loads``, so anyone able to write the
        ``CACHEDB`` file can make this run arbitrary code. Only use a database you trust.
    """
    if CACHELEVEL == 0:
        return None
    if isinstance(key, (str, int)):
        key = {"key": key}
    # closing() guarantees the cursor is released on every path, including the early return.
    with contextlib.closing(db_connection().cursor()) as cur:
        try:
            res = cur.execute(
                f"SELECT val FROM {table}_{VERSION} WHERE {' AND '.join([f'{x}=?' for x in key.keys()])}",
                tuple(key.values()),
            )
        except sqlite3.OperationalError:
            return None  # most likely the table doesn't exist yet
        row = res.fetchone()
    # NOTE(review): pickle.loads on cache contents -- trusted local file assumed.
    return None if row is None else pickle.loads(row[0])
# Logical table names for which diskcache_put has already issued CREATE TABLE in this
# process, so the CREATE statement runs at most once per table.
_db_tables: set[str] = set()
[docs]
def diskcache_put(table: str, key: Union[Dict, str, int], val: Any):
    """
    Store ``val`` in the disk cache and hand it back.

    When ``CACHELEVEL`` is 0 nothing is written and ``val`` is returned untouched.
    Otherwise the SQL table ``f"{table}_{VERSION}"`` is created the first time this process
    writes to ``table``. It has one column per key field, typed from the key's values, plus a
    ``val`` blob, with the key fields as the primary key. The pickled value is then written
    with ``REPLACE INTO``, the connection is committed and the cursor closed.

    :param table: Logical table name; the version suffix is appended automatically.
    :type table: str
    :param key: Column/value pairs identifying the row; a bare ``str`` or ``int`` is stored
        under the single column ``key``.
    :type key: Union[Dict, str, int]
    :param val: Any picklable value.
    :type val: Any
    :return: ``val`` itself, so the call can be used inline.
    :raises KeyError: on the first write to a table, if a key value's type is not one of
        str, bool, int, float or bytes.
    """
    if CACHELEVEL == 0:
        return val
    row_key = {"key": key} if isinstance(key, (str, int)) else key
    conn = db_connection()
    cursor = conn.cursor()
    sql_table = f"{table}_{VERSION}"
    if table not in _db_tables:
        # Python value type -> SQL column type used in the CREATE TABLE statement.
        column_type = {
            str: "text",
            bool: "integer",
            int: "integer",
            float: "numeric",
            bytes: "blob",
        }
        column_defs = ", ".join(
            f"{name} {column_type[type(row_key[name])]}" for name in row_key.keys()
        )
        cursor.execute(
            f"CREATE TABLE IF NOT EXISTS {sql_table} ({column_defs}, val blob, PRIMARY KEY ({', '.join(row_key.keys())}))"
        )
        _db_tables.add(table)
    column_names = ", ".join(row_key.keys())
    placeholders = ", ".join(["?"] * len(row_key.keys()))
    cursor.execute(
        f"REPLACE INTO {sql_table} ({column_names}, val) VALUES ({placeholders}, ?)",
        (*row_key.values(), pickle.dumps(val)),
    )
    conn.commit()
    cursor.close()
    return val
[docs]
def diskcache(func):
    """
    Decorator that memoizes ``func`` in the on-disk SQLite cache.

    Each call is keyed by the SHA-256 hex digest of ``pickle.dumps((args, kwargs))``. Results
    are read and written through ``diskcache_get`` / ``diskcache_put`` in the table
    ``f"cache_{func.__name__}"``, so arguments and results must be picklable. A result of
    ``None`` cannot be told apart from a cache miss, so such calls are always re-executed.

    :param func: The function to memoize.
    :return: A wrapper with ``func``'s call signature and metadata (``functools.wraps``,
        which also sets ``__wrapped__``).
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        # Cached result of func(*args, **kwargs), computing and storing it on a miss.
        table = f"cache_{func.__name__}"
        key = hashlib.sha256(pickle.dumps((args, kwargs))).hexdigest()
        # Compare with None rather than truthiness: a cached falsy result (0, b"", [], ...)
        # must be served from the cache instead of being recomputed on every call.
        if (ret := diskcache_get(table, key)) is not None:
            return ret
        return diskcache_put(table, key, func(*args, **kwargs))

    return wrapper
"""
This module provides HTTP support.
The functions and classes within this module facilitate the implementation of HTTP related functionality.
"""
[docs]
def fetch(
    url: str,
    name: Optional[Union[pathlib.Path, str]] = None,
    allow_caching=not getenv("DISABLE_HTTP_CACHE"),
) -> pathlib.Path:
    """
    Download ``url`` to a local file and return its path.

    URLs starting with "/" or "." are treated as local paths and returned as-is, with no I/O.
    Otherwise the destination is ``name`` itself when it is a ``pathlib.Path`` or a string
    containing "/". In all other cases it is ``<cache dir>/tinygrad/downloads/<name>``, using
    the MD5 hex digest of the URL when no name is given. If the destination already exists and
    caching is allowed, no request is made.

    The body is streamed into a temporary file in the destination directory. That file is only
    moved onto the destination once the download is complete, so a partial file never appears
    at the final path. If fewer bytes than the advertised ``Content-Length`` arrive,
    RuntimeError is raised and the temporary file is removed.

    :param url: The URL to download, or a local path starting with "/" or ".".
    :type url: str
    :param name: Optional file name (placed in the download cache) or explicit path.
    :type name: Optional[Union[pathlib.Path, str]]
    :param allow_caching: Reuse an existing file instead of downloading again. The default,
        ``not getenv("DISABLE_HTTP_CACHE")``, is evaluated once when the module is imported.
    :type allow_caching: bool
    :return: The path of the local file.
    :rtype: pathlib.Path
    :raises RuntimeError: if the download is shorter than the advertised Content-Length.
    """
    if url.startswith(("/", ".")):
        return pathlib.Path(url)
    fp = (
        pathlib.Path(name)
        if name is not None and (isinstance(name, pathlib.Path) or "/" in name)
        else pathlib.Path(_cache_dir)
        / "tinygrad"
        / "downloads"
        / (name if name else hashlib.md5(url.encode("utf-8")).hexdigest())
    )
    if not fp.is_file() or not allow_caching:
        with request.urlopen(url, timeout=10) as r:
            assert r.status == 200
            total_length = int(r.headers.get("content-length", 0))
            (path := fp.parent).mkdir(parents=True, exist_ok=True)
            progress_bar = tqdm(total=total_length, unit="B", unit_scale=True, desc=url)
            tmp_path: Optional[pathlib.Path] = None
            try:
                with tempfile.NamedTemporaryFile(dir=path, delete=False) as f:
                    tmp_path = pathlib.Path(f.name)
                    while chunk := r.read(16384):
                        progress_bar.update(f.write(chunk))
                # The temporary file is closed at this point, which Windows requires before
                # the file can be moved.
                if (file_size := tmp_path.stat().st_size) < total_length:
                    raise RuntimeError(
                        f"fetch size incomplete, {file_size} < {total_length}"
                    )
                # replace() overwrites an existing fp on every platform, whereas rename()
                # fails on Windows when fp exists (a re-download with allow_caching=False).
                tmp_path.replace(fp)
                tmp_path = None  # now owned by fp; nothing left to clean up
            finally:
                progress_bar.close()
                if tmp_path is not None:
                    tmp_path.unlink(missing_ok=True)
    return fp
"""
Execution helpers module.
This module provides helper functions for executing tasks.
"""
[docs]
def cpu_time_execution(cb, enable):
    """
    Run ``cb`` and, if requested, report how long it took.

    Timing uses ``time.perf_counter``, i.e. elapsed wall-clock time, not process CPU time.

    :param cb: Zero-argument callable to run; its return value is discarded.
    :type cb: Callable[[], None]
    :param enable: Whether to time the call.
    :type enable: bool
    :return: Elapsed seconds when ``enable`` is truthy, otherwise None.
    :rtype: float | None
    """
    if not enable:
        cb()
        return None
    started = time.perf_counter()
    cb()
    return time.perf_counter() - started
"""
ctypes helpers module.
This module provides helper functions for working with ctypes. It is designed to make it easier
to work with the ctypes library in Python, by providing useful and well-documented functions.
"""
[docs]
def from_mv(mv, to_type=ctypes.c_char):
"""
This function takes a memoryview object 'mv' and optionally a ctypes data type 'to_type'.
It returns the result of casting the address of 'to_type' which is created from buffer 'mv' to a pointer of 'to_type'.
If no 'to_type' is provided, it defaults to ctypes.c_char.
:param mv: A memoryview object that needs to be converted.
:type mv: memoryview
:param to_type: The ctypes data type to which the memoryview will be casted. Defaults to ctypes.c_char.
:type to_type: ctypes data type, optional
"""
return ctypes.cast(
ctypes.addressof(to_type.from_buffer(mv)), ctypes.POINTER(to_type)
)
[docs]
def to_char_p_p(options: List[bytes], to_type=ctypes.c_char):
    """
    Build a C array of pointers (``char**``-style) from a list of byte strings.

    Each entry of ``options`` is copied into a new NUL-terminated buffer with
    ``ctypes.create_string_buffer``. A ``POINTER(to_type)`` to that buffer is stored at the
    same index of the returned array. Keep the returned array alive for as long as C code may
    read the strings.

    :param options: The byte strings to convert.
    :type options: List[bytes]
    :param to_type: Element type of each pointer. Defaults to ctypes.c_char.
    :type to_type: ctypes data type, optional
    :return: A ctypes array of ``len(options)`` ``POINTER(to_type)`` elements.
    """
    # (The original body repeated this return statement; the second copy was unreachable.)
    ptr_t = ctypes.POINTER(to_type)
    return (ptr_t * len(options))(
        *[ctypes.cast(ctypes.create_string_buffer(o), ptr_t) for o in options]
    )
[docs]
@functools.lru_cache(maxsize=None)
def init_c_struct_t(fields: Tuple[Tuple[str, ctypes._SimpleCData], ...]):
    """
    Create (and memoize) a byte-packed ``ctypes.Structure`` subclass with the given fields.

    ``_pack_ = 1`` removes padding between fields. Results are cached on ``fields`` by an
    unbounded ``functools.lru_cache``, so an equal field tuple returns the very same class.

    :param fields: ``(name, ctypes type)`` pairs in declaration order; must be hashable.
    :return: The ``CStruct`` class.
    """

    class CStruct(ctypes.Structure):
        _pack_ = 1  # no padding between fields
        _fields_ = fields

    return CStruct
[docs]
def init_c_var(ctypes_var, creat_cb):
    """
    Run an initializer callback on a ctypes variable and return that same variable.

    Lets an out-parameter be created, filled and used in a single expression.

    :param ctypes_var: The object to initialize, typically a fresh ctypes instance.
    :param creat_cb: Callable invoked as ``creat_cb(ctypes_var)``; its return value is ignored.
    :return: ``ctypes_var`` itself (the same object, not a copy).
    """
    creat_cb(ctypes_var)
    return ctypes_var
[docs]
def get_bytes(arg, get_sz, get_str, check) -> bytes:
    """
    Read a variable-length byte result out of a C API using a size-then-fill pair of calls.

    First ``get_sz(arg, byref(size))`` is called with a fresh ``ctypes.c_size_t``. Then
    ``get_str(arg, buf)`` is called with a string buffer of ``size.value`` bytes. The return
    value of each call is passed to ``check``.

    :param arg: Handle passed as the first argument to both calls.
    :param get_sz: Function that writes the required size through its second argument.
    :param get_str: Function that fills the buffer passed as its second argument.
    :param check: Called with the return value of each of the two calls, e.g. to raise on error.
    :return: Exactly ``size.value`` bytes read from the filled buffer.
    """
    size = init_c_var(
        ctypes.c_size_t(), lambda size_var: check(get_sz(arg, ctypes.byref(size_var)))
    )
    buffer = init_c_var(
        ctypes.create_string_buffer(size.value), lambda buf: check(get_str(arg, buf))
    )
    return ctypes.string_at(buffer, size=size.value)
[docs]
def flat_mv(mv: memoryview):
    """
    Return a flat, byte-typed view of ``mv`` without copying.

    The result has format "B" and shape ``(mv.nbytes,)``. A view that holds no bytes is
    returned unchanged, because ``memoryview.cast`` rejects zero-sized shapes. This also covers
    multi-dimensional empty views such as shape ``(3, 0)``, whose ``len()`` is 3 even though
    they contain no data.

    :param mv: A C-contiguous memoryview (``memoryview.cast`` requires C-contiguity).
    :return: ``mv.cast("B", shape=(mv.nbytes,))``, or ``mv`` itself when it holds no bytes.
    """
    # Check nbytes rather than len(): len() only reports the first dimension.
    if mv.nbytes == 0:
        return mv
    return mv.cast("B", shape=(mv.nbytes,))
"""
Helpers for CUDA-like APIs.
This module provides helper functions that resemble the CUDA API, offering similar functionalities.
It is intended to serve as a foundation for high-level GPU operations and data transfers.
"""
[docs]
def pretty_ptx(s):
    """
    Colorize PTX assembly text for terminal display.

    A fixed sequence of ``re.sub`` passes, each with ``re.M``, wraps matched tokens in
    ``colored``. Each pass runs on the output of the previous one:

    1. identifiers (names starting with ``_``, ``%`` or ``$``, optionally followed by
       ``.x``/``.y``/``.z`` and a ``:``, plus ``bufN``) in blue;
    2. type names (``b``/``s``/``u``/``f`` followed by 8/16/32/64, and ``pred``) in green;
    3. the first word of every line that ends in ``;`` (the instruction) in yellow;
    4. numeric literals (decimal, ``0x`` hex, ``0f`` + 8 hex digits) in yellow;
    5. the ``.param``/``.reg``/``.global`` state-space names in magenta;
    6. the ``.version``/``.target``/``.address_size``/``.visible``/``.entry`` directives in
       magenta.

    Every pattern captures the delimiter(s) around the token and writes them back unchanged,
    so only the token itself is colored.

    :param s: PTX source text.
    :type s: str
    :return: The same text with ``colored`` applied to the matched tokens.
    :rtype: str
    """
    # (pattern, color, whether the pattern captures a trailing delimiter group)
    rules = (
        (  # identifiers
            r"([!@<\[\s,\+\-;\n])((?:[_%$][\w%\$_]+(?:\.[xyz])?\:?)|(?:buf\d+))([<>\]\s,\+\-;\n\)])",
            "blue",
            True,
        ),
        (  # types
            r"(.)((?:b|s|u|f)(?:8|16|32|64)|pred)([\.\s])",
            "green",
            True,
        ),
        (  # instructions
            r"^(\s*)([\w]+)(.*?;$)",
            "yellow",
            True,
        ),
        (  # numbers
            r"([<>\[\]\s,\+\-;])((?:0[fF][0-9a-fA-F]{8})|(?:[0-9]+)|(?:0[xX][0-9a-fA-F]+))([<>\[\]\s,\+\-;])",
            "yellow",
            True,
        ),
        (  # state spaces
            r"(\.)(param|reg|global)",
            "magenta",
            False,
        ),
        (  # directives
            r"(\.)(version|target|address_size|visible|entry)",
            "magenta",
            False,
        ),
    )
    for pattern, color, has_tail in rules:
        s = re.sub(
            pattern,
            lambda m, color=color, has_tail=has_tail: m[1]
            + colored(m[2], color)
            + (m[3] if has_tail else ""),
            s,
            flags=re.M,
        )
    return s
[docs]
def compile_cuda_style(
    prg,
    compile_options,
    prog_t,
    create_prog,
    compile_prog,
    get_code,
    get_code_size,
    get_log,
    get_log_size,
    check,
) -> bytes:
    """
    Compile source code through a CUDA-style (NVRTC-like) compiler API.

    A program handle of type ``prog_t`` is created from ``prg`` and compiled with
    ``compile_options``. On success the compiled code is read back with
    ``get_code_size``/``get_code``. A non-zero compile status raises RuntimeError carrying the
    compiler log read with ``get_log_size``/``get_log``.

    Parameters:
        prg (str): Source code to compile.
        compile_options (List[str]): Compiler options; each is UTF-8 encoded.
        prog_t (ctypes type): ctypes type of the program handle.
        create_prog (Callable): Creates the program; its result is passed to ``check``.
        compile_prog (Callable): Compiles the program; returns a status (0 means success).
        get_code (Callable): Copies the compiled code into a buffer.
        get_code_size (Callable): Reports the size of the compiled code.
        get_log (Callable): Copies the compiler log into a buffer.
        get_log_size (Callable): Reports the size of the compiler log.
        check (Callable): Error check applied to API return values.
    Returns:
        bytes: The compiled code.
    Raises:
        RuntimeError: if ``compile_prog`` returns a non-zero status.
    """
    prog = prog_t()
    # Program name is the placeholder "<null>". The trailing 0, None, None presumably mean
    # "no headers" (nvrtcCreateProgram-style signature) -- confirm against the callers.
    check(create_prog(ctypes.byref(prog), prg.encode(), "<null>".encode(), 0, None, None))
    encoded_options = [option.encode() for option in compile_options]
    status = compile_prog(prog, len(compile_options), to_char_p_p(encoded_options))
    if status != 0:
        compile_log = get_bytes(prog, get_log_size, get_log, check).decode()
        raise RuntimeError(f"compile failed: {compile_log}")
    return get_bytes(prog, get_code_size, get_code, check)
[docs]
def encode_args_cuda_style(
    bufs, vals, device_ptr_t, marks
) -> Tuple[ctypes.Array, ctypes.Structure]:
    """
    Pack kernel arguments into a packed C struct plus a 5-entry ``void*`` descriptor array.

    The struct has fields ``f0, f1, ...``: first one ``device_ptr_t`` field per entry of
    ``bufs``, then one ``ctypes.c_int`` field per entry of ``vals``. The descriptor array is
    ``[marks[0], &struct, marks[1], &sizeof(struct), marks[2]]``. ``marks`` are presumably the
    CU_LAUNCH_PARAM_* markers of cuLaunchKernel's ``extra`` argument -- confirm with callers.

    :param bufs: Buffer handles, stored as ``device_ptr_t`` fields.
    :param vals: Integer values, stored as ``ctypes.c_int`` fields.
    :param device_ptr_t: ctypes type used for buffer fields.
    :param marks: At least three marker values placed around the two pointers.
    :return: ``(descriptor_array, struct_instance)``. Keep the struct alive while the
        descriptor array is in use.
    :rtype: Tuple[ctypes.Array, ctypes.Structure]
    """
    n_bufs = len(bufs)
    field_specs = tuple(
        (f"f{idx}", device_ptr_t if idx < n_bufs else ctypes.c_int)
        for idx in range(n_bufs + len(vals))
    )
    c_args = init_c_struct_t(field_specs)(*bufs, *vals)
    args_size = ctypes.c_size_t(ctypes.sizeof(c_args))
    launch_config = (ctypes.c_void_p * 5)(
        ctypes.c_void_p(marks[0]),
        ctypes.cast(ctypes.pointer(c_args), ctypes.c_void_p),
        ctypes.c_void_p(marks[1]),
        ctypes.cast(ctypes.pointer(args_size), ctypes.c_void_p),
        ctypes.c_void_p(marks[2]),
    )
    return launch_config, c_args
[docs]
def time_execution_cuda_style(
    cb, ev_t, evcreate, evrecord, evsync, evdestroy, evtime, enable=False
) -> Optional[float]:
    """
    Run ``cb`` and, when ``enable`` is set, time it with a pair of CUDA-style events.

    With timing enabled, two events are created with ``evcreate(byref(ev), 0)``. The start
    event is recorded (stream None), ``cb()`` is run, the end event is recorded and
    synchronized, and ``evtime`` writes the elapsed time into a ``c_float``. Both events are
    destroyed even if ``cb`` or any event call raises.

    :param cb: Zero-argument callable to run.
    :type cb: Callable[[], Any]
    :param ev_t: ctypes type of an event handle.
    :param evcreate: Called as ``evcreate(byref(event), 0)``.
    :param evrecord: Called as ``evrecord(event, None)``.
    :param evsync: Called as ``evsync(end_event)``.
    :param evdestroy: Called as ``evdestroy(event)`` for every created event.
    :param evtime: Called as ``evtime(byref(c_float), start_event, end_event)``.
    :param enable: Whether to time the call. Defaults to False.
    :type enable: bool
    :return: If ``enable`` is False, whatever ``cb()`` returns. Otherwise the value written by
        ``evtime`` multiplied by 1e-3, i.e. seconds if ``evtime`` reports milliseconds as
        CUDA's cuEventElapsedTime does.
    :rtype: Optional[float]
    """
    if not enable:
        return cb()
    events: List[Any] = []
    try:
        for _ in range(2):
            ev = ev_t()
            evcreate(ctypes.byref(ev), 0)
            events.append(ev)
        start, end = events
        evrecord(start, None)
        cb()
        evrecord(end, None)
        evsync(end)
        evtime(ctypes.byref(elapsed := ctypes.c_float()), start, end)
    finally:
        # Release every event that was created, even when cb() or an event call raised.
        for created in events:
            evdestroy(created)
    return elapsed.value * 1e-3