# Source code for jactorch.optim.accum_grad

#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : accum_grad.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 01/24/2018
#
# This file is part of Jacinle.
# Distributed under terms of the MIT license.

__all__ = ['AccumGrad']

from .custom_optimizer_base import CustomizedOptimizer


class AccumGrad(CustomizedOptimizer):
    """A wrapper for optimizer that accumulates gradients for several steps.

    Each call to :meth:`step` adds the current gradient of every parameter into
    a per-parameter buffer. Once ``nr_acc`` calls have been accumulated, the
    buffers are averaged over the number of accumulated calls, written back into
    ``p.grad``, and the base optimizer's ``step`` method is invoked.

    Because every call adds ``p.grad`` into the buffer, gradients should be
    cleared between calls (e.g. with :meth:`zero_grad`); otherwise the same
    gradient is counted again.

    The buffers are kept on the wrapper instead of inside
    ``base_optimizer.state``: optimizers such as ``torch.optim.Adam`` only
    initialize a parameter's state when that state is empty, so an extra
    ``grad_buffer`` entry there makes them skip initialization and fail.
    """

    def __init__(self, base_optimizer, nr_acc):
        """Initialize the wrapper.

        Args:
            base_optimizer: the base optimizer.
            nr_acc: the number of steps to accumulate gradients.
        """
        self._base_optimizer = base_optimizer
        self._nr_acc = nr_acc
        self._current = 0
        # Gradients accumulated in the current window, keyed by parameter.
        # A parameter has an entry only if it received a gradient in this window;
        # entries are released when the window is flushed into the base optimizer.
        self._grad_buffers = {}

    @property
    def state(self):
        return self._base_optimizer.state

    @property
    def param_groups(self):
        return self._base_optimizer.param_groups

    def _iter_params(self):
        """Yield every parameter of the base optimizer, in ``param_groups`` order."""
        for group in self._base_optimizer.param_groups:
            yield from group['params']

    def state_dict(self):
        """Return the state of the wrapper and of the base optimizer.

        Returns:
            dict with keys ``'base_optimizer'`` (the base optimizer's state dict),
            ``'current'`` (number of steps accumulated in the current window) and
            ``'grad_buffers'`` (one entry per parameter in ``param_groups`` order:
            the accumulated gradient, or ``None`` if the parameter has none).
        """
        return {
            'base_optimizer': self._base_optimizer.state_dict(),
            'current': self._current,
            'grad_buffers': [self._grad_buffers.get(p) for p in self._iter_params()],
        }

    def load_state_dict(self, state_dict):
        """Restore a state produced by :meth:`state_dict`.

        Also accepts older checkpoints that have no ``'grad_buffers'`` entry and
        instead stored each buffer as ``base_optimizer.state[p]['grad_buffer']``.

        Returns:
            whatever the base optimizer's ``load_state_dict`` returns.
        """
        self._current = state_dict['current']
        ret = self._base_optimizer.load_state_dict(state_dict['base_optimizer'])

        params = list(self._iter_params())
        self._grad_buffers = {}
        saved_buffers = state_dict.get('grad_buffers')
        if saved_buffers is not None:
            for p, buf in zip(params, saved_buffers):
                if buf is not None:
                    # Copy so the live buffer does not alias the caller's dict.
                    self._grad_buffers[p] = buf.to(p.device).clone()

        # Legacy layout: pull 'grad_buffer' out of the base state so optimizers
        # with lazy state initialization are not confused by it. At a window
        # boundary (current == 0) the old buffers had been zeroed, so they are
        # only worth keeping in the middle of a window.
        base_state = self._base_optimizer.state
        for p in params:
            param_state = base_state.get(p)
            if param_state is None or 'grad_buffer' not in param_state:
                continue
            legacy_buf = param_state.pop('grad_buffer')
            if self._current > 0 and p not in self._grad_buffers:
                self._grad_buffers[p] = legacy_buf
        return ret

    def zero_grad(self, *args, **kwargs):
        """Clear gradients through the base optimizer.

        Extra arguments (e.g. ``set_to_none``) are forwarded unchanged. The
        accumulation buffers are not touched.
        """
        return self._base_optimizer.zero_grad(*args, **kwargs)

    def step(self, closure=None):
        """Accumulate the current gradients; every ``nr_acc`` calls, update.

        Args:
            closure: optional callable; if given, it is called once at the start
                of this method and its return value is returned.

        Returns:
            the closure's return value, or ``None`` if no closure was given.
        """
        loss = None
        if closure is not None:
            loss = closure()

        self._current += 1
        flush = self._current >= self._nr_acc

        for p in self._iter_params():
            buf = self._grad_buffers.get(p)
            if p.grad is not None:
                d_p = p.grad.data
                if buf is None:
                    # NB(Jiayuan Mao @ 02/16): cloning .data guarantees that the
                    # buffer does not require grad.
                    buf = self._grad_buffers[p] = d_p.clone()
                else:
                    buf.add_(d_p)

            if not flush or buf is None:
                continue

            # End of window: average over all accumulated calls and hand the
            # result to the base optimizer. The buffer is released so that the
            # next window starts empty.
            del self._grad_buffers[p]
            buf.mul_(1. / self._current)
            if p.grad is None:
                # p got gradients earlier in this window but not in this call;
                # without this its contribution would leak into the next window.
                p.grad = buf
            else:
                p.grad.data.copy_(buf)

        if flush:
            self._base_optimizer.step()
            self._current = 0
        return loss