#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File : parameter.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 03/31/2018
#
# This file is part of Jacinle.
# Distributed under terms of the MIT license.
"""Utilities to access, filter, and mark parameters in a :class:`torch.nn.Module`."""
import contextlib
from typing import Union, Iterable, Sequence, Tuple, List, Dict
import torch.nn as nn
from jacinle.logging import get_logger
from jacinle.utils.matching import NameMatcher
logger = get_logger(__file__)
__all__ = [
'find_parameters', 'filter_parameters', 'exclude_parameters', 'compose_param_groups', 'param_group',
'mark_freezed', 'mark_unfreezed', 'detach_modules'
]
[docs]
def find_parameters(
    module: nn.Module,
    pattern: Union[Iterable[str], str],
    return_names: bool = False,
) -> Union[List[nn.Parameter], List[Tuple[str, nn.Parameter]]]:
    """Collect the parameters of a module whose names match a pattern.

    This is a convenience wrapper that feeds ``module.named_parameters()``
    into :func:`filter_parameters`.

    Args:
        module: the module whose parameters are searched.
        pattern: a single pattern, or an iterable of patterns, to match
            against the parameter names.
        return_names: if True, return ``(name, parameter)`` pairs instead of
            bare parameters.

    Returns:
        a list of parameters, or a list of (name, parameter) pairs if
        `return_names` is True.
    """
    named_params = module.named_parameters()
    return filter_parameters(named_params, pattern, return_names=return_names)
[docs]
def filter_parameters(
    params: Iterable[Tuple[str, nn.Parameter]],
    pattern: Union[Iterable[str], str],
    return_names: bool = False,
) -> Union[List[nn.Parameter], List[Tuple[str, nn.Parameter]]]:
    """Filter named parameters with a pattern.

    Args:
        params: an iterable of ``(name, parameter)`` pairs, e.g. the result of
            ``module.named_parameters()``. Each item is unpacked into a name
            and a parameter, so bare parameters are not accepted.
        pattern: the pattern(s) to match against the names. A single string
            (or bytes) is treated as a one-element list of patterns.
        return_names: whether to return the names of the parameters.

    Returns:
        a list of parameters, or a list of (name, parameter) pairs if
        `return_names` is True. The order follows the iteration order of
        `params`.
    """
    if isinstance(pattern, (str, bytes)):
        pattern = [pattern]

    # Every pattern maps to True, so ``match`` is truthy exactly when some
    # pattern accepted the name.
    matcher = NameMatcher({p: True for p in pattern})

    # The matcher is used as a context manager around the matching loop.
    with matcher:
        matched = [(name, p) for name, p in params if matcher.match(name)]

    if return_names:
        return matched
    return [p for _, p in matched]
[docs]
def exclude_parameters(params: Iterable[nn.Parameter], exclude: Sequence[nn.Parameter]) -> List[nn.Parameter]:
    """Exclude parameters from a list of parameters.

    Parameters are compared by identity (the very same object), not by value.

    Args:
        params: the parameters to filter.
        exclude: the parameters that must not appear in the result.

    Returns:
        the parameters of `params`, in their original order, that are not one
        of the objects in `exclude`.
    """
    # NOTE: ``p not in exclude`` must not be used here. List membership falls
    # back to ``==`` when the identity check fails, and ``==`` between tensors
    # is element-wise: the result has no single truth value (or the shapes do
    # not even broadcast), so the test raises RuntimeError for ordinary
    # multi-element parameters.
    excluded_ids = {id(p) for p in exclude}
    return [p for p in params if id(p) not in excluded_ids]
