jactorch.graph.context

Context manager used in a forward pass. It provides access to global variables and monitors from different modules. See ForwardContext for more details.

Classes

ForwardContext

A context manager that serves as a global variable for the forward pass.

Functions

get_forward_context()

Get the current forward context.

Class ForwardContext

class ForwardContext

Bases: object

A context manager that serves as a global variable for the forward pass. It provides access to shared state (loss, monitors, and outputs) from different modules.

Example

from jactorch.graph.context import ForwardContext, get_forward_context

# `loss`, `accuracy`, `output`, and `some_other_loss` are placeholders for values computed by your model.
with ForwardContext(training=True) as ctx:
    ctx.add_loss(loss)
    ctx.add_accuracy(accuracy, 'acc')
    ctx.add_output(output, 'output')

    # In a different file, you can access the context by:
    ctx = get_forward_context()
    ctx.add_loss(some_other_loss)

loss, monitors, output_dict = ctx.finalize()

__init__(training, *, loss=0, monitors=None, output_dict=None)

Initialize the context.

Parameters:
  • training (bool) – whether the forward pass is in training mode.

  • loss (float) – the initial loss.

  • monitors (Dict | None) – the initial monitors.

  • output_dict (Dict | None) – the initial output dictionary.

__new__(**kwargs)

add_accuracy(accuracy, key)

Add an accuracy to the context.

Parameters:
  • accuracy (float | Tensor) – the accuracy to add.

  • key (str) – the name of the accuracy.

Returns:

self.

Return type:

ForwardContext

add_loss(loss, key=None, accumulate=True)

Add a (sub) loss to the context.

Parameters:
  • loss (float | Tensor) – the sub-loss to add.

  • key (str | None) – the name of the loss. If None, the loss will not be monitored.

  • accumulate – whether to accumulate the loss in the final loss. This value can either be a boolean or a float number. If it is a boolean, it indicates whether to accumulate the loss. If it is a float number, it indicates the weight of the loss.

Returns:

self.

Return type:

ForwardContext
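The accumulate rule described above can be illustrated with a small standalone sketch. The helper `accumulated_total` is hypothetical and is not part of jactorch; it assumes that `True` behaves like a weight of 1 and `False` like a weight of 0, which is one natural reading of the parameter description.

```python
def accumulated_total(total, loss, accumulate=True):
    """Hypothetical helper mirroring the documented `accumulate` rule."""
    # Check for bool first: bool is a subclass of int in Python.
    if isinstance(accumulate, bool):
        weight = 1.0 if accumulate else 0.0  # assumption: True -> 1, False -> 0
    else:
        weight = float(accumulate)           # a number is used as the loss weight
    return total + weight * loss


total = 0.0
total = accumulated_total(total, 2.0)                    # counted in full
total = accumulated_total(total, 5.0, accumulate=False)  # not added to the total
total = accumulated_total(total, 4.0, accumulate=0.5)    # added with weight 0.5
```

With `accumulate=False`, a sub-loss can still be recorded under `key` for monitoring without contributing to the final loss.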

add_output(output, key)

Add an output to the context.

Parameters:
  • output (Any) – the output to add.

  • key (str) – the name of the output.

Returns:

self.

Return type:

ForwardContext
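Because add_loss, add_accuracy, and add_output are documented as returning self, calls can be chained. The sketch below shows the pattern with a hypothetical stand-in class, `ChainableContext`; how jactorch stores values internally (for example, the exact monitor key names) is not specified here, so the storage is deliberately simplified.

```python
class ChainableContext:
    """Hypothetical stand-in; not jactorch's ForwardContext."""

    def __init__(self):
        self.monitors = {}
        self.output_dict = {}

    def add_accuracy(self, accuracy, key):
        # Assumption: store the value directly under `key`.
        self.monitors[key] = accuracy
        return self  # returning self is what enables chaining

    def add_output(self, output, key):
        self.output_dict[key] = output
        return self


ctx = ChainableContext()
result = ctx.add_accuracy(0.75, 'acc').add_output([1, 0, 1], 'pred')
```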

as_default()

Set the context as the default context.

Return type:

ForwardContext

binary_classification_accuracy(pred, label, name='', saturation=True)

Compute the accuracy of binary classification.

Parameters:
  • pred (Tensor) – the prediction, of the same shape as label.

  • label (Tensor) – the label, of the same shape as pred.

  • name (str) – the name of this monitor.

  • saturation (bool) – whether to check the saturation of the prediction. Saturation is defined as \(1 - \min(\text{pred}, 1 - \text{pred})\).

Returns:

a dict of monitor values.

Return type:

Dict[str, float]
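To make the saturation formula concrete, here is a plain-Python sketch applied to scalar probabilities. The function name `saturation` is hypothetical; how jactorch aggregates the per-element values into a monitor is not shown and is not assumed here.

```python
def saturation(pred):
    """Per-element saturation: 1 - min(pred, 1 - pred), for pred in [0, 1]."""
    return 1.0 - min(pred, 1.0 - pred)


# Confident predictions (near 0 or 1) give values near 1;
# an undecided prediction of 0.5 gives the minimum value, 0.5.
preds = [0.9, 0.2, 0.5]
saturations = [saturation(p) for p in preds]  # roughly 0.9, 0.8, 0.5
```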

classification_accuracy(pred, label, name='')

Compute the accuracy of N-way classification.

Parameters:
  • pred (Tensor) – the prediction, of the same shape as label.

  • label (Tensor) – the label, of the same shape as pred.

  • name (str) – the name of this monitor.

Returns:

a dict of monitor values.

Return type:

Dict[str, float]

finalize()

Finalize the context and return the loss, monitors, and output dictionary.

get_hyperparameter(key, default=None)

Get a hyperparameter for the forward pass.

Parameters:
  • key (str) – the key of the hyperparameter.

  • default (Any) – the default value of the hyperparameter.

Returns:

the value of the hyperparameter.

Return type:

Any

monitor_param_gradrms(model)

Monitor the RMS of the gradients of the parameters of the given model.

Parameters:

model (Module) – the model to monitor.

Returns:

a dict of monitor values.

Return type:

Dict[str, float]

monitor_param_gradrms_ratio(model)

Monitor the ratio of the RMS of the gradients of the parameters of the given model.

Parameters:

model (Module) – the model to monitor.

Returns:

a dict of monitor values.

Return type:

Dict[str, float]

monitor_param_rms(model)

Monitor the RMS of the parameters of the given model.

Parameters:

model (Module) – the model to monitor.

Returns:

a dict of monitor values.

Return type:

Dict[str, float]

monitor_param_saturation(model)

Monitor the saturation of the parameters of the given model.

Parameters:

model (Module) – the model to monitor.

Returns:

a dict of monitor values.

Return type:

Dict[str, float]

monitor_rms(_dict=None, **values)

Monitor the RMS of the given values. This function takes either a dict or multiple keyword arguments.

Parameters:
  • _dict (Dict[str, Tensor] | None) – a dict of values.

  • **values (Tensor) – multiple keyword arguments.

Returns:

a dict of monitor values.

Return type:

Dict[str, float]
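The root-mean-square computation and the "dict or keyword arguments" calling convention can be sketched in plain Python as follows. `sketch_monitor_rms` and the `rms/` key prefix are hypothetical choices made for illustration; jactorch operates on tensors and may name its monitor keys differently.

```python
import math


def rms(values):
    """Root mean square: sqrt(mean(x ** 2)) over a non-empty sequence."""
    return math.sqrt(sum(v * v for v in values) / len(values))


def sketch_monitor_rms(_dict=None, **values):
    """Hypothetical stand-in accepting a dict, keyword arguments, or both."""
    merged = dict(_dict or {})
    merged.update(values)
    # Assumption: the 'rms/' prefix is illustrative, not jactorch's naming.
    return {f"rms/{name}": rms(v) for name, v in merged.items()}


from_dict = sketch_monitor_rms({'a': [3.0, 4.0]})
from_kwargs = sketch_monitor_rms(b=[1.0, 1.0])
```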

regression_accuracy(pred, label, name='')

Compute the accuracy of regression.

Parameters:
  • pred (Tensor) – the prediction, of the same shape as label.

  • label (Tensor) – the label, of the same shape as pred.

  • name (str) – the name of this monitor.

Returns:

a dict of monitor values.

Return type:

Dict[str, float]

set_hyperparameter(key, value)

Set a hyperparameter for the forward pass.

Parameters:
  • key (str) – the key of the hyperparameter.

  • value (Any) – the value of the hyperparameter.

update_mo(monitors, output_dict)

Update the monitors and output dictionary in the context.

Parameters:
  • monitors (Dict[str, float | Tensor]) – the monitors to update.

  • output_dict – the output entries to update.

update_monitors(monitors)

Update the monitors in the context.

Parameters:

monitors (Dict[str, float | Tensor]) – the monitors to update.

loss: float | Tensor

The current loss.

monitors: GView

The current monitors.

output_dict: GView

The current output dictionary.

Functions

get_forward_context()

Get the current forward context.

Return type:

ForwardContext