jactorch.io
Functions
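| load_state_dict(model, state_dict[, ...]) | Load a state dict into the model. |
| load_weights(model, filename[, ...]) | Load weights from a file. |
| state_dict(model[, include, exclude, cpu]) | Get a state dict representation of the model. |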
Exceptions
Functions
- load_state_dict(model, state_dict, include=None, exclude=None)
Load a state dict into the model. This function is similar to model.load_state_dict(), but it also supports the following additional features:
  - include and exclude: only include/exclude some parameters.
  - extra_state_dict: each module can implement a custom load_extra_state_dict method, which loads additional state of the module (such as hyperparameters, random seeds, etc.).
Furthermore, this function does not raise exceptions when there are missing or unexpected parameters. This is similar to PyTorch's behavior with strict=False.
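The filtering and non-strict behavior described above can be illustrated with a small plain-Python sketch. It is not jactorch's implementation: it assumes that include/exclude patterns are matched against full parameter names with shell-style globbing (fnmatch), and it models the model's parameters as a plain dict so that it runs without PyTorch. filter_keys and lenient_load are hypothetical helpers.

```python
from fnmatch import fnmatchcase
from typing import Dict, Iterable, List, Optional, Tuple


def filter_keys(
    state: Dict[str, object],
    include: Optional[Iterable[str]] = None,
    exclude: Optional[Iterable[str]] = None,
) -> Dict[str, object]:
    """Keep keys matching some `include` pattern (if given) and no `exclude` pattern."""
    include = list(include) if include is not None else None
    exclude = list(exclude) if exclude is not None else []
    kept = {}
    for key, value in state.items():
        if include is not None and not any(fnmatchcase(key, p) for p in include):
            continue  # not selected by any include pattern
        if any(fnmatchcase(key, p) for p in exclude):
            continue  # explicitly excluded
        kept[key] = value
    return kept


def lenient_load(
    params: Dict[str, object],
    state: Dict[str, object],
    include: Optional[Iterable[str]] = None,
    exclude: Optional[Iterable[str]] = None,
) -> Tuple[List[str], List[str]]:
    """Copy matching entries into `params`; report mismatches instead of raising."""
    state = filter_keys(state, include, exclude)
    missing = sorted(k for k in params if k not in state)
    unexpected = sorted(k for k in state if k not in params)
    for key in state.keys() & params.keys():
        params[key] = state[key]
    return missing, unexpected


params = {"encoder.weight": 0, "encoder.bias": 0, "head.weight": 0}
saved = {"encoder.weight": 1, "encoder.bias": 2, "head.weight": 3, "aux.weight": 4}
missing, unexpected = lenient_load(
    params, saved, include=["encoder.*", "aux.*"], exclude=["*.bias"]
)
print(missing)     # ['encoder.bias', 'head.weight']
print(unexpected)  # ['aux.weight']
print(params)      # {'encoder.weight': 1, 'encoder.bias': 0, 'head.weight': 0}
```

With a real torch.nn.Module, the final copy step corresponds to module.load_state_dict(filtered, strict=False), which returns the missing and unexpected keys rather than raising.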
- load_weights(model, filename, include=None, exclude=None, return_raw=True)
Load weights from a file. Internally, this function calls load_state_dict(). It also handles the case where the loaded file is a checkpoint (i.e., one containing both weights and optimizer states): such a file is detected automatically and the weights are extracted from it.
- Parameters:
model – the PyTorch model.
filename – the file name.
include – a list of parameter names to be included in the state dict (supports globbing).
exclude – a list of parameter names to be excluded from the state dict (supports globbing).
return_raw – whether to return the raw loaded file. If False, this function returns True when the weights are successfully loaded, and False otherwise. If True, this function returns the loaded file (either a state dict or a checkpoint) when the weights are successfully loaded, and None otherwise.
- Returns:
If return_raw=False, return True when the weights are successfully loaded, and False otherwise.
If return_raw=True, return the loaded file (either a state dict or a checkpoint) when the weights are successfully loaded, and None otherwise.
- Return type:
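A rough sketch of the checkpoint detection and the return_raw contract follows. Several details are assumptions, not documented jactorch behavior: the file is read with pickle rather than PyTorch's own loader, a checkpoint is recognized as a dict whose "model" entry is itself a dict, and a missing file counts as a failed load. extract_weights and load_weights_sketch are hypothetical names.

```python
import os
import pickle
import tempfile
from typing import Any, Callable, Dict, Optional


def extract_weights(obj: Any) -> Optional[Dict[str, Any]]:
    """Return the weights in `obj`; a dict with a dict-valued "model" entry is
    treated as a checkpoint (this key name is an assumption of the sketch)."""
    if isinstance(obj, dict) and isinstance(obj.get("model"), dict):
        return obj["model"]  # checkpoint: weights plus e.g. optimizer state
    if isinstance(obj, dict):
        return obj  # already a plain state dict
    return None  # unrecognized content


def load_weights_sketch(
    apply: Callable[[Dict[str, Any]], None],
    filename: str,
    return_raw: bool = True,
) -> Any:
    """Mimic the documented return values of load_weights."""
    if not os.path.isfile(filename):  # assumed failure case
        return None if return_raw else False
    with open(filename, "rb") as f:
        raw = pickle.load(f)
    weights = extract_weights(raw)
    if weights is None:
        return None if return_raw else False
    apply(weights)  # stands in for calling load_state_dict(model, weights)
    return raw if return_raw else True


loaded: Dict[str, Any] = {}
checkpoint = {"model": {"fc.weight": [1.0, 2.0]}, "optimizer": {"lr": 0.1}}
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "ckpt.pkl")
    with open(path, "wb") as f:
        pickle.dump(checkpoint, f)
    raw = load_weights_sketch(loaded.update, path, return_raw=True)
    ok = load_weights_sketch(loaded.update, path, return_raw=False)
    failed = load_weights_sketch(loaded.update, os.path.join(tmp, "none.pkl"), return_raw=False)

print(raw == checkpoint, ok, failed)  # True True False
print(loaded)  # {'fc.weight': [1.0, 2.0]}
```

Returning the whole raw object when return_raw=True lets a caller reuse other checkpoint contents (such as optimizer state) without reading the file twice.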
- state_dict(model, include=None, exclude=None, cpu=True)
Get a state dict representation of the model. This function is similar to model.state_dict(), but it also supports the following additional features:
  - include and exclude: only include/exclude some parameters.
  - cpu: move the parameters to CPU.
  - extra_state_dict: each module can implement a custom extra_state_dict method, which returns additional state of the module (such as hyperparameters, random seeds, etc.).
- Parameters:
model (Module) – the PyTorch model.
include (Sequence[str] | None) – a list of parameter names to be included in the state dict (supports globbing).
exclude (Sequence[str] | None) – a list of parameter names to be excluded from the state dict (supports globbing).
cpu (bool) – whether to move the parameters to CPU.
- Returns:
A state dict.
- Return type:
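The extra-state hooks can be pictured as a round trip: state_dict() collects what extra_state_dict() returns, and load_state_dict() later hands it back to load_extra_state_dict(). The sketch below assumes both hook signatures and stores the extra state under a hypothetical reserved key; how jactorch actually lays out this state is not specified here. The cpu option is modeled by duck typing, calling .cpu() only on values that have it (such as torch tensors), so the example runs without PyTorch. ToyModule, state_dict_sketch, and load_state_dict_sketch are illustrative names.

```python
from typing import Any, Dict

EXTRA_KEY = "__extra__"  # hypothetical: where this sketch stores extra state


class ToyModule:
    """Stand-in for a module: plain-dict parameters plus non-parameter state."""

    def __init__(self) -> None:
        self.params: Dict[str, Any] = {"scale": 1.0, "bias": 0.0}
        self.seed = 0

    # Hook names follow the documentation; their signatures are assumed.
    def extra_state_dict(self) -> Dict[str, Any]:
        return {"seed": self.seed}

    def load_extra_state_dict(self, state: Dict[str, Any]) -> None:
        self.seed = state["seed"]


def state_dict_sketch(module: Any, cpu: bool = True) -> Dict[str, Any]:
    out = {}
    for name, value in module.params.items():
        # Duck-typed: torch.Tensor has .cpu(); plain floats are left unchanged.
        out[name] = value.cpu() if cpu and hasattr(value, "cpu") else value
    if hasattr(module, "extra_state_dict"):
        out[EXTRA_KEY] = module.extra_state_dict()
    return out


def load_state_dict_sketch(module: Any, state: Dict[str, Any]) -> None:
    for name, value in state.items():
        if name in module.params:  # unknown keys are skipped, not errors
            module.params[name] = value
    if EXTRA_KEY in state and hasattr(module, "load_extra_state_dict"):
        module.load_extra_state_dict(state[EXTRA_KEY])


src = ToyModule()
src.params["scale"] = 2.5
src.seed = 42
sd = state_dict_sketch(src)
print(sd)  # {'scale': 2.5, 'bias': 0.0, '__extra__': {'seed': 42}}

dst = ToyModule()
load_state_dict_sketch(dst, sd)
print(dst.params, dst.seed)  # {'scale': 2.5, 'bias': 0.0} 42
```

Keeping the extra state next to the parameters means a single saved file is enough to restore both the weights and settings such as a random seed.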