Source code for jactorch.optim.optimizer_group

#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : optimizer_group.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 02/16/2020
#
# This file is part of Jacinle.
# Distributed under terms of the MIT license.

from .custom_optimizer_base import CustomizedOptimizer

__all__ = ['OptimizerGroup']


class OptimizerGroup(CustomizedOptimizer):
    """A group of optimizers. Useful when using multiple optimizers for different parts of the model.

    Each member optimizer is registered under a keyword name. It can be accessed either as an
    attribute (``group.name``) or by subscription (``group['name']``). ``state_dict``,
    ``load_state_dict``, ``zero_grad`` and ``step`` are forwarded to every member optimizer.
    """

    def __init__(self, **optimizers):
        """Initialize the optimizer group.

        Args:
            **optimizers: the member optimizers, passed as keyword arguments. Each keyword becomes
                the name under which the corresponding optimizer is stored.
        """
        self.optimizers = optimizers

    def __getattr__(self, item):
        """Look up a member optimizer by name as an attribute (only reached when normal lookup fails).

        Raises:
            AttributeError: if no optimizer is registered under ``item``.
        """
        # Guard the lookup of ``optimizers`` itself: if it has not been set yet (e.g. on an
        # instance created by copy/pickle before its state is restored), ``self.optimizers``
        # would re-enter __getattr__ and recurse forever.
        if item == 'optimizers':
            raise AttributeError(item)
        try:
            return self.optimizers[item]
        except KeyError:
            # AttributeError (not KeyError) so that hasattr() / getattr(obj, name, default) work.
            raise AttributeError('{!r} object has no attribute {!r}'.format(type(self).__name__, item)) from None

    def __getitem__(self, item):
        """Return the member optimizer registered under ``item`` (raises KeyError if absent)."""
        return self.optimizers[item]

    @property
    def state(self):
        """A dict mapping each optimizer name to that optimizer's ``state`` attribute."""
        return {name: opt.state for name, opt in self.optimizers.items()}

    @property
    def param_groups(self):
        """A dict mapping each optimizer name to that optimizer's ``param_groups`` attribute."""
        return {name: opt.param_groups for name, opt in self.optimizers.items()}

    def state_dict(self):
        """Return a dict mapping each optimizer name to that optimizer's ``state_dict()``."""
        return {name: opt.state_dict() for name, opt in self.optimizers.items()}

    def load_state_dict(self, state_dict):
        """Load per-optimizer states produced by :meth:`state_dict`.

        Args:
            state_dict: a dict mapping optimizer names to their individual state dicts. Entries
                whose name is not registered in this group are skipped.
        """
        for name, opt_state in state_dict.items():
            if name in self.optimizers:
                # Load into the registered optimizer, not into the state value itself.
                self.optimizers[name].load_state_dict(opt_state)

    def zero_grad(self):
        """Call ``zero_grad()`` on every member optimizer."""
        # Iterate the optimizer objects; iterating the dict directly would yield the name strings.
        for opt in self.optimizers.values():
            opt.zero_grad()

    def step(self, closure=None):
        """Perform one step on every member optimizer.

        Args:
            closure: optional callable. It is evaluated once before stepping and its return
                value is returned. It is not passed to the individual optimizers.

        Returns:
            The value returned by ``closure``, or None if no closure was given.
        """
        loss = None
        if closure is not None:
            loss = closure()
        for opt in self.optimizers.values():
            opt.step()
        return loss