Setting "training=False" in "tf.layers.batch_normalization" during training gives a better validation result


Solution 1

TL;DR: Use a smaller momentum than the default for the normalization layers, like this:

tf.layers.batch_normalization(h1, momentum=0.9, training=flag_training)

TS;WM:

When you set training = False, the batch normalization layer normalizes the batch with its internally stored moving averages of the mean and variance, not with the batch's own mean and variance. With training = False those internal variables are also not updated. Since they are initialized to mean = 0 and variance = 1, batch normalization is effectively turned off: the layer subtracts zero and divides the result by 1 (strictly, by sqrt(1 + epsilon)).
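To make the "effectively turned off" point concrete, here is a minimal NumPy sketch of the inference-mode formula, assuming the layer's defaults (epsilon = 1e-3, gamma initialized to 1, beta to 0). The helper name batch_norm_inference is hypothetical, not part of TensorFlow.

```python
import numpy as np

def batch_norm_inference(x, moving_mean, moving_var, gamma, beta, eps=1e-3):
    """Hypothetical helper: normalize with the stored statistics, not the batch's own."""
    return gamma * (x - moving_mean) / np.sqrt(moving_var + eps) + beta

x = np.array([[120.0, -3.0], [80.0, 5.0]])
features = x.shape[1]

# Freshly initialized layer: moving mean 0, moving variance 1, gamma 1, beta 0.
y = batch_norm_inference(x, np.zeros(features), np.ones(features),
                         np.ones(features), np.zeros(features))

# Every output is the input scaled by 1 / sqrt(1.001), i.e. almost unchanged.
print(y / x)
```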

So if you train with training = False and evaluate like that, that just means you're training your network without any batch normalization whatsoever. It will still yield reasonable results, because hey, there was life before batch normalization, albeit admittedly not that glamorous...

If you turn on batch normalization with training = True, the layer starts to normalize each batch with that batch's own statistics and collects a moving average of the mean and variance of the batches. Now here's the tricky part. The moving average is an exponential moving average, with a default momentum of 0.99 for tf.layers.batch_normalization(). Again, the mean starts at 0 and the variance at 1. Since each update is applied with a weight of (1 - momentum), the stored values only approach the actual mean and variance asymptotically. For example, in 100 steps they cover only about 63.4% of the distance to the real values, because 0.99^100 ≈ 0.366. If you have numerically large values, the difference can be enormous.
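The arithmetic can be checked with a short plain-Python sketch of the moving-average update (moving_mean = momentum * moving_mean + (1 - momentum) * batch_mean). The idealized assumption, also noted in the code, is that every batch has exactly the true mean; real batches fluctuate around it.

```python
def ema_fraction_reached(momentum, steps):
    """Fraction of the distance from the initial value to the target covered after `steps` updates."""
    return 1.0 - momentum ** steps

def simulate_moving_mean(true_mean, momentum=0.99, steps=100, init=0.0):
    """Run the moving-average update; assumes (idealized) every batch mean equals true_mean."""
    moving_mean = init
    for _ in range(steps):
        batch_mean = true_mean
        moving_mean = momentum * moving_mean + (1.0 - momentum) * batch_mean
    return moving_mean

print(round(ema_fraction_reached(0.99, 100), 3))  # 0.634
# With a numerically large feature (true mean 500), the stored mean is still far off:
print(simulate_moving_mean(500.0))  # roughly 317, about 183 below the true mean
```

A network evaluated with that stored mean would see its test inputs shifted by the remaining gap, which is the mis-normalization described next.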

So if you have processed relatively few batches, the internally stored mean and variance can still be significantly off by the time you run the test. Your network is then trained on properly normalized data but tested on mis-normalized data.

In order to speed up the convergence of the internal batch normalization values, you can apply a smaller momentum, like 0.9:

tf.layers.batch_normalization(h1, momentum=0.9, training=flag_training)

(Repeat this for all batch normalization layers.) Note that there is a downside to this, however. With a small momentum like this, random fluctuations in your data will "tug" on your stored mean and variance a lot more, and the resulting values (later used in inference) can be strongly influenced by exactly where you stop the training, which is clearly not optimal. It is useful to have as large a momentum as possible. Depending on the number of training steps, we generally use 0.9, 0.99 and 0.999 for 100, 1,000 and 10,000 training steps, respectively. There is no point in going above 0.999.
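One way to see why those momentum values pair with those step counts: the remaining gap after n updates is momentum^n, so the number of updates needed to shrink it below some tolerance is log(tolerance) / log(momentum). The 1% tolerance below is an assumption chosen for illustration, not something the answer specifies.

```python
import math

def steps_to_converge(momentum, tolerance=0.01):
    """Smallest n with momentum**n <= tolerance (n counts moving-average updates)."""
    return math.ceil(math.log(tolerance) / math.log(momentum))

for momentum in (0.9, 0.99, 0.999):
    print(momentum, steps_to_converge(momentum))
# Roughly 44, 459 and 4603 updates: each fits inside the 100 / 1,000 / 10,000 budgets.
```

A stricter tolerance needs more updates, so for a fixed training budget it pushes toward a smaller momentum.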

Another important thing is proper randomization of the training data. If you train first on, say, the smaller numeric values of your whole data set, the normalization will converge even more slowly. It is best to completely randomize the order of the training data and to make sure you use a batch size of at least 14 (a rule of thumb).
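A minimal NumPy sketch of that advice: reshuffle the whole training set every epoch and only emit full batches, so no batch falls below the chosen size. The function iterate_minibatches and its defaults are hypothetical; the answer only gives the batch size of 14 as a rule of thumb.

```python
import numpy as np

MIN_BATCH_SIZE = 14  # rule of thumb from the answer above

def iterate_minibatches(features, labels, batch_size=64, rng=None):
    """Hypothetical helper: yield full minibatches in a fresh random order (call once per epoch)."""
    if batch_size < MIN_BATCH_SIZE:
        raise ValueError(f"batch_size should be at least {MIN_BATCH_SIZE}")
    rng = np.random.default_rng() if rng is None else rng
    order = rng.permutation(len(features))  # new shuffle on every call
    # Drop the final partial batch so every batch has exactly batch_size rows.
    for start in range(0, len(order) - batch_size + 1, batch_size):
        idx = order[start:start + batch_size]
        yield features[idx], labels[idx]

# Usage: data stored in sorted order would otherwise feed the small values first.
X = np.arange(100, dtype=np.float32).reshape(100, 1)
Y = 2 * X
batches = list(iterate_minibatches(X, Y, batch_size=16, rng=np.random.default_rng(0)))
```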


Side note: zero-debiasing the moving averages is known to speed up convergence significantly, and the ExponentialMovingAverage class has this feature (its zero_debias argument). The batch normalization layers don't, except for tf.slim's batch_norm (its zero_debias_moving_mean argument), if you're willing to restructure your code for slim.
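The idea behind zero debiasing fits in a few lines of plain Python: an average that starts at zero is biased toward zero by a factor of (1 - momentum^t) after t updates, so dividing by that factor removes the bias. This illustrates the technique (the same correction Adam applies to its moment estimates), not TensorFlow's own implementation. It only makes sense for statistics initialized at zero, such as the moving mean, not the moving variance, which starts at 1.

```python
def zero_debiased_ema(values, momentum=0.99):
    """Return (raw, corrected) moving averages; the raw average starts at zero."""
    ema = 0.0
    raw, corrected = [], []
    for t, value in enumerate(values, start=1):
        ema = momentum * ema + (1.0 - momentum) * value
        raw.append(ema)
        corrected.append(ema / (1.0 - momentum ** t))  # undo the pull toward zero
    return raw, corrected

raw, corrected = zero_debiased_ema([5.0] * 10)
print(raw[-1])        # about 0.478: still far below 5.0 after 10 updates
print(corrected[-1])  # 5.0 up to floating-point rounding
```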

Solution 2

The reason that setting training = False improves performance is that batch normalization has four variables (beta, gamma, mean, variance). It is true that the mean and variance don't get updated when training = False. However, gamma and beta still get updated. So your model still has two extra trainable variables per layer (a scale and a shift for each feature) and thus performs better.
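A small NumPy sketch of that argument, assuming the layer's default epsilon of 1e-3: the stored statistics stay frozen at their initial values, yet gradient descent still fits gamma and beta, so the layer keeps acting as a learnable per-feature scale and shift. The data, target and learning rate are made up for illustration.

```python
import numpy as np

EPS = 1e-3  # default epsilon of tf.layers.batch_normalization

rng = np.random.default_rng(0)
x = rng.normal(size=(256, 4))
target = 3.0 * x + 2.0  # made-up target: a per-feature scale and shift

gamma, beta = np.ones(4), np.zeros(4)
moving_mean, moving_var = np.zeros(4), np.ones(4)  # frozen: never updated when training=False

learning_rate = 0.1
for _ in range(500):
    x_hat = (x - moving_mean) / np.sqrt(moving_var + EPS)  # normalize with stored statistics
    y = gamma * x_hat + beta
    err = y - target
    # Gradients of 0.5 * mean squared error with respect to gamma and beta.
    gamma -= learning_rate * np.mean(err * x_hat, axis=0)
    beta -= learning_rate * np.mean(err, axis=0)

print(gamma, beta)  # gamma close to 3 * sqrt(1.001), beta close to 2
```

In TensorFlow terms, gamma and beta are trainable variables of the layer, while moving_mean and moving_variance are created as non-trainable.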

Also, I suspect your model already performs relatively well without batch normalization.

Author: Evan

Updated on June 09, 2022

Question

    I use TensorFlow to train a DNN. I learned that batch normalization is very helpful for DNNs, so I used it in my DNN.

    I use "tf.layers.batch_normalization" and followed the instructions in the API documentation to build the network: when training, I set its parameter "training=True", and when validating, I set "training=False". I also run the ops in tf.get_collection(tf.GraphKeys.UPDATE_OPS) as control dependencies of the training step.

    Here is my code:

    # -*- coding: utf-8 -*-
    import tensorflow as tf
    import numpy as np
    
    input_node_num=257*7
    output_node_num=257
    
    tf_X = tf.placeholder(tf.float32,[None,input_node_num])
    tf_Y = tf.placeholder(tf.float32,[None,output_node_num])
    dropout_rate=tf.placeholder(tf.float32)  # passed to tf.nn.dropout as keep_prob, so 1 means no dropout
    flag_training=tf.placeholder(tf.bool)
    hid_node_num=2048
    
    h1=tf.contrib.layers.fully_connected(tf_X, hid_node_num, activation_fn=None)
    h1_2=tf.nn.relu(tf.layers.batch_normalization(h1,training=flag_training))
    h1_3=tf.nn.dropout(h1_2,dropout_rate)
    
    h2=tf.contrib.layers.fully_connected(h1_3, hid_node_num, activation_fn=None)
    h2_2=tf.nn.relu(tf.layers.batch_normalization(h2,training=flag_training))
    h2_3=tf.nn.dropout(h2_2,dropout_rate)
    
    h3=tf.contrib.layers.fully_connected(h2_3, hid_node_num, activation_fn=None)
    h3_2=tf.nn.relu(tf.layers.batch_normalization(h3,training=flag_training))
    h3_3=tf.nn.dropout(h3_2,dropout_rate)
    
    tf_Y_pre=tf.contrib.layers.fully_connected(h3_3, output_node_num, activation_fn=None)
    
    loss=tf.reduce_mean(tf.square(tf_Y-tf_Y_pre))
    
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_step = tf.train.AdamOptimizer(1e-4).minimize(loss)
    
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
    
        for i1 in range(3000*num_batch):
            train_feature=... # Some processing
            train_label=...  # Some processing
            # Training with "training=True" and validating with "training=False" gives a bad result.
            # Training with "training=False" and validating with "training=False" gives a better result.
            sess.run(train_step,feed_dict={tf_X:train_feature,tf_Y:train_label,flag_training:True,dropout_rate:1})
    
            if((i1+1)%277200==0):# print validate loss every 0.1 epoch
                validate_feature=... # Some processing
                validate_label=... # Some processing
    
                validate_loss = sess.run(loss,feed_dict={tf_X:validate_feature,tf_Y:validate_label,flag_training:False,dropout_rate:1})
                print(validate_loss)
    

    Is there any error in my code? If my code is right, then I get a strange result:

    When training I set "training = True", and when validating I set "training = False", but the result is not good. I print the validation loss every 0.1 epoch; the validation loss from the 1st to the 3rd epoch is:

     0.929624
     0.992692
     0.814033
     0.858562
     1.042705
     0.665418
     0.753507
     0.700503
     0.508338
     0.761886
     0.787044
     0.817034
     0.726586
     0.901634
     0.633383
     0.783920
     0.528140
     0.847496
     0.804937
     0.828761
     0.802314
     0.855557
     0.702335
     0.764318
     0.776465
     0.719034
     0.678497
     0.596230
     0.739280
     0.970555
    

    However, when I change the line "sess.run(train_step,feed_dict={tf_X:train_feature,tf_Y:train_label,flag_training:True,dropout_rate:1})" so that I set "training=False" when training (and still "training=False" when validating), the result is good. The validation loss in the 1st epoch is:

     0.474313
     0.391002
     0.369357
     0.366732
     0.383477
     0.346027
     0.336518
     0.368153
     0.330749
     0.322070
     0.335551
    

    Why does this happen? Is it necessary to set "training=True" when training and "training=False" when validating?