# Source code for jactorch.train.utils

#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : utils.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 03/03/2018
#
# This file is part of Jacinle.
# Distributed under terms of the MIT license.

from jacinle.utils.deprecated import deprecated
from jactorch.graph.parameter import mark_freezed as gmark_freezed, mark_unfreezed as gmark_unfreezed

__all__ = ['mark_freezed', 'mark_unfreezed', 'set_learning_rate', 'decay_learning_rate']


# Backward-compatibility shim: forwards `model` to
# jactorch.graph.parameter.mark_freezed (imported in this file as
# `gmark_freezed`) and returns whatever that call returns.
# NOTE(review): the docstring below is kept verbatim because the `deprecated`
# decorator is defined outside this file and may read it at runtime
# (presumably as the warning text) -- confirm before rewording it.
@deprecated
def mark_freezed(model):
    """DEPRECATED(Jiayuan Mao): jactorch.train.utils.mark_freezed has been deprecated and will be removed by 07/16/2022; please use jactorch.graph.parameter.mark_freezed instead."""
    return gmark_freezed(model)
# Backward-compatibility shim: forwards `model` to
# jactorch.graph.parameter.mark_unfreezed (imported in this file as
# `gmark_unfreezed`) and returns whatever that call returns.
# NOTE(review): the docstring below is kept verbatim because the `deprecated`
# decorator is defined outside this file and may read it at runtime
# (presumably as the warning text) -- confirm before rewording it.
@deprecated
def mark_unfreezed(model):
    """DEPRECATED(Jiayuan Mao): jactorch.train.utils.mark_unfreezed has been deprecated and will be removed by 07/16/2022; please use jactorch.graph.parameter.mark_unfreezed instead."""
    return gmark_unfreezed(model)
def set_learning_rate(optimizer, lr: float):
    """Overwrite the learning rate of every parameter group with ``lr``.

    Args:
        optimizer: an object exposing ``param_groups``, an iterable of
            mappings that support item assignment; the ``'lr'`` entry of
            each one is replaced in place.
        lr: the learning rate to store in each group.

    Returns:
        None. The optimizer's parameter groups are modified in place.
    """
    groups = optimizer.param_groups
    for group in groups:
        group['lr'] = lr
def decay_learning_rate(optimizer, decay: float):
    """Scale the learning rate of every parameter group by ``decay``.

    Args:
        optimizer: an object exposing ``param_groups``, an iterable of
            mappings that each already hold an ``'lr'`` entry.
        decay: the multiplicative factor applied to each group's ``'lr'``.

    Returns:
        None. The optimizer's parameter groups are modified in place.

    Raises:
        KeyError: if a dict parameter group has no ``'lr'`` entry.
    """
    groups = optimizer.param_groups
    for group in groups:
        # Augmented assignment is kept on purpose: it updates the stored
        # value exactly as before rather than rebinding via ``lr * decay``.
        group['lr'] *= decay