Source code for jactorch.nn.embedding

#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File   : embedding.py
# Author : Jiayuan Mao
# Email  : maojiayuan@gmail.com
# Date   : 10/03/2018
#
# This file is part of Jacinle.
# Distributed under terms of the MIT license.

import torch
import torch.nn as nn

__all__ = ['LearnedPositionalEmbedding']


class LearnedPositionalEmbedding(nn.Embedding):
    """This module learns positional embeddings up to a fixed maximum size.

    Padding symbols are ignored, but it is necessary to specify whether padding
    is added on the left side (left_pad=True) or right side (left_pad=False).

    Adapted from: https://github.com/pytorch/fairseq/blob/master/fairseq/modules/learned_positional_embedding.py.
    """

    def __init__(self, num_embeddings, embedding_dim, padding_idx=0, left_pad=False):
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        self.left_pad = left_pad

    def forward(self, input, incremental_state=None):
        """Input is expected to be of size [bsz x seqlen]."""
        if incremental_state is not None:
            # Positions are the same for every token when decoding a single step:
            # the next position is padding_idx + seqlen for every batch element.
            positions = input.data.new(1, 1).fill_(self.padding_idx + input.size(1))
        else:
            positions = make_positions(input.data, self.padding_idx, self.left_pad)
        return super().forward(positions)

    def max_positions(self):
        """Maximum number of supported positions."""
        return self.num_embeddings - self.padding_idx - 1

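
# Illustrative usage sketch (added for this page, not part of the original
# module; the sizes and token values below are hypothetical):
#
#   emb = LearnedPositionalEmbedding(num_embeddings=512, embedding_dim=64, padding_idx=0)
#   emb.max_positions()                        # -> 511, i.e. positions 1..511
#   tokens = torch.tensor([[5, 6, 7, 0, 0]])   # bsz=1, seqlen=5, right-padded
#   emb(tokens).shape                          # -> torch.Size([1, 5, 64])
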
def make_positions(tensor, padding_idx, left_pad):
    """Replace non-padding symbols with their position numbers.

    Position numbers begin at padding_idx + 1. Padding symbols are ignored, but
    it is necessary to specify whether padding is added on the left side
    (left_pad=True) or right side (left_pad=False).
    """
    max_pos = padding_idx + 1 + tensor.size(1)
    # Cache a reusable buffer of position numbers on the function object, and
    # regenerate it only when a longer sequence is encountered.
    if not hasattr(make_positions, 'range_buf'):
        make_positions.range_buf = tensor.new()
    make_positions.range_buf = make_positions.range_buf.type_as(tensor)
    if make_positions.range_buf.numel() < max_pos:
        torch.arange(padding_idx + 1, max_pos, out=make_positions.range_buf)
    mask = tensor.ne(padding_idx)
    positions = make_positions.range_buf[:tensor.size(1)].expand_as(tensor)
    if left_pad:
        # Shift each row right so that its first non-padding symbol still
        # receives position padding_idx + 1.
        positions = positions - mask.size(1) + mask.long().sum(dim=1).unsqueeze(1)
    return tensor.clone().masked_scatter_(mask, positions[mask])
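

if __name__ == '__main__':
    # A minimal, hedged sanity check (added for illustration; the token values
    # below are hypothetical and not part of the original module).
    tokens = torch.tensor([[7, 8, 9, 0, 0],
                           [4, 5, 0, 0, 0]])  # right-padded, padding_idx=0

    # Non-padding symbols are numbered starting at padding_idx + 1 = 1;
    # padding slots keep the value padding_idx.
    print(make_positions(tokens, padding_idx=0, left_pad=False))
    # -> tensor([[1, 2, 3, 0, 0],
    #            [1, 2, 0, 0, 0]])

    left = torch.tensor([[0, 0, 7, 8, 9],
                         [0, 0, 0, 4, 5]])  # the same batch, left-padded
    print(make_positions(left, padding_idx=0, left_pad=True))
    # -> tensor([[0, 0, 1, 2, 3],
    #            [0, 0, 0, 1, 2]])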