import os, json, pathlib, zipfile, pickle, tarfile, struct
from tqdm import tqdm
from typing import Dict, Union, List, Optional, Any, Tuple
from tinygrad.tensor import Tensor
from tinygrad.helpers import (
dtypes,
prod,
argsort,
DEBUG,
Timing,
GlobalCounters,
CI,
unwrap,
)
from tinygrad.shape.view import strides_for_shape
from tinygrad import Device
safe_dtypes = {
"F16": dtypes.float16,
"F32": dtypes.float32,
"U8": dtypes.uint8,
"I8": dtypes.int8,
"I32": dtypes.int32,
"I64": dtypes.int64,
}
inverse_safe_dtypes = {v: k for k, v in safe_dtypes.items()}
def safe_load(fn: Union[Tensor, str]) -> Dict[str, Tensor]:
"""
    Load tensors stored in the safetensors format.
    Parameters:
        fn (Union[Tensor, str]): Path to a safetensors file, or a uint8 disk tensor holding its bytes.
    Returns:
        Dict[str, Tensor]: The loaded tensors, keyed by name.
"""
t, json_len, metadata = safe_load_metadata(fn)
return {
k: t[8 + json_len + v["data_offsets"][0] :]
.cast(safe_dtypes[v["dtype"]])[: prod(v["shape"])]
.reshape(v["shape"])
for k, v in metadata.items()
if k != "__metadata__"
}
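# Illustrative sketch (not part of this module): the on-disk layout safe_load parses,
# read with only the standard library. A safetensors file starts with an 8-byte
# little-endian header length, followed by a JSON header mapping each tensor name to
# its dtype, shape, and byte range. The filename argument is a placeholder.
def _example_read_safetensors_header(fn: str) -> Dict[str, Any]:
    with open(fn, "rb") as fh:
        (json_len,) = struct.unpack("<Q", fh.read(8))  # 8-byte header-length prefix
        return json.loads(fh.read(json_len))  # {name: {"dtype", "shape", "data_offsets"}, ...}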
def safe_save(
tensors: Dict[str, Tensor], fn: str, metadata: Optional[Dict[str, Any]] = None
):
"""
    Save tensors to a file in the safetensors format.
Parameters:
tensors (Dict[str, Tensor]): Dictionary of tensors to save.
fn (str): File name to save data to.
metadata (Optional[Dict[str, Any]]): Optional metadata dictionary. Default is None.
Returns:
None
"""
headers, offset = {}, 0
if metadata:
headers["__metadata__"] = metadata
for k, v in tensors.items():
headers[k] = {
"dtype": inverse_safe_dtypes[v.dtype],
"shape": list(v.shape),
"data_offsets": [offset, offset + v.nbytes()],
}
offset += v.nbytes()
j = json.dumps(headers, separators=(",", ":"))
j += "\x20" * ((8 - len(j) % 8) % 8)
pathlib.Path(fn).unlink(missing_ok=True)
t = Tensor.empty(8 + len(j) + offset, dtype=dtypes.uint8, device=f"disk:{fn}")
    t[0:1].cast(dtypes.int64).assign([len(j)])  # write the header length as a little-endian int64 in the first 8 bytes
t[8 : 8 + len(j)].assign(
Tensor(list(j.encode("utf-8")), dtype=dtypes.uint8, device="cpu")
)
for k, v in safe_load(t).items():
v.assign(tensors[k])
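# Illustrative round trip (a sketch; the path and names are placeholders): save a
# dict of tensors with safe_save and read it back with safe_load.
def _example_safetensors_roundtrip():
    tensors = {"weight": Tensor.ones(2, 2), "bias": Tensor.zeros(2)}
    safe_save(tensors, "/tmp/example.safetensors", metadata={"framework": "tinygrad"})
    loaded = safe_load("/tmp/example.safetensors")
    assert loaded["weight"].shape == (2, 2)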
# state dict
from collections import OrderedDict
def get_state_dict(obj, prefix: str = "", tensor_type=Tensor) -> Dict[str, Tensor]:
"""
Recursively retrieve the state dictionary of an object.
    Walks obj recursively and collects every tensor it contains into a flat dictionary
    keyed by dotted path (e.g. "blocks.0.weight"). Tensors are matched by isinstance
    against tensor_type; namedtuples (via _asdict), OrderedDicts, objects with a
    __dict__, lists and tuples (keyed by index), and dicts (keyed by key) are all
    traversed. Anything else contributes nothing. See the sketch after this function.
    Parameters:
        obj: The object to retrieve the state dictionary from.
        prefix (str): Optional; a prefix to prepend to the keys in the state dictionary. Defaults to "".
        tensor_type: Optional; the type matched as a tensor when building the state dictionary. Defaults to Tensor.
Returns:
Dict[str, Tensor]: The state dictionary of the object.
"""
if isinstance(obj, tensor_type):
return {prefix.strip("."): obj}
if hasattr(obj, "_asdict"):
return get_state_dict(obj._asdict(), prefix, tensor_type) # namedtuple
if isinstance(obj, OrderedDict):
return get_state_dict(dict(obj), prefix, tensor_type)
if hasattr(obj, "__dict__"):
return get_state_dict(obj.__dict__, prefix, tensor_type)
state_dict = {}
if isinstance(obj, (list, tuple)):
for i, x in enumerate(obj):
state_dict.update(get_state_dict(x, f"{prefix}{str(i)}.", tensor_type))
elif isinstance(obj, dict):
for k, v in obj.items():
state_dict.update(get_state_dict(v, f"{prefix}{str(k)}.", tensor_type))
return state_dict
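# Illustrative sketch of the key naming: nested containers flatten to dotted paths.
# TinyNet and its fields are placeholders, not a tinygrad class.
def _example_state_dict_keys():
    class TinyNet:
        def __init__(self):
            self.weight = Tensor.ones(3, 3)
            self.blocks = [Tensor.zeros(3), Tensor.zeros(3)]
    assert sorted(get_state_dict(TinyNet())) == ["blocks.0", "blocks.1", "weight"]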
def get_parameters(obj) -> List[Tensor]:
"""
Retrieve the parameters of an object as a list of tensors.
This function calls get_state_dict on the given object and returns its values as a list.
    Parameters:
obj: The object to retrieve the parameters from.
Returns:
List[Tensor]: A list of tensors representing the parameters of the object.
"""
return list(get_state_dict(obj).values())
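# Typical use (a sketch): hand every tensor in a model to an optimizer. The SGD
# import path below matches tinygrad's nn.optim but treat it as an assumption here.
def _example_optimizer_setup(model):
    from tinygrad.nn.optim import SGD  # assumed import path
    return SGD(get_parameters(model), lr=0.001)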
def load_state_dict(model, state_dict, strict=True, verbose=True):
"""
Load a state dictionary into a model.
    Loads the weights from state_dict into the given model. The model's own state
    dict (from get_state_dict) determines which keys are loaded: extra keys in
    state_dict are reported when DEBUG >= 1, and missing keys raise a KeyError unless
    strict is False, in which case they are skipped with a warning. Each matching
    tensor is moved to the target device, assigned, and realized.
    Parameters:
        model: The model to load the weights into.
        state_dict (Dict[str, Tensor]): The weights to load, keyed by name.
        strict (bool): Optional; whether every key in the model's state dictionary must be
            present in the given state_dict. Defaults to True.
        verbose (bool): Optional; whether to display a progress bar during weight loading. Defaults to True.
"""
with Timing(
"loaded weights in ",
lambda et_ns: f", {GlobalCounters.mem_used/1e9:.2f} GB loaded at {GlobalCounters.mem_used/et_ns:.2f} GB/s",
):
model_state_dict = get_state_dict(model)
if DEBUG >= 1 and len(state_dict) > len(model_state_dict):
print(
"WARNING: unused weights in state_dict",
sorted(list(state_dict.keys() - model_state_dict.keys())),
)
for k, v in (t := tqdm(model_state_dict.items(), disable=CI or not verbose)):
t.set_description(
f"ram used: {GlobalCounters.mem_used/1e9:5.2f} GB, {k:50s}"
)
if k not in state_dict and not strict:
if DEBUG >= 1:
print(f"WARNING: not loading {k}")
continue
v.assign(state_dict[k].to(v.device)).realize()
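# Typical use (a sketch; the checkpoint path is a placeholder): load safetensors
# weights into a model, skipping any keys the checkpoint lacks.
def _example_load_checkpoint(model):
    load_state_dict(model, safe_load("/tmp/example.safetensors"), strict=False)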
# torch support!
def torch_load(fn: str):
"""
    Load a checkpoint written by torch.save. Handles the modern zip container, the
    legacy tar container, and the oldest raw-pickle format.
    Parameters:
        fn (str): Path to the checkpoint file.
    Returns:
        The deserialized object, typically a Dict[str, Tensor] state dict. Tensor data
        is viewed lazily through a uint8 disk tensor backed by the file.
"""
t = Tensor.empty(os.stat(fn).st_size, dtype=dtypes.uint8, device=f"disk:{fn}")
offsets: Dict[Union[str, int], int] = {}
lens: Dict[Union[str, int], int] = {}
def _rebuild_tensor_v2(
storage,
storage_offset,
size,
stride,
requires_grad=None,
backward_hooks=None,
metadata=None,
):
"""
        Rebuild a tensor from pickled storage parameters (torch's _rebuild_tensor_v2
        protocol). The fields of the storage tuple used here are the dtype
        (storage[1]), the storage key (storage[2]), and the element count (storage[4]).
        Parameters:
            storage (Tuple): Storage descriptor for the tensor.
            storage_offset (int): Element offset of the tensor within its storage.
            size (Tuple[int, ...]): Shape of the tensor.
            stride (Tuple[int, ...]): Strides of the tensor.
            requires_grad, backward_hooks, metadata: Accepted for protocol compatibility; unused.
        Returns:
            Tensor: A disk-backed view of the tensor, or None if the storage's offset is not known yet.
"""
# print(storage, storage_offset, size, stride, requires_grad, backward_hooks, metadata)
        lens[storage[2]] = storage[4] * storage[1].itemsize  # remember this storage's byte length (the legacy format needs it)
        if storage[2] not in offsets:
            return None  # offset unknown: this is the legacy format's first pass
        byte_offset = offsets[storage[2]] + storage_offset * storage[1].itemsize
        ret = t[byte_offset : byte_offset + prod(size)].cast(storage[1])
# convert bfloat16 -> float16 using LLVM for Llama 2
# upstream LLaMA also does this conversion:
# https://github.com/facebookresearch/llama/blob/6c7fe276574e78057f917549435a2554000a876d/llama/generation.py#L95
# TODO: should this be done in the example instead? or maybe we don't need this anymore with better bfloat16 support
if storage[1] == dtypes.bfloat16:
ret = (
ret.bitcast(dtypes.uint16)
.to("CPU")
.cast(dtypes.uint32)
.mul(1 << 16)
.bitcast(dtypes.float32)
.to(Device.DEFAULT)
.half()
)
# ret = ret.to("LLVM").half().to(Device.DEFAULT)
# 7 lines to deal with permuted tensors. NOTE: this currently requires reading off the disk
shape_strides = [(s, st) for s, st in zip(size, stride) if s != 1]
permute_indexes = [
len(shape_strides) - 1 - y for y in argsort([x[1] for x in shape_strides])
]
if tuple(permute_indexes) != tuple(range(len(permute_indexes))):
intermediate_shape = tuple(
[shape_strides[x][0] for x in argsort(permute_indexes)]
)
assert tuple(
[shape_strides[i][1] for i in argsort(permute_indexes)]
) == strides_for_shape(intermediate_shape), "nonpermutable strides"
if DEBUG >= 3:
print(
f"WARNING: this torch load is slow. CPU to permute {intermediate_shape} with {permute_indexes}"
)
# TODO: find a nice way to support all shapetracker on disktensors
ret = ret.cpu().reshape(intermediate_shape).permute(permute_indexes)
return ret.reshape(size)
class Parameter:
"""
        Minimal stand-in for torch.nn.Parameter during unpickling; it keeps only the wrapped tensor.
        Attributes:
            tensor (Tensor): The underlying tensor restored from the pickled state.
"""
def __setstate__(self, state):
"""
Set the state of the tensor.
Args:
state (List[Any]): A list containing the state information for a tensor.
"""
self.tensor = state[0]
    deserialized_objects: Dict[str, Any] = {}  # storages already materialized, keyed by persistent id
    # intercept torch storage/class names found in the pickle, mapping them to
    # tinygrad dtypes and rebuild handlers
    intercept = {
        "HalfStorage": dtypes.float16,
        "FloatStorage": dtypes.float32,
        "BFloat16Storage": dtypes.bfloat16,
        "IntStorage": dtypes.int32,
        "LongStorage": dtypes.int64,
        "_rebuild_tensor_v2": _rebuild_tensor_v2,
        "FloatTensor": None,
        "Parameter": Parameter,
    }
    # modules pickle may import directly
    whitelist = {"torch", "collections", "numpy", "_codecs"}  # NOTE: this is not for security, only speed
    class Dummy:
        """Placeholder class returned for anything from a module outside the whitelist."""
        pass
class TorchPickle(pickle.Unpickler):
"""
Custom Unpickler class for loading PyTorch objects from pickled files.
Attributes:
find_class (method): Overrides the default `find_class` method of `pickle.Unpickler`. This method is called by
the unpickler whenever it encounters a class definition in the pickled data. It will return a custom Dummy
class if the module being loaded is not in the whitelist, otherwise it will either return a class from the
intercept dictionary or use the default behavior of the parent class.
persistent_load (method): Overrides the `persistent_load` method of `pickle.Unpickler`. This method is called
by the unpickler when loading persistent objects. It will return a deserialized object if its pid is found in
the `deserialized_objects` dictionary, otherwise it will simply return the pid.
"""
def find_class(self, module, name):
"""
Overrides default method to handle loading of classes from pickled data.
Args:
module (str): The name of the module where the class is defined.
name (str): The name of the class.
Returns:
type: The Class object for the given module and name, or a Dummy class if the module is not in the whitelist.
"""
module_root = module.split(".")[0]
if module_root not in whitelist:
if DEBUG >= 2:
print(f"WARNING: returning Dummy for {module} {name}")
return Dummy
return (
intercept[name]
if module_root == "torch"
else super().find_class(module, name)
)
def persistent_load(self, pid):
"""
Overrides default method to handle loading of persistent objects.
Args:
                pid: The persistent ID to resolve; a storage-descriptor tuple in the zip format, or a storage key in the tar format.
Returns:
object: The deserialized object with the given pid, or the pid if no such object is found.
"""
return deserialized_objects[pid] if pid in deserialized_objects else pid
    if tuple(t[0:2].numpy()) == (0x50, 0x4B):  # zip local-file magic "PK": modern torch.save container
        myzip = zipfile.ZipFile(fn, "r")
        base_name = myzip.namelist()[0].split("/", 1)[0]
        for n in myzip.namelist():
            if n.startswith(f"{base_name}/data/"):
                with myzip.open(n) as myfile:
                    # entries are stored uncompressed; record where each storage's raw bytes begin
                    offsets[n.split("/")[-1]] = myfile._orig_compress_start  # type: ignore
with myzip.open(f"{base_name}/data.pkl") as myfile:
return TorchPickle(myfile).load()
elif (
bytes(t[0:0xE].numpy()) == b"././@PaxHeader"
): # TODO: is this how you detect a tarfile?
with tarfile.open(fn, "r") as tar:
storages_offset = tar.getmember("storages").offset_data
f = unwrap(tar.extractfile("storages"))
for i in range(TorchPickle(f).load()): # num_storages
(key, _, storage_type), sz = (
TorchPickle(f).load(),
struct.unpack("<q", f.read(8))[0],
)
offsets[key] = storages_offset + f.tell()
f.seek(sz * storage_type.itemsize, 1)
f = unwrap(tar.extractfile("tensors"))
for _ in range(TorchPickle(f).load()): # num_tensors
(key, storage_id, _), ndim, _ = (
TorchPickle(f).load(),
struct.unpack("<i", f.read(4))[0],
f.read(4),
)
size, stride, storage_offset = (
struct.unpack(f"<{ndim}q", f.read(8 * ndim)),
struct.unpack(f"<{ndim}q", f.read(8 * ndim)),
struct.unpack("<q", f.read(8))[0],
)
deserialized_objects[str(key)] = _rebuild_tensor_v2(
(None, storage_type, storage_id, None, -1),
storage_offset,
size,
stride,
)
return {
k: v.tensor if isinstance(v, Parameter) else v
for k, v in TorchPickle(unwrap(tar.extractfile("pickle")))
.load()
.items()
}
    else:  # oldest torch.save format: consecutive raw pickles with storage data inline
        with open(fn, "rb") as f:
            pkl = TorchPickle(f)
            # first pass: skip the magic number, protocol version, and sys info pickles,
            # remember where the object pickle starts (rwd), then load it once so
            # _rebuild_tensor_v2 records every storage's byte length in lens, and read
            # the list of storage keys (ids)
            _, _, _, rwd, _, ids, base_offset = (
                pkl.load(),
                pkl.load(),
                pkl.load(),
                f.tell(),
                pkl.load(),
                pkl.load(),
                f.tell(),
            )
            # each storage is prefixed by an 8-byte element count; walk them to compute data offsets
            for i in ids:
                offsets[i] = base_offset + 8
                base_offset += 8 + lens[i]
            # second pass: re-read the object pickle now that the offsets are known
            f.seek(rwd)
            return TorchPickle(f).load()
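# Illustrative sketch of the three container formats torch_load distinguishes, using
# the same magic-byte checks as above. The filename argument is a placeholder.
def _example_torch_container_kind(fn: str) -> str:
    with open(fn, "rb") as fh:
        head = fh.read(0xE)
    if head[:2] == b"PK":  # zip local-file magic: modern torch.save container
        return "zip"
    if head == b"././@PaxHeader":  # pax tar header: legacy torch.save container
        return "tar"
    return "raw-pickle"  # oldest format: consecutive pickles with inline storage data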