Context manager used in a forward pass. It provides access to global variables and monitors
across different modules. See ForwardContext for more details.
Class ForwardContext
class ForwardContext
Bases: object
A context manager that serves as a global store for the forward pass. It provides
access to shared variables across different modules.
with ForwardContext(training=True) as ctx:
    ctx.add_accuracy(accuracy, 'acc')
    ctx.add_output(output, 'output')

# In a different file, you can access the context by:
ctx = get_forward_context()
loss, monitors, output_dict = ctx.finalize()
__init__(training, *, loss=0, monitors=None, output_dict=None)
Initialize the context.
- Parameters:
training (bool) – whether the forward pass is in training mode.
loss (float) – the initial loss.
monitors (Dict | None) – the initial monitors.
output_dict (Dict | None) – the initial output dictionary.
add_accuracy(accuracy, key)
Add an accuracy to the context.
- Parameters:
accuracy – the accuracy to add.
key – the name of the accuracy monitor.
add_loss(loss, key=None, accumulate=True)
Add a (sub) loss to the context.
- Parameters:
loss (float | Tensor) – the sub-loss to add.
key (str | None) – the name of the loss. If None, the loss will not be monitored.
accumulate (bool | float) – whether to add the loss to the final loss. If it is a boolean, it
indicates whether the loss is accumulated; if it is a float, it is the weight applied to the
loss when accumulating.
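To make the two forms of `accumulate` concrete, here is a hypothetical helper (`accumulate_loss` is an illustrative name, not part of this API). It assumes a boolean acts as a weight of 1 or 0, a float is used directly as the weight, and the monitor records the unweighted loss; that last point is an assumption the description above does not settle.

```python
from typing import Dict, Optional, Tuple, Union


def accumulate_loss(
    total: float,
    monitors: Dict[str, float],
    loss: float,
    key: Optional[str] = None,
    accumulate: Union[bool, float] = True,
) -> Tuple[float, Dict[str, float]]:
    """Hypothetical sketch of the `accumulate` semantics described for add_loss."""
    # True -> 1.0, False -> 0.0 (not added), a float is used as the weight.
    weight = float(accumulate)
    total = total + weight * loss
    # Named losses are monitored; storing the unweighted value is an assumption.
    if key is not None:
        monitors = {**monitors, key: float(loss)}
    return total, monitors


total, monitors = 0.0, {}
total, monitors = accumulate_loss(total, monitors, 2.0, key="main")
total, monitors = accumulate_loss(total, monitors, 4.0, key="aux", accumulate=0.5)
total, monitors = accumulate_loss(total, monitors, 10.0, key="probe", accumulate=False)
print(total)  # 4.0
```

The total is 2.0 * 1.0 + 4.0 * 0.5 + 10.0 * 0.0, so the `probe` loss is still visible as a monitor without affecting the optimized value.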
add_output(output, key)
Add an output to the context.
- Parameters:
output – the output to add.
key – the name under which the output is stored in the output dictionary.
Set the context as the default context.
binary_classification_accuracy(pred, label, name='', saturation=True)
Compute the accuracy of binary classification.
- Parameters:
pred (Tensor) – the prediction, of the same shape as label
label (Tensor) – the label, of the same shape as pred
name (str) – the name of this monitor.
saturation (bool) – whether to check the saturation of the prediction. Saturation
is defined as
- Returns:
a dict of monitor values.
- Return type:
Dict[str, float]
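As an illustration of what this monitor measures, here is a plain-Python sketch. It assumes `pred` holds probabilities in [0, 1] that are thresholded at 0.5 and `label` holds 0/1 values; the threshold, the `accuracy/<name>` key format, and the function name `binary_accuracy_sketch` are all hypothetical. The saturation statistic is left out because its definition is not given above.

```python
from typing import Dict, Sequence


def binary_accuracy_sketch(
    pred: Sequence[float], label: Sequence[int], name: str = ""
) -> Dict[str, float]:
    """Hypothetical sketch: fraction of thresholded predictions that match the labels."""
    if len(pred) != len(label):
        raise ValueError("pred and label must have the same shape")
    # Threshold probabilities at 0.5 (assumed) to get hard 0/1 predictions.
    hard = [1 if p > 0.5 else 0 for p in pred]
    correct = sum(int(h == y) for h, y in zip(hard, label))
    key = "accuracy" + ("/" + name if name else "")
    return {key: correct / len(label)}


print(binary_accuracy_sketch([0.9, 0.2, 0.6, 0.4], [1, 0, 0, 0], name="cls"))
# prints {'accuracy/cls': 0.75}
```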
classification_accuracy(pred, label, name='')
Compute the accuracy of N-way classification.
- Parameters:
pred (Tensor) – the prediction, of the same shape as label
label (Tensor) – the label, of the same shape as pred
name (str) – the name of this monitor.
- Returns:
a dict of monitor values.
- Return type:
Dict[str, float]
finalize()
Finalize the context and return the loss, monitors, and output dictionary.
get_hyperparameter(key, default=None)
Get a hyperparameter for the forward pass.
- Parameters:
key – the name of the hyperparameter.
default – the value to return if the hyperparameter has not been set.
- Returns:
the value of the hyperparameter.
Monitor the RMS of the gradients of the parameters of the given model.
- Parameters:
model (Module) – the model to monitor.
- Returns:
a dict of monitor values.
- Return type:
Dict[str, float]
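For reference, the root-mean-square of a gradient tensor $g$ with $n$ entries is the standard quantity below. Whether this monitor reports it per parameter or aggregated over the whole model is not stated here, so treat the per-tensor form as an assumption.

```latex
\mathrm{RMS}(g) = \sqrt{\frac{1}{n} \sum_{i=1}^{n} g_i^{2}}
```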
Monitor the ratio of the RMS of the gradients of the parameters of the given model.
- Parameters:
model (Module) – the model to monitor.
- Returns:
a dict of monitor values.
- Return type:
Dict[str, float]
Monitor the RMS of the parameters of the given model.
- Parameters:
model (Module) – the model to monitor.
- Returns:
a dict of monitor values.
- Return type:
Dict[str, float]
Monitor the saturation of the parameters of the given model.
- Parameters:
model (Module) – the model to monitor.
- Returns:
a dict of monitor values.
- Return type:
Dict[str, float]
Monitor the RMS of the given values. This function takes either a dict or multiple keyword arguments.
- Returns:
a dict of monitor values.
- Return type:
Dict[str, float]
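The dict-or-keyword-arguments calling convention can be sketched as follows. The function name `rms_monitors_sketch`, the `_dict` parameter, and the `rms/<key>` output keys are hypothetical; values are plain sequences of floats here instead of tensors so the sketch runs without extra dependencies.

```python
import math
from typing import Dict, Optional, Sequence


def rms_monitors_sketch(
    _dict: Optional[Dict[str, Sequence[float]]] = None, **values: Sequence[float]
) -> Dict[str, float]:
    """Hypothetical sketch: RMS of each named value, passed as a dict and/or keyword arguments."""
    merged: Dict[str, Sequence[float]] = dict(_dict or {})
    merged.update(values)  # keyword arguments win on duplicate keys
    result: Dict[str, float] = {}
    for key, xs in merged.items():
        if len(xs) == 0:
            raise ValueError(f"value {key!r} is empty")
        # Root-mean-square: square root of the mean of the squared entries.
        result[f"rms/{key}"] = math.sqrt(sum(x * x for x in xs) / len(xs))
    return result


print(rms_monitors_sketch({"a": [3.0, -3.0]}, b=[1.0, -1.0, 1.0, -1.0]))
# prints {'rms/a': 3.0, 'rms/b': 1.0}
```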
regression_accuracy(pred, label, name='')
Compute the accuracy of regression.
- Parameters:
pred (Tensor) – the prediction, of the same shape as label
label (Tensor) – the label, of the same shape as pred
name (str) – the name of this monitor.
- Returns:
a dict of monitor values.
- Return type:
Dict[str, float]
set_hyperparameter(key, value)
Set a hyperparameter for the forward pass.
- Parameters:
key – the name of the hyperparameter.
value – the value to set.
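The set/get pair behaves like a small key-value store with a fallback default. A minimal sketch, assuming an unset key returns `default` instead of raising (as the `default=None` parameter suggests); the class name `HyperparameterStore` is hypothetical.

```python
from typing import Any, Dict


class HyperparameterStore:
    """Hypothetical sketch of per-forward-pass hyperparameters with a default fallback."""

    def __init__(self) -> None:
        self._values: Dict[str, Any] = {}

    def set_hyperparameter(self, key: str, value: Any) -> None:
        self._values[key] = value

    def get_hyperparameter(self, key: str, default: Any = None) -> Any:
        # Unset keys fall back to `default` instead of raising KeyError.
        return self._values.get(key, default)


store = HyperparameterStore()
store.set_hyperparameter("temperature", 0.5)
print(store.get_hyperparameter("temperature"))           # 0.5
print(store.get_hyperparameter("dropout", default=0.1))  # 0.1
```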
update_mo(monitors, output_dict)
Update the monitors and output dictionary in the context.
- Parameters:
monitors – the monitors to update.
output_dict – the outputs to add to the output dictionary.
Update the monitors in the context.
- Parameters:
monitors (Dict[str, float | Tensor]) – the monitors to update.
loss: float | Tensor
The current loss.
monitors: GView
The current monitors.
output_dict: GView
The current output dictionary.
get_forward_context()
Get the current forward context.
- Return type:
ForwardContext