jactorch.train.env

Classes

TrainerEnv

Functions

cuda_time([sync])

Return the current time in seconds, with optional CUDA synchronization.

default_reduce_func(k, v)

Default reduce function for the TrainerEnv.

Class TrainerEnv

class TrainerEnv

Bases: NNEnv

__init__(model, optimizer)

Initialize the environment.

Parameters:

model (Module) – the model to be wrapped.

optimizer – the optimizer associated with the wrapped model.

__new__(**kwargs)
decay_learning_rate(decay)
evaluate(feed_dict, cast_tensor=False)
forward(*args, cast_tensor=False, **kwargs)

Forward the model. Roughly equivalent to self.model(*args, **kwargs).

Parameters:

cast_tensor – whether to cast inputs to tensors.

load_checkpoint(filename, **kwargs)
load_weights(filename, **kwargs)

Load weights from a checkpoint file.

Parameters:

filename (str) – the checkpoint file.

prepare()
register_event(name, callback)
save_checkpoint(filename, extra=None)
set_learning_rate(lr)
step(feed_dict, grad_clip=0., reduce_func=default_reduce_func, cast_tensor=False, measure_time=False)
trigger_event(name, *args, **kwargs)
update(feed_dict, loss, monitors, output_dict, grad_clip=0., reduce_func=default_reduce_func, measure_time=False, extra=None)
property model

Get the model.

property model_unwrapped

Get the underlying model, unwrapping it from DataParallel if necessary.

property optimizer

Get the optimizer.

Functions

cuda_time(sync=True)

Return the current time in seconds, with optional CUDA synchronization.

Parameters:

sync (bool) – if True, synchronize CUDA (wait for pending GPU work to finish) before taking the time.

Returns:

the current time in seconds.

Return type:

float

default_reduce_func(k, v)

Default reduce function for the TrainerEnv.