jactorch.graph.parameter

Utilities to access, filter, and mark parameters in a torch.nn.Module.

Functions

compose_param_groups(model, *groups[, ...])

Compose the param_groups argument for torch optimizers.

detach_modules(*modules)

A context manager that temporarily detaches all parameters in the input list of modules.

exclude_parameters(params, exclude)

Exclude parameters from a list of parameters.

filter_parameters(params, pattern[, ...])

Filter parameters with a pattern.

find_parameters(module, pattern[, return_names])

Find parameters in a module with a pattern.

mark_freezed(model)

Freeze all parameters in a model.

mark_unfreezed(model)

Unfreeze all parameters in a model.

param_group(pattern, **kwargs)

A helper function used for human-friendly declaration of param groups.

Functions

compose_param_groups(model, *groups, filter_grad=True, verbose=True)

Compose the param_groups argument for torch optimizers.

Examples

>>> optim.Adam(compose_param_groups(
...     model,
...     param_group('*.weight', lr=0.01),
...     param_group('*.bias', lr=0.02),
... ), lr=0.1)
Parameters:
  • model (Module) – the model containing optimizable variables.

  • *groups (Tuple[str, Dict]) – groups defined by patterns, of form (pattern, special_params).

  • filter_grad (bool) – whether to include only parameters with requires_grad=True.

  • verbose (bool) – whether to print the parameters in each group.

Returns:

param_groups argument that can be passed to torch optimizers.
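To make the shape of the returned value concrete, here is a minimal sketch of the list-of-dicts format that torch optimizers accept as per-group options. It is not the jactorch implementation: build_groups is a hypothetical stand-in, and matching names with fnmatch-style globs is an assumption about how patterns behave.

```python
import fnmatch

import torch
from torch import nn

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))


def build_groups(model, *groups):
    """Hypothetical stand-in: one optimizer group per (pattern, options) pair."""
    param_groups = []
    for pattern, options in groups:
        matched = [
            p for name, p in model.named_parameters()
            # Assumption: glob-style matching on the parameter name.
            if p.requires_grad and fnmatch.fnmatchcase(name, pattern)
        ]
        param_groups.append({'params': matched, **options})
    return param_groups


param_groups = build_groups(
    model,
    ('*.weight', {'lr': 0.01}),
    ('*.bias', {'lr': 0.02}),
)
# Options given per group override the optimizer-level default (lr=0.1 here).
optimizer = torch.optim.Adam(param_groups, lr=0.1)
```

Each dict carries a 'params' list plus the options specific to that group; any option not set in a group falls back to the keyword arguments passed to the optimizer.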

detach_modules(*modules)

A context manager that temporarily detaches all parameters in the input list of modules.

Example

>>> output1 = m2(m1(input1))
>>> with jactorch.detach_modules(m1, m2):  # or jactorch.detach_modules([m1, m2])
...     output2 = m2(m1(input2))
>>> loss(output1, output2).backward()

Gradients from the output2 branch will not back-propagate into m1 and m2; only the output1 branch contributes to their parameter gradients.

Parameters:

*modules – the modules to detach. It can also be a single list of modules.
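One plausible way to get this behavior in plain PyTorch is to switch off requires_grad for the duration of the block and restore the previous flags afterwards. The sketch below illustrates that idea only; detach_sketch is a hypothetical name, and the mechanism is an assumption rather than a description of jactorch's code.

```python
import contextlib

import torch
from torch import nn


@contextlib.contextmanager
def detach_sketch(*modules):
    """Hypothetical stand-in for detach_modules, not the library's code."""
    params = [p for m in modules for p in m.parameters()]
    saved = [p.requires_grad for p in params]
    for p in params:
        p.requires_grad_(False)
    try:
        yield
    finally:
        # Restore the original flags even if the body raises.
        for p, flag in zip(params, saved):
            p.requires_grad_(flag)


m1, m2 = nn.Linear(3, 3), nn.Linear(3, 1)
x1, x2 = torch.randn(5, 3), torch.randn(5, 3)

output1 = m2(m1(x1))          # recorded in the autograd graph
with detach_sketch(m1, m2):
    output2 = m2(m1(x2))      # no graph is built back to m1 or m2

(output1 - output2).pow(2).mean().backward()
```

In this sketch, output2 behaves like a constant during the backward pass, so the gradients stored on m1 and m2 come from the output1 branch alone.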

exclude_parameters(params, exclude)

Exclude parameters from a list of parameters.

Parameters:
  • params – the parameters to filter.

  • exclude – the parameters to exclude from params.

Return type:

List[Parameter]
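As an illustration of what such an exclusion involves, the sketch below removes parameters by object identity in plain PyTorch. Comparing by id() is an assumption made for the example (membership tests on tensors fall back to elementwise ==, which is not what is wanted here); it is not taken from the library's source.

```python
from torch import nn

model = nn.Linear(4, 2)
all_params = list(model.parameters())

# Compare by identity: two distinct parameters may hold equal values.
excluded_ids = {id(p) for p in [model.bias]}
remaining = [p for p in all_params if id(p) not in excluded_ids]
```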

filter_parameters(params, pattern, return_names=False)

Filter parameters with a pattern.

Parameters:
  • params (Iterable[Parameter]) – the parameters to filter.

  • pattern (Iterable[str] | str) – the pattern(s) to match.

  • return_names (bool) – whether to return the names of the parameters.

Returns:

a list of parameters, or a list of (name, parameter) pairs if return_names is True.

Return type:

List[Parameter] | List[Tuple[str, Parameter]]
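The following sketch shows how a single pattern or several patterns, together with return_names, could shape the result. It rests on two assumptions that the text above does not confirm: that the input carries names, for example (name, parameter) pairs as produced by named_parameters(), and that patterns are fnmatch-style globs. filter_sketch is a hypothetical helper, and plain strings stand in for parameter tensors.

```python
import fnmatch


def filter_sketch(named_params, pattern, return_names=False):
    """Hypothetical stand-in: keep entries whose name matches any pattern."""
    patterns = [pattern] if isinstance(pattern, str) else list(pattern)
    kept = [
        (name, p) for name, p in named_params
        if any(fnmatch.fnmatchcase(name, pat) for pat in patterns)
    ]
    return kept if return_names else [p for _, p in kept]


# Plain strings stand in for parameter tensors.
named = [
    ('encoder.weight', 'W_enc'),
    ('encoder.bias', 'b_enc'),
    ('decoder.weight', 'W_dec'),
]
weights = filter_sketch(named, '*.weight')
encoder = filter_sketch(named, ['encoder.*'], return_names=True)
```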

find_parameters(module, pattern, return_names=False)

Find parameters in a module with a pattern.

Parameters:
  • module (Module) – the module to search.

  • pattern (Iterable[str] | str) – the pattern(s) to match.

  • return_names (bool) – whether to return the names of the parameters.

Returns:

a list of parameters, or a list of (name, parameter) pairs if return_names is True.

Return type:

List[Parameter] | List[Tuple[str, Parameter]]
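When writing a pattern, it helps to know what parameter names look like for a nested module. The sketch below lists the dotted names that named_parameters() yields for a small hypothetical Net class and selects some of them with an fnmatch-style glob; the glob semantics are an assumption about how patterns are interpreted.

```python
import fnmatch

from torch import nn


class Net(nn.Module):
    """Hypothetical model used only to show parameter naming."""

    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(4, 8)
        self.head = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 2))

    def forward(self, x):
        return self.head(self.encoder(x))


net = Net()
# Names are dotted paths: attribute names, then indices inside containers.
names = [name for name, _ in net.named_parameters()]
head_weights = [n for n in names if fnmatch.fnmatchcase(n, 'head.*.weight')]
```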

mark_freezed(model)

Freeze all parameters in a model.

Parameters:

model (Module)

mark_unfreezed(model)

Unfreeze all parameters in a model.

Parameters:

model (Module)
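In plain PyTorch, freezing and unfreezing are commonly expressed by toggling requires_grad on every parameter. The sketch below shows that idiom; treating it as the meaning of mark_freezed and mark_unfreezed is an assumption, suggested by the filter_grad option of compose_param_groups.

```python
from torch import nn

model = nn.Linear(4, 2)

# "Freeze": the optimizer and autograd will ignore these parameters.
for p in model.parameters():
    p.requires_grad = False
frozen_flags = [p.requires_grad for p in model.parameters()]

# "Unfreeze": make every parameter trainable again.
for p in model.parameters():
    p.requires_grad = True
unfrozen_flags = [p.requires_grad for p in model.parameters()]
```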

param_group(pattern, **kwargs)

A helper function used for human-friendly declaration of param groups.

Parameters:

pattern (str)

Return type:

Tuple[str, Dict]
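Given the Tuple[str, Dict] return type and the (pattern, special_params) form expected by compose_param_groups, the helper plausibly just packs its arguments. The sketch below shows that reading; param_group_sketch is a hypothetical name, and the body is an inference from the signature rather than the library's code.

```python
def param_group_sketch(pattern, **kwargs):
    """Hypothetical stand-in: pair a name pattern with per-group options."""
    return pattern, kwargs


group = param_group_sketch('*.bias', lr=0.02, weight_decay=0.0)
```

Writing param_group('*.bias', lr=0.02) instead of the raw tuple ('*.bias', {'lr': 0.02}) keeps the options in keyword form, which is the human-friendly part.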