Why do we need to call zero_grad() in PyTorch?

Solution 1

In PyTorch, for every mini-batch during the training phase, we typically want to explicitly set the gradients to zero before starting backpropagation (i.e., computing the gradients that are then used to update the weights and biases), because PyTorch accumulates the gradients on subsequent backward passes. This accumulating behaviour is convenient when training RNNs or when we want to compute the gradient of the loss summed over multiple mini-batches. So the default action is to accumulate (i.e., sum) the gradients on every loss.backward() call.
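To check that accumulation really is a sum, the following minimal sketch uses two made-up single-sample losses on a hypothetical parameter w. It calls backward() twice without zeroing in between and compares the result with a single backward() on the summed loss; the two gradients should agree.

```python
import torch

# Hypothetical parameter and two hypothetical single-sample "mini-batches"
w = torch.tensor([1.0, -2.0], requires_grad=True)
x1 = torch.tensor([0.5, 1.5])
x2 = torch.tensor([2.0, -1.0])

# backward() once per mini-batch, WITHOUT zeroing in between
loss1 = (w * x1).sum() ** 2
loss1.backward()
loss2 = (w * x2).sum() ** 2
loss2.backward()
accumulated = w.grad.clone()

# A single backward() on the summed loss, starting from an empty .grad
w.grad = None
total = (w * x1).sum() ** 2 + (w * x2).sum() ** 2
total.backward()
summed = w.grad.clone()

print(accumulated.tolist(), summed.tolist())
```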

Because of this, at the start of each training iteration you should zero out the gradients so that the parameter update is computed correctly. Otherwise, the gradient would be a combination of the old gradient, which you have already used to update your model parameters, and the newly computed gradient. It would therefore point in a direction other than the intended one towards the minimum (or maximum, in the case of maximization objectives).
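A minimal sketch of that stale-gradient problem, assuming a hypothetical two-parameter model w and two single-sample mini-batches with the loss (w . x - y)^2. After the first update, forgetting to zero leaves the first mini-batch's gradient in w.grad, so the second update points diagonally instead of along the direction the second mini-batch actually asks for.

```python
import torch

lr = 0.1
w = torch.zeros(2, requires_grad=True)

# Two hypothetical single-sample mini-batches for the loss (w . x - y)^2
x1, y1 = torch.tensor([1.0, 0.0]), torch.tensor(1.0)
x2, y2 = torch.tensor([0.0, 1.0]), torch.tensor(1.0)

# Step 1 on the first mini-batch
loss = (w @ x1 - y1) ** 2
loss.backward()                 # w.grad = [-2, 0]
with torch.no_grad():
    w -= lr * w.grad            # w = [0.2, 0]

# Step 2 on the second mini-batch, forgetting to zero the gradient first
loss = (w @ x2 - y2) ** 2
loss.backward()
stale = w.grad.clone()          # old [-2, 0] + new [0, -2] = [-2, -2]

# The gradient we actually wanted for step 2
w.grad.zero_()
loss = (w @ x2 - y2) ** 2
loss.backward()
fresh = w.grad.clone()          # [0, -2]

print("without zero_grad:", stale.tolist())
print("with zero_grad:   ", fresh.tolist())
```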

Here is a simple example:

import torch
import torch.optim as optim

def linear_model(x, W, b):
    return torch.matmul(x, W) + b

data, targets = ...  # your training samples and their targets

# Tensors track gradients directly (torch.autograd.Variable is deprecated)
W = torch.randn(4, 3, requires_grad=True)
b = torch.randn(3, requires_grad=True)

optimizer = optim.Adam([W, b])

for sample, target in zip(data, targets):
    # clear out the gradients of all parameters
    # in this optimizer (i.e. W, b)
    optimizer.zero_grad()
    output = linear_model(sample, W, b)
    # backward() needs a scalar, so sum the squared errors
    loss = ((output - target) ** 2).sum()
    loss.backward()
    optimizer.step()
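For comparison, here is a self-contained sketch of the same loop written with nn.Linear and nn.MSELoss instead of the hand-written linear_model. The synthetic data, learning rate and epoch count are assumptions chosen only so that it runs; the zero_grad() / backward() / step() order is the same as above.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)

# Hypothetical synthetic data: 64 samples with 4 features and 3 targets each
true_W = torch.randn(4, 3)
data = torch.randn(64, 4)
targets = data @ true_W

model = nn.Linear(4, 3)          # holds W and b as registered parameters
loss_fn = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

def evaluate():
    # Loss over the whole dataset, without building an autograd graph
    with torch.no_grad():
        return loss_fn(model(data), targets).item()

initial_loss = evaluate()
for epoch in range(20):
    for sample, target in zip(data, targets):
        optimizer.zero_grad()            # drop gradients from the previous sample
        loss = loss_fn(model(sample), target)
        loss.backward()                  # .grad now holds this sample's gradient only
        optimizer.step()                 # update using that fresh gradient
final_loss = evaluate()

print(f"loss before: {initial_loss:.4f}, after: {final_loss:.4f}")
```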

Alternatively, if you're doing vanilla gradient descent without an optimizer object, then:

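learning_rate = 0.01  # example step size

W = torch.randn(4, 3, requires_grad=True)
b = torch.randn(3, requires_grad=True)

for sample, target in zip(data, targets):
    # clear out the gradients of W and b
    # (.grad is None until the first backward() call)
    if W.grad is not None:
        W.grad.zero_()
        b.grad.zero_()

    output = linear_model(sample, W, b)
    loss = ((output - target) ** 2).sum()
    loss.backward()

    # update outside autograd; an in-place update on a leaf
    # that requires grad would otherwise raise an error
    with torch.no_grad():
        W -= learning_rate * W.grad
        b -= learning_rate * b.grad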
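Because plain SGD without momentum or weight decay updates each parameter as p -= lr * p.grad, the manual loop above should track optim.SGD step for step. The sketch below checks this on made-up data (the random inputs, targets and learning_rate are assumptions). It resets .grad to None by hand, which is what zero_grad(set_to_none=True) does.

```python
import torch
import torch.optim as optim

torch.manual_seed(0)
data = torch.randn(8, 4)          # hypothetical inputs
targets = torch.randn(8, 3)       # hypothetical targets
learning_rate = 0.01

def linear_model(x, W, b):
    return torch.matmul(x, W) + b

# Two identical copies of the parameters
W_manual = torch.randn(4, 3, requires_grad=True)
b_manual = torch.randn(3, requires_grad=True)
W_opt = W_manual.detach().clone().requires_grad_(True)
b_opt = b_manual.detach().clone().requires_grad_(True)
optimizer = optim.SGD([W_opt, b_opt], lr=learning_rate)

for sample, target in zip(data, targets):
    # Manual version: reset .grad by hand, then step under no_grad
    W_manual.grad = None
    b_manual.grad = None
    loss = ((linear_model(sample, W_manual, b_manual) - target) ** 2).sum()
    loss.backward()
    with torch.no_grad():
        W_manual -= learning_rate * W_manual.grad
        b_manual -= learning_rate * b_manual.grad

    # Optimizer version: zero_grad() and step() do the same bookkeeping
    optimizer.zero_grad()
    loss = ((linear_model(sample, W_opt, b_opt) - target) ** 2).sum()
    loss.backward()
    optimizer.step()

print(torch.allclose(W_manual, W_opt), torch.allclose(b_manual, b_opt))
```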

Note:

  • The accumulation (i.e., sum) of gradients happens when .backward() is called on the loss tensor.
  • As of v1.7.0, PyTorch offers the option to reset the gradients to None with optimizer.zero_grad(set_to_none=True) instead of filling them with a tensor of zeros. The docs state that this reduces memory usage and can slightly improve performance, but it is error-prone if your code assumes .grad is always a tensor. Since PyTorch 2.0, set_to_none=True is the default.
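A minimal sketch of the difference between the two settings, using a hypothetical one-tensor SGD optimizer. It passes set_to_none explicitly both ways, so it does not depend on the default of your installed version. The None guard at the end is the kind of check hand-written code needs once .grad can be None.

```python
import torch

w = torch.ones(3, requires_grad=True)
optimizer = torch.optim.SGD([w], lr=0.1)

# set_to_none=False: the .grad tensor is kept and filled with zeros
(w * 2).sum().backward()
optimizer.zero_grad(set_to_none=False)
filled_is_tensor = w.grad is not None
filled_values = w.grad.clone()

# set_to_none=True: the .grad tensor is dropped entirely
(w * 2).sum().backward()
optimizer.zero_grad(set_to_none=True)
reset_is_none = w.grad is None

# Calling w.grad.zero_() directly would raise here, so guard against None
if w.grad is not None:
    w.grad.zero_()

print(filled_values.tolist(), reset_is_none)
```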

Solution 2

Although the idea can be derived from the accepted answer, I want to spell it out explicitly.

Being able to decide when to call optimizer.zero_grad() and optimizer.step() gives you more freedom over how gradients are accumulated and applied by the optimizer in the training loop. This is crucial when the model or input data is big and one actual training batch does not fit into GPU memory.

In this example from google-research, there are two arguments, train_batch_size and gradient_accumulation_steps:

  • train_batch_size is the batch size for each forward pass and the loss.backward() that follows it. This is limited by GPU memory.

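  • gradient_accumulation_steps is the number of forward/backward passes whose gradients are accumulated before one optimizer step, so the actual (effective) training batch size is train_batch_size * gradient_accumulation_steps. This is NOT limited by GPU memory.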
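The relationship between the two arguments is simple arithmetic. Here is a hypothetical helper (not part of the linked script) that computes the effective batch size and the number of optimizer steps per epoch. It assumes a step is taken every time the micro-batch counter reaches a multiple of gradient_accumulation_steps, and that leftover micro-batches at the end of an epoch are not stepped.

```python
def accumulation_schedule(num_examples, train_batch_size, gradient_accumulation_steps):
    """Return (effective batch size, forward passes per epoch, optimizer steps per epoch)."""
    effective_batch_size = train_batch_size * gradient_accumulation_steps
    micro_batches = -(-num_examples // train_batch_size)   # ceiling division
    optimizer_steps = micro_batches // gradient_accumulation_steps
    return effective_batch_size, micro_batches, optimizer_steps

# Hypothetical numbers: 10,000 examples, 8 per forward pass, accumulate 4 passes
print(accumulation_schedule(10_000, 8, 4))
```

With these numbers the effective batch is 32 examples, spread over 1,250 forward passes and 312 optimizer steps per epoch.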

From this example, you can see that optimizer.zero_grad() and optimizer.step() do not have to accompany every loss.backward(). loss.backward() is invoked in every single iteration (line 216), but optimizer.zero_grad() and optimizer.step() are only invoked when the number of accumulated training batches equals gradient_accumulation_steps (line 227, inside the if block on line 219).
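Here is a minimal, self-contained sketch of the pattern (not the code from the linked script): backward() runs on every micro-batch, while step() and zero_grad() run only once every gradient_accumulation_steps. Dividing each micro-batch loss by gradient_accumulation_steps is a common convention, assumed here, that makes the accumulated gradient equal to the gradient of the mean loss over the whole effective batch. The sketch checks that against a single full-batch backward() on hypothetical random data.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
train_batch_size = 4                 # what fits in memory per forward pass
gradient_accumulation_steps = 3      # micro-batches per optimizer step

# Hypothetical data: exactly one effective batch of 12 examples
x = torch.randn(train_batch_size * gradient_accumulation_steps, 5)
y = torch.randn(train_batch_size * gradient_accumulation_steps, 1)
model = nn.Linear(5, 1)
loss_fn = nn.MSELoss()               # mean over whatever batch it is given
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Reference: gradient of the full 12-example batch in one pass
model.zero_grad()
loss_fn(model(x), y).backward()
full_grad = model.weight.grad.clone()

# Accumulation: backward() on every micro-batch, step/zero only every N
optimizer.zero_grad()
batches = zip(x.split(train_batch_size), y.split(train_batch_size))
for step, (xb, yb) in enumerate(batches):
    loss = loss_fn(model(xb), yb) / gradient_accumulation_steps  # keep the mean scale
    loss.backward()                  # gradients add up in .grad
    if (step + 1) % gradient_accumulation_steps == 0:
        accumulated_grad = model.weight.grad.clone()   # inspect before stepping
        optimizer.step()
        optimizer.zero_grad()

print(torch.allclose(full_grad, accumulated_grad, atol=1e-6))
```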

https://github.com/google-research/xtreme/blob/master/third_party/run_classify.py

Someone also asked about an equivalent method in TensorFlow. I guess tf.GradientTape serves the same purpose.

(I am still new to AI libraries, so please correct me if anything I said is wrong.)

Solution 3

zero_grad() makes each iteration start without the gradients left over from the previous step, which matters when you use gradient descent to decrease the error (loss).

If you do not call zero_grad(), the old gradients keep being added to the new ones, and the loss can increase instead of decreasing as intended.

For example:

If you use zero_grad() you will get the following output:

model training loss is 1.5
model training loss is 1.4
model training loss is 1.3
model training loss is 1.2

If you do not use zero_grad() you will get the following output:

model training loss is 1.4
model training loss is 1.9
model training loss is 2
model training loss is 2.8
model training loss is 3.5
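A minimal toy sketch of the effect (an assumption, not the answerer's setup): it minimises (w - 3)^2 by gradient descent with and without resetting w.grad. In this toy case the loss does not grow without bound. The un-zeroed gradient keeps a running sum of all past gradients, which acts like momentum with no damping, so w overshoots and oscillates around the minimum, while the zeroed version decreases steadily.

```python
import torch

def train(steps, zero_grad, lr=0.1):
    # Minimise the toy loss (w - 3)^2 starting from w = 0
    w = torch.tensor(0.0, requires_grad=True)
    losses = []
    for _ in range(steps):
        if zero_grad and w.grad is not None:
            w.grad.zero_()
        loss = (w - 3) ** 2
        losses.append(loss.item())
        loss.backward()             # adds this step's gradient into w.grad
        with torch.no_grad():
            w -= lr * w.grad
    return losses

with_zero = train(8, zero_grad=True)
without_zero = train(8, zero_grad=False)
for a, b in zip(with_zero, without_zero):
    print(f"with zero_grad: {a:8.4f}   without: {b:8.4f}")
```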

Author: user1424739

Updated on July 08, 2022

Comments

  • user1424739
    user1424739 almost 2 years

    Why does zero_grad() need to be called during training?

    |  zero_grad(self)
    |      Sets gradients of all model parameters to zero.
    
  • layser
    layser over 4 years
    thank you very much, this is really helpful! Do you happen to know whether TensorFlow has the same behaviour?
  • zwep
    zwep over 4 years
    Just to be sure.. if you don't do this, then you will run into an exploding gradient problem, right?
  • Tom Roth
    Tom Roth about 4 years
    @zwep If we accumulate gradients, it doesn't mean their magnitude increases: an example would be if the sign of the gradient keeps flipping. So it wouldn't guarantee you'd run into the exploding gradient problem. Besides, exploding gradients exist even if you zero correctly.
  • MUAS
    MUAS almost 4 years
    When you run the vanilla gradient descent do you not get a "leaf Variable that requires grad has been used in an in-place operation" error when you try to update the weights?
  • dedObed
    dedObed about 3 years
    This is confusing to say the least. What looping gets restarted? Loss increase/decrease is affected indirectly, it can increase when you do .zero_grad() and it can decrease when you don't. Where are the outputs you're showing coming from?
  • Youssri Abo Elseod
    Youssri Abo Elseod about 3 years
    Dear dedObed, this example shows what happens if you remove zero_grad() from otherwise correct code. We are talking about the .zero_grad() function; it only makes each iteration start without the last result. If your loss is increasing, you should review your input (write up your problem in a new topic and give me the link).
  • dedObed
    dedObed about 3 years
    I (think I) do understand PyTorch well enough. I'm just pointing out what I perceive as flaws in your answer: it's not clear, it draws quick conclusions, and it shows outputs of who-knows-what.
  • StanGeo
    StanGeo about 3 years
    In other words, it's done to set the variables delta_w and delta_b back to zero.
  • Loqz
    Loqz almost 3 years
    A follow-up question on this: so you're saying we shouldn't call optimizer.zero_grad() when training RNN models such as LSTM, for example?
  • mrgloom
    mrgloom almost 3 years
    Why is optimizer.zero_grad() called before output = linear_model(sample, W, b)?
  • C-3PO
    C-3PO almost 3 years
    Thanks for the explanation. I have an additional question: What happens if I have two networks net_A, and net_B which are interconnected? If I set net_B.parameters()[i].requires_grad = False, and then compute the gradient w.r.t net_A, Would the gradients of net_A.parameters() be affected by nonsense values stored in net_B.parameters()?
  • Alaa M.
    Alaa M. almost 3 years
    Can someone answer @Loqz's question? I'm wondering about that too. Do you need to call zero_grad() when training an RNN?
  • yotabyte
    yotabyte over 2 years
    Thanks for this example. It helped me.
  • Under-qualified NASA Intern
    Under-qualified NASA Intern over 2 years
    This relates to training large models with limited GPU memory. Your ideas are expanded on in this nice post: towardsdatascience.com/…
  • Mona Jalal
    Mona Jalal over 2 years
    Do you do this for validation as well, or only for training? I also have loss.backward() in my validation step. I understand we shouldn't do this in the test phase since it is only one epoch.