# Source code for jactorch.nn.losses.losses (Sphinx viewcode page title; kept as a comment so the file is valid Python)

#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : losses.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 01/25/2018
#
# This file is part of Jacinle.
# Distributed under terms of the MIT license.

import torch.nn as nn

from jacinle.utils.enum import JacEnum
from jactorch.functional.indexing import one_hot

from . import functional as F
from .functional import weighted_loss

__all__ = [
    'LossAverageMethod', 'AverageLoss',
    'BinaryCrossEntropyLossWithProbs', 'PNBalancedBinaryCrossEntropyLossWithProbs',
    'CrossEntropyLossWithLogits', 'CrossEntropyLoss', 'MSELoss',
    'CrossEntropyLossWithProbs',
    'SmoothL1Loss',
    'CompatibleCrossEntropyLossWithProbs', 'CompatibleMSEProbabilityLoss',
    'CosineLoss',
]


class LossAverageMethod(JacEnum):
    """Strategies for reducing an element-wise loss tensor to a scalar.

    Consumed by :class:`AverageLoss` (constructed from a string via
    ``LossAverageMethod.from_string``).
    """

    NONE = 'none'    # keep the element-wise loss (no reduction by method).
    ALL = 'all'      # mean over all elements.
    VALID = 'valid'  # masked sum divided by the number of valid (mask) elements.
class AverageLoss(nn.Module):
    """Base module implementing mask-aware averaging of element-wise losses.

    Subclasses compute an element-wise loss tensor and call :meth:`_average`
    to reduce it according to the configured :class:`LossAverageMethod`.
    """

    def __init__(self, average='valid'):
        """
        Args:
            average (str): averaging method; one of ``'none'``, ``'all'``,
                ``'valid'`` (parsed into a :class:`LossAverageMethod`).
        """
        super().__init__()
        self.average_method = LossAverageMethod.from_string(average)

    def _average(self, loss, mask):
        """Reduce an element-wise ``loss`` tensor, optionally weighted by ``mask``.

        NOTE(review): the flattened source is ambiguous about which ``if`` the
        final ``else: loss = loss.mean()`` belongs to; this layout (mask-less
        input falls back to a plain mean; method NONE returns the loss
        unreduced) matches upstream Jacinle — confirm against the repository.
        """
        if self.average_method is not LossAverageMethod.NONE:
            if mask is not None:
                loss = loss * mask

                if self.average_method is LossAverageMethod.ALL:
                    loss = loss.mean()
                elif self.average_method is LossAverageMethod.VALID:
                    # Clamp guards against division by zero for an all-zero mask.
                    loss = loss.sum() / mask.sum().clamp(min=0.1)
                else:
                    raise ValueError('Unknown average method: {}.'.format(self.average_method))
            else:
                # No mask supplied: average over everything.
                loss = loss.mean()
        return loss
class BinaryCrossEntropyLossWithProbs(AverageLoss):
    """Binary cross-entropy computed by ``F.binary_cross_entropy_with_probs``,
    reduced with the inherited mask-aware averaging."""

    def forward(self, logits, target, mask=None):
        # NOTE(review): despite the parameter name ``logits``, the functional's
        # name suggests the first argument holds probabilities — confirm with callers.
        return self._average(F.binary_cross_entropy_with_probs(logits, target), mask)
class PNBalancedBinaryCrossEntropyLossWithProbs(nn.Module):
    """Positive/negative-balanced binary cross-entropy on probabilities.

    Delegates entirely to ``F.pn_balanced_binary_cross_entropy_with_probs``;
    the mask (if any) is handled by the functional itself.
    """

    def forward(self, probs, target, mask=None):
        balanced = F.pn_balanced_binary_cross_entropy_with_probs(probs, target, mask)
        return balanced
