TypeError: __init__() got an unexpected keyword argument 'name' when loading a model with a custom layer
Solution 1
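Based on the error message alone, I would suggest adding `**kwargs` to `__init__`. The layer will then accept any other keyword arguments that you haven't listed explicitly (such as `name`, which Keras passes back in when it rebuilds the layer from its saved config).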
def __init__(self, batch_size, num_patches, **kwargs):
    """Create the reshape layer.

    Args:
        batch_size: Stored on the layer; used as the leading dimension of
            the reshaped output in ``call``.
        num_patches: Stored on the layer; used as the second dimension of
            the reshaped output in ``call``.
        **kwargs: Standard Keras ``Layer`` keyword arguments (for example
            ``name``). ``get_config`` serializes these and ``from_config``
            passes them back via ``cls(**config)`` when the model is loaded.
    """
    # Forward kwargs to the base Layer so options saved in the config (such
    # as ``name``) are actually restored on load. Otherwise they would be
    # silently replaced by defaults.
    super(TemporalReshape, self).__init__(**kwargs)
    self.batch_size = batch_size
    self.num_patches = num_patches
Solution 2
Insert `**kwargs` into the `__init__()` function.
Error message: "TypeError: __init__() missing 3 required positional arguments: 'batch_size', 'num_patches'"
Siladittya
I like writing code for various problems related to my field of study, and I also code for fun. Interested in AI, computer vision, and astrophysics.
Updated on June 17, 2022
Comments
-
Siladittya almost 2 years
I made a custom layer in Keras to reshape the outputs of a CNN before feeding them to a ConvLSTM2D layer:
class TemporalReshape(Layer):
    """Keras layer that reshapes its input to (batch_size, num_patches, *inputs.shape[1:])."""

    # NOTE(review): this __init__ accepts only batch_size and num_patches.
    # get_config() below starts from super().get_config(), which also carries
    # base-layer keys such as 'name'. On load, from_config() calls
    # cls(**config), so it fails with "unexpected keyword argument 'name'".
    # Accepting **kwargs here and forwarding them to the parent __init__
    # resolves the error.
    def __init__(self,batch_size,num_patches):
        super(TemporalReshape,self).__init__()
        self.batch_size = batch_size
        self.num_patches = num_patches
    def call(self,inputs):
        # Prepend (batch_size, num_patches) to the per-sample dimensions.
        # NOTE(review): assumes the incoming leading dimension equals
        # batch_size * num_patches. A smaller final batch would make the
        # reshape fail; confirm.
        nshape = (self.batch_size,self.num_patches)+inputs.shape[1:]
        return tf.reshape(inputs, nshape)
    def get_config(self):
        # Serialize the constructor arguments alongside the base-layer config.
        config = super().get_config().copy()
        config.update({'batch_size':self.batch_size,'num_patches':self.num_patches})
        return config
When I try to load the best model using
# custom_objects maps the class name stored in the saved config to the
# Python class that should be used to rebuild the layer.
model = tf.keras.models.load_model(
    '/content/saved_models/model_best.h5',
    custom_objects={'TemporalReshape': TemporalReshape},
)
I get the error
TypeError Traceback (most recent call last) <ipython-input-83-40b46da33e91> in <module>() ----> 1 model = tf.keras.models.load_model('/content/saved_models/model_best.h5',custom_objects={'TemporalReshape':TemporalReshape}) /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/save.py in load_model(filepath, custom_objects, compile, options) 180 if (h5py is not None and ( 181 isinstance(filepath, h5py.File) or h5py.is_hdf5(filepath))): --> 182 return hdf5_format.load_model_from_hdf5(filepath, custom_objects, compile) 183 184 filepath = path_to_string(filepath) /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/hdf5_format.py in load_model_from_hdf5(filepath, custom_objects, compile) 176 model_config = json.loads(model_config.decode('utf-8')) 177 model = model_config_lib.model_from_config(model_config, --> 178 custom_objects=custom_objects) 179 180 # set weights /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/model_config.py in model_from_config(config, custom_objects) 53 '`Sequential.from_config(config)`?') 54 from tensorflow.python.keras.layers import deserialize # pylint: disable=g-import-not-at-top ---> 55 return deserialize(config, custom_objects=custom_objects) 56 57 /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/layers/serialization.py in deserialize(config, custom_objects) 173 module_objects=LOCAL.ALL_OBJECTS, 174 custom_objects=custom_objects, --> 175 printable_module_name='layer') /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/utils/generic_utils.py in deserialize_keras_object(identifier, module_objects, custom_objects, printable_module_name) 356 custom_objects=dict( 357 list(_GLOBAL_CUSTOM_OBJECTS.items()) + --> 358 list(custom_objects.items()))) 359 with CustomObjectScope(custom_objects): 360 return cls.from_config(cls_config) /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/functional.py in from_config(cls, config, custom_objects) 615 """ 616 
input_tensors, output_tensors, created_layers = reconstruct_from_config( --> 617 config, custom_objects) 618 model = cls(inputs=input_tensors, outputs=output_tensors, 619 name=config.get('name')) /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/functional.py in reconstruct_from_config(config, custom_objects, created_layers) 1202 # First, we create all layers and enqueue nodes to be processed 1203 for layer_data in config['layers']: -> 1204 process_layer(layer_data) 1205 # Then we process nodes in order of layer depth. 1206 # Nodes that cannot yet be processed (if the inbound node /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/functional.py in process_layer(layer_data) 1184 from tensorflow.python.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top 1185 -> 1186 layer = deserialize_layer(layer_data, custom_objects=custom_objects) 1187 created_layers[layer_name] = layer 1188 /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/layers/serialization.py in deserialize(config, custom_objects) 173 module_objects=LOCAL.ALL_OBJECTS, 174 custom_objects=custom_objects, --> 175 printable_module_name='layer') /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/utils/generic_utils.py in deserialize_keras_object(identifier, module_objects, custom_objects, printable_module_name) 358 list(custom_objects.items()))) 359 with CustomObjectScope(custom_objects): --> 360 return cls.from_config(cls_config) 361 else: 362 # Then `cls` may be a function returning a class. /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/base_layer.py in from_config(cls, config) 695 A layer instance. 696 """ --> 697 return cls(**config) 698 699 def compute_output_shape(self, input_shape): TypeError: __init__() got an unexpected keyword argument 'name'
When building the model, I used the custom layer like this:
# Apply the custom reshape layer to the tensor x.
x = TemporalReshape(batch_size=8, num_patches=16)(x)
What is causing the error, and how can I load the model without it?
-
Dr. Snoopy over 3 years: This is correct, but you are missing one key thing: the kwargs need to be passed to the parent `__init__`.
-
Nicolas Gervais over 3 years: Like this?
# Forward the remaining keyword arguments (e.g. name) to the Layer base class.
super(TemporalReshape, self).__init__(**kwargs)
-
Dr. Snoopy over 3 years: Yes, that is what I mean.
-
Siladittya over 3 years: Even without that, I got no error. But thanks for the suggestion.
-
wessel almost 3 years: That's because the dropped kwargs (such as `name`) have default values. As a consequence, you run the risk of not reconstructing exactly what you serialized.
-
Nicolas Gervais almost 3 years: This is not an answer, and the recent edit doesn't correspond to what the original answerer said. Furthermore, this edit is an exact copy of another answer.