How to correctly implement a batch-input LSTM network in PyTorch?

Solution 1

Question 1 - Last Timestep

This is the code that I use to get the output of the last timestep. I don't know if there is a simpler solution; if there is, I'd like to know it. I followed this discussion and took the relevant code snippet for my last_timestep method. This is my forward:

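import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


class BaselineRNN(nn.Module):
    def __init__(self, **kwargs):
        ...

    def last_timestep(self, unpacked, lengths):
        # Index of the last output for each sequence, shaped (B, 1, H) so that
        # gather() along the time dimension picks one timestep per sequence.
        idx = (lengths - 1).view(-1, 1).expand(unpacked.size(0),
                                               unpacked.size(2)).unsqueeze(1)
        idx = idx.to(unpacked.device)
        # squeeze(1) rather than squeeze(), so a batch of size 1 keeps its batch dim
        return unpacked.gather(1, idx).squeeze(1)

    def forward(self, x, lengths):
        embs = self.embedding(x)

        # pack the batch (lengths must be a CPU tensor or a list of ints;
        # enforce_sorted=False removes the need to sort the batch by length)
        packed = pack_padded_sequence(embs, lengths.cpu(),
                                      batch_first=True, enforce_sorted=False)

        out_packed, (h, c) = self.rnn(packed)

        out_unpacked, _ = pad_packed_sequence(out_packed, batch_first=True)

        # get the outputs from the last *non-masked* timestep for each sentence
        last_outputs = self.last_timestep(out_unpacked, lengths)

        # project to the classes using a linear layer
        logits = self.linear(last_outputs)

        return logits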
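On the "simpler solution" question: when nn.LSTM receives a PackedSequence, the final hidden state h_n it returns already corresponds to each sequence's last non-padded step, so for a unidirectional LSTM h[-1] should match what last_timestep() extracts. The sketch below checks that on a toy batch; the sizes, the random input and the standalone copy of last_timestep() are illustrative assumptions, and it assumes a current PyTorch release.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


def last_timestep(unpacked, lengths):
    # Same gather trick as last_timestep() above, for batch_first outputs (B, T, H).
    idx = (lengths - 1).view(-1, 1).expand(unpacked.size(0), unpacked.size(2)).unsqueeze(1)
    return unpacked.gather(1, idx).squeeze(1)


torch.manual_seed(0)
B, T, E, H = 3, 5, 4, 6
x = torch.randn(B, T, E)                # padded batch, batch_first layout
lengths = torch.tensor([5, 3, 2])       # true lengths, sorted in decreasing order

lstm = nn.LSTM(E, H, batch_first=True)  # single-layer, unidirectional
packed = pack_padded_sequence(x, lengths, batch_first=True)
out_packed, (h, c) = lstm(packed)
out_unpacked, _ = pad_packed_sequence(out_packed, batch_first=True)

via_gather = last_timestep(out_unpacked, lengths)
via_hidden = h[-1]                      # final hidden state of the last layer, shape (B, H)

print(torch.allclose(via_gather, via_hidden))  # True
```

For a bidirectional LSTM this shortcut does not carry over directly: h[-1] is then the backward direction's final state, which corresponds to the first timestep, so the two directions have to be handled separately.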

Question 2 - Masked Cross Entropy Loss

Yes, by default the zero-padded timesteps (targets) matter: they are counted in the loss like any other target. However, it is very easy to mask them. You have two options, depending on the version of PyTorch that you use.

  1. PyTorch 0.2.0 and newer: PyTorch supports masking directly in CrossEntropyLoss through the ignore_index argument. For example, in language modeling or seq2seq, where I add zero padding, I mask the zero-padded words (targets) simply like this:

    loss_function = nn.CrossEntropyLoss(ignore_index=0)

  2. PyTorch 0.1.12 and older: older versions of PyTorch did not support masking, so you had to implement your own workaround. One solution that I used was masked_cross_entropy.py by jihunchoi. You may also be interested in this discussion.
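The ignore_index option can be checked against a hand-written mask on the TB x O layout from the question. This is a minimal sketch, not jihunchoi's masked_cross_entropy.py; it assumes class index 0 is reserved for padding and never used as a real target, the names T, B, O and PAD are hypothetical, and reduction="none" requires a current PyTorch release.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
T, B, O = 4, 2, 5          # time steps, batch size, output classes
PAD = 0                    # assumption: class index 0 is reserved for padding

logits = torch.randn(T, B, O)                 # T x B x O output of the linear layer
targets = torch.tensor([[3, 1],
                        [2, 4],
                        [1, PAD],
                        [4, PAD]])            # T x B, second sequence has length 2

# Reshape T x B x O -> TB x O and T x B -> TB, as described in the question.
flat_logits = logits.view(-1, O)
flat_targets = targets.view(-1)

# Option 1: let CrossEntropyLoss skip the padded targets.
loss_ignore = nn.CrossEntropyLoss(ignore_index=PAD)(flat_logits, flat_targets)

# Option 2: a manual mask: compute per-token losses, zero out the padded
# ones, and average over the real tokens only.
per_token = nn.CrossEntropyLoss(reduction="none")(flat_logits, flat_targets)
mask = (flat_targets != PAD).float()
loss_masked = (per_token * mask).sum() / mask.sum()

print(torch.allclose(loss_ignore, loss_masked))  # True
```

Both versions divide by the number of real tokens rather than by TB, which is what keeps the padding from diluting the loss.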

Solution 2

A few days ago, I found this method which uses indexing to accomplish the same task with a one-liner.

I have my dataset batch first ([batch size, sequence length, features]), so for me:

unpacked_out = unpacked_out[torch.arange(unpacked_out.shape[0]), lengths - 1, :]

where unpacked_out is the padded output tensor returned by torch.nn.utils.rnn.pad_packed_sequence (the first element of the tuple it returns).

I have compared it with the method described here, which looks similar to the last_timestep() method Christos Baziotis is using above (also recommended here), and the results are the same in my case.
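The equivalence of the indexing one-liner and the gather-based method can be checked without an RNN at all. In this minimal sketch a random tensor stands in for the pad_packed_sequence output, and the sizes and lengths are arbitrary assumptions.

```python
import torch

torch.manual_seed(0)
B, T, H = 3, 5, 4
unpacked_out = torch.randn(B, T, H)      # stand-in for pad_packed_sequence output (batch first)
lengths = torch.tensor([5, 3, 2])        # true sequence lengths

# Advanced indexing: for row i, take timestep lengths[i] - 1 and all features.
by_index = unpacked_out[torch.arange(B), lengths - 1, :]

# gather-based version, as in last_timestep() from Solution 1.
idx = (lengths - 1).view(-1, 1).expand(B, H).unsqueeze(1)
by_gather = unpacked_out.gather(1, idx).squeeze(1)

print(by_index.shape)                    # torch.Size([3, 4])
print(torch.equal(by_index, by_gather))  # True
```

The indexing form is shorter; the gather form makes the (B, 1, H) index shape explicit. Both select exactly the same elements.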

Author: Saddle Point

Updated on October 09, 2022

Comments

  • Saddle Point

    This release of PyTorch seems to provide PackedSequence for variable-length inputs to recurrent neural networks. However, I found it a bit hard to use correctly.

    Using pad_packed_sequence to recover the output of an RNN layer that was fed by pack_padded_sequence, we get a T x B x N tensor outputs, where T is the maximum number of time steps, B is the batch size and N is the hidden size. I found that for the shorter sequences in the batch, the outputs after their last time step are all zeros.

    Here are my questions.

    1. For a single-output task, where one needs the last output of every sequence, simply taking outputs[-1] gives a wrong result, since this tensor contains zeros for the short sequences. One needs to construct indices from the sequence lengths to fetch the last output of each sequence. Is there a simpler way to do that?
    2. For a multiple-output task (e.g. seq2seq), one usually adds an N x O linear layer, reshapes the batch outputs from T x B x O into TB x O, and computes the cross-entropy loss against the TB true targets (usually integers in a language model). In this situation, do these zeros in the batch output matter?
  • chenfei
    I tried your solution and got this error:

        File "/root/PycharmProjects/skip-thoughts.torch/pytorch/tmpRNN.py", line 13, in last_timestep
          return unpacked.gather(1, idx).squeeze()
        File "/usr/local/lib/python3.5/dist-packages/torch/autograd/variable.py", line 684, in gather
          return Gather.apply(self, dim, index)
        RuntimeError: save_for_backward can only save input or output tensors, but argument 0 doesn't satisfy this condition
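The behaviour described in the question, where the T x B x N tensor from pad_packed_sequence is zero-filled past each sequence's length so that outputs[-1] is wrong for short sequences, can be reproduced with a small sketch. The sizes and names are arbitrary, and it assumes a current PyTorch release with time-major (batch_first=False) input.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

torch.manual_seed(0)
T, B, E, H = 4, 2, 3, 5
x = torch.randn(T, B, E)                 # time-major padded batch (T x B x E)
lengths = torch.tensor([4, 2])           # second sequence only has 2 real steps

lstm = nn.LSTM(E, H)                     # batch_first=False, so outputs are T x B x N
out_packed, _ = lstm(pack_padded_sequence(x, lengths))
outputs, out_lengths = pad_packed_sequence(out_packed)

print(outputs.shape)                     # torch.Size([4, 2, 5])
# Steps past the second sequence's length are zero-filled, so outputs[-1]
# holds zeros for it instead of its real last output outputs[1, 1].
print(torch.all(outputs[2:, 1] == 0))    # tensor(True)
# Indexing by length picks the real last output of every sequence.
last = outputs[lengths - 1, torch.arange(B)]
print(last.shape)                        # torch.Size([2, 5])
```

Because the time dimension comes first here, the length index goes in the first position and the batch index in the second, the reverse of the batch_first one-liner in Solution 2.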