#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File : train.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 02/27/2018
#
# This file is part of Jacinle.
# Distributed under terms of the MIT license.
import time
import torch
from jacinle.utils.meter import GroupMeters
from jactorch.data.dataloader.dataloader import JacDataLoader
from jactorch.optim.quickaccess import get_optimizer
from jactorch.utils.meta import as_numpy, as_float, as_tensor
from jacinle.logging import get_logger
logger = get_logger(__file__)
__all__ = ['simple_fit', 'ModelTrainer']
def simple_fit(model, loss_function, dataset, optimizer, epochs, lr=0.01, weight_decay=0, print_interval=1, batch_size=None, **opt_kwargs):
    """Fit a model with a simple epoch/iteration loop.

    ``loss_function`` is called as ``loss_function(model, data)`` and must return a
    ``(loss, monitors)`` pair. If ``batch_size`` is None, ``dataset`` is assumed to be
    an iterable that already yields batches (e.g., a data loader).
    """
    optimizer = get_optimizer(optimizer, model, lr=lr, weight_decay=weight_decay, **opt_kwargs)
    if batch_size is None:
        dataloader = dataset
    else:
        dataloader = JacDataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)

    iterations = 1
    model.train()
    for epoch_index in range(1, 1 + epochs):
        for data_index, data in enumerate(dataloader):
            optimizer.zero_grad()
            loss, monitors = loss_function(model, data)
            loss.backward()
            optimizer.step()

            if iterations % print_interval == 0:
                logger.info(f'Epoch {epoch_index} Index {data_index} (Iteration {iterations}): loss = {loss.item():.4f}, monitors={monitors}.')
            iterations += 1
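
# A minimal usage sketch for `simple_fit`, assuming `get_optimizer` resolves an
# optimizer name such as 'adam'. The toy model, dataset, and `mse_loss` below are
# hypothetical and only illustrate the `(loss, monitors)` contract of `loss_function`:
#
#     import torch.nn as nn
#     from torch.utils.data import TensorDataset
#
#     model = nn.Linear(4, 1)
#     dataset = TensorDataset(torch.randn(64, 4), torch.randn(64, 1))
#
#     def mse_loss(model, data):
#         x, y = data
#         loss = nn.functional.mse_loss(model(x), y)
#         return loss, {'loss/mse': loss.item()}
#
#     simple_fit(model, mse_loss, dataset, 'adam', epochs=2, batch_size=16)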
class ModelTrainer(object):
    """A thin wrapper bundling a model and an optimizer with train/validate loops."""
    def __init__(self, model, optimizer, lr=0.01, weight_decay=0, **opt_kwargs):
        optimizer = get_optimizer(optimizer, model, lr=lr, weight_decay=weight_decay, **opt_kwargs)
        self._model = model
        self._optimizer = optimizer
    def train_step(self, feed_dict, meters=None):
        assert self._model.training, 'Call model.train() (or train_epoch) before train_step.'
        feed_dict = as_tensor(feed_dict)

        self._optimizer.zero_grad()
        loss, monitors, output_dict = self._model(feed_dict)
        loss.backward()
        self._optimizer.step()

        # Convert to plain Python floats before logging; `loss` is already a float below,
        # so it can be returned directly.
        loss, monitors = map(as_float, [loss, monitors])
        if meters is not None:
            meters.update(loss=loss)
            meters.update(monitors)
        return loss
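
    # `train_step` assumes the model's forward returns a `(loss, monitors, output_dict)`
    # triple in training mode, and (per `validate_step` below) a bare output dict in
    # eval mode. A hypothetical sketch of a compatible model, assuming `torch.nn as nn`:
    #
    #     class SketchModel(nn.Module):
    #         def __init__(self):
    #             super().__init__()
    #             self.linear = nn.Linear(4, 1)
    #
    #         def forward(self, feed_dict):
    #             pred = self.linear(feed_dict['x'])
    #             if self.training:
    #                 loss = nn.functional.mse_loss(pred, feed_dict['y'])
    #                 return loss, {'loss/mse': loss}, {'pred': pred}
    #             return {'pred': pred}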
    def train_epoch(self, data_loader, meters=None):
        if meters is None:
            meters = GroupMeters()

        self._model.train()
        end = time.time()
        for fd in data_loader:
            data_time = time.time() - end
            end = time.time()
            self.train_step(fd, meters=meters)
            step_time = time.time() - end
            end = time.time()
            meters.update({'time/data': data_time, 'time/step': step_time})
        return meters
    def train(self, data_loader, nr_epochs, verbose=True, meters=None, early_stop=None, print_interval=1):
        if meters is None:
            meters = GroupMeters()

        for epoch in range(1, 1 + nr_epochs):
            meters.reset()
            self.train_epoch(data_loader, meters=meters)

            if verbose and epoch % print_interval == 0:
                caption = 'Epoch: {}:'.format(epoch)
                logger.info(meters.format_simple(caption))

            if early_stop is not None:
                flag = early_stop(self._model)
                if flag:
                    break
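
    # A usage sketch for the full training loop. `SketchModel` is the hypothetical
    # model sketched above; the loader is assumed to yield feed dicts with 'x' and
    # 'y' keys, and 'adam' assumes `get_optimizer` resolves optimizer names:
    #
    #     trainer = ModelTrainer(SketchModel(), 'adam', lr=1e-3)
    #     loader = JacDataLoader(dataset, batch_size=16, shuffle=True)
    #     trainer.train(loader, nr_epochs=10)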
    def validate_step(self, feed_dict, metric, meters=None):
        feed_dict_np = as_numpy(feed_dict)
        feed_dict = as_tensor(feed_dict)

        with torch.no_grad():
            output_dict = self._model(feed_dict)
        output_dict_np = as_numpy(output_dict)

        result = as_float(metric(feed_dict_np, output_dict_np))
        if meters is not None:
            meters.update(result)
        return result
    def validate(self, data_loader, metric, meters=None):
        if meters is None:
            meters = GroupMeters()

        self._model.eval()
        end = time.time()
        for fd in data_loader:
            data_time = time.time() - end
            end = time.time()
            self.validate_step(fd, metric, meters=meters)
            step_time = time.time() - end
            end = time.time()
            meters.update({'time/data': data_time, 'time/step': step_time})
        return meters.avg
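
# A validation sketch: `metric` receives numpy versions of the feed dict and the
# model's output dict and returns a dict of scalar results; `validate` returns the
# per-key averages accumulated in the meters. The `mse_metric` below, and the
# `trainer`/`loader` names from the sketch above, are hypothetical:
#
#     def mse_metric(feed_dict_np, output_dict_np):
#         err = ((output_dict_np['pred'] - feed_dict_np['y']) ** 2).mean()
#         return {'val/mse': err}
#
#     avg_results = trainer.validate(loader, mse_metric)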