Source code for jactorch.functional.kernel

#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : kernel.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 02/04/2018
#
# This file is part of Jacinle.
# Distributed under terms of the MIT license.

"""Useful utilities for kernel-based attention mechanism."""

import torch

from .linalg import normalize

__all__ = ['inverse_distance', 'cosine_distance', 'dot']


[docs] def inverse_distance(f_lookup, f, p=2, eps=1e-8): """ Inverse distance kernel. Args: f_lookup (FloatTensor): features of the lookup keys f (FloatTensor): features of the value keys Returns: FloatTensor: the attention mask for each lookup keys. """ n, m, k = f_lookup.size(0), f.size(0), f.size(1) f_lookup = f_lookup.view(n, 1, k).expand(n, m, k) f = f.view(1, m, k).expand(n, m, k) # TODO(Jiayuan Mao @ 05/26): this function can be optimized. dist = (f_lookup - f).norm(p, dim=2) return 1. / dist.clamp(min=eps)
[docs] def cosine_distance(f_lookup, f): """ Cosine distance kernel. Args: f_lookup (FloatTensor): features of the lookup keys f (FloatTensor): features of the value keys Returns: FloatTensor: the attention mask for each lookup keys. """ f_lookup = normalize(f_lookup, 2, dim=1) f = normalize(f, 2, dim=1) return torch.mm(f_lookup, f.t())
[docs] def dot(f_lookup, f): """ Dot product kernel, essentially a cosine distance kernel without normalization. Args: f_lookup (FloatTensor): features of the lookup keys f (FloatTensor): features of the value keys Returns: FloatTensor: the attention mask for each lookup keys. """ return torch.mm(f_lookup, f.t())