jactorch.train.env
Classes
Functions
Return the current time in seconds, with CUDA synchronization.
Default reduce function for the TrainerEnv.
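The synchronization in the timing helper matters because CUDA kernels are launched asynchronously: reading the clock without waiting would measure only the launch, not the work. The pattern can be sketched with the standard library alone. The names `synced_time` and `sync` are hypothetical and chosen for illustration; they are not this module's API.

```python
import time


def synced_time(sync=None):
    """Return wall-clock time in seconds, optionally after a synchronization hook.

    `sync` is a hypothetical zero-argument callable that blocks until pending
    device work has finished (for example `torch.cuda.synchronize`).
    """
    if sync is not None:
        sync()
    return time.time()


calls = []
start = synced_time(sync=lambda: calls.append("sync"))
time.sleep(0.01)  # stand-in for the work being timed
elapsed = synced_time(sync=lambda: calls.append("sync")) - start
```

With PyTorch available, one would typically pass `torch.cuda.synchronize` as `sync` so that both clock readings happen after queued GPU work has completed.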
Class TrainerEnv
- class TrainerEnv
Bases: NNEnv
- __init__(model, optimizer)
Initialize the environment.
- Parameters:
model (Module) – the model to be wrapped.
optimizer – the optimizer used to update the model.
- __new__(**kwargs)
- forward(*args, cast_tensor=False, **kwargs)
Forward the model. Roughly equivalent to self.model(*args, **kwargs).
- Parameters:
cast_tensor – whether to cast inputs to tensors.
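The delegation described for `forward` can be illustrated with a small stand-in class. This is a sketch under assumptions, not the library's implementation: `ForwardingEnv` and its `to_tensor` hook are hypothetical, and the "model" is any callable.

```python
class ForwardingEnv:
    """Hypothetical stand-in for an environment that wraps a callable model."""

    def __init__(self, model, to_tensor=None):
        self.model = model
        # `to_tensor` is an assumed conversion hook (e.g. a function that maps
        # Python or NumPy values to tensors); identity when not provided.
        self._to_tensor = to_tensor or (lambda x: x)

    def forward(self, *args, cast_tensor=False, **kwargs):
        # Optionally convert every positional and keyword input first.
        if cast_tensor:
            args = tuple(self._to_tensor(a) for a in args)
            kwargs = {k: self._to_tensor(v) for k, v in kwargs.items()}
        # Then delegate to the wrapped model unchanged.
        return self.model(*args, **kwargs)


env = ForwardingEnv(model=lambda x, scale=1: x * scale, to_tensor=float)
plain = env.forward(3, scale=2)                       # inputs passed as-is
cast = env.forward("3", scale="2", cast_tensor=True)  # inputs converted first
```

Here `float` plays the role of the tensor conversion, so the second call works even though its inputs arrive as strings.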
- load_weights(filename, **kwargs)
Load weights from a checkpoint file.
- Parameters:
filename (str) – the checkpoint file.
- step(feed_dict, grad_clip=0., reduce_func=default_reduce_func, cast_tensor=False, measure_time=False)
- update(feed_dict, loss, monitors, output_dict, grad_clip=0., reduce_func=default_reduce_func, measure_time=False, extra=None)
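Neither `step` nor `update` is described on this page, so the following is only one plausible reading of their signatures, written as a toy in plain Python. It assumes a forward pass that yields `(loss, monitors, output_dict)`, a reduce function called as `reduce_func(name, value)`, gradient clipping that is active only when `grad_clip` is positive, and timing stored in an `extra` dict. The names `toy_step`, `mean_reduce`, `lr`, and the key `"time/forward"` are all hypothetical.

```python
import time


def mean_reduce(name, value):
    """Assumed reduce function: average per-sample values, pass scalars through."""
    if isinstance(value, (list, tuple)):
        return sum(value) / len(value)
    return value


def toy_step(params, feed_dict, lr=0.1, grad_clip=0.0,
             reduce_func=mean_reduce, measure_time=False):
    """One hypothetical training step on the scalar model y = w * x."""
    extra = {}
    start = time.time()

    # Forward: per-sample losses plus a monitor, mirroring a
    # (loss, monitors, output_dict) return convention.
    preds = [params["w"] * x for x in feed_dict["x"]]
    errors = [p - y for p, y in zip(preds, feed_dict["y"])]
    loss = reduce_func("loss", [e * e for e in errors])
    monitors = {"abs_error": reduce_func("abs_error", [abs(e) for e in errors])}
    output_dict = {"pred": preds}
    if measure_time:
        extra["time/forward"] = time.time() - start

    # Backward: analytic gradient of the mean squared error w.r.t. w.
    grad = sum(2 * e * x for e, x in zip(errors, feed_dict["x"])) / len(errors)

    # Clip the gradient magnitude only when grad_clip is positive.
    if grad_clip > 0 and abs(grad) > grad_clip:
        grad = grad_clip * (1 if grad > 0 else -1)

    params["w"] -= lr * grad  # plain SGD update
    return loss, monitors, output_dict, extra


params = {"w": 0.0}
feed = {"x": [1.0, 2.0], "y": [2.0, 4.0]}
loss, monitors, output_dict, extra = toy_step(
    params, feed, grad_clip=1.0, measure_time=True
)
```

In this run the raw gradient is -10, so clipping at 1.0 limits the update and `w` moves from 0.0 to 0.1 rather than to 1.0.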
- property model
Get the model.
- property model_unwrapped
Get the model, but unwrap the DataParallel if necessary.
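The difference between `model` and `model_unwrapped` can be sketched as follows. `torch.nn.DataParallel` keeps the original network in its `.module` attribute; to stay dependency-free, this example uses a hypothetical `FakeDataParallel` stand-in, and `UnwrappingEnv` is likewise illustrative rather than the library's class.

```python
class FakeDataParallel:
    """Stand-in for a parallel wrapper that stores the wrapped model in `.module`."""

    def __init__(self, module):
        self.module = module


class UnwrappingEnv:
    """Hypothetical environment exposing both the wrapped and unwrapped model."""

    def __init__(self, model):
        self._model = model

    @property
    def model(self):
        # Returns whatever was passed in, wrapper included.
        return self._model

    @property
    def model_unwrapped(self):
        model = self._model
        # Peel off the parallel wrapper if present; otherwise return as-is.
        if isinstance(model, FakeDataParallel):
            model = model.module
        return model


inner = object()
env = UnwrappingEnv(FakeDataParallel(inner))
bare_env = UnwrappingEnv(inner)
```

Access to the unwrapped model is commonly wanted when saving weights, since a `DataParallel` wrapper prefixes every state-dict key with `module.`.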
- property optimizer
Functions