Source code for jactorch.functional.loglinear

#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : loglinear.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 05/31/2018
#
# This file is part of Jacinle.
# Distributed under terms of the MIT license.

"""Linear algebra functions in the log space."""

from typing import Optional

import math
import torch

from .shape import concat_shape, move_dim

__all__ = ['logaddexp', 'logsumexp', 'logmatmulexp', 'batch_logmatmulexp', 'logits_and', 'logits_or', 'log1mexp']


def logaddexp(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Computes ``log(exp(x) + exp(y))`` in a numerically stable way."""
    return torch.max(x, y) + torch.log(1 + torch.exp(-torch.abs(y - x)))
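

# A minimal usage sketch (not part of the original module): on inputs whose
# exponentials overflow float32, the stable form stays finite and agrees with
# torch.logaddexp; the naive (x.exp() + y.exp()).log() would return inf here.
def _demo_logaddexp():
    x = torch.tensor([100.0, -2.0])
    y = torch.tensor([101.0, 3.0])
    out = logaddexp(x, y)
    assert torch.allclose(out, torch.logaddexp(x, y))
    assert torch.isfinite(out).all()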


def logsumexp(tensor: torch.Tensor, dim: Optional[int] = None, keepdim: bool = False) -> torch.Tensor:
    """Computes ``tensor.exp().sum(dim, keepdim).log()`` in a numerically stable way."""
    if dim is None:
        tensor = tensor.reshape(-1)
        dim = -1
    inputs_max = tensor.max(dim=dim, keepdim=True)[0]
    tensor = tensor - inputs_max
    if not keepdim:
        inputs_max = inputs_max.squeeze(dim)
    out = _safe_log(tensor.exp().sum(dim=dim, keepdim=keepdim)) + inputs_max
    return out
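

# A minimal usage sketch (not part of the original module): logsumexp agrees
# with torch.logsumexp along a given dim, and dim=None reduces over all elements.
def _demo_logsumexp():
    t = torch.randn(3, 4) * 50  # large magnitudes, where the naive form can overflow
    assert torch.allclose(logsumexp(t, dim=1), torch.logsumexp(t, dim=1))
    assert torch.allclose(logsumexp(t), torch.logsumexp(t.reshape(-1), dim=0))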


def logmatmulexp(mat1: torch.Tensor, mat2: torch.Tensor, use_mm: bool = False) -> torch.Tensor:
    """Computes ``torch.matmul(mat1.exp(), mat2.exp()).log()`` in a numerically stable way."""
    mat1_shape = mat1.size()
    mat2_shape = mat2.size()
    mat1 = mat1.contiguous().view(-1, mat1_shape[-1])
    mat2 = move_dim(mat2, 0, -1)
    mat2 = mat2.contiguous().view(-1, mat2_shape[0])

    if use_mm:
        # Factor out the per-row maxima so that exp() only sees non-positive
        # values, then add the shifts back after the log.
        mat1_max = mat1.max(dim=-1, keepdim=True)[0]
        mat2_max = mat2.max(dim=-1, keepdim=True)[0]
        mat1 = mat1 - mat1_max
        mat2 = mat2 - mat2_max
        out = _safe_log(torch.matmul(mat1.exp(), mat2.exp().t()))
        out = out + mat1_max + mat2_max.t()
    else:
        # Broadcast to [N, K, M] and reduce over the shared dimension.
        out_sum = mat1.unsqueeze(1) + mat2.unsqueeze(0)
        out = logsumexp(out_sum, dim=-1)

    return out.view(concat_shape(mat1_shape[:-1], mat2_shape[1:]))
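

# A minimal usage sketch (not part of the original module): both code paths
# agree with the naive computation on inputs small enough for it not to overflow.
def _demo_logmatmulexp():
    a, b = torch.randn(2, 3), torch.randn(3, 4)
    naive = torch.matmul(a.exp(), b.exp()).log()
    assert torch.allclose(logmatmulexp(a, b), naive, atol=1e-6)
    assert torch.allclose(logmatmulexp(a, b, use_mm=True), naive, atol=1e-6)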


def batch_logmatmulexp(mat1: torch.Tensor, mat2: torch.Tensor, use_mm: bool = False) -> torch.Tensor:
    """Computes ``torch.bmm(mat1.exp(), mat2.exp()).log()`` in a numerically stable way.

    Args:
        mat1: the first tensor of shape [B, N, M].
        mat2: the second tensor of shape [B, M, K].
        use_mm: whether to use torch.bmm internally.

    Returns:
        the output of shape [B, N, K].
    """
    mat1_shape = mat1.size()
    mat2_shape = mat2.size()
    mat1 = mat1.contiguous().view(mat1_shape[0], -1, mat1_shape[-1])
    mat2 = move_dim(mat2, 1, -1)
    mat2 = mat2.contiguous().view(mat2_shape[0], -1, mat2_shape[1])

    if use_mm:
        # Same max-shift trick as in logmatmulexp, applied per batch element.
        mat1_max = mat1.max(dim=-1, keepdim=True)[0]
        mat2_max = mat2.max(dim=-1, keepdim=True)[0]
        mat1 = mat1 - mat1_max
        mat2 = mat2 - mat2_max
        out = _safe_log(torch.bmm(mat1.exp(), mat2.exp().permute(0, 2, 1)))
        out = out + mat1_max + mat2_max.permute(0, 2, 1)
    else:
        # Broadcast to [B, N, K, M] and reduce over the shared dimension.
        out_sum = mat1.unsqueeze(2) + mat2.unsqueeze(1)
        out = logsumexp(out_sum, dim=-1)

    return out.view(concat_shape(mat1_shape[:-1], mat2_shape[2:]))
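

# A minimal usage sketch (not part of the original module): a batched check
# against torch.bmm in probability space, with shapes [B, N, M] x [B, M, K].
def _demo_batch_logmatmulexp():
    a, b = torch.randn(5, 2, 3), torch.randn(5, 3, 4)
    naive = torch.bmm(a.exp(), b.exp()).log()
    out = batch_logmatmulexp(a, b)
    assert out.shape == (5, 2, 4)
    assert torch.allclose(out, naive, atol=1e-6)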


def logits_and(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Computes ``logit(sigmoid(x) * sigmoid(y))`` in a numerically stable way."""
    # t - f == (x + y) - log(1 + exp(x) + exp(y)) == logit(sigmoid(x) * sigmoid(y)).
    t = (x + y) / 2
    f = logaddexp(logaddexp((x - y) / 2, (y - x) / 2), -t)
    return t - f
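

# A minimal usage sketch (not part of the original module): logits_and of two
# logits equals the logit of the product of the corresponding probabilities.
def _demo_logits_and():
    x, y = torch.randn(10), torch.randn(10)
    p = torch.sigmoid(x) * torch.sigmoid(y)
    assert torch.allclose(logits_and(x, y), torch.logit(p), atol=1e-5)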


def logits_or(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Computes ``logit(sigmoid(x) + sigmoid(y) - sigmoid(x) * sigmoid(y))`` in a numerically stable way."""
    # t - f == log(exp(x) + exp(y) + exp(x + y)) == the logit of the union probability.
    f = -(x + y) / 2
    t = logaddexp(logaddexp((x - y) / 2, (y - x) / 2), -f)
    return t - f
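

# A minimal usage sketch (not part of the original module): logits_or mirrors
# the inclusion-exclusion formula p + q - p * q in probability space.
def _demo_logits_or():
    x, y = torch.randn(10), torch.randn(10)
    p, q = torch.sigmoid(x), torch.sigmoid(y)
    assert torch.allclose(logits_or(x, y), torch.logit(p + q - p * q), atol=1e-5)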


def log1mexp(x: torch.Tensor) -> torch.Tensor:
    """Computes ``log(1 - exp(x))`` in a numerically stable way."""
    # Two-branch scheme (Maechler, 2012): use log1p(-exp(x)) when x < log(1/2),
    # and log(-expm1(x)) when x is closer to zero.
    mask = (x < _log05).to(x.dtype)
    impl1 = torch.log1p(-torch.exp(x))
    impl2 = torch.log(-torch.expm1(x))
    return impl1 * mask + impl2 * (1 - mask)
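

# A minimal usage sketch (not part of the original module): for negative x,
# log1mexp matches the naive log(1 - exp(x)) on well-conditioned inputs.
def _demo_log1mexp():
    x = torch.tensor([-20.0, -1.0, -0.01])
    assert torch.allclose(log1mexp(x), torch.log(1 - torch.exp(x)), atol=1e-5)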


def _safe_log(x):
    # mask = (x < 1e-8).float()
    # return x.clamp(min=1e-8).log() * (1 - mask) + -1e5 * mask
    return x.log()


_log05 = math.log(0.5)