In TensorFlow, how do I iterate over a sequence of inputs stored in a tensor?


Solution 1

You can convert a tensor into a list of tensors with the unstack function (called unpack before TF 1.0), which splits the tensor along its first dimension. The split function does something similar. I use unstack in an RNN model I am working on.

y = tf.unstack(tf.transpose(y, (1, 0, 2)))

In this case y starts out with shape (BATCH_SIZE, TIME_STEPS, 128). I transpose it to make the time steps the outer dimension and then unstack it into a list of tensors, one per time step. Now every element in the y list is of shape (BATCH_SIZE, 128) and I can feed it into my RNN.
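As a minimal sketch of the transpose-then-unstack step described above, the snippet below uses a zero tensor and small made-up sizes (4 sequences, 5 time steps) in place of real data; only the feature size of 128 comes from the answer. It also shows that passing axis=1 to tf.unstack gives the same list without the explicit transpose.

```python
import tensorflow as tf

# Illustrative sizes (assumptions for this sketch, not from the original model).
BATCH_SIZE, TIME_STEPS, FEATURES = 4, 5, 128

# Stand-in for the real input, shape (BATCH_SIZE, TIME_STEPS, FEATURES).
y = tf.zeros((BATCH_SIZE, TIME_STEPS, FEATURES))

# Make time the outer dimension: (TIME_STEPS, BATCH_SIZE, FEATURES).
y_time_major = tf.transpose(y, (1, 0, 2))

# Split along axis 0 into a Python list with one tensor per time step.
steps = tf.unstack(y_time_major)

# Same result without the transpose: unstack directly along the time axis.
steps_alt = tf.unstack(y, axis=1)

for x_single in steps:
    # Each element has shape (BATCH_SIZE, FEATURES) and can be fed to an RNN cell.
    pass
```

Because the result is an ordinary Python list, a plain for loop over it works, which is what the loop in the question needs.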

Solution 2

In TF >= 1.0, tf.pack and tf.unpack were renamed to tf.stack and tf.unstack, respectively.

Author by

exAres

Updated on June 09, 2022

Comments

  • exAres almost 2 years

    I am trying an RNN on a variable-length multivariate sequence classification problem.

    I have defined the following function to get the output of the sequence (i.e. the output of the RNN cell after the final input from the sequence is fed):

    def get_sequence_output(x_sequence, initial_hidden_state):
        previous_hidden_state = initial_hidden_state
        for x_single in x_sequence:
            hidden_state = gru_unit(previous_hidden_state, x_single)
            previous_hidden_state = hidden_state
        final_hidden_state = hidden_state
        return final_hidden_state
    

    Here x_sequence is a tensor of shape (?, ?, 10), where the first ? is the batch size, the second ? is the sequence length, and each input element has length 10. The gru_unit function takes the previous hidden state and the current input and returns the next hidden state (a standard gated recurrent unit).

    I am getting an error: 'Tensor' object is not iterable. How do I iterate over a tensor sequentially (reading a single element at a time)?

    My objective is to apply gru_unit to every input in the sequence and get the final hidden state.

  • Cospel over 7 years
    This will not work if time_steps is not statically known (i.e. for variable-length sequences).
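For the variable-length case raised in the last comment, one possible approach (not taken from the answers above) is tf.scan, which runs the loop over the first dimension inside TensorFlow, so the number of time steps does not have to be known up front. The sketch below is hedged accordingly: toy_step is a hypothetical stand-in for the question's gru_unit, and the function does not mask padded time steps, so it assumes every sequence in the batch has the same length.

```python
import tensorflow as tf

def toy_step(previous_hidden_state, x_single):
    # Hypothetical stand-in for gru_unit: takes the previous hidden state and
    # the current input, returns the next hidden state of the same shape.
    summed = tf.reduce_sum(x_single, axis=1, keepdims=True)
    return tf.tanh(previous_hidden_state + summed)

def get_sequence_output(x_sequence, initial_hidden_state, step_fn=toy_step):
    # x_sequence: (batch, time, features). tf.scan iterates over axis 0,
    # so move time to the front first.
    x_time_major = tf.transpose(x_sequence, (1, 0, 2))
    # Calls step_fn(state, x_t) once per time step, threading the state through.
    all_states = tf.scan(step_fn, x_time_major, initializer=initial_hidden_state)
    # all_states has shape (time, batch, hidden); the last entry is the final state.
    return all_states[-1]

# Usage with made-up sizes: 3 sequences, 7 time steps, 10 features each.
x = tf.zeros((3, 7, 10))
h0 = tf.zeros((3, 1))
final_hidden_state = get_sequence_output(x, h0)
```

The initializer must have the same shape and dtype as what the step function returns. Sequences of different lengths within one batch would need padding plus masking on top of this.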