Source code for jactorch.nn.word_embedding
#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File : word_embedding.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 06/15/2018
#
# This file is part of Jacinle.
# Distributed under terms of the MIT license.
import torch
import torch.nn as nn
import torch.nn.functional as F
__all__ = ['WordEmbedding']
class WordEmbedding(nn.Module):
    def __init__(self, word_embeddings, nr_extra_words, fake=False):
        super().__init__()

        # Vocabulary bookkeeping: pretrained words, extra words, and their total.
        self.nr_words = word_embeddings.shape[0]
        self.nr_extra_words = nr_extra_words
        self.nr_tot_words = self.nr_words + self.nr_extra_words
        self.embedding_dim = word_embeddings.shape[1]
        self.fake = fake

        # A single embedding table covers both vocabularies; index 0 is the padding token.
        self.impl = nn.Embedding(self.nr_tot_words, self.embedding_dim, padding_idx=0)
        if not fake:
            # Keep the pretrained vectors as a frozen parameter and allocate trainable
            # vectors for the extra words.
            self.word_embeddings = nn.Parameter(torch.tensor(word_embeddings))
            self.word_embeddings.requires_grad = False
            self.extra_word_embeddings = nn.Parameter(torch.zeros(
                nr_extra_words, self.embedding_dim,
                dtype=self.word_embeddings.dtype, device=self.word_embeddings.device
            ))
            self.extra_word_embeddings.requires_grad = True
            # Initialize the embedding table from the concatenation of the two blocks.
            self.impl.weight = nn.Parameter(torch.cat((self.word_embeddings, self.extra_word_embeddings), dim=0))
        self.reset_parameters()

    def reset_parameters(self):
        if not self.fake:
            # Draw the extra word embeddings from a normal distribution that matches
            # the mean and standard deviation of the pretrained embeddings.
            self.extra_word_embeddings.data.normal_(
                self.word_embeddings.data.mean(),
                self.word_embeddings.data.std()
            )

    @property
    def weight(self):
        return self.impl.weight

    def forward(self, words):
        return self.impl(words)
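
A minimal usage sketch, not part of the module source above; the embedding matrix, its dimensions, and the word indices below are placeholders chosen for illustration. The module is constructed from a pretrained embedding matrix of shape (nr_words, embedding_dim) plus the number of extra words, and maps integer word indices to embedding vectors.

import numpy as np
import torch

from jactorch.nn.word_embedding import WordEmbedding

# Placeholder pretrained embeddings: 100 words, 50 dimensions (row 0 serves as padding).
pretrained = np.random.randn(100, 50).astype('float32')

# Reserve 10 extra embedding rows, e.g. for task-specific tokens appended to the vocabulary.
embedding = WordEmbedding(pretrained, nr_extra_words=10)

# Look up a batch of word indices; indices >= 100 address the extra rows, and 0 is padding.
words = torch.tensor([[1, 2, 105, 0], [3, 4, 0, 0]], dtype=torch.long)
vectors = embedding(words)  # shape: (2, 4, 50)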