jactorch.graph.nn_env
Classes
NNEnv | A basic environment that wraps around an nn.Module.
Class NNEnv
- class NNEnv
Bases: object
A basic environment that wraps around an nn.Module. NNEnv provides basic utility functions, such as loading weights from a checkpoint.
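The wrapping pattern can be illustrated with a minimal sketch. MinimalNNEnv below is a hypothetical stand-in, not jactorch's implementation: it only stores the module, exposes it through a model property, and delegates forward to it.

```python
import torch
from torch import nn


class MinimalNNEnv(object):
    """Hypothetical stand-in for NNEnv: holds an nn.Module and exposes it."""

    def __init__(self, model: nn.Module):
        # Keep a reference to the wrapped module.
        self._model = model

    @property
    def model(self) -> nn.Module:
        # Return the wrapped module as-is.
        return self._model

    def forward(self, *args, **kwargs):
        # Delegate directly to the wrapped module.
        return self.model(*args, **kwargs)


net = nn.Linear(4, 2)
env = MinimalNNEnv(net)
out = env.forward(torch.zeros(3, 4))
# out has shape (3, 2): a batch of 3 inputs mapped to 2 features each
```

Keeping the module behind a property lets the environment add helpers (checkpoint loading, unwrapping) without changing how callers reach the model.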
- __init__(model)
Initialize the environment.
- Parameters:
model (Module) – the model to be wrapped.
- __new__(**kwargs)
- forward(*args, cast_tensor=False, **kwargs)
Run a forward pass of the wrapped model. Roughly equivalent to self.model(*args, **kwargs).
- Parameters:
cast_tensor (bool) – whether to cast the inputs to tensors before calling the model.
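A sketch of what cast_tensor=True might do, assuming it converts plain Python sequences to tensors before the call. The helpers _to_tensor and forward_with_cast are hypothetical, and the conversion rule (lists and tuples become tensors of torch's default floating dtype) is an assumption; jactorch's actual casting may handle more types, such as NumPy arrays or nested containers.

```python
import torch
from torch import nn


def _to_tensor(value):
    # Hypothetical rule: tensors pass through unchanged, lists/tuples of numbers
    # become tensors of the default floating dtype, anything else is left alone.
    if torch.is_tensor(value):
        return value
    if isinstance(value, (list, tuple)):
        return torch.as_tensor(value, dtype=torch.get_default_dtype())
    return value


def forward_with_cast(model, *args, cast_tensor=False, **kwargs):
    # Optionally cast positional and keyword inputs, then call the module.
    if cast_tensor:
        args = tuple(_to_tensor(a) for a in args)
        kwargs = {k: _to_tensor(v) for k, v in kwargs.items()}
    return model(*args, **kwargs)


# A linear layer that sums its two inputs (weights fixed to 1, bias to 0).
model = nn.Linear(2, 1)
with torch.no_grad():
    model.weight.fill_(1.0)
    model.bias.zero_()

out = forward_with_cast(model, [[1.0, 2.0], [3.0, 4.0]], cast_tensor=True)
# out has shape (2, 1) and holds the row sums 3.0 and 7.0
```

Without cast_tensor=True, passing the nested list straight to nn.Linear would fail, since the layer expects a tensor input.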
- load_weights(filename, **kwargs)
Load weights from a checkpoint file.
- Parameters:
filename (str) – the checkpoint file.
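The checkpoint layout that load_weights expects is not documented here. The sketch below assumes the simplest case, a plain state dict written with torch.save, and restores it with the standard torch.load and load_state_dict calls; load_weights_sketch is a hypothetical helper, and how the real method uses **kwargs is not covered.

```python
import os
import tempfile

import torch
from torch import nn


def load_weights_sketch(model: nn.Module, filename: str, strict: bool = True) -> nn.Module:
    # Load onto the CPU so the example works without a GPU.
    state = torch.load(filename, map_location="cpu")
    # Assumption: the file holds a plain state dict produced by model.state_dict().
    model.load_state_dict(state, strict=strict)
    return model


source = nn.Linear(3, 2)
target = nn.Linear(3, 2)  # randomly initialized, so it differs from `source`

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "checkpoint.pth")
    torch.save(source.state_dict(), path)
    load_weights_sketch(target, path)

# After loading, every parameter of `target` matches `source`.
same = all(
    torch.equal(a, b)
    for a, b in zip(source.state_dict().values(), target.state_dict().values())
)
```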
- property model
Get the model.
- property model_unwrapped
Get the model, unwrapping it from DataParallel if necessary.
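When a model is wrapped in nn.DataParallel, the original module is stored in the wrapper's module attribute. A minimal sketch of the unwrapping, with unwrap_model as a hypothetical helper (jactorch may also handle other wrapper types):

```python
import torch
from torch import nn


def unwrap_model(model: nn.Module) -> nn.Module:
    # nn.DataParallel keeps the original module in its `.module` attribute.
    if isinstance(model, nn.DataParallel):
        return model.module
    return model


net = nn.Linear(4, 2)
# On a CPU-only machine, DataParallel just stores `net` and calls it directly.
wrapped = nn.DataParallel(net)

unwrapped = unwrap_model(wrapped)
plain = unwrap_model(net)  # an unwrapped model is returned unchanged
```

Unwrapping matters mainly for state dicts: parameters of a DataParallel wrapper are prefixed with "module.", so saving or loading through the unwrapped module keeps checkpoint keys consistent.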