jactorch.graph.nn_env

Classes

NNEnv

A basic environment that wraps an nn.Module.

Class NNEnv

class NNEnv

Bases: object

A basic environment that wraps an nn.Module. It provides basic utility functions, such as loading weights from a checkpoint.

__init__(model)

Initialize the environment.

Parameters:

model (Module) – the model to be wrapped.

__new__(**kwargs)

forward(*args, cast_tensor=False, **kwargs)

Forward the model. Roughly equivalent to self.model(*args, **kwargs).

Parameters:

cast_tensor (bool) – whether to cast the inputs to tensors before they are passed to the model. Defaults to False.
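The delegation described for forward can be illustrated with a small, library-free sketch. This is not the jactorch implementation: ForwardSketchEnv, cast_value, and total are hypothetical names, and a float array from the standard library stands in for a tensor so the example runs without PyTorch. It only shows the general shape of "optionally convert the inputs, then call the wrapped model".

```python
from array import array


def cast_value(value):
    """Hypothetical stand-in for a tensor cast: number sequences become float arrays."""
    if isinstance(value, (list, tuple)):
        return array("d", value)
    return value


class ForwardSketchEnv:
    """Hypothetical sketch of an environment that forwards calls to a wrapped model."""

    def __init__(self, model):
        self._model = model

    def forward(self, *args, cast_tensor=False, **kwargs):
        # Assumption: casting applies to both positional and keyword inputs.
        if cast_tensor:
            args = tuple(cast_value(a) for a in args)
            kwargs = {k: cast_value(v) for k, v in kwargs.items()}
        # Without casting, this is simply self._model(*args, **kwargs).
        return self._model(*args, **kwargs)


def total(values, scale=1.0):
    """Stand-in model: any callable works for this sketch."""
    return sum(values) * scale


env = ForwardSketchEnv(total)
plain = env.forward([1, 2, 3], scale=2.0)          # same as total([1, 2, 3], scale=2.0)
casted = env.forward([1, 2, 3], cast_tensor=True)  # the list is converted before the call
```

With a real nn.Module, the wrapped callable would be the module itself and the cast would produce torch tensors instead of arrays.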

load_weights(filename, **kwargs)

Load weights from a checkpoint file.

Parameters:

filename (str) – the checkpoint file.

property model

Get the model.

property model_unwrapped

Get the model, unwrapping it from DataParallel if it is wrapped.
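The difference between model and model_unwrapped can be sketched as follows. PyTorch's nn.DataParallel keeps the wrapped network in its module attribute; everything else here is an assumption made for illustration. TinyModel, ParallelWrapper, and UnwrapSketchEnv are hypothetical stand-ins (no PyTorch import is needed), and the actual NNEnv code may differ.

```python
class TinyModel:
    """Hypothetical stand-in for an nn.Module."""

    def __call__(self, x):
        return x * 2


class ParallelWrapper:
    """Hypothetical stand-in for nn.DataParallel, which stores the model in `.module`."""

    def __init__(self, module):
        self.module = module

    def __call__(self, *args, **kwargs):
        return self.module(*args, **kwargs)


class UnwrapSketchEnv:
    """Hypothetical sketch of the two model accessors."""

    def __init__(self, model):
        self._model = model

    @property
    def model(self):
        # Returns whatever was passed in, wrapper included.
        return self._model

    @property
    def model_unwrapped(self):
        # Strip the parallel wrapper, if present, to reach the underlying model.
        model = self._model
        if isinstance(model, ParallelWrapper):
            model = model.module
        return model


inner = TinyModel()
wrapped_env = UnwrapSketchEnv(ParallelWrapper(inner))
plain_env = UnwrapSketchEnv(inner)
```

The unwrapped accessor is typically the one to use when reading or saving parameters, since the wrapper adds its own layer around the model's attributes.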