model.summary() can't print output shape while using subclass model


Solution 1

I used this method to solve the problem; I don't know whether there is an easier way.

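class subclass(Model):
    def __init__(self):
        ...
    def call(self, x):
        ...

    def model(self):
        # Wrap call() in a functional Model so that shapes are inferred
        x = Input(shape=(24, 24, 3))
        return Model(inputs=[x], outputs=self.call(x))

if __name__ == '__main__':
    sub = subclass()
    sub.model().summary()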

Solution 2

The way I solved the problem is very similar to what Elazar mentioned: override the summary() function in the class subclass. Then you can call summary() directly while using model subclassing:

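class subclass(Model):
    def __init__(self):
        ...
    def call(self, x):
        ...

    def summary(self):
        # Build a functional Model around call() and print its summary
        x = Input(shape=(24, 24, 3))
        model = Model(inputs=[x], outputs=self.call(x))
        return model.summary()

if __name__ == '__main__':
    sub = subclass()
    sub.summary()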

Solution 3

I guess the key point is the _init_graph_network method in the class Network, which is the parent class of Model. _init_graph_network is called if you specify the inputs and outputs arguments when calling the __init__ method.

So there will be two possible methods:

  1. Manually calling the _init_graph_network method to build the graph of the model.
  2. Reinitialize with the input layer and output.

Both methods need the input layer and the output (obtained from the call method).

Now calling summary will give the exact output shape. However, it will also show the Input layer, which isn't part of the subclassed Model.

from tensorflow import keras
from tensorflow.keras import layers as klayers

class MLP(keras.Model):
    def __init__(self, input_shape=(32,), **kwargs):
        super(MLP, self).__init__(**kwargs)
        # Add input layer
        self.input_layer = klayers.Input(input_shape)

        self.dense_1 = klayers.Dense(64, activation='relu')
        self.dense_2 = klayers.Dense(10)

        # Get output layer with `call` method
        self.out = self.call(self.input_layer)

        # Reinitialize with the input layer and output
        super(MLP, self).__init__(
            inputs=self.input_layer,
            outputs=self.out,
            **kwargs)

    def build(self):
        # Initialize the graph (relies on private Keras internals,
        # which may differ between versions)
        self._is_graph_network = True
        self._init_graph_network(
            inputs=self.input_layer,
            outputs=self.out)

    def call(self, inputs):
        x = self.dense_1(inputs)
        return self.dense_2(x)

if __name__ == '__main__':
    mlp = MLP((16,))
    mlp.summary()

The output will be:

Model: "mlp_1"
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 16)]              0         
dense (Dense)                (None, 64)                1088      
dense_1 (Dense)              (None, 10)                650       
Total params: 1,738
Trainable params: 1,738
Non-trainable params: 0
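
The parameter counts in this summary can be checked by hand, which is a quick way to confirm that the reported shapes are the ones you intended. The snippet below is only a sketch in plain Python: the helper name dense_params is made up for illustration, and it assumes each Dense layer has a bias term (the Keras default).

```python
def dense_params(in_features, units):
    # One weight per (input feature, unit) pair, plus one bias per unit.
    return in_features * units + units

# Feature sizes along the network: input 16 -> Dense(64) -> Dense(10)
layer_sizes = [16, 64, 10]
per_layer = [dense_params(i, o) for i, o in zip(layer_sizes, layer_sizes[1:])]

print(per_layer, sum(per_layer))  # [1088, 650] 1738
```

These match the Param # column and the total of 1,738 shown above.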

Solution 4

I analyzed the answer of Adi Shumely:

  • Adding an input_shape argument should not be needed, since you pass it to build() as a parameter.
  • Adding an Input layer does nothing to the model; it is only passed as a parameter to the call() method.
  • Adding the so-called output is not how I see it. The only, and most important, thing it does is call the call() method.

So I came up with a solution that needs no modification to the model itself: after building the model and before calling summary(), call the model's call() method with an Input tensor. I tried this on my own model and on the three models presented in this thread, and it works so far.

From the first post of this thread:

import tensorflow as tf
from tensorflow.keras import Input, layers, Model

class subclass(Model):
    def __init__(self):
        super(subclass, self).__init__()
        self.conv = layers.Conv2D(28, 3, strides=1)

    def call(self, x):
        return self.conv(x)

if __name__ == '__main__':
    sub = subclass()
    sub.build(input_shape=(None, 24, 24, 3))

    # Adding this call to the call() method solves it all
    sub.call(Input(shape=(24, 24, 3)))

    # And the summary() outputs all the information
    sub.summary()
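
To judge whether the summary now reports the right thing, it helps to know what the Conv2D row should say. The sketch below works out the expected output shape and parameter count in plain Python. The helper name conv2d_summary_row is made up for illustration, and it assumes the Keras defaults of 'valid' padding and a bias term.

```python
def conv2d_summary_row(input_hw, in_channels, filters, kernel, stride=1):
    """Expected output shape and parameter count for a Conv2D layer
    with 'valid' padding and a bias term."""
    h, w = input_hw
    out_h = (h - kernel) // stride + 1
    out_w = (w - kernel) // stride + 1
    # Each filter has kernel * kernel * in_channels weights plus one bias.
    params = (kernel * kernel * in_channels + 1) * filters
    return (None, out_h, out_w, filters), params

# Conv2D(28, 3, strides=1) applied to a (24, 24, 3) input
shape, params = conv2d_summary_row((24, 24), in_channels=3, filters=28, kernel=3)
print(shape, params)  # (None, 22, 22, 28) 784
```

This agrees with the functional API summary in the original question, which lists (None, 22, 22, 28) and 784 for the conv2d layer.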

From the second post of the thread:

from tensorflow import keras
from tensorflow.keras import layers as klayers

class MLP(keras.Model):
    def __init__(self, **kwargs):
        super(MLP, self).__init__(**kwargs)
        self.dense_1 = klayers.Dense(64, activation='relu')
        self.dense_2 = klayers.Dense(10)

    def call(self, inputs):
        x = self.dense_1(inputs)
        return self.dense_2(x)

if __name__ == '__main__':
    mlp = MLP()
    mlp.build(input_shape=(None, 16))
    mlp.call(klayers.Input(shape=(16,)))
    mlp.summary()

From the last post of the thread:

import tensorflow as tf
class MyModel(tf.keras.Model):
    def __init__(self, **kwargs):
        super(MyModel, self).__init__(**kwargs)
        self.dense10 = tf.keras.layers.Dense(10, activation=tf.keras.activations.softmax)
        self.dense20 = tf.keras.layers.Dense(20, activation=tf.keras.activations.softmax)

    def call(self, inputs):
        x = self.dense10(inputs)
        y_pred = self.dense20(x)
        return y_pred

model = MyModel()
model.build(input_shape=(None, 32, 32, 1))
model.call(tf.keras.layers.Input(shape=(32, 32, 1)))
model.summary()

Solution 5

I had the same problem and fixed it in three steps:

  1. Add an input_shape argument to __init__
  2. Add an input_layer
  3. Add an out layer

import tensorflow as tf

class MyModel(tf.keras.Model):
    def __init__(self, input_shape=(32, 32, 1), **kwargs):
        super(MyModel, self).__init__(**kwargs)
        self.input_layer = tf.keras.layers.Input(input_shape)
        self.dense10 = tf.keras.layers.Dense(10, activation=tf.keras.activations.softmax)
        self.dense20 = tf.keras.layers.Dense(20, activation=tf.keras.activations.softmax)
        self.out = self.call(self.input_layer)

    def call(self, inputs):
        x = self.dense10(inputs)
        y_pred = self.dense20(x)
        return y_pred

model = MyModel()
# x_test is the poster's own data, which is not shown in this post
model(x_test[:99])
print('x_test[:99].shape:', x_test[:99].shape)
model.summary()


x_test[:99].shape: (99, 32, 32, 1)
Model: "my_model_32"
Layer (type)                 Output Shape              Param #   
dense_79 (Dense)             (None, 32, 32, 10)        20        
dense_80 (Dense)             (None, 32, 32, 20)        220       
Total params: 240
Trainable params: 240
Non-trainable params: 0
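
Note that the output shapes above keep the two 32-sized axes. A Dense layer transforms only the last axis of its input, so the parameter counts depend on the channel size (1, then 10) and not on the 32 x 32 grid. A rough plain-Python check, with a helper name (dense_on_last_axis) that is made up for illustration and a bias term assumed:

```python
def dense_on_last_axis(input_shape, units):
    """Dense acts on the last axis only; leading axes are carried through."""
    in_features = input_shape[-1]
    out_shape = input_shape[:-1] + (units,)
    params = in_features * units + units  # weights plus biases
    return out_shape, params

shape = (None, 32, 32, 1)
rows = []
for units in (10, 20):
    shape, params = dense_on_last_axis(shape, units)
    rows.append((shape, params))

print(rows)
# [((None, 32, 32, 10), 20), ((None, 32, 32, 20), 220)]
```

The two rows add up to the 240 total parameters reported in the summary.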


Updated on January 28, 2022


  • Gary about 2 years ago

    These are two methods for creating a Keras model, but the output shapes in their summaries differ. The former (the functional API) prints more information, which makes it easier to check the correctness of the network.

    import tensorflow as tf
    from tensorflow.keras import Input, layers, Model
    class subclass(Model):
        def __init__(self):
            super(subclass, self).__init__()
            self.conv = layers.Conv2D(28, 3, strides=1)
        def call(self, x):
            return self.conv(x)
    def func_api():
        x = Input(shape=(24, 24, 3))
        y = layers.Conv2D(28, 3, strides=1)(x)
        return Model(inputs=[x], outputs=[y])
    if __name__ == '__main__':
        func = func_api()
        func.summary()

        sub = subclass()
        sub.build(input_shape=(None, 24, 24, 3))
        sub.summary()


    Layer (type)                 Output Shape              Param #   
    input_1 (InputLayer)         (None, 24, 24, 3)         0         
    conv2d (Conv2D)              (None, 22, 22, 28)        784       
    Total params: 784
    Trainable params: 784
    Non-trainable params: 0
    Layer (type)                 Output Shape              Param #   
    conv2d_1 (Conv2D)            multiple                  784       
    Total params: 784
    Trainable params: 784
    Non-trainable params: 0

    So, how should I use the subclass method to get the output shape at the summary()?

  • Gilfoyle over 3 years ago
    Can you explain why this works? Especially the outputs=self.call(x) part.
  • Rob Hall over 3 years ago
    @Samuel By evaluating outputs=self.call(x), the subclass.call(self, x) method is invoked. This triggers shape computation in the encapsulating instance. Furthermore, the returned Model instance also computes its own shape, which is what .summary() reports. The main drawback of this approach is that the input shape is fixed at shape=(24, 24, 3), so it won't work if you need a dynamic solution.
  • GuySoft about 3 years ago
    Can you explain what goes in the ...? Is this a general solution, or do you need model-specific code in those calls?
  • DeWil almost 3 years ago
    @GuySoft The ... in __init__ instantiates your layers, while the ... in call connects the layers to build the network. It's generic for all subclassed Keras models.
  • MOON about 2 years ago
    Is there any advantage over Elazar's solution? I like your approach since it is more succinct.