jactorch.graph.parameter
Utilities to access, filter, and mark parameters in a torch.nn.Module.
Functions
- Compose the param_groups argument for torch optimizers.
- A context manager that temporarily detaches all parameters in the input list of modules.
- Exclude parameters from a list of parameters.
- Filter parameters with a pattern.
- Find parameters in a module with a pattern.
- Freeze all parameters in a model.
- Unfreeze all parameters in a model.
- A helper function used for human-friendly declaration of param groups.
Functions
- compose_param_groups(model, *groups, filter_grad=True, verbose=True)
Compose the param_groups argument for torch optimizers.
Examples
>>> optim.Adam(compose_param_groups(
...     model,
...     param_group('*.weight', lr=0.01),
...     param_group('*.bias', lr=0.02),
... ), lr=0.1)
- Parameters:
- Returns:
param_groups argument that can be passed to torch optimizers.
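The page does not spell out how parameters are matched to groups. The sketch below shows one plausible version of the grouping logic in plain Python. It assumes glob-style (fnmatch) patterns, which is consistent with '*.weight' in the example above. It also makes three assumptions that the docs do not confirm: each parameter joins the first group whose pattern matches, groups that match nothing are dropped, and leftover parameters fall into a final group with no overrides. The names `param_group_sketch` and `compose_param_groups_sketch` are hypothetical. The sketch omits `filter_grad` and `verbose`, and it uses strings in place of tensors.

```python
import fnmatch
from typing import Any, Dict, Iterable, List, Tuple

# Hypothetical stand-in for param_group: a glob pattern plus optimizer options.
Group = Tuple[str, Dict[str, Any]]


def param_group_sketch(pattern: str, **options: Any) -> Group:
    return pattern, options


def compose_param_groups_sketch(
    named_params: Iterable[Tuple[str, Any]], *groups: Group
) -> List[Dict[str, Any]]:
    """Build a torch-style param_groups list from (name, param) pairs.

    Assumptions (not confirmed by the jactorch docs): first matching group
    wins, groups matching nothing are dropped, and unmatched parameters form
    a final group without option overrides.
    """
    named_params = list(named_params)
    assigned = set()
    result = []
    for pattern, options in groups:
        params = []
        for name, param in named_params:
            if name not in assigned and fnmatch.fnmatchcase(name, pattern):
                assigned.add(name)
                params.append(param)
        if params:
            result.append({'params': params, **options})
    # Unmatched parameters keep the optimizer's defaults (lr=0.1 in the example).
    rest = [param for name, param in named_params if name not in assigned]
    if rest:
        result.append({'params': rest})
    return result


# Usage with strings standing in for tensors:
named = [('fc.weight', 'W'), ('fc.bias', 'b'), ('norm.scale', 's')]
groups = compose_param_groups_sketch(
    named,
    param_group_sketch('*.weight', lr=0.01),
    param_group_sketch('*.bias', lr=0.02),
)
# groups == [{'params': ['W'], 'lr': 0.01},
#            {'params': ['b'], 'lr': 0.02},
#            {'params': ['s']}]
```

Assigning each parameter to at most one group is deliberate: torch optimizers raise an error when the same parameter appears in more than one parameter group.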
- detach_modules(*modules)
A context manager that temporarily detaches all parameters in the input list of modules.
Example
>>> output1 = m2(m1(input1))
>>> with jactorch.detach_modules(m1, m2):  # or jactorch.detach_modules([m1, m2])
...     output2 = m2(m1(input2))
>>> loss(output1, output2).backward()
Gradients flowing through the output2 branch do not reach the parameters of m1 and m2; only the output1 branch contributes to their gradients.
- Parameters:
*modules – the modules to detach. It can also be a single list of modules.
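The page does not say how the detaching is implemented. One plausible approach is to switch off `requires_grad` on every parameter for the duration of the block and then restore the previous flags on exit. The sketch below shows that approach; it is an assumption, not jactorch's actual code. `detach_modules_sketch`, `FakeParam` and `FakeModule` are hypothetical names. The two fakes only stand in for torch tensors and modules so that the example runs without PyTorch.

```python
import contextlib
from typing import Iterator


@contextlib.contextmanager
def detach_modules_sketch(*modules) -> Iterator[None]:
    """Temporarily set requires_grad=False on every parameter of `modules`.

    Accepts modules as varargs or as one list/tuple, like detach_modules.
    Works with any object whose .parameters() yields objects carrying a
    writable `requires_grad` flag (e.g. torch.nn.Module).
    """
    if len(modules) == 1 and isinstance(modules[0], (list, tuple)):
        modules = tuple(modules[0])
    # Record each parameter once (keyed by id), so a parameter shared between
    # modules is restored to its true original flag.
    saved = {}
    for module in modules:
        for p in module.parameters():
            saved.setdefault(id(p), (p, p.requires_grad))
    for p, _ in saved.values():
        p.requires_grad = False
    try:
        yield
    finally:
        # Restore the original flags even if the body raised.
        for p, flag in saved.values():
            p.requires_grad = flag


# Minimal stand-ins so the sketch runs without PyTorch.
class FakeParam:
    def __init__(self, requires_grad: bool = True):
        self.requires_grad = requires_grad


class FakeModule:
    def __init__(self, *params: FakeParam):
        self._params = list(params)

    def parameters(self):
        return iter(self._params)


shared = FakeParam()
m1 = FakeModule(FakeParam(), shared)
m2 = FakeModule(shared, FakeParam(requires_grad=False))
with detach_modules_sketch([m1, m2]):
    pass  # every parameter of m1 and m2 has requires_grad == False here
# After the block, each flag is back to its original value.
```

The sketch saves each parameter only once, keyed by `id`. Without that step, a parameter shared by two modules would be recorded a second time after it had already been switched off, and it would be "restored" to the wrong flag.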
- filter_parameters(params, pattern, return_names=False)
Filter parameters with a pattern.
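The `return_names` flag suggests that `params` carries names. The sketch below therefore assumes that `params` yields (name, parameter) pairs, as `torch.nn.Module.named_parameters()` does, and that `pattern` is a glob-style (fnmatch) pattern. Neither assumption is confirmed by this page. `filter_parameters_sketch` is a hypothetical name, and strings stand in for tensors.

```python
import fnmatch
from typing import Any, Iterable, List, Tuple


def filter_parameters_sketch(
    params: Iterable[Tuple[str, Any]], pattern: str, return_names: bool = False
) -> List[Any]:
    """Keep the parameters whose name matches a glob-style `pattern`.

    Assumes `params` yields (name, parameter) pairs, as
    torch.nn.Module.named_parameters() does.
    """
    matched = [(name, p) for name, p in params if fnmatch.fnmatchcase(name, pattern)]
    if return_names:
        return matched
    return [p for _, p in matched]


# Strings stand in for tensors:
named = [('encoder.fc.weight', 'W1'), ('encoder.fc.bias', 'b1'), ('head.weight', 'W2')]
filter_parameters_sketch(named, 'encoder.*')
# → ['W1', 'b1']
filter_parameters_sketch(named, '*.weight', return_names=True)
# → [('encoder.fc.weight', 'W1'), ('head.weight', 'W2')]
```

In fnmatch, `*` also matches dots. A pattern such as '*.weight' therefore selects weights at any nesting depth.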
- Parameters:
- Returns:
a list of parameters, or a list of (name, parameter) pairs if return_names is True.
- Return type: