NaN loss in tensorflow LSTM model


Solution 1

This may be a case of exploding gradients. In LSTMs, gradients can grow very large during backpropagation through time, which results in numeric overflow. A common technique for dealing with exploding gradients is gradient clipping.
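In the question's TF 1.x code, gradient clipping would replace the single `minimize` call. You would get `(gradient, variable)` pairs from `optimizer.compute_gradients(self.cost)`, pass the gradients through `tf.clip_by_global_norm`, and hand the clipped pairs to `optimizer.apply_gradients`. The sketch below is plain Python, so it runs without TensorFlow. It reproduces only the rescaling rule that `tf.clip_by_global_norm` documents. The threshold of 5.0 is an arbitrary assumption, not a recommended value.

```python
import math
from typing import List, Tuple

Grads = List[List[float]]  # each inner list stands in for one flattened gradient tensor


def global_norm(grads: Grads) -> float:
    """L2 norm of all gradient values taken together."""
    return math.sqrt(sum(g * g for tensor in grads for g in tensor))


def clip_by_global_norm(grads: Grads, clip_norm: float) -> Tuple[Grads, float]:
    """Scale every gradient by clip_norm / max(global_norm, clip_norm).

    If the combined norm is already <= clip_norm, the gradients are left unchanged.
    Otherwise they are shrunk so that the combined norm becomes clip_norm.
    Returns the (possibly) clipped gradients and the norm measured before clipping.
    """
    norm = global_norm(grads)
    scale = clip_norm / max(norm, clip_norm)
    return [[g * scale for g in tensor] for tensor in grads], norm


if __name__ == "__main__":
    big = [[3.0, 4.0], [0.0, 12.0]]  # combined norm is 13
    clipped, norm = clip_by_global_norm(big, clip_norm=5.0)
    print(norm, global_norm(clipped))  # original norm, then roughly 5.0
```

All tensors are scaled by the same factor, so the direction of the overall update is preserved. Only its length is capped.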

Solution 2

Check the columns that are fed to the model. In my case, one column contained NaN values, and after removing them it worked.
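A minimal plain-Python sketch of that check follows. It counts NaN values per column and drops the affected rows. With pandas, `df.isna().sum()` and `df.dropna()` do the same job. The column names `price` and `volume` and the helpers `nan_counts` and `drop_nan_rows` are hypothetical and used only for illustration.

```python
import math
from typing import Dict, List

Row = Dict[str, float]  # hypothetical: one training example as column -> value


def is_nan(value: object) -> bool:
    return isinstance(value, float) and math.isnan(value)


def nan_counts(rows: List[Row]) -> Dict[str, int]:
    """Number of NaN values found in each column."""
    counts: Dict[str, int] = {}
    for row in rows:
        for col, val in row.items():
            if is_nan(val):
                counts[col] = counts.get(col, 0) + 1
    return counts


def drop_nan_rows(rows: List[Row]) -> List[Row]:
    """Keep only the rows in which no column is NaN."""
    return [row for row in rows if not any(is_nan(v) for v in row.values())]


if __name__ == "__main__":
    data = [
        {"price": 1.0, "volume": 10.0},
        {"price": 2.0, "volume": float("nan")},
        {"price": 3.0, "volume": 30.0},
    ]
    print(nan_counts(data))          # which columns contain NaN
    print(len(drop_nan_rows(data)))  # rows left after cleaning
```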

Author: Chum-Chum Scarecrows

Updated on June 08, 2022

Comments

  • Chum-Chum Scarecrows, almost 2 years

    The following code, which should be a classic simple LSTM language model, starts outputting NaN loss after a while. On my training set this takes a couple of hours, and I couldn't easily reproduce it on smaller datasets, but it always happens in serious training runs.

    `sparse_softmax_cross_entropy_with_logits` should be numerically stable, so it shouldn't be the cause. Other than that, I don't see any node in the graph that could cause a problem. What could it be?

    import tensorflow as tf  # TF 1.x graph-mode API (tf.placeholder, tf.get_variable, tf.train)
    
    class MyLM():
        def __init__(self, batch_size, embedding_size, hidden_size, vocab_size):
            self.x = tf.placeholder(tf.int32, [batch_size, None])  # [batch_size, seq-len]
            self.lengths = tf.placeholder(tf.int32, [batch_size])  # [batch_size]
    
            # remove padding. [batch_size * seq_len] -> [batch_size * sum(lengths)]
            mask = tf.sequence_mask(self.lengths)  # [batch_size, seq_len]
            mask = tf.cast(mask, tf.int32)  # [batch_size, seq_len]
            mask = tf.reshape(mask, [-1])  # [batch_size * seq_len]
    
            # remove padding + last token. [batch_size * seq_len] -> [batch_size * sum(lengths-1)]
            mask_m1 = tf.cast(tf.sequence_mask(self.lengths - 1, maxlen=tf.reduce_max(self.lengths)), tf.int32)  # [batch_size, seq_len]
            mask_m1 = tf.reshape(mask_m1, [-1])  # [batch_size * seq_len]
    
            # remove padding + first token.  [batch_size * seq_len] -> [batch_size * sum(lengths-1)]
            m1_mask = tf.cast(tf.sequence_mask(self.lengths - 1), tf.int32)  # [batch_size, seq_len-1]
            m1_mask = tf.concat([tf.zeros([batch_size, 1], dtype=tf.int32), m1_mask], axis=1)  # [batch_size, seq_len]
            m1_mask = tf.reshape(m1_mask, [-1])  # [batch_size * seq_len]
    
            embedding = tf.get_variable("TokenEmbedding", shape=[vocab_size, embedding_size])
            x_embed = tf.nn.embedding_lookup(embedding, self.x)  # [batch_size, seq_len, embedding_size]
    
            lstm = tf.nn.rnn_cell.LSTMCell(hidden_size, use_peepholes=True)
    
            # outputs shape: [batch_size, seq_len, hidden_size]
            outputs, final_state = tf.nn.dynamic_rnn(lstm, x_embed, dtype=tf.float32,
                                                     sequence_length=self.lengths)
            outputs = tf.reshape(outputs, [-1, hidden_size])  # [batch_size * seq_len, hidden_size]
    
            w = tf.get_variable("w_out", shape=[hidden_size, vocab_size])
            b = tf.get_variable("b_out", shape=[vocab_size])
            logits_padded = tf.matmul(outputs, w) + b  # [batch_size * seq_len, vocab_size]
            self.logits = tf.dynamic_partition(logits_padded, mask_m1, 2)[1]  # [batch_size * sum(lengths-1), vocab_size]
    
            predict = tf.argmax(logits_padded, axis=1)  # [batch_size * seq_len]
            self.predict = tf.dynamic_partition(predict, mask, 2)[1]  # [batch_size * sum(lengths)]
    
            flat_y = tf.dynamic_partition(tf.reshape(self.x, [-1]), m1_mask, 2)[1]  # [batch_size * sum(lengths-1)]
    
            self.cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.logits, labels=flat_y)
            self.cost = tf.reduce_mean(self.cross_entropy)
            self.train_step = tf.train.AdamOptimizer(learning_rate=0.01).minimize(self.cost)
    
  • KeithWM, over 5 years
    Thanks for this answer. I chose to remedy the issue by initializing the LSTM kernel with a very small value (1.e-10). I'll have to see whether this messes things up elsewhere...
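The question assumes that `sparse_softmax_cross_entropy_with_logits` is numerically stable. It helps to check what that stability does and does not cover. The plain-Python sketch below is not TensorFlow's kernel. It computes the same quantity, logsumexp(logits) - logits[label], in two ways. The naive version exponentiates the logits directly. The stable version subtracts the maximum logit first, which is the usual log-sum-exp trick. The function names are hypothetical.

```python
import math
from typing import List


def naive_xent(logits: List[float], label: int) -> float:
    """-log(softmax(logits)[label]) computed literally; math.exp overflows for large logits."""
    exps = [math.exp(z) for z in logits]
    return -math.log(exps[label] / sum(exps))


def stable_xent(logits: List[float], label: int) -> float:
    """Same quantity as logsumexp(logits) - logits[label], using the max-shift trick."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(z - m) for z in logits))
    return lse - logits[label]


if __name__ == "__main__":
    print(stable_xent([1000.0, 0.0], 1))          # large but finite logits are fine
    print(stable_xent([float("inf"), 0.0], 0))    # an inf logit still yields nan
    try:
        naive_xent([1000.0, 0.0], 1)
    except OverflowError:
        print("naive version overflows")
```

The max-shift protects against large but finite logits. It cannot rescue logits that are already inf or NaN, because inf - inf is NaN. If the weights feeding `logits_padded` have blown up, the loss will be NaN even though the loss op itself is stable. That is consistent with the exploding-gradient explanation in Solution 1.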
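The masks in the question's code are its least obvious part. Here `mask_m1` selects the positions whose logits are scored, and `m1_mask` selects the tokens used as labels. The plain-Python sketch below mirrors that masking so the alignment can be checked by hand. `sequence_mask` follows the semantics of `tf.sequence_mask`, and `keep_where` behaves like `tf.dynamic_partition(flat, mask, 2)[1]`. Both are hypothetical helpers, not TensorFlow code. The sketch also assumes that the time dimension of `x` equals `max(lengths)`.

```python
from typing import List, Tuple


def sequence_mask(lengths: List[int], maxlen: int) -> List[List[int]]:
    """1 where position t < length, else 0 (same semantics as tf.sequence_mask)."""
    return [[1 if t < n else 0 for t in range(maxlen)] for n in lengths]


def flatten(rows: List[List[int]]) -> List[int]:
    return [v for row in rows for v in row]


def keep_where(flat: List[int], mask: List[int]) -> List[int]:
    """Elements whose mask is 1, in order (like tf.dynamic_partition(..., 2)[1])."""
    return [v for v, m in zip(flat, mask) if m == 1]


def lm_pairs(x: List[List[int]], lengths: List[int]) -> List[Tuple[int, int]]:
    maxlen = max(lengths)
    shorter = [n - 1 for n in lengths]
    # mask_m1: drop padding and the last real token (positions whose logits are scored)
    mask_m1 = flatten(sequence_mask(shorter, maxlen))
    # m1_mask: default maxlen is max(lengths - 1); prepend a 0 column to drop the first token
    m1_mask = flatten([[0] + row for row in sequence_mask(shorter, maxlen - 1)])
    inputs = keep_where(flatten(x), mask_m1)
    labels = keep_where(flatten(x), m1_mask)
    return list(zip(inputs, labels))


if __name__ == "__main__":
    x = [[10, 11, 12], [20, 21, 0]]  # second sequence is padded with 0
    print(lm_pairs(x, [3, 2]))       # each input token paired with the next token
```

Each kept position t is paired with token t + 1 of the same sequence, which is the intended next-token target. Padding never appears on either side.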