jactorch.graph.context
Context manager used in a forward pass. It gives different modules access to shared global variables and monitors. See ForwardContext for more details.
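Classes
ForwardContext – A context manager that serves as a global variable for the forward pass.
Functions
get_forward_context – Get the current forward context.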
Class ForwardContext
-
class ForwardContext
Bases: object
A context manager that serves as a global variable for the forward pass. It lets different modules access shared global variables.
Example
from jactorch.graph.context import ForwardContext, get_forward_context

with ForwardContext(training=True) as ctx:
    ctx.add_loss(loss)
    ctx.add_accuracy(accuracy, 'acc')
    ctx.add_output(output, 'output')

# In a different file, while the context is active, you can access it by:
ctx = get_forward_context()
ctx.add_loss(some_other_loss)

loss, monitors, output_dict = ctx.finalize()
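The idea in the example, where code in another module reaches the active context without receiving it as an argument, can be sketched in pure Python with a module-level stack. This is a hypothetical illustration, not jactorch's implementation: `MiniForwardContext`, `_CONTEXT_STACK`, and `get_current_context` are names made up for this sketch, and its `finalize` simply returns the three collected values.

```python
from typing import Any, Dict, List, Optional

# Hypothetical module-level stack standing in for the "global variable".
_CONTEXT_STACK: List["MiniForwardContext"] = []


class MiniForwardContext:
    """Illustrative stand-in for a forward context (not the jactorch class)."""

    def __init__(self, training: bool) -> None:
        self.training = training
        self.loss: float = 0.0
        self.monitors: Dict[str, float] = {}
        self.output_dict: Dict[str, Any] = {}

    def __enter__(self) -> "MiniForwardContext":
        _CONTEXT_STACK.append(self)  # make this the current context
        return self

    def __exit__(self, *exc_info: Any) -> None:
        _CONTEXT_STACK.pop()  # restore the previous context, if any

    def add_loss(self, loss: float) -> "MiniForwardContext":
        self.loss += loss
        return self

    def finalize(self):
        return self.loss, self.monitors, self.output_dict


def get_current_context() -> Optional[MiniForwardContext]:
    # Innermost active context, or None when no context is active.
    return _CONTEXT_STACK[-1] if _CONTEXT_STACK else None


def some_submodule_forward() -> None:
    # Code in "another module" adds a loss without being passed ctx.
    ctx = get_current_context()
    if ctx is not None:
        ctx.add_loss(0.5)


with MiniForwardContext(training=True) as ctx:
    ctx.add_loss(1.0)
    some_submodule_forward()
    loss, monitors, output_dict = ctx.finalize()

print(loss)  # 1.5
```

A `contextvars.ContextVar` would be a thread- and asyncio-safe alternative to the plain list used in this sketch.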
-
__init__(training, *, loss=0, monitors=None, output_dict=None)
Initialize the context.
- Parameters:
training (bool) – whether the forward pass is in training mode.
loss (float) – the initial loss.
monitors (Dict | None) – the initial monitors.
output_dict (Dict | None) – the initial output dictionary.
-
__new__(**kwargs)
-
add_accuracy(accuracy, key)
Add an accuracy to the context.
- Parameters:
accuracy – the accuracy value to add.
key – the name under which the accuracy is added.
- Returns:
self.
- Return type:
ForwardContext
-
add_loss(loss, key=None, accumulate=True)
Add a (sub-)loss to the context.
- Parameters:
loss (float | Tensor) – the sub-loss to add.
key (str | None) – the name of the loss. If None, the loss will not be monitored.
accumulate (bool | float) – whether to accumulate the loss into the final loss. This value can be either a boolean or a float. A boolean indicates whether to accumulate the loss; a float gives the weight of the loss.
- Returns:
self.
- Return type:
ForwardContext
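The `accumulate` rule can be read as a weight: `True` means weight 1, `False` means the loss is not added, and a float is used as the weight directly. The helper below is a hypothetical sketch of that reading on plain floats; the function name `add_loss_sketch` and the choice to record the unweighted loss under `key` are assumptions, not documented behavior.

```python
from typing import Dict, Optional, Union


def add_loss_sketch(
    total: float,
    monitors: Dict[str, float],
    loss: float,
    key: Optional[str] = None,
    accumulate: Union[bool, float] = True,
) -> float:
    """Return the new total loss; optionally record `loss` under `key`."""
    # bool is a subclass of int, so float(True) == 1.0 and float(False) == 0.0;
    # a float value is taken as the weight directly.
    weight = float(accumulate)
    if weight != 0.0:
        total = total + weight * loss
    if key is not None:
        # Assumption: the raw (unweighted) value is what gets monitored.
        monitors[key] = loss
    return total


monitors: Dict[str, float] = {}
total = 0.0
total = add_loss_sketch(total, monitors, 2.0, key='main')  # weight 1
total = add_loss_sketch(total, monitors, 4.0, key='aux', accumulate=0.25)
total = add_loss_sketch(total, monitors, 8.0, key='probe', accumulate=False)
print(total)     # 3.0
print(monitors)  # {'main': 2.0, 'aux': 4.0, 'probe': 8.0}
```

The `probe` call shows how `accumulate=False` with a key lets a loss be watched without affecting the optimized objective.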
-
add_output(output, key)
Add an output to the context.
- Parameters:
output – the output to add.
key – the name of the output in the output dictionary.
- Returns:
self.
- Return type:
ForwardContext
-
as_default()
Set the context as the default context.
- Return type:
ForwardContext
-
binary_classification_accuracy(pred, label, name='', saturation=True)
Compute the accuracy of binary classification.
- Parameters:
pred (Tensor) – the prediction, of the same shape as label.
label (Tensor) – the label, of the same shape as pred.
name (str) – the name of this monitor.
saturation (bool) – whether to check the saturation of the prediction. Saturation is defined as \(1 - \min(pred, 1 - pred)\).
- Returns:
a dict of monitor values.
- Return type:
Dict[str, float]
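For intuition, here is a pure-Python sketch of the two quantities this monitor describes, assuming `pred` holds probabilities in [0, 1], `label` holds 0/1 targets, and predictions are thresholded at 0.5. The threshold, the averaging over elements, the function name, and the monitor key names are all assumptions; only the saturation formula \(1 - \min(pred, 1 - pred)\) comes from the documentation above.

```python
from typing import Dict, List


def binary_accuracy_sketch(
    pred: List[float], label: List[int], name: str = '', saturation: bool = True
) -> Dict[str, float]:
    assert len(pred) == len(label), 'pred and label must have the same shape'
    suffix = '/' + name if name else ''  # hypothetical key naming
    # Assumed rule: a prediction counts as class 1 when it is at least 0.5.
    correct = [int(p >= 0.5) == y for p, y in zip(pred, label)]
    result = {'acc' + suffix: sum(correct) / len(correct)}
    if saturation:
        # Per element: 1 - min(p, 1 - p); 1.0 at p in {0, 1}, 0.5 at p = 0.5.
        sat = [1 - min(p, 1 - p) for p in pred]
        result['saturation' + suffix] = sum(sat) / len(sat)
    return result


result = binary_accuracy_sketch([0.9, 0.2, 0.6, 0.5], [1, 0, 0, 1])
print(result)
```

Low saturation (close to 0.5) flags predictions that sit near the decision boundary, which plain accuracy does not reveal.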
-
classification_accuracy(pred, label, name='')
Compute the accuracy of N-way classification.
- Parameters:
pred (Tensor) – the prediction, of the same shape as label.
label (Tensor) – the label, of the same shape as pred.
name (str) – the name of this monitor.
- Returns:
a dict of monitor values.
- Return type:
Dict[str, float]
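Because `pred` has the same shape as `label`, it presumably holds predicted class indices rather than per-class scores. Under that assumption, N-way accuracy is the fraction of positions where the two agree. The sketch also shows one way to obtain indices from per-class scores with an argmax; both function names are hypothetical.

```python
from typing import List, Sequence


def argmax(scores: Sequence[float]) -> int:
    # Index of the largest score (the first one wins on ties).
    return max(range(len(scores)), key=lambda i: scores[i])


def classification_accuracy_sketch(pred: List[int], label: List[int]) -> float:
    assert len(pred) == len(label), 'pred and label must have the same shape'
    return sum(p == y for p, y in zip(pred, label)) / len(label)


scores = [[0.1, 0.7, 0.2], [0.5, 0.3, 0.2], [0.2, 0.2, 0.6], [0.3, 0.4, 0.3]]
pred = [argmax(s) for s in scores]  # [1, 0, 2, 1]
label = [1, 0, 1, 1]
print(classification_accuracy_sketch(pred, label))  # 0.75
```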
-
finalize()
Finalize the context and return the loss, monitors, and output dictionary.
-
get_hyperparameter(key, default=None)
Get a hyperparameter for the forward pass.
- Parameters:
key – the name of the hyperparameter.
default – the value to return if the hyperparameter has not been set.
- Returns:
the value of the hyperparameter.
- Return type:
Any
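`set_hyperparameter` and `get_hyperparameter` behave like a per-pass key/value store. A minimal dictionary-backed sketch follows; the class name `HyperparameterStore` is hypothetical, and returning `default` for unset keys is the reading of the `default=None` parameter assumed here.

```python
from typing import Any, Dict


class HyperparameterStore:
    """Hypothetical stand-in for the hyperparameter part of a forward context."""

    def __init__(self) -> None:
        self._values: Dict[str, Any] = {}

    def set_hyperparameter(self, key: str, value: Any) -> None:
        self._values[key] = value

    def get_hyperparameter(self, key: str, default: Any = None) -> Any:
        # Fall back to `default` when the key has not been set in this pass.
        return self._values.get(key, default)


store = HyperparameterStore()
store.set_hyperparameter('temperature', 0.5)
print(store.get_hyperparameter('temperature'))   # 0.5
print(store.get_hyperparameter('dropout', 0.1))  # 0.1
print(store.get_hyperparameter('missing'))       # None
```

A typical use is an outer training loop setting a schedule-dependent value (such as a temperature) that a deeply nested submodule reads during the same forward pass.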
-
monitor_param_gradrms(model)
Monitor the RMS of the gradients of the parameters of the given model.
- Parameters:
model (Module) – the model to monitor.
- Returns:
a dict of monitor values.
- Return type:
Dict[str, float]
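RMS here is taken to be the usual root mean square, \(\sqrt{\frac{1}{n}\sum_i g_i^2}\), over the entries of each parameter's gradient. The sketch computes it for a toy model represented as a mapping from parameter names to flat gradient lists; the data layout, function names, and the `grad_rms/<name>` key format are assumptions for illustration, not jactorch's output format.

```python
import math
from typing import Dict, List


def rms(values: List[float]) -> float:
    # Root mean square: sqrt(mean(x ** 2)).
    return math.sqrt(sum(v * v for v in values) / len(values))


def gradrms_sketch(grads: Dict[str, List[float]]) -> Dict[str, float]:
    # One monitor entry per parameter; the key format is hypothetical.
    return {f'grad_rms/{name}': rms(g) for name, g in grads.items()}


grads = {'linear.weight': [3.0, -4.0], 'linear.bias': [1.0, 1.0, 1.0, 1.0]}
result = gradrms_sketch(grads)
print(result)
```

Applied to parameter values instead of gradients, the same `rms` helper gives the per-parameter quantity that `monitor_param_rms` describes.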
-
monitor_param_gradrms_ratio(model)
Monitor the ratio of the RMS of the gradients of the parameters of the given model.
- Parameters:
model (Module) – the model to monitor.
- Returns:
a dict of monitor values.
- Return type:
Dict[str, float]
-
monitor_param_rms(model)
Monitor the RMS of the parameters of the given model.
- Parameters:
model (Module) – the model to monitor.
- Returns:
a dict of monitor values.
- Return type:
Dict[str, float]
-
monitor_param_saturation(model)
Monitor the saturation of the parameters of the given model.
- Parameters:
model (Module) – the model to monitor.
- Returns:
a dict of monitor values.
- Return type:
Dict[str, float]
-
monitor_rms(**values)
Monitor the RMS of the given values. This function takes either a dict or multiple keyword arguments.
- Parameters:
values – the values to monitor, keyed by name.
- Returns:
a dict of monitor values.
- Return type:
Dict[str, float]
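The "either a dict or keyword arguments" calling convention can be sketched as follows. The positional `values_dict` parameter is a hypothetical name (the documented signature shows only `**values`), and the function name and `rms/<key>` monitor names are assumptions.

```python
import math
from typing import Dict, List, Optional


def monitor_rms_sketch(
    values_dict: Optional[Dict[str, List[float]]] = None, **values: List[float]
) -> Dict[str, float]:
    # Merge the optional dict with keyword arguments (keywords win on clashes).
    merged: Dict[str, List[float]] = dict(values_dict or {})
    merged.update(values)
    return {
        f'rms/{key}': math.sqrt(sum(x * x for x in xs) / len(xs))
        for key, xs in merged.items()
    }


print(monitor_rms_sketch({'h1': [1.0, -1.0]}))           # {'rms/h1': 1.0}
print(monitor_rms_sketch(h2=[2.0, 2.0], h3=[0.0, 0.0]))  # {'rms/h2': 2.0, 'rms/h3': 0.0}
```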
-
regression_accuracy(pred, label, name='')
Compute the accuracy of regression.
- Parameters:
pred (Tensor) – the prediction, of the same shape as label.
label (Tensor) – the label, of the same shape as pred.
name (str) – the name of this monitor.
- Returns:
a dict of monitor values.
- Return type:
Dict[str, float]
-
set_hyperparameter(key, value)
Set a hyperparameter for the forward pass.
- Parameters:
key – the name of the hyperparameter.
value – the value of the hyperparameter.
-
update_mo(monitors, output_dict)
Update the monitors and output dictionary in the context.
- Parameters:
monitors – the monitors to update.
output_dict – the output dictionary entries to update.
-
update_monitors(monitors)
Update the monitors in the context.
- Parameters:
monitors (Dict[str, float | Tensor]) – the monitors to update.
-
loss: float | Tensor
The current loss.
-
monitors: GView
The current monitors.
-
output_dict: GView
The current output dictionary.
Functions
-
get_forward_context()
Get the current forward context.
- Return type:
ForwardContext