# Source code for jactorch.nn.rnn_utils
#! /usr/bin/env python3
# -*- coding: utf-8 -*-
# File : rnn_utils.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 01/25/2018
#
# This file is part of Jacinle.
# Distributed under terms of the MIT license.
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from jactorch.functional.indexing import inverse_permutation
__all__ = ['rnn_with_length']
# [docs] (Sphinx viewcode anchor left over from HTML extraction; not part of the module)
def _permute_states(states, index, strict):
    """Reorder the batch axis of a recurrent state, or of each state in a tuple.

    A 2-D state is treated as ``(batch, hidden)`` and reordered along axis 0.
    Any other rank is reordered along axis 1, which is where torch.nn
    recurrent modules keep the batch for ``(layers * directions, batch,
    hidden)`` states, independently of ``batch_first``. ``None`` is returned
    unchanged.

    Args:
        states: a tensor, a tuple of tensors (e.g. LSTM ``(h, c)``), or None.
        index: 1-D integer index tensor giving the new batch order.
        strict: when true, assert that every state is 2-D or 3-D.

    Returns:
        The reordered state, with the same nesting as ``states``.
    """
    if states is None:
        return None
    if isinstance(states, tuple):
        return tuple(_permute_states(state, index, strict) for state in states)
    rank = states.dim()
    if strict:
        assert rank in (2, 3), 'recurrent states must be 2-D or 3-D, got {}-D'.format(rank)
    return states[index] if rank == 2 else states[:, index]


def rnn_with_length(rnn, seq_tensor, seq_lengths, initial_states, batch_first=True, sorted=False):
    """Run ``rnn`` over a padded batch of variable-length sequences.

    The batch is sorted by decreasing length (unless ``sorted`` says it
    already is), packed with ``pack_padded_sequence``, fed to ``rnn``, and
    unpacked again. Outputs and final states are then put back into the
    caller's original batch order.

    Args:
        rnn: callable invoked as ``rnn(packed_input, initial_states)`` and
            returning ``(packed_output, last_state)``, e.g. a torch.nn
            RNN/GRU/LSTM module.
        seq_tensor: padded input; the batch is axis 0 when ``batch_first`` is
            true and axis 1 otherwise.
        seq_lengths: 1-D integer tensor holding the length of each sequence.
        initial_states: a tensor, a tuple of tensors, or None. 3-D states are
            taken to have the batch on axis 1, 2-D states on axis 0.
        batch_first: layout flag forwarded to the packing/unpacking helpers.
        sorted: set to true when the batch is already ordered by decreasing
            length; no reordering is done in that case.

    Returns:
        ``(output, last_output)``. ``output`` is the padded RNN output, which
        ``pad_packed_sequence`` pads only up to the longest sequence in the
        batch. ``last_output`` is whatever state ``rnn`` returned. Both are in
        the original batch order.
    """
    perm_idx = None
    if not sorted:
        # pack_padded_sequence expects the longest sequence first.
        seq_lengths, perm_idx = seq_lengths.sort(0, descending=True)
        # The batch sits on axis 0 only for batch-first input; for time-major
        # input axis 0 is time, so the batch to reorder is axis 1.
        seq_tensor = seq_tensor[perm_idx] if batch_first else seq_tensor[:, perm_idx]
        initial_states = _permute_states(initial_states, perm_idx, strict=True)

    # Lengths are handed over as a plain list of ints living on the host.
    packed_input = pack_padded_sequence(seq_tensor, seq_lengths.cpu().tolist(), batch_first=batch_first)
    packed_output, last_output = rnn(packed_input, initial_states)
    output, _ = pad_packed_sequence(packed_output, batch_first=batch_first)

    if perm_idx is not None:
        # NOTE(review): inverse_permutation is a project helper; it is assumed
        # to return the index tensor that undoes ``perm_idx``.
        perm_inv = inverse_permutation(perm_idx)
        output = output[perm_inv] if batch_first else output[:, perm_inv]
        # Final states keep the batch on axis 1 whatever ``batch_first`` is,
        # so their axis is derived from their rank, not from the input layout.
        last_output = _permute_states(last_output, perm_inv, strict=False)
    return output, last_output