Source code for jaclearn.embedding.word_embedding
#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File : word_embedding.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 01/24/2018
#
# This file is part of Jacinle.
# Distributed under terms of the MIT license.
# This file is adapted from UKP Lab's project:
# https://github.com/UKPLab/emnlp2017-relation-extraction
# Original copyrights:
# Copyright (C) 2016 UKP lab
#
# Author: Daniil Sorokin (ukp.tu-darmstadt.de/ukp-home/)
#
# Embeddings and vocabulary utility methods
import re
import numpy as np
from jacinle.logging import get_logger
from .constant import EBD_ALL_ZEROS, EBD_UNKNOWN
logger = get_logger(__file__)
# Substitutions applied before embedding lookup: un-escape the HTML entities
# for en/em dashes, and map the '@card@' number placeholder to '0'.
special_tokens = {"&ndash;": "–", "&mdash;": "—", "@card@": "0"}
def load_word_index(path, filter=None, format='glove'):
"""
Loads only the word index from the embeddings file
@return word to index dictionary
"""
assert format == 'glove'
word2idx = {} # Maps a word to the index in the embeddings matrix
    with open(path, 'r', encoding='utf-8') as fIn:
idx = 1
for line in fIn:
split = line.strip().split(' ')
if filter is not None and split[0] not in filter:
continue
word2idx[split[0]] = idx
idx += 1
word2idx[EBD_ALL_ZEROS] = 0
word2idx[EBD_UNKNOWN] = idx
return word2idx
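# A minimal usage sketch (the path below is hypothetical). Index 0 is reserved
# for the all-zeros padding token, file entries start at index 1, and the
# unknown token takes the last index:
#
#   word2idx = load_word_index('data/glove.6B.50d.txt')
#   pad_idx, unk_idx = word2idx[EBD_ALL_ZEROS], word2idx[EBD_UNKNOWN]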
def load(path, word_index_only=False, filter=None, format='glove'):
"""
Loads pre-trained embeddings from the specified path.
"""
assert format in ('glove', 'fasttext')
if word_index_only:
return load_word_index(path, filter=filter, format=format)
word2idx = {} # Maps a word to the index in the embeddings matrix
embeddings = []
embedding_size = None
with open(path, 'r', encoding='utf-8') as fIn:
idx = 1
for lineno, line in enumerate(fIn):
            if format == 'fasttext':
                # fastText .vec files start with a header line: "<vocab_size> <dim>".
                if lineno == 0:
                    continue
try:
split = line.strip().split(' ')
if filter is not None and split[0] not in filter:
continue
val = np.array([float(num) for num in split[1:]], dtype='float32')
if embedding_size is None:
embedding_size = len(val)
else:
if embedding_size != len(val):
logger.warning('Skip invalid entry (vector length): Line#{}.'.format(lineno))
continue
embeddings.append(val)
word2idx[split[0]] = idx
idx += 1
except ValueError:
logger.warning('Skip invalid entry (encoding): Line#{}.'.format(lineno))
                # The GloVe 840B (Common Crawl) file has some encoding errors...
continue
word2idx[EBD_ALL_ZEROS] = 0
embeddings.insert(0, np.zeros(embedding_size, dtype='float32'))
    # Unknown-word vector: the average of the last ~100 word vectors, which are
    # typically the rarest entries in a frequency-sorted embeddings file.
    unknown_emb = np.average(np.array(embeddings[-101:]), axis=0)
embeddings.append(unknown_emb)
word2idx[EBD_UNKNOWN] = idx
idx += 1
return np.array(embeddings, dtype='float32'), word2idx
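# A minimal usage sketch (the path below is hypothetical). The returned matrix
# carries two extra rows beyond the file's vocabulary: row 0 is all zeros
# (padding), and the last row is the averaged unknown-word vector:
#
#   embeddings, word2idx = load('data/glove.6B.50d.txt')
#   vec = embeddings[word2idx.get('apple', word2idx[EBD_UNKNOWN])]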
def map(word, word2idx):
"""
Get the word index for the given word. Maps all numbers to 0, lowercases if necessary.
:param word: the word in question
:param word2idx: dictionary constructed from an embeddings file
:return: integer index of the word
"""
unknown_idx = word2idx[EBD_UNKNOWN]
word = word.strip()
if word in word2idx:
return word2idx[word]
elif word.lower() in word2idx:
return word2idx[word.lower()]
elif word in special_tokens:
return word2idx[special_tokens[word]]
    trimmed = re.sub(r"(^\W|\W$)", "", word)
if trimmed in word2idx:
return word2idx[trimmed]
elif trimmed.lower() in word2idx:
return word2idx[trimmed.lower()]
    no_digits = re.sub(r"([0-9][0-9.,]*)", '0', word)
if no_digits in word2idx:
return word2idx[no_digits]
return unknown_idx
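# A few illustrative lookups, assuming word2idx was built from a typical
# frequency-sorted GloVe vocabulary:
#
#   map('The', word2idx)       # falls back to word2idx['the']
#   map('"hello', word2idx)    # strips the leading quote, then matches 'hello'
#   map('1,234.5', word2idx)   # the number collapses to word2idx['0']
#   map('zzxqvy', word2idx)    # no match at any step -> word2idx[EBD_UNKNOWN]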
def map_sequence(word_sequence, word2idx):
"""
Get embedding indices for the given word sequence.
:param word_sequence: sequence of words to process
    :param word2idx: dictionary of words mapped to their embedding indices
:return: a sequence of embedding indices
"""
return [map(word, word2idx) for word in word_sequence]
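# For example (word2idx as above), a tokenized sentence becomes a list of
# embedding indices, one per token:
#
#   indices = map_sequence(['The', 'cat', 'sat', '.'], word2idx)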