TypeError: __init__() got an unexpected keyword argument 'trainable'


I think you missed a small detail in your layer definition. Your layer's __init__ method should accept keyword arguments (**kwargs) and pass them on to the parent class's __init__, like this:

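from keras import initializers
from keras.layers import Layer

class AttLayer(Layer):
    def __init__(self, attention_dim, **kwargs):
        self.init = initializers.get('normal')
        self.supports_masking = True
        self.attention_dim = attention_dim
        # Forward generic layer arguments (e.g. name, trainable) to Layer
        super(AttLayer, self).__init__(**kwargs)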

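This way, any generic layer argument is forwarded to the parent class. In your case that is the trainable flag (and name): the base Layer.get_config() writes both into the saved config, and when the model is rebuilt, from_config() passes them back to __init__ via cls(**config), as the last frame of your traceback shows.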
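To see why this fixes the error, here is a minimal, Keras-free sketch of the save/load round trip (Python 3). BaseLayer, BrokenAttLayer, FixedAttLayer and round_trip are hypothetical names made up for illustration; BaseLayer only mimics the relevant behavior of Keras's Layer (get_config() records name and trainable, and from_config() rebuilds the layer with cls(**config), the call at the bottom of the traceback). It is not the real Keras class.

```python
# Hypothetical, Keras-free sketch (Python 3). BaseLayer only mimics the
# relevant part of keras.layers.Layer; it is not the real class.


class BaseLayer:
    def __init__(self, name=None, trainable=True):
        self.name = name or type(self).__name__.lower()
        self.trainable = trainable

    def get_config(self):
        # Generic arguments that every layer writes into its config.
        return {'name': self.name, 'trainable': self.trainable}

    @classmethod
    def from_config(cls, config):
        # Mirrors `return cls(**config)` at the bottom of the traceback.
        return cls(**config)


class BrokenAttLayer(BaseLayer):
    def __init__(self, attention_dim):  # no **kwargs: name/trainable rejected
        self.attention_dim = attention_dim
        super().__init__()

    def get_config(self):
        config = super().get_config()
        config['attention_dim'] = self.attention_dim
        return config


class FixedAttLayer(BaseLayer):
    def __init__(self, attention_dim, **kwargs):
        self.attention_dim = attention_dim
        super().__init__(**kwargs)  # forward name/trainable to the parent

    def get_config(self):
        config = super().get_config()
        config['attention_dim'] = self.attention_dim
        return config


def round_trip(layer):
    """Serialize a layer to a config dict, then rebuild it from that dict."""
    return type(layer).from_config(layer.get_config())


try:
    round_trip(BrokenAttLayer(100))
except TypeError as err:
    print('Broken:', err)  # an "unexpected keyword argument" TypeError

restored = round_trip(FixedAttLayer(100, name='att', trainable=False))
print('Fixed:', restored.name, restored.trainable, restored.attention_dim)
# prints: Fixed: att False 100
```

The only difference between the two classes is the **kwargs parameter forwarded to the parent, which is exactly the change suggested for AttLayer above.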

Author: Biswadip Mandal

Updated on June 24, 2022

Comments

  • Biswadip Mandal (almost 2 years ago)

    I am trying to load an RNN model architecture trained in Keras using keras.models.model_from_json, and I am getting the error in the title:

    from keras.models import model_from_json
    
    # AttLayer is the custom layer class shown below
    with open('model_architecture.json', 'r') as f:
        model = model_from_json(f.read(), custom_objects={'AttLayer':AttLayer})
    
    # Load weights into the new model
    model.load_weights('model_weights.h5')
    

    Here is the custom layer I am using:

    from keras import backend as K
    from keras import initializers
    from keras.layers import Layer
    
    class AttLayer(Layer):
        def __init__(self, attention_dim):
            self.init = initializers.get('normal')
            self.supports_masking = True
            self.attention_dim = attention_dim
            super(AttLayer, self).__init__()
    
        def build(self, input_shape):
            assert len(input_shape) == 3
            self.W = K.variable(self.init((input_shape[-1], self.attention_dim)))
            self.b = K.variable(self.init((self.attention_dim, )))
            self.u = K.variable(self.init((self.attention_dim, 1)))
            self.trainable_weights = [self.W, self.b, self.u]
            super(AttLayer, self).build(input_shape)
    
        def compute_mask(self, inputs, mask=None):
            return None
    
        def call(self, x, mask=None):
            # size of x: [batch_size, seq_len, input_dim]
            # size of u: [attention_dim, 1]
            # uit = tanh(xW + b)
            uit = K.tanh(K.bias_add(K.dot(x, self.W), self.b))
            ait = K.dot(uit, self.u)
            ait = K.squeeze(ait, -1)
    
            ait = K.exp(ait)
    
            if mask is not None:
                # Cast the mask to floatX to avoid float64 upcasting in theano
                ait *= K.cast(mask, K.floatx())
            ait /= K.cast(K.sum(ait, axis=1, keepdims=True) + K.epsilon(), K.floatx())
            ait = K.expand_dims(ait)
            weighted_input = x * ait
            output = K.sum(weighted_input, axis=1)
    
            return output
    
        def compute_output_shape(self, input_shape):
            return (input_shape[0], input_shape[-1])
    
        def get_config(self):
            config = {'attention_dim': self.attention_dim}
            base_config = super(AttLayer, self).get_config()
            return dict(list(base_config.items()) + list(config.items()))
    

    error:

    File "scripts/Classifier.py", line 254, in test
        model = model_from_json(f.read(), custom_objects={'AttLayer':AttLayer})
      File "/home/biswadip/.local/lib/python2.7/site-packages/keras/models.py", line 345, in model_from_json
        return layer_module.deserialize(config, custom_objects=custom_objects)
      File "/home/biswadip/.local/lib/python2.7/site-packages/keras/layers/__init__.py", line 54, in deserialize
        printable_module_name='layer')
      File "/home/biswadip/.local/lib/python2.7/site-packages/keras/utils/generic_utils.py", line 139, in deserialize_keras_object
        list(custom_objects.items())))
      File "/home/biswadip/.local/lib/python2.7/site-packages/keras/engine/topology.py", line 2489, in from_config
        process_layer(layer_data)
      File "/home/biswadip/.local/lib/python2.7/site-packages/keras/engine/topology.py", line 2475, in process_layer
        custom_objects=custom_objects)
      File "/home/biswadip/.local/lib/python2.7/site-packages/keras/layers/__init__.py", line 54, in deserialize
        printable_module_name='layer')
      File "/home/biswadip/.local/lib/python2.7/site-packages/keras/utils/generic_utils.py", line 139, in deserialize_keras_object
        list(custom_objects.items())))
      File "/home/biswadip/.local/lib/python2.7/site-packages/keras/layers/wrappers.py", line 100, in from_config
        custom_objects=custom_objects)
      File "/home/biswadip/.local/lib/python2.7/site-packages/keras/layers/__init__.py", line 54, in deserialize
        printable_module_name='layer')
      File "/home/biswadip/.local/lib/python2.7/site-packages/keras/utils/generic_utils.py", line 139, in deserialize_keras_object
        list(custom_objects.items())))
      File "/home/biswadip/.local/lib/python2.7/site-packages/keras/engine/topology.py", line 2489, in from_config
        process_layer(layer_data)
      File "/home/biswadip/.local/lib/python2.7/site-packages/keras/engine/topology.py", line 2475, in process_layer
        custom_objects=custom_objects)
      File "/home/biswadip/.local/lib/python2.7/site-packages/keras/layers/__init__.py", line 54, in deserialize
        printable_module_name='layer')
      File "/home/biswadip/.local/lib/python2.7/site-packages/keras/utils/generic_utils.py", line 141, in deserialize_keras_object
        return cls.from_config(config['config'])
      File "/home/biswadip/.local/lib/python2.7/site-packages/keras/engine/topology.py", line 1254, in from_config
        return cls(**config)
    TypeError: __init__() got an unexpected keyword argument 'trainable'
    

    Versions:

    Keras==2.0.8
    tensorflow==1.4.1
    

    I tried training and loading with different versions, but with no luck. Finally, I removed the 'trainable' and 'name' key-value pairs from my custom layer's entry in the model architecture file (model_architecture.json), and the model now loads without any error. But this is only a workaround, and I would have to do it every time I train the model.

  • Biswadip Mandal (over 5 years ago)
    That seemed to be the issue. It's working fine now. Thank you :)
  • Ethan Chen (about 5 years ago)
    Perfect solution to my problem. With this **kwargs change, though, for some reason saving the model with .to_json works, but saving it with .save doesn't.