Source code for jactorch.optim.weight_decay
#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File :
# Author : Jiayuan Mao
# Email :
# Date : 01/24/2018
# This file is part of Jacinle.
# Distributed under terms of the MIT license.
import math
import torch
from torch.optim import Optimizer
# Public API of this module: only the AdamW optimizer is exported.
__all__ = ['AdamW']
class AdamW(Optimizer):
    """Implements AdamW: Adam with decoupled weight decay.

    Each step first applies the usual bias-corrected Adam update, then shrinks
    every parameter multiplicatively by ``(1 - lr * weight_decay)``. The decay
    is *not* folded into the gradient, so it does not pass through the
    adaptive second-moment normalisation.

    Reference: Loshchilov & Hutter, "Decoupled Weight Decay Regularization",
    https://arxiv.org/abs/1711.05101
    """

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
        """Initialize the AdamW optimizer.

        Args:
            params: iterable of parameters to optimize, or dicts defining
                parameter groups.
            lr: learning rate (default: 1e-3).
            betas: coefficients used for computing running averages of the
                gradient and its square (default: (0.9, 0.999)).
            eps: term added to the denominator to improve numerical stability
                (default: 1e-8).
            weight_decay: decoupled weight decay coefficient (default: 0).

        Raises:
            ValueError: if a hyper-parameter is outside its valid range.
        """
        if lr < 0.0:
            raise ValueError('Invalid learning rate: {}'.format(lr))
        if eps < 0.0:
            raise ValueError('Invalid epsilon value: {}'.format(eps))
        # beta == 1 would make the bias correction (1 - beta ** t) zero.
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError('Invalid beta parameter at index 0: {}'.format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError('Invalid beta parameter at index 1: {}'.format(betas[1]))
        if weight_decay < 0.0:
            raise ValueError('Invalid weight_decay value: {}'.format(weight_decay))

        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Perform a single optimization step.

        Args:
            closure (callable, optional): a closure that re-evaluates the model
                and returns the loss. It runs with autograd enabled, so it may
                call ``backward()``.

        Returns:
            Whatever ``closure`` returned, or ``None`` if no closure was given.

        Raises:
            RuntimeError: if a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            # The parameter update runs under no_grad, but the closure usually
            # recomputes the loss and back-propagates, which needs autograd.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            beta1, beta2 = group['betas']
            eps = group['eps']
            weight_decay = group['weight_decay']

            for p in group['params']:
                if p.grad is None:
                    # Parameter did not take part in this backward pass.
                    continue
                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')

                state = self.state[p]

                # State initialization
                if 'step' not in state:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                state['step'] += 1

                # Decay the first and second moment running average coefficient
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                denom = exp_avg_sq.sqrt().add_(eps)

                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']
                step_size = lr * math.sqrt(bias_correction2) / bias_correction1

                # Adam update: p <- p - step_size * m / (sqrt(v) + eps)
                p.addcdiv_(exp_avg, denom, value=-step_size)

                # Decoupled weight decay: p <- p * (1 - lr * weight_decay)
                if weight_decay != 0:
                    p.mul_(1 - lr * weight_decay)

        return loss