#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File : env.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 05/09/2018
#
# This file is part of Jacinle.
# Distributed under terms of the MIT license.
import os.path as osp
import time
import torch
import torch.nn as nn
from jacinle.logging import get_logger
from jacinle.utils.registry import SimpleEventRegistry
from jactorch.graph.nn_env import NNEnv
from jactorch.io import state_dict, load_state_dict
from jactorch.utils.meta import as_tensor, as_float, as_cpu
from .utils import set_learning_rate, decay_learning_rate
logger = get_logger(__file__)
__all__ = ['TrainerEnv']
def cuda_time(sync: bool = True) -> float:
"""Return the current time in seconds, with CUDA synchronization.
Args:
sync: if True, synchronize the CUDA stream before taking the time.
Returns:
the current time in seconds.
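
    Example (the stamp-then-measure pattern used by ``TrainerEnv.update``)::

        start = cuda_time()            # synchronize, then take a timestamp
        ...                            # CUDA work to be timed
        elapsed = cuda_time() - start  # seconds, including all queued kernels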
"""
if sync:
torch.cuda.synchronize()
return time.time()
def default_reduce_func(k, v):
"""Default reduce function for the TrainerEnv."""
if torch.is_tensor(v):
return v.mean()
return v
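
# For instance, under nn.DataParallel the per-replica losses arrive stacked in
# a single tensor; default_reduce_func('loss', torch.tensor([0.5, 0.7])) reduces
# it to tensor(0.6000), while non-tensor values pass through unchanged.
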
class TrainerEnv(NNEnv):
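    """A wrapper around an ``nn.Module`` and an optimizer that implements a
    standard training/evaluation step with checkpointing, learning-rate helpers
    and event hooks.

    In training mode, the wrapped model's ``forward(feed_dict)`` is expected to
    return a ``(loss, monitors, output_dict)`` triple; in evaluation mode it
    should return a single output dict (see :meth:`step` and :meth:`evaluate`).
    Callbacks can be attached to the ``epoch``, ``step``, ``forward`` and
    ``backward`` events, each in a ``:before`` and an ``:after`` variant.
    """
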
def __init__(self, model: nn.Module, optimizer):
super().__init__(model)
self._optimizer = optimizer
self._train_loader = None
self._validation_loader = None
self._event_manager = SimpleEventRegistry({
'epoch:before', 'epoch:after',
'step:before', 'step:after',
'forward:before', 'forward:after',
'backward:before', 'backward:after',
})
self._init_event_triggers()
self.__prepared = False
@property
def optimizer(self):
return self._optimizer
def _init_event_triggers(self):
for key in self._event_manager.allowed_events:
            name, when = key.split(':')
            callback_name = f'on_{name}_{when}'
if hasattr(self.model_unwrapped, callback_name):
callback = getattr(self.model_unwrapped, callback_name)
self.register_event(key, callback)
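
    # A model can opt into these hooks simply by defining methods that follow
    # the naming convention above; a hypothetical sketch:
    #
    #     class MyModel(nn.Module):
    #         def on_step_after(self, trainer):
    #             ...  # runs at the end of every TrainerEnv.update() call
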
def register_event(self, name, callback):
logger.info('Register trainer event: name={}, callback={}.'.format(name, callback.__module__ + '.' + callback.__name__))
self._event_manager.register(name, callback)
def trigger_event(self, name, *args, **kwargs):
self._event_manager.trigger(name, *args, **kwargs)
def save_checkpoint(self, filename, extra=None):
        # Note: state_dict() below handles the nn.DataParallel wrapper itself,
        # so no manual unwrapping is needed here (unlike in load_checkpoint).
model = self._model
state = {
'model': state_dict(model, cpu=True),
'optimizer': as_cpu(self._optimizer.state_dict()),
'extra': extra
}
try:
torch.save(state, filename)
logger.info('Checkpoint saved: "{}".'.format(filename))
except Exception:
            logger.exception('Error occurred when dumping the checkpoint "{}".'.format(filename))
def load_checkpoint(self, filename, **kwargs):
if osp.isfile(filename):
model = self._model
if isinstance(model, nn.DataParallel):
model = model.module
try:
checkpoint = torch.load(filename)
load_state_dict(model, checkpoint['model'], **kwargs)
self._optimizer.load_state_dict(checkpoint['optimizer'])
logger.critical('Checkpoint loaded: {}.'.format(filename))
return checkpoint['extra']
except Exception:
                logger.exception('Error occurred when loading the checkpoint "{}".'.format(filename))
else:
            logger.warning('No checkpoint found at the specified path: "{}".'.format(filename))
return None
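
    # A round-trip sketch ("ckpt.pth" and the epoch counter are hypothetical):
    #
    #     env.save_checkpoint('ckpt.pth', extra=dict(epoch=epoch))
    #     extra = env.load_checkpoint('ckpt.pth')  # the saved `extra`, or None on failure
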
def set_learning_rate(self, lr):
set_learning_rate(self._optimizer, lr)
def decay_learning_rate(self, decay):
decay_learning_rate(self._optimizer, decay)
def prepare(self):
self.__prepared = True
        assert self._model.training, 'Cannot step an evaluation-mode model.'
self.trigger_event('step:before', self)
self._optimizer.zero_grad()
def update(self, feed_dict, loss, monitors, output_dict, grad_clip=0., reduce_func=default_reduce_func, measure_time=False, extra=None):
        assert self.__prepared, 'TrainerEnv.update() called without a preceding call to TrainerEnv.prepare().'
self.__prepared = False
        if extra is None:
            extra = dict()

        if measure_time:
            end_time = cuda_time()

        loss = reduce_func('loss', loss)
        monitors = {k: reduce_func(k, v) for k, v in monitors.items()}

        loss_f = as_float(loss)
        monitors_f = as_float(monitors)

        if measure_time:
            # Time spent reducing the loss and monitors.
            extra['time/loss'] = cuda_time() - end_time
            end_time = cuda_time(False)
self.trigger_event('backward:before', self, feed_dict, loss, monitors, output_dict)
if loss.requires_grad:
loss.backward()
if grad_clip > 0:
from torch.nn.utils.clip_grad import clip_grad_norm_
clip_grad_norm_(self.model.parameters(), grad_clip)
if measure_time:
extra['time/backward'] = cuda_time() - end_time
end_time = cuda_time(False)
self.trigger_event('backward:after', self, feed_dict, loss, monitors, output_dict)
if loss.requires_grad:
self._optimizer.step()
if measure_time:
extra['time/optimize'] = cuda_time() - end_time
end_time = cuda_time(False)
self.trigger_event('step:after', self)
return loss_f, monitors_f, output_dict, extra
def step(self, feed_dict, grad_clip=0., reduce_func=default_reduce_func, cast_tensor=False, measure_time=False):
if hasattr(self.model, 'train_step'):
try:
                return self.model.train_step(
                    self.optimizer, feed_dict,
                    grad_clip=grad_clip, reduce_func=reduce_func, cast_tensor=cast_tensor
                )
except NotImplementedError:
pass
extra = dict()
self.prepare()
if measure_time:
end_time = cuda_time()
if cast_tensor:
feed_dict = as_tensor(feed_dict)
self.trigger_event('forward:before', self, feed_dict)
loss, monitors, output_dict = self._model(feed_dict)
self.trigger_event('forward:after', self, feed_dict, loss, monitors, output_dict)
if measure_time:
extra['time/forward'] = cuda_time() - end_time
end_time = cuda_time(False)
return self.update(feed_dict, loss, monitors, output_dict, grad_clip=grad_clip, reduce_func=reduce_func, measure_time=measure_time, extra=extra)
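
    # A minimal training-loop sketch; ``MyModel`` and ``loader`` are assumptions,
    # and the model's forward() must follow the contract in the class docstring:
    #
    #     model = MyModel()
    #     env = TrainerEnv(model, torch.optim.Adam(model.parameters(), lr=1e-3))
    #     model.train()
    #     for feed_dict in loader:
    #         loss, monitors, outputs, extra = env.step(feed_dict, cast_tensor=True)
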
def evaluate(self, feed_dict, cast_tensor=False):
assert not self._model.training, 'Evaluating a training-mode model.'
begin = time.time()
if cast_tensor:
feed_dict = as_tensor(feed_dict)
with torch.no_grad():
output_dict = self._model(feed_dict)
end = time.time()
return output_dict, dict(gpu_time=end - begin)
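
# Evaluation usage sketch, continuing the hypothetical training loop above; note
# that ``gpu_time`` is wall-clock seconds and the model must be in eval mode:
#
#     model.eval()
#     output_dict, meters = env.evaluate(feed_dict, cast_tensor=True)
#     print(meters['gpu_time'])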