tinygrad nn.state

Note

You likely want the upstream tinygrad, not tinygrab. Tinygrab contains AI-generated docstrings for a snapshot of tinygrad. Upstream: https://tinygrad.org

tinygrad.nn.state.get_parameters(obj) → List[Tensor]

Retrieve the parameters of an object as a list of tensors.

This function calls get_state_dict on the given object and returns its values as a list.

Parameters:

obj – The object to retrieve the parameters from.

Returns:

A list of tensors representing the parameters of the object.

Return type:

List[Tensor]

tinygrad.nn.state.get_state_dict(obj, prefix: str = '', tensor_type=<class 'tinygrad.tensor.Tensor'>) → Dict[str, Tensor]

Recursively retrieve the state dictionary of an object.

This function walks an object and collects every tensor it contains into a flat state dictionary. The checks are applied in order:
  • If the object is an instance of tensor_type, return a dictionary with the prefixed key and the object as its value.

  • If the object has an _asdict attribute (as named tuples do), is an instance of OrderedDict, or has a __dict__ attribute, treat its contents as a dictionary and recurse.

  • For list or tuple objects, call get_state_dict on each element and merge the results into the state dictionary.

  • For dict objects, call get_state_dict on each value and merge the results into the state dictionary.

  • If none of these conditions is met, return an empty dictionary.

Parameters:
  • obj – The object to retrieve the state dictionary from.

  • prefix (str) – Optional; a prefix to prepend to the keys in the state dictionary. Defaults to the empty string.

  • tensor_type – Optional; the type of tensors to consider when building the state dictionary. Defaults to Tensor.

Returns:

The state dictionary of the object.

Return type:

Dict[str, Tensor]
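The traversal described above can be sketched in plain Python. This is a simplified stand-in, not the library's code: FakeTensor, sketch_get_state_dict, sketch_get_parameters, Linear, and Net are hypothetical names, and the dot-separated key scheme is an assumption modeled on upstream tinygrad.

```python
from collections import OrderedDict


class FakeTensor:
    """Hypothetical stand-in for tinygrad's Tensor; it only holds a list."""
    def __init__(self, data):
        self.data = data


def sketch_get_state_dict(obj, prefix="", tensor_type=FakeTensor):
    # A tensor is a leaf: return it under the accumulated key.
    if isinstance(obj, tensor_type):
        return {prefix.strip("."): obj}
    # Named tuples, OrderedDicts, and plain objects are reduced to dicts.
    if hasattr(obj, "_asdict"):
        return sketch_get_state_dict(obj._asdict(), prefix, tensor_type)
    if isinstance(obj, OrderedDict):
        return sketch_get_state_dict(dict(obj), prefix, tensor_type)
    if hasattr(obj, "__dict__"):
        return sketch_get_state_dict(obj.__dict__, prefix, tensor_type)
    state = {}
    if isinstance(obj, (list, tuple)):
        # Sequence elements are keyed by their index.
        for i, item in enumerate(obj):
            state.update(sketch_get_state_dict(item, f"{prefix}{i}.", tensor_type))
    elif isinstance(obj, dict):
        # Dict values are keyed by their dict key.
        for key, value in obj.items():
            state.update(sketch_get_state_dict(value, f"{prefix}{key}.", tensor_type))
    # Anything else (str, int, None, ...) contributes nothing.
    return state


def sketch_get_parameters(obj):
    # Mirrors the description of get_parameters: the state dict's values as a list.
    return list(sketch_get_state_dict(obj).values())


class Linear:
    def __init__(self):
        self.weight = FakeTensor([[1.0, 2.0]])
        self.bias = FakeTensor([0.0])


class Net:
    def __init__(self):
        self.layers = [Linear(), Linear()]
        self.name = "demo"  # not a tensor, so it is ignored


state = sketch_get_state_dict(Net())
print(list(state))  # layers.0.weight, layers.0.bias, layers.1.weight, layers.1.bias
```

Because the tensor check comes first, an object that is itself a tensor is never expanded through its __dict__.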

tinygrad.nn.state.load_state_dict(model, state_dict, strict=True, verbose=True)

Load a state dictionary into a model.

This function copies the weights from a state dictionary into a given model. It first retrieves the model’s own state dictionary and prints a warning if state_dict contains weights that the model does not use. It then iterates over each entry in the model’s state dictionary, updating the progress bar description with the current key and memory usage. If strict is False and a key from the model’s state dictionary is missing from state_dict, a warning is printed and that key is skipped. Otherwise, the value from state_dict is assigned to the corresponding tensor in the model and the assignment is realized.

Parameters:
  • model – The model to load the weights into.

  • state_dict (Dict[str, Tensor]) – The state dictionary containing the weights to be loaded.

  • strict (bool) – Optional; whether to strictly enforce that all keys in the model’s state dictionary are present in the given state_dict. Defaults to True.

  • verbose (bool) – Optional; whether to display a progress bar during weight loading. Defaults to True.
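The effect of strict can be illustrated with a small plain-Python sketch of the loading loop. Slot and sketch_load_state_dict are hypothetical stand-ins, the model is represented directly by its flat state dictionary, and the progress bar and the realize step are left out. That a missing key raises KeyError in strict mode is a property of this sketch and an assumption about the library.

```python
class Slot:
    """Hypothetical stand-in for a tensor that supports in-place assignment."""
    def __init__(self, value):
        self.value = value

    def assign(self, other):
        self.value = other.value
        return self


def sketch_load_state_dict(model_state, state_dict, strict=True):
    # Warn about weights that the model has no slot for.
    unused = sorted(state_dict.keys() - model_state.keys())
    if unused:
        print("WARNING: unused weights in state_dict", unused)
    skipped = []
    for key, target in model_state.items():
        if key not in state_dict and not strict:
            # Non-strict mode: warn and leave the existing value untouched.
            print(f"WARNING: not loading {key}")
            skipped.append(key)
            continue
        # In strict mode a missing key surfaces here as a KeyError.
        target.assign(state_dict[key])
    return skipped


model_state = {"w": Slot(0), "b": Slot(0)}
weights = {"w": Slot(5), "extra": Slot(9)}
skipped = sketch_load_state_dict(model_state, weights, strict=False)
print(skipped, model_state["w"].value, model_state["b"].value)
```

Note that the loop is driven by the model's keys, so extra entries in the supplied dictionary only produce a warning, while missing entries are what strict controls.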

tinygrad.nn.state.safe_load(fn: Tensor | str) → Dict[str, Tensor]

Safely load tensors from a file or tensor.

Parameters:

fn (Union[Tensor, str]) – File name or tensor to load data from.

Returns:

A dictionary containing the loaded tensors.

Return type:

Dict[str, Tensor]

tinygrad.nn.state.safe_load_metadata(fn: Tensor | str) → Tuple[Tensor, int, Any]

Safely load metadata from a file or tensor.

Parameters:

fn (Union[Tensor, str]) – File name or tensor to load metadata from.

Returns:

A tuple containing the tensor holding the file contents, the length of the JSON header, and the parsed metadata.

Return type:

Tuple[Tensor, int, Any]
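The three returned values line up with a safetensors-style file: an 8-byte little-endian length, a JSON header of that length, and then the raw tensor bytes. Below is a minimal sketch of reading such a header from an in-memory byte string with the standard library only. sketch_read_header and the sample blob are hypothetical, and the assumption that safe_load_metadata targets the safetensors layout comes from upstream tinygrad rather than from this page.

```python
import json
import struct


def sketch_read_header(raw: bytes):
    # First 8 bytes: unsigned 64-bit little-endian length of the JSON header.
    (json_len,) = struct.unpack("<Q", raw[:8])
    # The next json_len bytes hold the JSON header describing each tensor.
    header = json.loads(raw[8:8 + json_len].decode("utf-8"))
    return raw, json_len, header


# Build a tiny blob holding one float32 tensor "w" with two elements.
header = {"w": {"dtype": "F32", "shape": [2], "data_offsets": [0, 8]}}
header_bytes = json.dumps(header).encode("utf-8")
payload = struct.pack("<2f", 1.0, 2.0)
blob = struct.pack("<Q", len(header_bytes)) + header_bytes + payload

raw, json_len, meta = sketch_read_header(blob)
# data_offsets are relative to the first byte after the header.
start, end = meta["w"]["data_offsets"]
values = struct.unpack("<2f", raw[8 + json_len + start:8 + json_len + end])
print(json_len, values)
```

Returning the header length alongside the parsed header lets a caller locate the data region without re-reading the prefix.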

tinygrad.nn.state.safe_save(tensors: Dict[str, Tensor], fn: str, metadata: Dict[str, Any] | None = None)

Safely save tensors to a file.

Parameters:
  • tensors (Dict[str, Tensor]) – Dictionary of tensors to save.

  • fn (str) – File name to save data to.

  • metadata (Optional[Dict[str, Any]]) – Optional metadata dictionary. Default is None.

Returns:

None
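As a rough illustration of what saving involves, the sketch below writes a safetensors-style file by hand: a length prefix, a JSON header carrying the optional metadata under "__metadata__", and the tensor bytes. sketch_safe_save is hypothetical, tensors are represented as (dtype, shape, raw bytes) tuples rather than Tensor objects, and padding the header to a multiple of 8 bytes with spaces is a choice made here, not something this page specifies.

```python
import json
import os
import struct
import tempfile


def sketch_safe_save(tensors, fn, metadata=None):
    """tensors maps name -> (dtype string, shape, raw bytes); a stand-in for Tensor values."""
    header, offset = {}, 0
    if metadata:
        header["__metadata__"] = metadata
    for name, (dtype, shape, raw) in tensors.items():
        # Each entry records where its bytes sit inside the data region.
        header[name] = {"dtype": dtype, "shape": list(shape),
                        "data_offsets": [offset, offset + len(raw)]}
        offset += len(raw)
    header_bytes = json.dumps(header, separators=(",", ":")).encode("utf-8")
    # Pad with spaces so the data region starts on an 8-byte boundary.
    header_bytes += b" " * (-len(header_bytes) % 8)
    with open(fn, "wb") as f:
        f.write(struct.pack("<Q", len(header_bytes)))
        f.write(header_bytes)
        for _, _, raw in tensors.values():
            f.write(raw)


path = os.path.join(tempfile.mkdtemp(), "demo.safetensors")
sketch_safe_save({"w": ("F32", (2,), struct.pack("<2f", 1.0, 2.0))},
                 path, metadata={"note": "demo"})
print(os.path.getsize(path))
```

Trailing spaces are valid JSON whitespace, so a reader can parse the padded header without trimming it first.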

tinygrad.nn.state.torch_load(fn: str)

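Load a file saved by PyTorch from disk.

Parameters:

fn (str) – Path to the file saved by PyTorch.

Returns:

The loaded contents of the specified file. For a typical PyTorch checkpoint this is a dictionary mapping names to tensors rather than a single tensor.

Return type:

Typically Dict[str, Tensor]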

Working variables:
  • t (Tensor) – Empty tensor whose size is taken from the input file’s size.

  • offsets (Dict[Union[str, int], int]) – Dictionary storing the offsets of tensors in the file.

  • lens (Dict[Union[str, int], int]) – Dictionary storing the lengths of tensors in the file.