How to use model.reset_states() in Keras?
Solution 1
reset_states clears only the hidden states of your network; the weights are untouched. It's worth mentioning that the behaviour of this function depends on whether the option stateful=True was set in your network. If it is not set, all states are automatically reset after every batch computation (so e.g. after calling fit, predict and evaluate as well). If it is set, you should call reset_states every time you want to make consecutive model calls independent.
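As a rough illustration of the difference described above, here is a minimal pure-Python sketch. ToyRecurrentLayer, its weights and its methods are hypothetical stand-ins invented for this example; they only mimic the idea of a hidden state and are not the Keras implementation.

```python
import math


class ToyRecurrentLayer:
    """Toy stand-in for a recurrent layer (hypothetical, not Keras code)."""

    def __init__(self, stateful=False):
        self.stateful = stateful
        self.w_in, self.w_rec = 0.1, 0.9  # fixed "weights", never reset
        self.state = 0.0                  # hidden state

    def reset_states(self):
        # Only the hidden state is cleared; the weights stay as they are.
        self.state = 0.0

    def predict(self, sequence):
        # A non-stateful layer always starts from a zero state.
        h = self.state if self.stateful else 0.0
        for x in sequence:
            h = math.tanh(self.w_in * x + self.w_rec * h)
        if self.stateful:
            self.state = h  # carried over to the next call
        return h


x1 = [2, 4, 2, 1, 4]

stateless = ToyRecurrentLayer(stateful=False)
a = stateless.predict(x1)
b = stateless.predict(x1)      # same as a: the state is reset automatically

stateful = ToyRecurrentLayer(stateful=True)
c = stateful.predict(x1)       # same as a: the first call starts from zero
d = stateful.predict(x1)       # differs from c: starts from the carried state
stateful.reset_states()
e = stateful.predict(x1)       # same as c again after the explicit reset
```

In this toy model, two consecutive calls on the stateful layer only become independent once reset_states is called in between, which is the situation the answer describes.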
Solution 2
If you explicitly use either of:
model.reset_states()
to reset the states of all layers in the model, or
layer.reset_states()
to reset the states of a specific stateful RNN layer (including an LSTM layer), implemented here:
def reset_states(self, states=None):
    if not self.stateful:
        raise AttributeError('Layer must be stateful.')
this means your layer(s) must be stateful.
In LSTM you need to:
- explicitly specify the batch size you are using, by passing a batch_size argument to the first layer in your model, or a batch_input_shape argument;
- set stateful=True;
- specify shuffle=False when calling fit().
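The three steps could look roughly like the sketch below. It assumes TensorFlow's bundled Keras (tf.keras); the layer sizes, the random data and the variable names are made up for illustration. It fixes the batch size through tf.keras.Input and resets the state on the layer itself, because the exact argument names and the availability of model.reset_states() differ between Keras versions.

```python
import numpy as np
import tensorflow as tf

BATCH_SIZE = 1   # a stateful layer needs a fixed batch size
TIMESTEPS = 5
FEATURES = 1

# Step 1: fix the batch size on the model input (assumed equivalent of
# passing batch_size / batch_input_shape to the first layer).
inputs = tf.keras.Input(shape=(TIMESTEPS, FEATURES), batch_size=BATCH_SIZE)

# Step 2: make the recurrent layer stateful.
lstm = tf.keras.layers.LSTM(4, stateful=True)
outputs = tf.keras.layers.Dense(1)(lstm(inputs))
model = tf.keras.Model(inputs, outputs)
model.compile(optimizer="adam", loss="mse")

# Made-up data: 4 consecutive chunks of one long sequence.
x = np.random.rand(4, TIMESTEPS, FEATURES).astype("float32")
y = np.random.rand(4, 1).astype("float32")

# Step 3: keep the chunks in their original order while training.
model.fit(x, y, batch_size=BATCH_SIZE, epochs=1, shuffle=False, verbose=0)

# Start from a clean state, predict, then reset before predicting again
# so that the two calls are independent of each other.
lstm.reset_states()
first = model.predict(x[:1], batch_size=BATCH_SIZE, verbose=0)
lstm.reset_states()
fresh = model.predict(x[:1], batch_size=BATCH_SIZE, verbose=0)
```

With the reset in between, first and fresh should match; leaving the second reset out would let the state from the first predict call carry over into the next one.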
The benefits of using stateful models are probably best explained here.
jef
Updated on June 04, 2022
Comments
-
jef almost 2 years
I have sequential data and I declared an LSTM model in Keras which predicts y from x. So if I call model.predict(x1) and model.predict(x2), is it correct to call model.reset_states explicitly between the two predict() calls? model.reset_states clears the history of inputs, not the weights, right?
# data1
x1 = [2,4,2,1,4]
y1 = [1,2,3,2,1]
# data2
x2 = [5,3,2,4,5]
y2 = [5,3,2,3,2]
And in my actual code, I use model.evaluate(). In evaluate(), is reset_states called implicitly for each data sample?
model.evaluate(dataX, dataY)
-
jef about 7 years
I got it. If I do not set the stateful option (so the default, stateful=False), I do not need to call reset_states, right? And could you tell me in what kind of cases I should use stateful=True?
-
Marcin Możejko about 7 years
Yes, you are right. stateful=True is usually used when you want to treat consecutive batches as consecutive inputs. In this case the model treats consecutive batches as if they were in the same batch.
-
ajaysinghnegi over 5 years
@MarcinMożejko How does the model learn when the hidden states are cleared after every batch during training? What does clear mean here?