Source code for jactorch.graph.nn_env
#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File : nn_env.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 12/31/2021
#
# This file is part of Jacinle.
# Distributed under terms of the MIT license.
import torch.nn as nn
from jactorch.io import load_weights
from jactorch.utils.meta import as_tensor
__all__ = ['NNEnv']
[docs]
class NNEnv(object):
    """A basic environment that wraps around a nn.Module.

    This Env supports basic utility functions such as loading a checkpoint
    and running a forward pass through the wrapped model.
    """

    def __init__(self, model: nn.Module) -> None:
        """Initialize the environment.

        Args:
            model: the model to be wrapped. It is stored as-is; no copy is made.
        """
        self._model = model

    @property
    def model(self):
        """Get the model exactly as it was passed to the constructor."""
        return self._model

    @property
    def model_unwrapped(self):
        """Get the model, but unwrap the DataParallel if necessary.

        Returns:
            ``self.model.module`` when the wrapped model is an
            ``nn.DataParallel`` instance, otherwise ``self.model`` itself.
        """
        model = self.model
        if isinstance(model, nn.DataParallel):
            # Only a single DataParallel layer is peeled off.
            model = model.module
        return model

    def load_weights(self, filename: str, **kwargs):
        """Load weights from a checkpoint file.

        Delegates to the module-level ``load_weights`` helper, passing the
        wrapped model (not the unwrapped one).

        Args:
            filename: the checkpoint file.
            **kwargs: extra keyword arguments forwarded unchanged to the
                ``load_weights`` helper.

        Returns:
            Whatever the ``load_weights`` helper returns.
        """
        return load_weights(self._model, filename, **kwargs)

    def forward(self, *args, cast_tensor: bool = False, **kwargs):
        """Forward the model. Roughly equivalent to ``self.model(*args, **kwargs)``.

        Args:
            *args: positional inputs passed to the model.
            cast_tensor: whether to cast inputs to tensors. When true, both
                the positional and the keyword inputs are passed through
                ``as_tensor`` before the model is called.
            **kwargs: keyword inputs passed to the model.

        Returns:
            The outputs of the wrapped model.
        """
        if cast_tensor:
            # The whole tuple / dict is handed to as_tensor in one call;
            # presumably it converts nested containers -- confirm in
            # jactorch.utils.meta.
            args = as_tensor(args)
            kwargs = as_tensor(kwargs)
        outputs = self._model(*args, **kwargs)
        return outputs