TensorFlow: Remember LSTM state for next batch (stateful LSTM)


Solution 1

I found it easiest to save the whole state for all layers in a single placeholder.

init_state = np.zeros((num_layers, 2, batch_size, state_size))  # axis 1 holds (c, h)

...

state_placeholder = tf.placeholder(tf.float32, [num_layers, 2, batch_size, state_size])

Then unpack it and create a tuple of LSTMStateTuples before passing it to the native TensorFlow RNN API:

l = tf.unpack(state_placeholder, axis=0)  # tf.unpack was renamed tf.unstack in TF 1.0
rnn_tuple_state = tuple(
    [tf.nn.rnn_cell.LSTMStateTuple(l[idx][0], l[idx][1])
     for idx in range(num_layers)]
)

Then pass the tuple as the initial state when calling the RNN API:

cell = tf.nn.rnn_cell.LSTMCell(state_size, state_is_tuple=True)
# Note: on newer TF versions, create a separate LSTMCell per layer;
# reusing one cell object via [cell] * num_layers shares its weights
cell = tf.nn.rnn_cell.MultiRNNCell([cell] * num_layers, state_is_tuple=True)
outputs, state = tf.nn.dynamic_rnn(cell, x_input_batch, initial_state=rnn_tuple_state)

The returned state variable is then fed back in as the placeholder value for the next batch.
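A minimal driver loop under these definitions might look like the sketch below; the batches iterable is a stand-in for your own input pipeline, and the fetched tuple of LSTMStateTuples converts cleanly to the placeholder's array shape:

current_state = np.zeros((num_layers, 2, batch_size, state_size))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())  # tf.initialize_all_variables() on pre-0.12 TF
    for batch in batches:  # hypothetical iterable of input arrays shaped like x_input_batch
        # fetch the new state and feed it straight back in on the next iteration
        output_vals, current_state = sess.run(
            [outputs, state],
            feed_dict={x_input_batch: batch,
                       state_placeholder: current_state})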

Solution 2

Tensorflow, best way to save state in RNNs? was actually my original question. The code below shows how I use the state tuples.

with tf.variable_scope('decoder') as scope:
    rnn_cell = tf.nn.rnn_cell.MultiRNNCell([
        tf.nn.rnn_cell.LSTMCell(512, num_proj=256, state_is_tuple=True),
        tf.nn.rnn_cell.LSTMCell(512, num_proj=WORD_VEC_SIZE, state_is_tuple=True)
    ], state_is_tuple=True)

    state = [[tf.zeros((BATCH_SIZE, sz)) for sz in sz_outer] for sz_outer in rnn_cell.state_size]

    for t in range(TIME_STEPS):
        if t:
            # teacher forcing during training, otherwise feed back the prediction
            last = y_[t - 1] if TRAINING else y[t - 1]
        else:
            last = tf.zeros((BATCH_SIZE, WORD_VEC_SIZE))

        # tf.concat took (axis, values) before TF 1.0
        y[t] = tf.concat(1, (y[t], last))
        y[t], state = rnn_cell(y[t], state)

        scope.reuse_variables()

Rather than using tf.nn.rnn_cell.LSTMStateTuple I just create a list of lists, which works fine. In this example I am not saving the state between batches. However, you could easily build the state out of variables and use assign to save the values, as sketched below.
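A minimal sketch of that variable-based variant, assuming the decoder graph above (the names state_vars and save_state are illustrative, not from the original code):

# Hold (c, h) per layer in non-trainable variables that survive between runs
state_vars = [[tf.Variable(tf.zeros((BATCH_SIZE, sz)), trainable=False)
               for sz in sz_outer]
              for sz_outer in rnn_cell.state_size]

# ... run the decoder loop with `state` initialized from state_vars ...

# Copy the final state back into the variables; running these ops
# after each batch persists the state for the next one
save_state = [tf.assign(var, val)
              for var_row, val_row in zip(state_vars, state)
              for var, val in zip(var_row, val_row)]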

Comments

  • verified.human almost 2 years

    Given a trained LSTM model, I want to perform inference for single timesteps, i.e. seq_length = 1 in the example below. After each timestep the internal LSTM (memory and hidden) states need to be remembered for the next 'batch'. At the very beginning of inference the internal LSTM states init_c, init_h are computed from the input; these are then stored in an LSTMStateTuple object which is passed to the LSTM. During training this state is updated at every timestep. However, for inference I want the state to be saved between batches, i.e. the initial states only need to be computed at the very beginning, and after that the LSTM states should be saved after each 'batch' (n=1).

    I found this related StackOverflow question: Tensorflow, best way to save state in RNNs?. However, that approach only works if state_is_tuple=False, a behavior that is soon to be deprecated by TensorFlow (see rnn_cell.py). Keras seems to have a nice wrapper that makes stateful LSTMs possible (a minimal sketch of it appears after the code below), but I don't know the best way to achieve this in TensorFlow. This issue on the TensorFlow GitHub is also related to my question: https://github.com/tensorflow/tensorflow/issues/2838

    Does anyone have good suggestions for building a stateful LSTM model?

    inputs  = tf.placeholder(tf.float32, shape=[None, seq_length, 84, 84], name="inputs")
    targets = tf.placeholder(tf.float32, shape=[None, seq_length], name="targets")
    
    num_lstm_layers = 2
    
    with tf.variable_scope("LSTM") as scope:
    
        lstm_cell  = tf.nn.rnn_cell.LSTMCell(512, initializer=initializer, state_is_tuple=True)
        self.lstm  = tf.nn.rnn_cell.MultiRNNCell([lstm_cell] * num_lstm_layers, state_is_tuple=True)
    
        init_c = # compute initial LSTM memory state using contents in placeholder 'inputs'
        init_h = # compute initial LSTM hidden state using contents in placeholder 'inputs'
        self.state = [tf.nn.rnn_cell.LSTMStateTuple(init_c, init_h)] * num_lstm_layers
    
        outputs = []
    
        for step in range(seq_length):
    
            if step != 0:
                scope.reuse_variables()
    
            # CNN features, as input for LSTM
            x_t = # ... 
    
            # LSTM step through time
            output, self.state = self.lstm(x_t, self.state)
            outputs.append(output)
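
    As for the Keras wrapper mentioned above, a minimal sketch of its stateful mode (feature_size and the dummy inputs are hypothetical; with stateful=True the layer keeps its (c, h) state across calls until reset_states()):

    import numpy as np
    from keras.models import Sequential
    from keras.layers import LSTM, Dense

    feature_size = 84 * 84  # hypothetical input width

    model = Sequential()
    # stateful=True needs a fixed batch size via batch_input_shape;
    # the layer's (c, h) state then persists across predict/train calls
    model.add(LSTM(512, stateful=True, batch_input_shape=(1, 1, feature_size)))
    model.add(Dense(1))
    model.compile(loss='mse', optimizer='adam')

    # feed one timestep per call; the LSTM remembers its state in between
    for _ in range(10):
        x_t = np.zeros((1, 1, feature_size))  # dummy single-timestep input
        y_t = model.predict(x_t, batch_size=1)

    model.reset_states()  # clear the state before the next sequence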