Source code for jactorch.optim.weight_decay
#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File : weight_decay.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 01/24/2018
#
# This file is part of Jacinle.
# Distributed under terms of the MIT license.
import math
import torch
from torch.optim import Optimizer
# Public API: the names exported by ``from jactorch.optim.weight_decay import *``.
__all__ = ['AdamW']
[docs]
class AdamW(Optimizer):
    """Implements the AdamW algorithm (Adam with decoupled weight decay).

    Reference: Loshchilov & Hutter, "Decoupled Weight Decay Regularization"
    (ICLR 2019). Unlike an L2 penalty, the decay is never added to the
    gradient, so it does not pass through Adam's moment estimates; instead the
    parameters are shrunk directly by a factor ``1 - lr * weight_decay`` each
    step, before the Adam update is applied.

    Note on ``eps``: it is added to ``sqrt(exp_avg_sq)`` *before* the
    second-moment bias correction is folded into the scalar step size. For a
    given non-zero ``eps`` the result therefore differs slightly from
    :class:`torch.optim.AdamW`, which adds ``eps`` after bias-correcting; with
    ``eps=0`` the two updates coincide.
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
        """Initialize AdamW optimizer.

        Args:
            params: iterable of parameters to optimize or dicts defining parameter groups.
            lr: learning rate; must be >= 0.
            betas: coefficients ``(beta1, beta2)`` used for computing running averages
                of the gradient and its square; each must lie in ``[0, 1)``.
            eps: term added to the denominator to improve numerical stability; must be >= 0.
            weight_decay: decoupled weight decay coefficient; must be >= 0. Each step the
                parameters are multiplied by ``1 - lr * weight_decay`` before the Adam
                update. This is *not* an L2 penalty added to the loss or gradient.

        Raises:
            ValueError: if any hyper-parameter is outside its valid range.
        """
        if not 0.0 <= lr:
            raise ValueError('Invalid learning rate: {}'.format(lr))
        if not 0.0 <= eps:
            raise ValueError('Invalid epsilon value: {}'.format(eps))
        # beta1 == 1 would make the bias correction 1 - beta1 ** t zero (division by
        # zero in step()); beta2 == 1 freezes the second moment at its zero init.
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError('Invalid beta parameter at index 0: {}'.format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError('Invalid beta parameter at index 1: {}'.format(betas[1]))
        if not 0.0 <= weight_decay:
            raise ValueError('Invalid weight_decay value: {}'.format(weight_decay))
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss. It is executed with autograd enabled.

        Returns:
            The value returned by ``closure``, or ``None`` if no closure was given.

        Raises:
            RuntimeError: if a parameter carries a sparse gradient.
        """
        loss = None
        if closure is not None:
            # step() itself runs under no_grad (so the in-place parameter updates
            # below are not tracked); the closure usually calls backward(), so
            # autograd is re-enabled just for it.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            beta1, beta2 = group['betas']
            eps = group['eps']
            weight_decay = group['weight_decay']

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError(
                        'AdamW does not support sparse gradients, '
                        'please consider SparseAdam instead'
                    )

                state = self.state[p]
                # Lazy state initialization on the first step for this parameter.
                if 'step' not in state:
                    state['step'] = 0
                    # Exponential moving average of gradient values.
                    state['exp_avg'] = torch.zeros_like(p)
                    # Exponential moving average of squared gradient values.
                    state['exp_avg_sq'] = torch.zeros_like(p)
                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                state['step'] += 1
                step_count = state['step']

                # Decoupled weight decay acts on the *pre-update* parameter
                # (theta_{t-1} in Loshchilov & Hutter, Alg. 2). Applying it after
                # the Adam step would also shrink the freshly computed update.
                if weight_decay != 0:
                    p.mul_(1 - lr * weight_decay)

                # Decay the first and second moment running average coefficient.
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                bias_correction1 = 1 - beta1 ** step_count
                bias_correction2 = 1 - beta2 ** step_count
                # Both bias corrections are folded into one scalar step size so the
                # moment tensors never need to be rescaled element-wise.
                step_size = lr * math.sqrt(bias_correction2) / bias_correction1
                denom = exp_avg_sq.sqrt().add_(eps)
                p.addcdiv_(exp_avg, denom, value=-step_size)

        return loss