# Source code for jactorch.optim.quickaccess

#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : quickaccess.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 03/28/2018
#
# This file is part of Jacinle.
# Distributed under terms of the MIT license.

import torch.optim as optim
import jactorch.optim as jacoptim


def get_optimizer(optimizer, model, *args, **kwargs):
    """Build an optimizer over the trainable parameters of ``model``, or pass one through.

    Args:
        optimizer: one of
            - an already-constructed ``torch.optim.Optimizer`` or
              ``jactorch.optim.CustomizedOptimizer`` instance, which is returned
              unchanged (``model``, ``args`` and ``kwargs`` are ignored);
            - a string naming an attribute (an optimizer class) looked up first in
              ``torch.optim`` and, if absent there, in ``jactorch.optim``;
            - anything else (e.g. an optimizer class), which is called directly.
        model: object with a ``parameters()`` method whose items expose
            ``requires_grad``; only parameters with ``requires_grad`` truthy
            are handed to the optimizer.
        *args, **kwargs: forwarded to the optimizer constructor (e.g. ``lr``).

    Returns:
        The passed-in optimizer instance, or the result of calling the resolved
        constructor with the filtered parameters and ``*args``/``**kwargs``.

    Raises:
        ValueError: if ``optimizer`` is a string found in neither
            ``torch.optim`` nor ``jactorch.optim``.
    """
    if isinstance(optimizer, (optim.Optimizer, jacoptim.CustomizedOptimizer)):
        return optimizer

    if isinstance(optimizer, str):
        name = optimizer
        # torch.optim takes precedence over jactorch.optim for the same name.
        for namespace in (optim, jacoptim):
            optimizer = getattr(namespace, name, None)
            if optimizer is not None:
                break
        else:
            # Raised outside any except-block, so no spurious AttributeError chain.
            raise ValueError('Unknown optimizer type: {}.'.format(name))

    # Frozen parameters (requires_grad=False) are excluded from optimization.
    trainable = filter(lambda p: p.requires_grad, model.parameters())
    return optimizer(trainable, *args, **kwargs)