jactorch.io

Functions

load_state_dict(model, state_dict[, ...])

Load a state dict into the model.

load_weights(model, filename[, include, ...])

Load weights from a file.

state_dict(model[, include, exclude, cpu])

Get a state dict representation of the model.

Exceptions

Functions

load_state_dict(model, state_dict, include=None, exclude=None)

Load a state dict into the model. This function is similar to model.load_state_dict(), but it also supports additional features including:

  • include and exclude: only include/exclude some parameters.

  • extra_state_dict: each module can implement a custom load_extra_state_dict method that loads additional state of the module (such as hyperparameters, random seeds, etc.).

Furthermore, this function does not raise an exception when there are missing or unexpected parameters. This is similar to PyTorch's behavior when strict=False is passed to Module.load_state_dict().
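Non-strict loading can be pictured as a comparison of key sets. Names the model expects but the file lacks are "missing", and names in the file that the model does not have are "unexpected". PyTorch reports these as `missing_keys` and `unexpected_keys` when `strict=False` is used. The sketch below computes both lists on plain key collections. `diff_keys` is a hypothetical helper for illustration and is not jactorch's implementation.

```python
from typing import Dict, Iterable, List, Tuple


def diff_keys(
    model_keys: Iterable[str], state_dict: Dict[str, object]
) -> Tuple[List[str], List[str]]:
    """Hypothetical helper: return (missing, unexpected) parameter names.

    missing:    expected by the model but absent from state_dict.
    unexpected: present in state_dict but unknown to the model.
    Neither case raises; the caller decides whether to warn.
    """
    expected = set(model_keys)
    provided = set(state_dict)
    missing = sorted(expected - provided)
    unexpected = sorted(provided - expected)
    return missing, unexpected


if __name__ == "__main__":
    model_keys = ["fc.weight", "fc.bias", "head.weight"]
    loaded = {"fc.weight": 0, "fc.bias": 0, "old_head.weight": 0}
    missing, unexpected = diff_keys(model_keys, loaded)
    print("missing:", missing)        # missing: ['head.weight']
    print("unexpected:", unexpected)  # unexpected: ['old_head.weight']
```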

Parameters:
  • model (Module) – the PyTorch model.

  • state_dict (Dict[str, Any]) – the state dict to be loaded.

  • include (Sequence[str] | None) – a list of parameter names to be included in the state dict (supports globbing).

  • exclude (Sequence[str] | None) – a list of parameter names to be excluded from the state dict (supports globbing).
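The include/exclude filtering with glob patterns can be sketched with the standard-library `fnmatch` module. The sketch assumes that a key is kept when it matches at least one include pattern (or include is None) and matches no exclude pattern. `filter_state_dict` is a hypothetical name, and jactorch's exact matching rules may differ.

```python
from fnmatch import fnmatchcase
from typing import Any, Dict, Optional, Sequence


def filter_state_dict(
    state_dict: Dict[str, Any],
    include: Optional[Sequence[str]] = None,
    exclude: Optional[Sequence[str]] = None,
) -> Dict[str, Any]:
    """Hypothetical helper: keep keys matching any `include` pattern (all keys
    if include is None), then drop keys matching any `exclude` pattern."""

    def matches(key: str, patterns: Sequence[str]) -> bool:
        # fnmatchcase is case-sensitive; "*" also matches across dots.
        return any(fnmatchcase(key, p) for p in patterns)

    out = {}
    for key, value in state_dict.items():
        if include is not None and not matches(key, include):
            continue
        if exclude is not None and matches(key, exclude):
            continue
        out[key] = value
    return out


sd = {"backbone.conv1.weight": 1, "backbone.conv1.bias": 2, "head.fc.weight": 3}
print(filter_state_dict(sd, include=["backbone.*"], exclude=["*.bias"]))
# {'backbone.conv1.weight': 1}
```

Applying exclude after include lets a broad include pattern be narrowed by a more specific exclude pattern.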

load_weights(model, filename, include=None, exclude=None, return_raw=True)

Load weights from a file. Internally, this function calls load_state_dict(). It also handles the case where the file is a checkpoint (i.e., it contains both the weights and the optimizer states): such files are detected automatically, and the weights are extracted from them.
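Checkpoint detection could work as shown in the sketch below. It assumes a hypothetical layout in which a checkpoint is a dict that stores the weights under a `"model"` entry next to entries such as `"optimizer"`. The key name and `extract_weights` are illustrative assumptions, not jactorch's documented format.

```python
from typing import Any, Dict, Tuple

# Hypothetical layout: a checkpoint keeps the weights under "model" next to
# other entries such as "optimizer". jactorch's real key names may differ.
WEIGHTS_KEY = "model"


def extract_weights(loaded: Dict[str, Any]) -> Tuple[Dict[str, Any], bool]:
    """Return (weights, is_checkpoint) for an object read from disk.

    A plain state dict maps parameter names to tensors, so a nested dict
    stored under WEIGHTS_KEY is taken as the sign of a checkpoint.
    """
    nested = loaded.get(WEIGHTS_KEY)
    if isinstance(nested, dict):
        return nested, True
    return loaded, False


ckpt = {"model": {"fc.weight": 1}, "optimizer": {"step": 10}}
weights, is_ckpt = extract_weights(ckpt)
print(is_ckpt, weights)  # True {'fc.weight': 1}
```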

Parameters:
  • model – the PyTorch model.

  • filename – the file name.

  • include – a list of parameter names to be included in the state dict (supports globbing).

  • exclude – a list of parameter names to be excluded from the state dict (supports globbing).

  • return_raw – whether to return the raw loaded object. If True, this function returns the loaded file (either a state dict or a checkpoint) when the weights are successfully loaded, and None otherwise. If False, it returns True when the weights are successfully loaded, and False otherwise.

Returns:

  • If return_raw=False, return True when the weights are successfully loaded, and False otherwise.

  • If return_raw=True, return the loaded file (either a state dict or a checkpoint) when the weights are successfully loaded, and None otherwise.

Return type:

bool | Dict[str, Any] | None
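The two return modes can be illustrated with a stand-in that mirrors only the documented return values. The file reader and the weight-applying step are passed in so that the sketch runs without PyTorch. `load_weights_sketch` and `missing_file` are hypothetical names, and the assumption that OSError/KeyError count as "not loaded" is for illustration only.

```python
from typing import Any, Callable, Dict, Union

StateDict = Dict[str, Any]


def load_weights_sketch(
    read_file: Callable[[str], StateDict],
    apply_weights: Callable[[StateDict], None],
    filename: str,
    return_raw: bool = True,
) -> Union[bool, StateDict, None]:
    """Hypothetical stand-in that mirrors only the documented return values.

    Treating OSError/KeyError as "weights not loaded" is an assumption.
    """
    try:
        raw = read_file(filename)
        apply_weights(raw)
    except (OSError, KeyError):
        # Failure: None in raw mode, False in flag mode.
        return None if return_raw else False
    # Success: the loaded object in raw mode, True in flag mode.
    return raw if return_raw else True


def missing_file(path: str) -> StateDict:
    raise FileNotFoundError(path)


result = load_weights_sketch(missing_file, lambda sd: None, "model.pth")
if result is None:
    print("no weights loaded")  # no weights loaded
```

With this contract, it is safer to test the default (return_raw=True) result with `is None` than with truthiness, because an empty dict is falsy even when loading succeeded.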

state_dict(model, include=None, exclude=None, cpu=True)

Get a state dict representation of the model. This function is similar to model.state_dict(), but it also supports additional features including:

  • include and exclude: only include/exclude some parameters.

  • cpu: move the parameters to CPU.

  • extra_state_dict: each module can implement a custom extra_state_dict method that returns additional state of the module (such as hyperparameters, random seeds, etc.).
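The extra-state hooks can be sketched as a round trip, assuming that a module opts in simply by defining extra_state_dict() and load_extra_state_dict(). The sketch also assumes that the collected state is kept in a separate dict keyed by module name, which is not necessarily how jactorch stores it. `ToyModule`, `Sampler`, `collect_extra_state`, and `restore_extra_state` are hypothetical. `ToyModule` stands in for torch.nn.Module, following the naming convention of Module.named_modules(), so that the sketch runs without PyTorch.

```python
from typing import Any, Dict, Iterator, Tuple


class ToyModule:
    """Minimal stand-in for torch.nn.Module: only named submodules."""

    def __init__(self, **submodules: "ToyModule") -> None:
        self.submodules = submodules

    def named_modules(self, prefix: str = "") -> Iterator[Tuple[str, "ToyModule"]]:
        # Yield the root as "" and children as dotted paths, like PyTorch.
        yield prefix, self
        for name, child in self.submodules.items():
            child_prefix = f"{prefix}.{name}" if prefix else name
            yield from child.named_modules(child_prefix)


class Sampler(ToyModule):
    """A module with state that is not a parameter (here, a random seed)."""

    def __init__(self, seed: int = 0) -> None:
        super().__init__()
        self.seed = seed

    def extra_state_dict(self) -> Dict[str, Any]:
        return {"seed": self.seed}

    def load_extra_state_dict(self, state: Dict[str, Any]) -> None:
        self.seed = state["seed"]


def collect_extra_state(model: ToyModule) -> Dict[str, Dict[str, Any]]:
    """Gather extra state from every module that defines extra_state_dict()."""
    return {
        name: module.extra_state_dict()
        for name, module in model.named_modules()
        if hasattr(module, "extra_state_dict")
    }


def restore_extra_state(model: ToyModule, extra: Dict[str, Dict[str, Any]]) -> None:
    """Hand each saved entry back to the module with the same name."""
    modules = dict(model.named_modules())
    for name, state in extra.items():
        module = modules.get(name)
        if module is not None and hasattr(module, "load_extra_state_dict"):
            module.load_extra_state_dict(state)


model = ToyModule(sampler=Sampler(seed=42))
extra = collect_extra_state(model)
print(extra)  # {'sampler': {'seed': 42}}

fresh = ToyModule(sampler=Sampler())
restore_extra_state(fresh, extra)
print(fresh.submodules["sampler"].seed)  # 42
```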

Parameters:
  • model (Module) – the PyTorch model.

  • include (Sequence[str] | None) – a list of parameter names to be included in the state dict (supports globbing).

  • exclude (Sequence[str] | None) – a list of parameter names to be excluded from the state dict (supports globbing).

  • cpu (bool) – whether to move the parameters to CPU.

Returns:

A state dict.

Return type:

Dict[str, Any]