Source code for tinygrad.nn.state

import os, json, pathlib, zipfile, pickle, tarfile, struct
from tqdm import tqdm
from typing import Dict, Union, List, Optional, Any, Tuple
from tinygrad.tensor import Tensor
from tinygrad.helpers import (
    dtypes,
    prod,
    argsort,
    DEBUG,
    Timing,
    GlobalCounters,
    CI,
    unwrap,
)
from tinygrad.shape.view import strides_for_shape
from tinygrad import Device

safe_dtypes = {
    "F16": dtypes.float16,
    "F32": dtypes.float32,
    "U8": dtypes.uint8,
    "I8": dtypes.int8,
    "I32": dtypes.int32,
    "I64": dtypes.int64,
}
inverse_safe_dtypes = {v: k for k, v in safe_dtypes.items()}


def safe_load_metadata(fn: Union[Tensor, str]) -> Tuple[Tensor, int, Any]:
    """
    Safely load metadata from a file or tensor.

    Parameters:
        fn (Union[Tensor, str]): File name or tensor to load metadata from.

    Returns:
        Tuple[Tensor, int, Any]: A tuple containing the tensor, the json length, and the parsed metadata.
    """
    t = (
        fn
        if isinstance(fn, Tensor)
        else Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}")
    )
    json_len = t[0:1].cast(dtypes.int64).numpy()[0]
    return (t, json_len, json.loads(t[8 : 8 + json_len].numpy().tobytes()))
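# Example (illustrative sketch, not part of the original module): inspecting the
# header of a safetensors file. "model.safetensors" and the key "weight" are
# hypothetical.
#
#   t, json_len, metadata = safe_load_metadata("model.safetensors")
#   print(json_len)                     # byte length of the JSON header
#   print(metadata["weight"]["shape"])  # per-tensor dtype/shape/data_offsets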
def safe_load(fn: Union[Tensor, str]) -> Dict[str, Tensor]:
    """
    Safely load tensors from a file or tensor.

    Parameters:
        fn (Union[Tensor, str]): File name or tensor to load data from.

    Returns:
        Dict[str, Tensor]: A dictionary containing the loaded tensors.
    """
    t, json_len, metadata = safe_load_metadata(fn)
    return {
        k: t[8 + json_len + v["data_offsets"][0] :]
        .cast(safe_dtypes[v["dtype"]])[: prod(v["shape"])]
        .reshape(v["shape"])
        for k, v in metadata.items()
        if k != "__metadata__"
    }
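# Example (sketch): loading every tensor from a safetensors file. The returned
# tensors are disk-backed views; call .numpy() or move them to a device to
# materialize the data. "model.safetensors" is a hypothetical path.
#
#   weights = safe_load("model.safetensors")
#   for name, tensor in weights.items():
#       print(name, tensor.shape, tensor.dtype)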
def safe_save(
    tensors: Dict[str, Tensor], fn: str, metadata: Optional[Dict[str, Any]] = None
):
    """
    Safely save tensors to a file.

    Parameters:
        tensors (Dict[str, Tensor]): Dictionary of tensors to save.
        fn (str): File name to save data to.
        metadata (Optional[Dict[str, Any]]): Optional metadata dictionary. Default is None.

    Returns:
        None
    """
    headers, offset = {}, 0
    if metadata:
        headers["__metadata__"] = metadata
    for k, v in tensors.items():
        headers[k] = {
            "dtype": inverse_safe_dtypes[v.dtype],
            "shape": list(v.shape),
            "data_offsets": [offset, offset + v.nbytes()],
        }
        offset += v.nbytes()
    j = json.dumps(headers, separators=(",", ":"))
    j += "\x20" * ((8 - len(j) % 8) % 8)  # pad the header to an 8-byte boundary
    pathlib.Path(fn).unlink(missing_ok=True)
    t = Tensor.empty(8 + len(j) + offset, dtype=dtypes.uint8, device=f"disk:{fn}")
    t[0:1].cast(dtypes.int64).assign([len(j)])
    t[8 : 8 + len(j)].assign(
        Tensor(list(j.encode("utf-8")), dtype=dtypes.uint8, device="cpu")
    )
    for k, v in safe_load(t).items():
        v.assign(tensors[k])
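# Example (sketch): a save/load round trip with optional metadata.
# "/tmp/net.safetensors" is a hypothetical path.
#
#   tensors = {"w": Tensor.ones(2, 2), "b": Tensor.zeros(2)}
#   safe_save(tensors, "/tmp/net.safetensors", metadata={"step": "100"})
#   loaded = safe_load("/tmp/net.safetensors")
#   assert loaded["w"].shape == (2, 2)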
# state dict
from collections import OrderedDict
def get_state_dict(obj, prefix: str = "", tensor_type=Tensor) -> Dict[str, Tensor]:
    """
    Recursively retrieve the state dictionary of an object.

    If obj is an instance of tensor_type, a dictionary mapping the stripped
    prefix to obj is returned. Otherwise the function recurses: namedtuples
    (objects with an _asdict attribute) and OrderedDicts are converted to
    plain dicts, and objects with a __dict__ attribute recurse into their
    attributes. For lists and tuples the element index is appended to the
    prefix; for dicts the key is. If none of these cases apply, an empty
    dictionary is returned.

    Parameters:
        obj: The object to retrieve the state dictionary from.
        prefix (str): Optional; a prefix to prepend to the keys in the state
            dictionary. Defaults to "".
        tensor_type: Optional; the type of tensors to collect when building
            the state dictionary. Defaults to Tensor.

    Returns:
        Dict[str, Tensor]: The state dictionary of the object.
    """
    if isinstance(obj, tensor_type):
        return {prefix.strip("."): obj}
    if hasattr(obj, "_asdict"):
        return get_state_dict(obj._asdict(), prefix, tensor_type)  # namedtuple
    if isinstance(obj, OrderedDict):
        return get_state_dict(dict(obj), prefix, tensor_type)
    if hasattr(obj, "__dict__"):
        return get_state_dict(obj.__dict__, prefix, tensor_type)
    state_dict = {}
    if isinstance(obj, (list, tuple)):
        for i, x in enumerate(obj):
            state_dict.update(get_state_dict(x, f"{prefix}{str(i)}.", tensor_type))
    elif isinstance(obj, dict):
        for k, v in obj.items():
            state_dict.update(get_state_dict(v, f"{prefix}{str(k)}.", tensor_type))
    return state_dict
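# Example (sketch): keys are built from attribute names, list indices, and
# dict keys. The Block class here is hypothetical.
#
#   class Block:
#       def __init__(self):
#           self.weight = Tensor.ones(4, 4)
#           self.layers = [Tensor.zeros(4), Tensor.zeros(4)]
#
#   sd = get_state_dict(Block())
#   # keys: "weight", "layers.0", "layers.1"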
def get_parameters(obj) -> List[Tensor]:
    """
    Retrieve the parameters of an object as a list of tensors.

    This is a thin wrapper: it calls get_state_dict on the given object and
    returns the values as a list.

    Parameters:
        obj: The object to retrieve the parameters from.

    Returns:
        List[Tensor]: A list of tensors representing the parameters of the object.
    """
    return list(get_state_dict(obj).values())
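# Example (sketch): collecting parameters to hand to an optimizer; reuses the
# hypothetical Block class from the previous example.
#
#   from tinygrad.nn.optim import SGD
#   params = get_parameters(Block())  # [weight, layers.0, layers.1]
#   opt = SGD(params, lr=0.01)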
def load_state_dict(model, state_dict, strict=True, verbose=True):
    """
    Load a state dictionary into a model.

    The model's own state dictionary is retrieved first, and a warning is
    printed (at DEBUG >= 1) if the given state_dict holds more weights than
    the model uses. Each tensor in the model's state dictionary is then
    assigned the matching value from state_dict and realized, with the
    progress bar showing the current key and memory usage. If strict is
    disabled and a key from the model is missing from state_dict, a warning
    is printed and that key is skipped.

    Parameters:
        model: The model to load the weights into.
        state_dict (Dict[str, Tensor]): The state dictionary containing the weights to be loaded.
        strict (bool): Optional; whether to require that every key in the model's
            state dictionary is present in the given state_dict. Defaults to True.
        verbose (bool): Optional; whether to display a progress bar during weight
            loading. Defaults to True.
    """
    with Timing(
        "loaded weights in ",
        lambda et_ns: f", {GlobalCounters.mem_used/1e9:.2f} GB loaded at {GlobalCounters.mem_used/et_ns:.2f} GB/s",
    ):
        model_state_dict = get_state_dict(model)
        if DEBUG >= 1 and len(state_dict) > len(model_state_dict):
            print(
                "WARNING: unused weights in state_dict",
                sorted(list(state_dict.keys() - model_state_dict.keys())),
            )
        for k, v in (t := tqdm(model_state_dict.items(), disable=CI or not verbose)):
            t.set_description(
                f"ram used: {GlobalCounters.mem_used/1e9:5.2f} GB, {k:50s}"
            )
            if k not in state_dict and not strict:
                if DEBUG >= 1:
                    print(f"WARNING: not loading {k}")
                continue
            v.assign(state_dict[k].to(v.device)).realize()
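# Example (sketch): restoring weights saved with safe_save. The model and
# path are hypothetical.
#
#   model = Block()
#   load_state_dict(model, safe_load("/tmp/block.safetensors"))
#   # with strict=False, model keys missing from the file are skipped with a warning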
# torch support!
def torch_load(fn: str):
    """
    Load a PyTorch checkpoint from disk.

    Parameters:
        fn (str): Path to the checkpoint file. The zip-based, legacy tarfile,
            and raw pickle serialization formats are all supported.

    Returns:
        The deserialized object, with torch tensors mapped to tinygrad Tensors.
    """
    t = Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}")

    # byte offset and byte length of each storage within the file
    offsets: Dict[Union[str, int], int] = {}
    lens: Dict[Union[str, int], int] = {}

    def _rebuild_tensor_v2(
        storage,
        storage_offset,
        size,
        stride,
        requires_grad=None,
        backward_hooks=None,
        metadata=None,
    ):
        """
        Rebuild a tensor from the given storage parameters.

        Parameters:
            storage (Any): Storage object for the tensor.
            storage_offset (int): Offset of the tensor in the storage.
            size (Tuple[int]): Shape of the tensor.
            stride (Tuple[int]): Strides of the tensor dimensions.
            requires_grad (bool): Whether the tensor requires gradient computation. Default: None.
            backward_hooks (List[Callable]): List of functions to call on backward pass. Default: None.
            metadata (Any): Metadata associated with the tensor. Default: None.

        Returns:
            Tensor: Rebuilt tensor based on the provided storage parameters.
        """
        # print(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata)
        lens[storage[2]] = storage[4] * storage[1].itemsize
        if storage[2] not in offsets:
            return None
        byte_offset = offsets[storage[2]] + storage_offset * storage[1].itemsize
        ret = t[byte_offset : byte_offset + prod(size)].cast(storage[1])

        # convert bfloat16 -> float16 using LLVM for Llama 2
        # upstream LLaMA also does this conversion:
        # https://github.com/facebookresearch/llama/blob/6c7fe276574e78057f917549435a2554000a876d/llama/generation.py#L95
        # TODO: should this be done in the example instead? or maybe we don't need this anymore with better bfloat16 support
        if storage[1] == dtypes.bfloat16:
            ret = (
                ret.bitcast(dtypes.uint16)
                .to("CPU")
                .cast(dtypes.uint32)
                .mul(1 << 16)
                .bitcast(dtypes.float32)
                .to(Device.DEFAULT)
                .half()
            )
            # ret = ret.to("LLVM").half().to(Device.DEFAULT)

        # 7 lines to deal with permuted tensors. NOTE: this currently requires reading off the disk
        shape_strides = [(s, st) for s, st in zip(size, stride) if s != 1]
        permute_indexes = [
            len(shape_strides) - 1 - y for y in argsort([x[1] for x in shape_strides])
        ]
        if tuple(permute_indexes) != tuple(range(len(permute_indexes))):
            intermediate_shape = tuple(
                [shape_strides[x][0] for x in argsort(permute_indexes)]
            )
            assert tuple(
                [shape_strides[i][1] for i in argsort(permute_indexes)]
            ) == strides_for_shape(intermediate_shape), "nonpermutable strides"
            if DEBUG >= 3:
                print(
                    f"WARNING: this torch load is slow. CPU to permute {intermediate_shape} with {permute_indexes}"
                )
            # TODO: find a nice way to support all shapetracker on disktensors
            ret = ret.cpu().reshape(intermediate_shape).permute(permute_indexes)

        return ret.reshape(size)

    class Parameter:
        """
        Minimal stand-in for torch.nn.Parameter during unpickling; the
        underlying tensor ends up in self.tensor.
        """

        def __setstate__(self, state):
            self.tensor = state[0]

    # deserialized storages, keyed by persistent id
    deserialized_objects: Dict[str, Any] = {}
    # torch storage types and functions intercepted during deserialization
    intercept = {
        "HalfStorage": dtypes.float16,
        "FloatStorage": dtypes.float32,
        "BFloat16Storage": dtypes.bfloat16,
        "IntStorage": dtypes.int32,
        "LongStorage": dtypes.int64,
        "_rebuild_tensor_v2": _rebuild_tensor_v2,
        "FloatTensor": None,
        "Parameter": Parameter,
    }
    whitelist = {
        "torch",
        "collections",
        "numpy",
        "_codecs",
    }  # NOTE: this is not for security, only speed

    class Dummy:
        """
        Placeholder returned for classes from modules outside the whitelist.
        """

        pass

    class TorchPickle(pickle.Unpickler):
        """
        Custom Unpickler for loading PyTorch objects from pickled files.
        """

        def find_class(self, module, name):
            """
            Resolve a class referenced in the pickle stream. Modules outside
            the whitelist get Dummy; names from torch are resolved through the
            intercept table; everything else uses the default behavior.
            """
            module_root = module.split(".")[0]
            if module_root not in whitelist:
                if DEBUG >= 2:
                    print(f"WARNING: returning Dummy for {module} {name}")
                return Dummy
            return (
                intercept[name]
                if module_root == "torch"
                else super().find_class(module, name)
            )

        def persistent_load(self, pid):
            """
            Return the deserialized object for pid if one exists, otherwise the pid itself.
            """
            return deserialized_objects[pid] if pid in deserialized_objects else pid

    if tuple(t[0:2].numpy()) == (0x50, 0x4B):  # "PK": zipfile-based checkpoint
        myzip = zipfile.ZipFile(fn, "r")
        base_name = myzip.namelist()[0].split("/", 1)[0]
        for n in myzip.namelist():
            if n.startswith(f"{base_name}/data/"):
                with myzip.open(n) as myfile:
                    offsets[n.split("/")[-1]] = myfile._orig_compress_start  # type: ignore
        with myzip.open(f"{base_name}/data.pkl") as myfile:
            return TorchPickle(myfile).load()
    elif bytes(t[0:0xE].numpy()) == b"././@PaxHeader":  # TODO: is this how you detect a tarfile?
        with tarfile.open(fn, "r") as tar:
            storages_offset = tar.getmember("storages").offset_data
            f = unwrap(tar.extractfile("storages"))
            for i in range(TorchPickle(f).load()):  # num_storages
                (key, _, storage_type), sz = (
                    TorchPickle(f).load(),
                    struct.unpack("<q", f.read(8))[0],
                )
                offsets[key] = storages_offset + f.tell()
                f.seek(sz * storage_type.itemsize, 1)
            f = unwrap(tar.extractfile("tensors"))
            for _ in range(TorchPickle(f).load()):  # num_tensors
                (key, storage_id, _), ndim, _ = (
                    TorchPickle(f).load(),
                    struct.unpack("<i", f.read(4))[0],
                    f.read(4),
                )
                size, stride, storage_offset = (
                    struct.unpack(f"<{ndim}q", f.read(8 * ndim)),
                    struct.unpack(f"<{ndim}q", f.read(8 * ndim)),
                    struct.unpack("<q", f.read(8))[0],
                )
                deserialized_objects[str(key)] = _rebuild_tensor_v2(
                    (None, storage_type, storage_id, None, -1),
                    storage_offset,
                    size,
                    stride,
                )
            return {
                k: v.tensor if isinstance(v, Parameter) else v
                for k, v in TorchPickle(unwrap(tar.extractfile("pickle")))
                .load()
                .items()
            }
    else:  # raw pickle: two passes, the first to learn storage lengths
        with open(fn, "rb") as f:
            pkl = TorchPickle(f)
            _, _, _, rwd, _, ids, base_offset = (
                pkl.load(),  # magic number
                pkl.load(),  # protocol version
                pkl.load(),  # sys info
                f.tell(),    # start of the state dict pickle
                pkl.load(),  # first pass over the state dict, fills in lens
                pkl.load(),  # storage keys
                f.tell(),    # start of the raw storage data
            )
            # each storage is prefixed by an 8-byte length field
            for i in ids:
                offsets[i] = base_offset + 8
                base_offset += 8 + lens[i]
            f.seek(rwd)
            return TorchPickle(f).load()
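# Example (sketch): reading a PyTorch checkpoint and loading it into a model.
# "model.pth" is a hypothetical path; some checkpoints nest the weights under
# a "state_dict" key.
#
#   ckpt = torch_load("model.pth")
#   state_dict = ckpt.get("state_dict", ckpt)
#   load_state_dict(model, state_dict)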