[docs]
def compose_param_groups(model: nn.Module, *groups: Tuple[str, Dict], filter_grad: bool = True, verbose: bool = True):
    """Build the ``param_groups`` argument for torch optimizers from name patterns.

    Every named parameter of `model` is matched against the patterns of
    `groups` and placed into the corresponding group. Parameters that match no
    pattern are collected in one extra group at the end of the returned list,
    which carries no group-specific options.

    Examples:
        >>> optim.Adam(compose_param_groups(
        ...     model,
        ...     param_group('*.weight', lr=0.01),
        ...     param_group('*.bias', lr=0.02),
        ... ), lr=0.1)

    Args:
        model: the model containing optimizable variables.
        *groups: groups defined by patterns, of form ``(pattern, special_params)``.
        filter_grad: only choose parameters with ``requires_grad=True``.
        verbose: whether to log the parameters in each group.

    Returns:
        a list of ``len(groups) + 1`` dicts. Each dict has the keys ``'params'``
        and ``'names'`` plus the special params of its group; the last dict is
        the group for unmatched parameters.
    """
    # Each pattern maps to the index of the group that declared it.
    rules = [(spec[0], index) for index, spec in enumerate(groups)]
    matcher = NameMatcher(rules)

    # One bucket per declared group, plus a trailing bucket for the rest.
    param_groups = [{'params': [], 'names': []} for _ in range(len(groups) + 1)]

    with matcher:
        for param_name, param in model.named_parameters():
            if filter_grad and not param.requires_grad:
                continue
            matched_index = matcher.match(param_name)
            # No match: fall back to the last (default) bucket.
            bucket = param_groups[-1 if matched_index is None else matched_index]
            bucket['names'].append(param_name)
            bucket['params'].append(param)

    # Attach the group-specific options; the trailing bucket is left untouched.
    for bucket, spec in zip(param_groups, groups):
        bucket.update(spec[1])

    if verbose:
        lines = ['Param groups:']
        for bucket in param_groups:
            options = '; '.join(
                '{}: {}'.format(key, value)
                for key, value in bucket.items()
                if key not in ('params', 'names')
            )
            if options == '':
                options = '(default)'
            lines.extend(
                ' {name}: {extra}.'.format(name=param_name, extra=options)
                for param_name in bucket['names']
            )
        logger.info('\n'.join(lines))

    return param_groups
[docs]
def param_group(pattern: str, **kwargs) -> Tuple[str, Dict]:
    """Declare a param group in a human-friendly way.

    Args:
        pattern: the name pattern selecting the parameters of the group.
        **kwargs: the options specific to this group (e.g. ``lr=0.01``).

    Returns:
        the pair ``(pattern, kwargs)``, where ``kwargs`` is the dict of the
        given keyword arguments.
    """
    group_spec = (pattern, kwargs)
    return group_spec
[docs]
def mark_freezed(model: nn.Module):
    """Freeze a model by turning off gradients for all of its parameters.

    Args:
        model: the module whose parameters get ``requires_grad = False``.
            The parameters are modified in place; nothing is returned.
    """
    for param in model.parameters():
        param.requires_grad = False
[docs]
def mark_unfreezed(model: nn.Module):
    """Unfreeze a model by turning on gradients for all of its parameters.

    Args:
        model: the module whose parameters get ``requires_grad = True``.
            The parameters are modified in place; nothing is returned.
    """
    for param in model.parameters():
        param.requires_grad = True
[docs]
@contextlib.contextmanager
def detach_modules(*modules):
    """A context manager that temporarily detaches all parameters in the input modules.

    On entry, ``requires_grad`` is set to False on every parameter of the given
    modules. On exit, each parameter gets back the value it had on entry, even
    when the body of the ``with`` statement raises.

    Example:
        >>> output1 = m2(m1(input1))
        >>> with jactorch.detach_modules(m1, m2):  # or jactorch.detach_modules([m1, m2])
        ...     output2 = m2(m1(input2))
        >>> loss(output1, output2).backward()

        The loss from branch `output2` will not back-propagate to m1 and m2.

    Args:
        *modules: the modules to detach. It can also be a single list (or tuple) of modules.
    """
    if len(modules) == 1 and isinstance(modules[0], (list, tuple)):
        modules = modules[0]

    all_modules = nn.ModuleList(modules)

    # Remember the flag together with the parameter object itself, so that the
    # restore step does not depend on parameter names staying stable while the
    # body runs.
    saved_flags = [(p, p.requires_grad) for p in all_modules.parameters()]

    try:
        for p, _ in saved_flags:
            p.requires_grad = False
        yield
    finally:
        # Always restore, otherwise an exception raised inside the ``with``
        # body would leave every parameter frozen.
        for p, flag in saved_flags:
            p.requires_grad = flag