TypeError: __init__() got an unexpected keyword argument 'name' when loading a model with Custom Layer

14,237

Solution 1

Based on the error message alone, I would suggest adding **kwargs to __init__. The layer will then accept keyword arguments that are not in its signature, such as the name that Keras stores in the saved config.

def __init__(self, batch_size, num_patches, **kwargs):
    # Pass name, trainable, dtype, etc. on to the base Layer
    super(TemporalReshape, self).__init__(**kwargs)
    self.batch_size = batch_size
    self.num_patches = num_patches
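The fix in context: a sketch of the complete layer, assuming TensorFlow 2.x with tf.keras. It round-trips the layer through get_config() and from_config(), which is the cls(**config) path that load_model() follows in the question's traceback. Converting inputs.shape[1:] to a tuple in call(), and the layer name used below, are my own choices rather than part of the question.

```python
import tensorflow as tf


class TemporalReshape(tf.keras.layers.Layer):
    """Reshapes (batch_size * num_patches, ...) into (batch_size, num_patches, ...)."""

    def __init__(self, batch_size, num_patches, **kwargs):
        # Forward name, trainable, dtype, ... so the base Layer restores them on load.
        super().__init__(**kwargs)
        self.batch_size = batch_size
        self.num_patches = num_patches

    def call(self, inputs):
        # Assumes every non-batch dimension of `inputs` is statically known.
        new_shape = (self.batch_size, self.num_patches) + tuple(inputs.shape[1:])
        return tf.reshape(inputs, new_shape)

    def get_config(self):
        # Start from the base config (name, trainable, dtype) and add our own args.
        config = super().get_config()
        config.update({"batch_size": self.batch_size, "num_patches": self.num_patches})
        return config


# from_config() calls cls(**config), just as load_model() does for each layer.
layer = TemporalReshape(batch_size=8, num_patches=16, name="temporal_reshape")
restored = TemporalReshape.from_config(layer.get_config())
print(restored.name, restored.batch_size, restored.num_patches)
```

When loading a whole saved model, the custom_objects={'TemporalReshape': TemporalReshape} argument from the question is still needed in load_model().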

Solution 2

Add **kwargs to the __init__() method.

Error message: "TypeError: __init__() missing 3 required positional arguments: 'batch_size', 'num_patches'"
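Both kinds of error (the "unexpected keyword argument" one in the title and the "missing ... required positional arguments" one above) come from the same step: on load, Keras rebuilds each layer with cls(**config), the last frame of the question's traceback, where config is whatever get_config() returned. The plain-Python sketch below imitates that step. FakeLayer, NoKwargs, NoCustomConfig, Fixed and try_reload are hypothetical names, and FakeLayer is a much-simplified stand-in for the Keras Layer class, not its real code.

```python
class FakeLayer:
    """Hypothetical, much-simplified stand-in for the Keras Layer base class."""

    def __init__(self, name=None, trainable=True):
        self.name = name or type(self).__name__.lower()
        self.trainable = trainable

    def get_config(self):
        # Like Keras, the base class saves its own attributes (name, trainable).
        return {"name": self.name, "trainable": self.trainable}

    @classmethod
    def from_config(cls, config):
        # Mirrors the `return cls(**config)` frame in the question's traceback.
        return cls(**config)


class NoKwargs(FakeLayer):
    """Like the question's layer: __init__ has no room for `name`."""

    def __init__(self, batch_size, num_patches):
        super().__init__()
        self.batch_size = batch_size
        self.num_patches = num_patches

    def get_config(self):
        config = super().get_config()
        config.update({"batch_size": self.batch_size, "num_patches": self.num_patches})
        return config


class NoCustomConfig(FakeLayer):
    """Accepts and forwards **kwargs, but get_config() omits the custom arguments."""

    def __init__(self, batch_size, num_patches, **kwargs):
        super().__init__(**kwargs)
        self.batch_size = batch_size
        self.num_patches = num_patches


class Fixed(FakeLayer):
    """Accepts and forwards **kwargs AND saves its own arguments."""

    def __init__(self, batch_size, num_patches, **kwargs):
        super().__init__(**kwargs)
        self.batch_size = batch_size
        self.num_patches = num_patches

    def get_config(self):
        config = super().get_config()
        config.update({"batch_size": self.batch_size, "num_patches": self.num_patches})
        return config


def try_reload(layer):
    """Rebuild `layer` from its config; return the TypeError message, or None on success."""
    try:
        type(layer).from_config(layer.get_config())
    except TypeError as err:
        return str(err)
    return None


for cls in (NoKwargs, NoCustomConfig, Fixed):
    print(f"{cls.__name__}: {try_reload(cls(batch_size=8, num_patches=16))}")
```

Only the class that both accepts **kwargs and saves batch_size and num_patches in get_config() survives the round trip.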

Author: Siladittya


Updated on June 17, 2022

Comments

  • Siladittya
    Siladittya almost 2 years

    I made a custom layer in Keras to reshape the outputs of a CNN before feeding them to a ConvLSTM2D layer:

    class TemporalReshape(Layer):
        def __init__(self,batch_size,num_patches):
            super(TemporalReshape,self).__init__()
            self.batch_size = batch_size
            self.num_patches = num_patches
    
        def call(self,inputs):
            nshape = (self.batch_size,self.num_patches)+inputs.shape[1:]
            return tf.reshape(inputs, nshape)
    
        def get_config(self):
            config = super().get_config().copy()
            config.update({'batch_size':self.batch_size,'num_patches':self.num_patches})
            return config
    

    When I try to load the best model using

    model = tf.keras.models.load_model('/content/saved_models/model_best.h5',custom_objects={'TemporalReshape':TemporalReshape})
    

    I get the error

    TypeError                                 Traceback (most recent call last)
    <ipython-input-83-40b46da33e91> in <module>()
    ----> 1 model = tf.keras.models.load_model('/content/saved_models/model_best.h5',custom_objects={'TemporalReshape':TemporalReshape})
    
    
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/save.py in load_model(filepath, custom_objects, compile, options)
        180     if (h5py is not None and (
        181         isinstance(filepath, h5py.File) or h5py.is_hdf5(filepath))):
    --> 182       return hdf5_format.load_model_from_hdf5(filepath, custom_objects, compile)
        183 
        184     filepath = path_to_string(filepath)
    
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/hdf5_format.py in load_model_from_hdf5(filepath, custom_objects, compile)
        176     model_config = json.loads(model_config.decode('utf-8'))
        177     model = model_config_lib.model_from_config(model_config,
    --> 178                                                custom_objects=custom_objects)
        179 
        180     # set weights
    
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/model_config.py in model_from_config(config, custom_objects)
         53                     '`Sequential.from_config(config)`?')
         54   from tensorflow.python.keras.layers import deserialize  # pylint: disable=g-import-not-at-top
    ---> 55   return deserialize(config, custom_objects=custom_objects)
         56 
         57 
    
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/layers/serialization.py in deserialize(config, custom_objects)
        173       module_objects=LOCAL.ALL_OBJECTS,
        174       custom_objects=custom_objects,
    --> 175       printable_module_name='layer')
    
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/utils/generic_utils.py in deserialize_keras_object(identifier, module_objects, custom_objects, printable_module_name)
        356             custom_objects=dict(
        357                 list(_GLOBAL_CUSTOM_OBJECTS.items()) +
    --> 358                 list(custom_objects.items())))
        359       with CustomObjectScope(custom_objects):
        360         return cls.from_config(cls_config)
    
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/functional.py in from_config(cls, config, custom_objects)
        615     """
        616     input_tensors, output_tensors, created_layers = reconstruct_from_config(
    --> 617         config, custom_objects)
        618     model = cls(inputs=input_tensors, outputs=output_tensors,
        619                 name=config.get('name'))
    
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/functional.py in reconstruct_from_config(config, custom_objects, created_layers)
       1202   # First, we create all layers and enqueue nodes to be processed
       1203   for layer_data in config['layers']:
    -> 1204     process_layer(layer_data)
       1205   # Then we process nodes in order of layer depth.
       1206   # Nodes that cannot yet be processed (if the inbound node
    
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/functional.py in process_layer(layer_data)
       1184       from tensorflow.python.keras.layers import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top
       1185 
    -> 1186       layer = deserialize_layer(layer_data, custom_objects=custom_objects)
       1187       created_layers[layer_name] = layer
       1188 
    
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/layers/serialization.py in deserialize(config, custom_objects)
        173       module_objects=LOCAL.ALL_OBJECTS,
        174       custom_objects=custom_objects,
    --> 175       printable_module_name='layer')
    
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/utils/generic_utils.py in deserialize_keras_object(identifier, module_objects, custom_objects, printable_module_name)
        358                 list(custom_objects.items())))
        359       with CustomObjectScope(custom_objects):
    --> 360         return cls.from_config(cls_config)
        361     else:
        362       # Then `cls` may be a function returning a class.
    
    /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/base_layer.py in from_config(cls, config)
        695         A layer instance.
        696     """
    --> 697     return cls(**config)
        698 
        699   def compute_output_shape(self, input_shape):
    
    TypeError: __init__() got an unexpected keyword argument 'name'
    

    When building the model, I used the custom layer like this:

    x = TemporalReshape(batch_size = 8, num_patches = 16)(x)

    What is causing the error, and how can I load the model without it?

  • Dr. Snoopy
    Dr. Snoopy over 3 years
    This is correct, but you are missing one key thing: the kwargs need to be passed to the parent __init__.
  • Nicolas Gervais
    Nicolas Gervais over 3 years
    Like this? super(TemporalReshape,self).__init__(**kwargs)
  • Dr. Snoopy
    Dr. Snoopy over 3 years
    Yes that is what I mean
  • Siladittya
    Siladittya over 3 years
    But even without that I got no error. Thanks for the suggestion, though.
  • wessel
    wessel almost 3 years
    That's because the missing kwargs have default values. As a consequence you run the risk of not reconstructing the exact same thing you serialized.
  • Nicolas Gervais
    Nicolas Gervais almost 3 years
    This is not an answer, and the recent edit doesn't correspond to what the original answerer had said. Furthermore, this edit is a perfect copy of another answer.
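wessel's point about default values can be seen in a short sketch, assuming TensorFlow 2.x with tf.keras. The hypothetical SwallowKwargs layer accepts **kwargs but does not forward them to super().__init__(). Loading then succeeds, yet the saved trainable flag and name are replaced by defaults.

```python
import tensorflow as tf


class SwallowKwargs(tf.keras.layers.Layer):
    """Hypothetical layer that accepts **kwargs but does not forward them."""

    def __init__(self, batch_size, num_patches, **kwargs):
        super().__init__()  # kwargs (name, trainable, dtype, ...) are silently dropped
        self.batch_size = batch_size
        self.num_patches = num_patches

    def get_config(self):
        config = super().get_config()
        config.update({"batch_size": self.batch_size, "num_patches": self.num_patches})
        return config


original = SwallowKwargs(batch_size=8, num_patches=16)
original.trainable = False  # e.g. the layer was frozen before saving

config = original.get_config()  # stores trainable=False and the layer's name
restored = SwallowKwargs.from_config(config)

print("saved:   ", config["name"], config["trainable"])
print("restored:", restored.name, restored.trainable)
```

Forwarding the kwargs, as in super(TemporalReshape, self).__init__(**kwargs), lets the restored layer pick up the saved values instead.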