class CrossEntropyLossWithLogits(AverageLoss):
    """Cross-entropy on raw logits, with mask-aware averaging."""

    def __init__(self, dim=-1, average='valid'):
        """
        Args:
            dim (int): dimension along which the class scores live.
            average (str): averaging method; see :class:`LossAverageMethod`.
        """
        super().__init__(average)
        self.dim = dim

    def forward(self, logits, target, mask=None):
        loss = F.cross_entropy_with_logits(logits, target, self.dim)
        return self._average(loss, mask)


CrossEntropyLoss = CrossEntropyLossWithLogits  # Alias matching the typical PyTorch naming.
class MSELoss(AverageLoss):
    """Mean-squared-error loss (via ``F.l2_loss``) with mask-aware averaging."""

    def __init__(self, average='valid'):
        super().__init__(average)

    def forward(self, output, target, mask=None):
        return self._average(F.l2_loss(output, target), mask)
class CrossEntropyLossWithProbs(AverageLoss):
    """Cross-entropy computed from probabilities (not logits)."""

    # Small constant passed to the functional, presumably to stabilize log(0) —
    # TODO confirm against F.cross_entropy_with_probs.
    _eps = 1e-8

    def __init__(self, dim=-1, average='valid'):
        """
        Args:
            dim (int): dimension along which the class probabilities live.
            average (str): averaging method; see :class:`LossAverageMethod`.
        """
        super().__init__(average)
        self.dim = dim

    def forward(self, probs, target, mask=None):
        loss = F.cross_entropy_with_probs(probs, target, self.dim, self._eps)
        # NOTE(review): negation applied after averaging; the functional
        # presumably returns a log-likelihood-style quantity — confirm.
        return -self._average(loss, mask)
class SmoothL1Loss(AverageLoss):
    """Smooth-L1 loss (``F.smooth_l1``) summed over the last dimension.

    When ``sidechain`` is given, elements with ``sidechain > 0`` form the
    validity mask passed to the inherited averaging.
    """

    def __init__(self, sigma=3.0, average='valid'):
        super().__init__(average)
        self.sigma = sigma

    def forward(self, output, target, sidechain=None):
        mask = (sidechain > 0).float() if sidechain is not None else None
        per_element = F.smooth_l1(output, target, self.sigma).sum(dim=-1)
        return self._average(per_element, mask)
class CompatibleCrossEntropyLossWithProbs(CrossEntropyLossWithProbs):
    """Probability-input cross-entropy with a PyTorch-compatible interface
    (``weight`` / ``ignore_index``), applied via ``weighted_loss``."""

    def __init__(self, dim=-1, weight=None, ignore_index=None):
        # average='none' keeps the element-wise loss so weighted_loss can re-weight it.
        super().__init__(dim, average='none')
        self.weight = weight
        self.ignore_index = ignore_index

    def forward(self, probs, target, mask=None):
        # Masking is not supported here; NOTE(review): assert is stripped under
        # python -O — consider raising instead (would change the exception type).
        assert mask is None
        loss = super().forward(probs, target)
        return weighted_loss(loss, target, self.weight, self.ignore_index)
class CompatibleMSEProbabilityLoss(nn.Module):
    """MSE between predicted probabilities and a one-hot encoding of the target,
    with PyTorch-style ``weight`` / ``ignore_index`` handling via ``weighted_loss``."""

    def __init__(self, weight=None, ignore_index=None):
        super().__init__()
        self.weight = weight
        self.ignore_index = ignore_index

    def forward(self, probs, target):
        # probs.size(1) is used as the number of classes for the one-hot target,
        # i.e. classes are assumed to lie along dim 1.
        target_onehot = one_hot(target, probs.size(1))
        loss = F.l2_loss(probs, target_onehot)
        return weighted_loss(loss, target, self.weight, self.ignore_index)
class CosineLoss(AverageLoss):
    """Cosine loss (via ``F.cosine_loss``) with mask-aware averaging."""

    def forward(self, output, target, mask=None):
        return self._average(F.cosine_loss(output, target), mask)