#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File : embedding.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 10/03/2018
#
# This file is part of Jacinle.
# Distributed under terms of the MIT license.
import torch
import torch.nn as nn

__all__ = ['LearnedPositionalEmbedding']

class LearnedPositionalEmbedding(nn.Embedding):
    """This module learns positional embeddings up to a fixed maximum size.

    Padding symbols are ignored, but it is necessary to specify whether padding
    is added on the left side (left_pad=True) or right side (left_pad=False).

    Adapted from: https://github.com/pytorch/fairseq/blob/master/fairseq/modules/learned_positional_embedding.py.
    """
    def __init__(self, num_embeddings, embedding_dim, padding_idx=0, left_pad=False):
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        self.left_pad = left_pad
    def forward(self, input, incremental_state=None):
        """Input is expected to be of size [bsz x seqlen]."""
        if incremental_state is not None:
            # Positions are the same for every row when decoding a single step:
            # the next position index is padding_idx + (length of the prefix so far).
            positions = input.data.new(1, 1).fill_(self.padding_idx + input.size(1))
        else:
            positions = make_positions(input.data, self.padding_idx, self.left_pad)
        return super().forward(positions)
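
    # Illustrative note (added; not in the original source): in incremental
    # decoding, `input` holds the whole prefix generated so far. For a
    # [bsz x 3] prefix with padding_idx = 0, every row looks up position
    # index 0 + 3 = 3, i.e. the position of the token being emitted this step.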
    def max_positions(self):
        """Maximum number of supported positions."""
        return self.num_embeddings - self.padding_idx - 1
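
# Usage sketch (illustrative; the shapes and values below are assumptions,
# not part of the original file):
#
#   emb = LearnedPositionalEmbedding(num_embeddings=128, embedding_dim=16,
#                                    padding_idx=0, left_pad=False)
#   tokens = torch.tensor([[5, 6, 7, 0]])  # [bsz x seqlen], right-padded
#   out = emb(tokens)                      # -> [1, 4, 16] positional embeddings
#   emb.max_positions()                    # -> 127, i.e. 128 - padding_idx - 1
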
def make_positions(tensor, padding_idx, left_pad):
    """Replace non-padding symbols with their position numbers.

    Position numbers begin at padding_idx + 1. Padding symbols are ignored,
    but it is necessary to specify whether padding is added on the left side
    (left_pad=True) or right side (left_pad=False).
    """
    max_pos = padding_idx + 1 + tensor.size(1)
    # Lazily allocate a buffer of position indices on the function object,
    # growing it only when a longer sequence is seen.
    if not hasattr(make_positions, 'range_buf'):
        make_positions.range_buf = tensor.new()
    make_positions.range_buf = make_positions.range_buf.type_as(tensor)
    if make_positions.range_buf.numel() < max_pos:
        torch.arange(padding_idx + 1, max_pos, out=make_positions.range_buf)
    mask = tensor.ne(padding_idx)
    positions = make_positions.range_buf[:tensor.size(1)].expand_as(tensor)
    if left_pad:
        # Shift each row so that its first non-padding token receives
        # position padding_idx + 1.
        positions = positions - mask.size(1) + mask.long().sum(dim=1).unsqueeze(1)
    # Write position numbers into the non-padding slots; padding slots keep
    # their original value (padding_idx).
    return tensor.clone().masked_scatter_(mask, positions[mask])
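

if __name__ == '__main__':
    # Minimal sanity check, added for illustration (not in the original file).
    # Two right-padded rows with padding_idx = 0: non-padding tokens get
    # positions starting at padding_idx + 1; padding keeps padding_idx.
    tokens = torch.tensor([
        [5, 6, 7, 0],  # length-3 sequence
        [8, 9, 0, 0],  # length-2 sequence
    ])
    print(make_positions(tokens, 0, left_pad=False))
    # tensor([[1, 2, 3, 0],
    #         [1, 2, 0, 0]])

    # The same sequences left-padded: positions still begin at padding_idx + 1
    # at the first non-padding token of each row.
    tokens_left = torch.tensor([
        [0, 5, 6, 7],
        [0, 0, 8, 9],
    ])
    print(make_positions(tokens_left, 0, left_pad=True))
    # tensor([[0, 1, 2, 3],
    #         [0, 0, 1, 2]])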