tinygrad nn.state
Note
You likely want the upstream tinygrad, not tinygrab. Tinygrab contains AI-generated docstrings for a tinygrad snapshot. Upstream: https://tinygrad.org
- tinygrad.nn.state.get_parameters(obj) → List[Tensor]
Retrieve the parameters of an object as a list of tensors.
This function calls get_state_dict on the given object and returns the values of the resulting dictionary as a list.
- Parameters:
obj – The object to retrieve the parameters from.
- Returns:
A list of tensors representing the parameters of the object.
- Return type:
List[Tensor]
- tinygrad.nn.state.get_state_dict(obj, prefix: str = '', tensor_type=<class 'tinygrad.tensor.Tensor'>) → Dict[str, Tensor]
Recursively retrieve the state dictionary of an object.
This function walks an object recursively and collects every tensor it contains into a flat dictionary. If obj is an instance of tensor_type, the result is a single entry that maps the accumulated prefix (with its surrounding dots stripped) to obj. Otherwise the object is unwrapped and the function recurses. Namedtuples (objects with an _asdict method) are unwrapped through _asdict(). OrderedDict instances are converted to a plain dict. Other objects with a __dict__ attribute are unwrapped into their instance attributes. Lists and tuples are traversed by index and dicts by key. Each index or key is appended to the prefix, followed by a dot, and the results of the recursive calls are merged into one dictionary. Any other value contributes no entries, so an empty dictionary is returned for it. As a result, the keys are dotted paths such as layers.0.weight.
- Parameters:
obj – The object to retrieve the state dictionary from.
prefix (str) – Optional prefix prepended to every key in the state dictionary. Defaults to ''.
tensor_type – Optional; the type treated as a tensor (a leaf) when building the state dictionary. Defaults to Tensor.
- Returns:
The state dictionary of the object.
- Return type:
Dict[str, Tensor]
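The traversal described above can be sketched in plain Python. This is a minimal mirror of the documented rules, not tinygrad's source. FakeTensor, ToyLinear and ToyModel are hypothetical stand-ins so the sketch runs without tinygrad. get_parameters is written as the thin wrapper its description implies.

```python
from collections import OrderedDict, namedtuple
from typing import Any, Dict, List


class FakeTensor:
    """Hypothetical stand-in for tinygrad's Tensor (the leaf type)."""

    def __init__(self, name: str):
        self.name = name


def get_state_dict(obj: Any, prefix: str = "", tensor_type: type = FakeTensor) -> Dict[str, Any]:
    # Leaf: store the tensor under its dotted path, without the trailing dot.
    if isinstance(obj, tensor_type):
        return {prefix.strip("."): obj}
    # namedtuples expose _asdict(); recurse into the resulting mapping.
    if hasattr(obj, "_asdict"):
        return get_state_dict(obj._asdict(), prefix, tensor_type)
    # OrderedDict is checked before __dict__ and converted to a plain dict.
    if isinstance(obj, OrderedDict):
        return get_state_dict(dict(obj), prefix, tensor_type)
    # Other objects: walk their instance attributes.
    if hasattr(obj, "__dict__"):
        return get_state_dict(obj.__dict__, prefix, tensor_type)
    state_dict: Dict[str, Any] = {}
    if isinstance(obj, (list, tuple)):
        for i, x in enumerate(obj):  # elements are keyed by index
            state_dict.update(get_state_dict(x, f"{prefix}{i}.", tensor_type))
    elif isinstance(obj, dict):
        for k, v in obj.items():  # entries are keyed by their key
            state_dict.update(get_state_dict(v, f"{prefix}{k}.", tensor_type))
    # Anything else (numbers, strings, None, ...) contributes nothing.
    return state_dict


def get_parameters(obj: Any) -> List[Any]:
    # Only the values of the state dict, in traversal order.
    return list(get_state_dict(obj).values())


class ToyLinear:
    def __init__(self, name: str):
        self.weight = FakeTensor(name + ".w")
        self.bias = FakeTensor(name + ".b")


class ToyModel:
    def __init__(self):
        self.layers = [ToyLinear("l0"), ToyLinear("l1")]
        self.scale = 0.5  # not a tensor, so it is skipped


model = ToyModel()
print(list(get_state_dict(model)))
# ['layers.0.weight', 'layers.0.bias', 'layers.1.weight', 'layers.1.bias']

Pair = namedtuple("Pair", ["a", "b"])  # namedtuple fields become keys
print(list(get_state_dict(Pair(FakeTensor("x"), FakeTensor("y")), prefix="p.")))
# ['p.a', 'p.b']
```

Because the keys come from attribute names and list indices, two models built from the same class produce identical keys. This is what allows a saved state dictionary to be matched back onto a fresh model.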
- tinygrad.nn.state.load_state_dict(model, state_dict, strict=True, verbose=True)
Load a state dictionary into a model.
This function loads the weights from a state dictionary into a given model. It first retrieves the model's own state dictionary and prints a warning if state_dict contains weights that the model does not use. It then iterates over each entry of the model's state dictionary and updates the progress bar description with the current key and memory usage. If strict is False and a key is missing from state_dict, a warning is printed and that key is skipped. Otherwise, the value from state_dict is assigned to the model's tensor and the assignment is realized.
- Parameters:
model – The model to load the weights into.
state_dict (Dict[str, Tensor]) – The state dictionary containing the weights to load.
strict (bool) – Optional; whether every key in the model's state dictionary must be present in state_dict. Defaults to True.
verbose (bool) – Optional; whether to display a progress bar while loading weights. Defaults to True.
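The strict and non-strict paths can be sketched as follows. This assumes the model's state dictionary has already been flattened, which in tinygrad is done by get_state_dict(model). Slot and load_into are hypothetical names. The progress bar and the realize step are left out. The KeyError raised for a missing key when strict=True is how this sketch enforces strictness, not a confirmed detail of load_state_dict.

```python
from typing import Dict, List


class Slot:
    """Hypothetical stand-in for a model tensor; assign() replaces its value."""

    def __init__(self) -> None:
        self.value = None

    def assign(self, value: object) -> "Slot":
        self.value = value
        return self


def load_into(model_state: Dict[str, Slot], state_dict: Dict[str, object],
              strict: bool = True) -> List[str]:
    """Copy weights from state_dict into model_state; return the keys that were skipped."""
    # Keys present in the checkpoint but unknown to the model are only reported.
    unused = sorted(state_dict.keys() - model_state.keys())
    if unused:
        print("WARNING: unused weights in state_dict", unused)
    skipped: List[str] = []
    for k, slot in model_state.items():
        if k not in state_dict and not strict:
            print(f"WARNING: not loading {k}")
            skipped.append(k)
            continue
        # With strict=True, a missing key raises KeyError on this lookup.
        slot.assign(state_dict[k])
    return skipped


model_state = {"l1.weight": Slot(), "l1.bias": Slot()}
skipped = load_into(model_state, {"l1.weight": [1.0, 2.0], "extra": [0.0]}, strict=False)
print(skipped)  # ['l1.bias']
```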
- tinygrad.nn.state.safe_load(fn: Tensor | str) → Dict[str, Tensor]
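Load tensors from a safetensors file, given either as a path or as a Tensor holding the file's bytes, and return a dictionary mapping tensor names to tensors.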
- tinygrad.nn.state.safe_load_metadata(fn: Tensor | str) → Tuple[Tensor, int, Any]
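Read the header of a safetensors file, given either as a path or as a Tensor. Returns a tuple of the file contents as a Tensor, the length of the JSON header in bytes, and the parsed JSON header.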
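The safetensors layout that these two functions read has three parts. It starts with an 8-byte little-endian header length, followed by a JSON header that maps each tensor name to its dtype, shape and data_offsets, and then the raw tensor bytes. The sketch below parses that layout with only the standard library. parse_header and tensor_bytes are hypothetical helpers, and they return raw bytes rather than tinygrad Tensors.

```python
import json
import struct
from typing import Any, Dict, Tuple


def parse_header(data: bytes) -> Tuple[int, Dict[str, Any]]:
    """Return (json_len, header) for a safetensors-formatted byte string."""
    (json_len,) = struct.unpack("<Q", data[:8])  # little-endian unsigned 64-bit length
    header = json.loads(data[8:8 + json_len])
    return json_len, header


def tensor_bytes(data: bytes, json_len: int, entry: Dict[str, Any]) -> bytes:
    """Slice one tensor's raw bytes; data_offsets are relative to the end of the header."""
    begin, end = entry["data_offsets"]
    base = 8 + json_len
    return data[base + begin:base + end]


# Build a tiny in-memory file holding one float32 tensor "x" = [1.0, 2.0].
payload = struct.pack("<2f", 1.0, 2.0)
hdr = json.dumps({"x": {"dtype": "F32", "shape": [2], "data_offsets": [0, len(payload)]}}).encode()
data = struct.pack("<Q", len(hdr)) + hdr + payload

json_len, header = parse_header(data)
for name, entry in header.items():
    if name == "__metadata__":  # optional free-form metadata, not a tensor
        continue
    raw = tensor_bytes(data, json_len, entry)
    print(name, entry["shape"], struct.unpack(f"<{len(raw) // 4}f", raw))  # assumes F32
```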
- tinygrad.nn.state.safe_save(tensors: Dict[str, Tensor], fn: str, metadata: Dict[str, Any] | None = None)
Save a dictionary of tensors to a file in the safetensors format.
- Parameters:
tensors (Dict[str, Tensor]) – Dictionary of tensors to save.
fn (str) – File name to save data to.
metadata (Optional[Dict[str, Any]]) – Optional metadata dictionary. Default is None.
- Returns:
None
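Writing the file reverses the reading steps. This standard-library sketch handles only 1-D float32 data, and save_f32 is a hypothetical name. It pads the header with spaces to a multiple of 8 bytes, which the safetensors format permits through trailing whitespace. That padding is an assumption of the sketch, not a confirmed detail of safe_save.

```python
import json
import os
import struct
import tempfile
from typing import Dict, List, Optional


def save_f32(tensors: Dict[str, List[float]], fn: str,
             metadata: Optional[Dict[str, str]] = None) -> None:
    header: Dict[str, object] = {}
    if metadata:
        header["__metadata__"] = metadata  # stored alongside the tensor entries
    payload = b""
    for name, values in tensors.items():
        raw = struct.pack(f"<{len(values)}f", *values)  # little-endian float32
        header[name] = {"dtype": "F32", "shape": [len(values)],
                        "data_offsets": [len(payload), len(payload) + len(raw)]}
        payload += raw
    # json.dumps escapes non-ASCII by default, so len(j) equals the encoded byte length.
    j = json.dumps(header, separators=(",", ":"))
    j += " " * ((8 - len(j) % 8) % 8)  # pad the header to a multiple of 8 bytes
    with open(fn, "wb") as f:
        f.write(struct.pack("<Q", len(j)))
        f.write(j.encode())
        f.write(payload)


path = os.path.join(tempfile.mkdtemp(), "demo.safetensors")
save_f32({"w": [1.0, 2.0], "b": [0.5]}, path, metadata={"format": "demo"})
```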
- tinygrad.nn.state.torch_load(fn: str)
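Load a PyTorch checkpoint (a file written by torch.save) from disk, loading the tensors it contains as tinygrad Tensors.
- Parameters:
fn (str) – Path to the checkpoint file.
- Returns:
The object stored in the file (typically a state dictionary), with its tensors loaded as Tensors.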
- Return type:
- Internal variables (local to torch_load, not attributes of tinygrad.nn.state):
offsets (Dict[Union[str, int], int]) – Offsets of tensors in the file.
lens (Dict[Union[str, int], int]) – Lengths of tensors in the file.