Source code for jactorch.optim.weight_decay

#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : weight_decay.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 01/24/2018
#
# This file is part of Jacinle.
# Distributed under terms of the MIT license.

import math
import torch
from torch.optim import Optimizer

__all__ = ['AdamW']


class AdamW(Optimizer):
    """Implements the AdamW algorithm (Adam with decoupled weight decay)."""

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
        """Initialize the AdamW optimizer.

        Args:
            params: iterable of parameters to optimize or dicts defining parameter groups.
            lr: learning rate.
            betas: coefficients used for computing running averages of the gradient and its square.
            eps: term added to the denominator to improve numerical stability.
            weight_decay: weight decay coefficient (decoupled from the gradient update).
        """
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super().__init__(params, defaults)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('AdamW does not support sparse gradients, please consider SparseAdam instead.')

                state = self.state[p]

                # State initialization.
                if 'step' not in state:
                    state['step'] = 0
                    # Exponential moving average of gradient values.
                    state['exp_avg'] = torch.zeros_like(p.data)
                    # Exponential moving average of squared gradient values.
                    state['exp_avg_sq'] = torch.zeros_like(p.data)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1

                # Decay the first and second moment running averages.
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                denom = exp_avg_sq.sqrt().add_(group['eps'])

                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']
                step_size = group['lr'] * math.sqrt(bias_correction2) / bias_correction1

                # Adam update with the bias-corrected step size.
                p.data.addcdiv_(exp_avg, denom, value=-step_size)

                # Decoupled weight decay: shrink the parameters directly instead of
                # adding an L2 term to the gradient.
                if group['weight_decay'] != 0:
                    p.data.mul_(1 - group['lr'] * group['weight_decay'])

        return loss
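
A minimal usage sketch, not part of the module: it assumes a hypothetical two-layer model and shows the common pattern of applying weight decay only to weight matrices while exempting biases via parameter groups.

import torch.nn as nn
from jactorch.optim.weight_decay import AdamW

# Hypothetical toy model; any nn.Module works.
model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))

# Decay only the weight matrices; biases are commonly exempted.
decay, no_decay = [], []
for name, param in model.named_parameters():
    (no_decay if name.endswith('bias') else decay).append(param)

optimizer = AdamW([
    {'params': decay, 'weight_decay': 1e-2},
    {'params': no_decay, 'weight_decay': 0.0},
], lr=1e-3, betas=(0.9, 0.999), eps=1e-8)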
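
The step method also accepts an optional closure, following the torch.optim.Optimizer convention. A sketch of one training iteration, assuming the model and optimizer above and a hypothetical batch (x, y):

import torch
import torch.nn.functional as F

x, y = torch.randn(8, 16), torch.randn(8, 4)  # hypothetical batch

def closure():
    # Recompute the loss and gradients before the parameter update.
    optimizer.zero_grad()
    loss = F.mse_loss(model(x), y)
    loss.backward()
    return loss

loss = optimizer.step(closure)  # step() returns the closure's loss

Because the decay multiplies p.data by (1 - lr * weight_decay) after the Adam update, the regularization strength is decoupled from the adaptive per-parameter step sizes, which is what distinguishes this from passing weight_decay to a plain Adam optimizer.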