FaceNet triplet loss with Keras


Solution 1

Apart from the learning rate simply being too high, what may have happened is that you used an unstable triplet selection strategy. If, for example, you only use 'hard triplets' (triplets where the anchor-negative distance is smaller than the anchor-positive distance), the network weights may collapse all embeddings to a single point. The loss then always equals the margin (your _alpha), because all embedding distances are zero.
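To see why a collapsed network sits at a loss of exactly the margin, here is a minimal NumPy sketch. It assumes the FaceNet-style squared L2 distance and a margin of 0.2 (the actual value of _alpha is not given). The collapse point is a unit vector rather than the origin, since with L2-normalized outputs the embeddings cannot all be zero, but they can still all coincide.

```python
import numpy as np

def triplet_loss(anchor, positive, negative, alpha):
    """FaceNet-style triplet loss averaged over a batch, with squared L2 distances."""
    pos_dist = np.sum((anchor - positive) ** 2, axis=1)
    neg_dist = np.sum((anchor - negative) ** 2, axis=1)
    return float(np.mean(np.maximum(0.0, pos_dist - neg_dist + alpha)))

alpha = 0.2  # assumed margin; the question calls it _alpha

# Collapsed network: every image is mapped to the same (unit-length) point
point = np.array([0.6, 0.8])
collapsed = np.tile(point, (5, 1))
print(triplet_loss(collapsed, collapsed, collapsed, alpha))  # ≈ alpha

# A well-separated triplet: the negative is far from the anchor, so the loss is zero
anchor = np.array([[1.0, 0.0]])
negative = np.array([[-1.0, 0.0]])
print(triplet_loss(anchor, anchor, negative, alpha))  # 0.0
```

A collapsed network therefore never reaches zero loss, but it also never does worse than the margin, which is why it can act as a stable resting point when only hard triplets are used.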

Collapse like this can be avoided by also using other kinds of triplets, such as 'semi-hard triplets', where the anchor-negative distance is larger than the anchor-positive distance but by less than the margin (d(a,p) < d(a,n) < d(a,p) + margin). It may help to check which kinds of triplets your batches actually contain. This is explained in more detail in this blog post: https://omoindrot.github.io/triplet-loss
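Checking which kinds of triplets a batch contains can be done directly from the distances. The sketch below uses the easy / semi-hard / hard categories from the blog post above; the helper name classify_triplets, the squared L2 distance, and the margin of 0.2 are assumptions for illustration.

```python
import numpy as np

def classify_triplets(anchor, positive, negative, margin=0.2):
    """Label each triplet as 'easy', 'semi-hard' or 'hard' (squared L2 distances)."""
    d_ap = np.sum((anchor - positive) ** 2, axis=1)
    d_an = np.sum((anchor - negative) ** 2, axis=1)
    # Everything not overwritten below (d_ap <= d_an <= d_ap + margin) stays 'semi-hard'
    labels = np.full(d_ap.shape, "semi-hard", dtype=object)
    labels[d_an < d_ap] = "hard"           # negative is closer than the positive
    labels[d_an > d_ap + margin] = "easy"  # margin already satisfied, zero loss
    return labels

# One-dimensional toy embeddings so the distances are easy to check by hand
anchor = np.array([[0.0], [0.0], [0.0]])
positive = np.array([[0.1], [0.1], [0.5]])
negative = np.array([[1.0], [0.3], [0.2]])
print(list(classify_triplets(anchor, positive, negative)))  # ['easy', 'semi-hard', 'hard']
```

If most labels in a batch come out as 'hard', that is the situation described above; filtering for or mixing in 'semi-hard' triplets is one way to act on it.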

Solution 2

Are you constraining your embeddings to "be on a d-dimensional hypersphere"? Try running tf.nn.l2_normalize on your embeddings right after they come out of the CNN.

The problem could be that the embeddings are being smart-alecs: one easy way to reduce the loss is to just set everything to zero. l2_normalize forces them to have unit length.

It looks like you'll want to add the normalization right after the last average pool.
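For reference, tf.nn.l2_normalize divides its input by sqrt(max(sum(x**2), epsilon)) along the given axis. The NumPy sketch below applies the same formula to a (batch, dim) array of embeddings; the helper name l2_normalize and axis=1 are assumptions for illustration. Inside a Keras model the TensorFlow call would typically be wrapped in a Lambda layer.

```python
import numpy as np

def l2_normalize(x, axis=1, epsilon=1e-12):
    """NumPy version of the formula documented for tf.nn.l2_normalize."""
    sq_sum = np.sum(x ** 2, axis=axis, keepdims=True)
    # epsilon only guards against division by zero for all-zero rows
    return x / np.sqrt(np.maximum(sq_sum, epsilon))

embeddings = np.array([[3.0, 4.0], [0.01, 0.0], [-2.0, 2.0]])
unit = l2_normalize(embeddings)
print(np.linalg.norm(unit, axis=1))  # every row now has norm 1 (up to rounding)
print(unit[0])  # [0.6 0.8]
```

Because every embedding then lies on the unit hypersphere, the network can no longer lower the distances simply by shrinking all outputs toward zero.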

Solution 3

I have run into the same problem and did some research. I think it happens because the triplet loss needs multiple inputs, which may cause the network to produce outputs like these. I haven't fixed the problem yet, but you can check the Keras issue page for more details: https://github.com/keras-team/keras/issues/9498.

In that issue, I implemented a fake dataset and a fake triplet loss to reproduce the problem. After I changed the input structure of the network, the loss became normal.



Updated on June 04, 2022


  • DalekSupreme
    DalekSupreme almost 2 years

    I am trying to implement FaceNet in Keras with the TensorFlow backend, and I have a problem with the triplet loss.

    I call the fit function with 3*n images (n triplets) and define my custom loss function as follows:

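    from keras import backend as K
    # Assumes output_dim (embedding size) and _alpha (margin) are defined elsewhere
    def triplet_loss(self, y_true, y_pred):
        # Rows of y_pred are expected in (anchor, positive, negative) order
        embeddings = K.reshape(y_pred, (-1, 3, output_dim))
        # K.mean divides the squared L2 distance by output_dim; FaceNet uses the sum
        positive_distance = K.mean(K.square(embeddings[:,0] - embeddings[:,1]),axis=-1)
        negative_distance = K.mean(K.square(embeddings[:,0] - embeddings[:,2]),axis=-1)
        return K.mean(K.maximum(0.0, positive_distance - negative_distance + _alpha))
    # triplet_loss is a method, so it has to be passed as self.triplet_loss
    self._model.compile(loss=self.triplet_loss, optimizer="sgd")
    # Note: fit() shuffles samples by default (shuffle=True), which breaks the triplet order
    self._model.fit(x=x,y=y,nb_epoch=1, batch_size=len(x))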

    where y is just a dummy array filled with 0s
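To sanity-check the loss outside Keras, the same computation can be mirrored in NumPy. This is a minimal sketch: triplet_loss_reference and interleave are hypothetical helpers, and the margin of 0.2 is an assumed value. It also makes the row-ordering assumption explicit: the reshape to (-1, 3, output_dim) only groups the right images if x is laid out as anchor, positive, negative, anchor, positive, negative, and so on.

```python
import numpy as np

def triplet_loss_reference(y_pred, output_dim, alpha):
    """NumPy mirror of the Keras triplet_loss above (hypothetical helper)."""
    # Group consecutive rows into (anchor, positive, negative) triplets
    emb = y_pred.reshape(-1, 3, output_dim)
    # Mean (not sum) of squared differences, as the Keras version does with K.mean
    pos = np.mean((emb[:, 0] - emb[:, 1]) ** 2, axis=-1)
    neg = np.mean((emb[:, 0] - emb[:, 2]) ** 2, axis=-1)
    return float(np.mean(np.maximum(0.0, pos - neg + alpha)))

def interleave(anchors, positives, negatives):
    """Build the a0, p0, n0, a1, p1, n1, ... row order the reshape expects (hypothetical helper)."""
    stacked = np.stack([anchors, positives, negatives], axis=1)  # (n, 3, dim)
    return stacked.reshape(-1, anchors.shape[1])                 # (3n, dim)

anchors = np.array([[1.0, 0.0]])
positives = np.array([[1.0, 0.0]])
negatives = np.array([[0.0, 1.0]])

good = interleave(anchors, positives, negatives)
swapped = interleave(anchors, negatives, positives)
print(triplet_loss_reference(good, output_dim=2, alpha=0.2))     # 0.0
print(triplet_loss_reference(swapped, output_dim=2, alpha=0.2))  # 1.2
```

One consequence of averaging instead of summing: two unit-length 128-dimensional embeddings are at most 4/128 = 0.03125 apart under this distance, so if _alpha is larger than that, the hinge can never reach zero for any triplet.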

    The problem is that even after the first iteration with batch size 20, the model starts predicting the same embedding for all the images. When I first run prediction on the batch, every embedding is different. Then I run fit and predict again, and suddenly all the embeddings are almost the same for every image in the batch.
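One way to put a number on this symptom is to compare the spread of the embeddings before and after fit, for example on the output of self._model.predict(x). The helper below, mean_pairwise_distance, is a hypothetical NumPy sketch; a value near zero means the network is mapping every image to (almost) the same point. It assumes a batch of at least two embeddings.

```python
import numpy as np

def mean_pairwise_distance(embeddings):
    """Average Euclidean distance between all pairs of rows in a (batch, dim) array."""
    diffs = embeddings[:, None, :] - embeddings[None, :, :]
    dists = np.sqrt(np.sum(diffs ** 2, axis=-1))
    n = embeddings.shape[0]
    # The diagonal is zero, so divide by the number of off-diagonal entries only
    return float(dists.sum() / (n * (n - 1)))

spread_out = np.array([[1.0, 0.0], [0.0, 1.0]])
collapsed = np.tile([[0.6, 0.8]], (4, 1))
print(mean_pairwise_distance(spread_out))  # ≈ 1.4142 (sqrt(2))
print(mean_pairwise_distance(collapsed))   # 0.0
```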

    Also note that there is a Lambda layer at the end of the model. It normalizes the output of the network so that all the embeddings have unit length, as suggested in the FaceNet paper.

    Can anybody help me out here?

    Model summary

        Layer (type)                     Output Shape          Param #     Connected to                     
    input_1 (InputLayer)             (None, 224, 224, 3)   0                                            
    convolution2d_1 (Convolution2D)  (None, 112, 112, 64)  9472        input_1[0][0]                    
    batchnormalization_1 (BatchNormal(None, 112, 112, 64)  128         convolution2d_1[0][0]            
    maxpooling2d_1 (MaxPooling2D)    (None, 56, 56, 64)    0           batchnormalization_1[0][0]       
    convolution2d_2 (Convolution2D)  (None, 56, 56, 64)    4160        maxpooling2d_1[0][0]             
    batchnormalization_2 (BatchNormal(None, 56, 56, 64)    128         convolution2d_2[0][0]            
    convolution2d_3 (Convolution2D)  (None, 56, 56, 192)   110784      batchnormalization_2[0][0]       
    batchnormalization_3 (BatchNormal(None, 56, 56, 192)   384         convolution2d_3[0][0]            
    maxpooling2d_2 (MaxPooling2D)    (None, 28, 28, 192)   0           batchnormalization_3[0][0]       
    convolution2d_5 (Convolution2D)  (None, 28, 28, 96)    18528       maxpooling2d_2[0][0]             
    convolution2d_7 (Convolution2D)  (None, 28, 28, 16)    3088        maxpooling2d_2[0][0]             
    maxpooling2d_3 (MaxPooling2D)    (None, 28, 28, 192)   0           maxpooling2d_2[0][0]             
    convolution2d_4 (Convolution2D)  (None, 28, 28, 64)    12352       maxpooling2d_2[0][0]             
    convolution2d_6 (Convolution2D)  (None, 28, 28, 128)   110720      convolution2d_5[0][0]            
    convolution2d_8 (Convolution2D)  (None, 28, 28, 32)    12832       convolution2d_7[0][0]            
    convolution2d_9 (Convolution2D)  (None, 28, 28, 32)    6176        maxpooling2d_3[0][0]             
    merge_1 (Merge)                  (None, 28, 28, 256)   0           convolution2d_4[0][0]            
    convolution2d_11 (Convolution2D) (None, 28, 28, 96)    24672       merge_1[0][0]                    
    convolution2d_13 (Convolution2D) (None, 28, 28, 32)    8224        merge_1[0][0]                    
    maxpooling2d_4 (MaxPooling2D)    (None, 28, 28, 256)   0           merge_1[0][0]                    
    convolution2d_10 (Convolution2D) (None, 28, 28, 64)    16448       merge_1[0][0]                    
    convolution2d_12 (Convolution2D) (None, 28, 28, 128)   110720      convolution2d_11[0][0]           
    convolution2d_14 (Convolution2D) (None, 28, 28, 64)    51264       convolution2d_13[0][0]           
    convolution2d_15 (Convolution2D) (None, 28, 28, 64)    16448       maxpooling2d_4[0][0]             
    merge_2 (Merge)                  (None, 28, 28, 320)   0           convolution2d_10[0][0]           
    convolution2d_16 (Convolution2D) (None, 28, 28, 128)   41088       merge_2[0][0]                    
    convolution2d_18 (Convolution2D) (None, 28, 28, 32)    10272       merge_2[0][0]                    
    convolution2d_17 (Convolution2D) (None, 14, 14, 256)   295168      convolution2d_16[0][0]           
    convolution2d_19 (Convolution2D) (None, 14, 14, 64)    51264       convolution2d_18[0][0]           
    maxpooling2d_5 (MaxPooling2D)    (None, 14, 14, 320)   0           merge_2[0][0]                    
    merge_3 (Merge)                  (None, 14, 14, 640)   0           convolution2d_17[0][0]           
    convolution2d_21 (Convolution2D) (None, 14, 14, 96)    61536       merge_3[0][0]                    
    convolution2d_23 (Convolution2D) (None, 14, 14, 32)    20512       merge_3[0][0]                    
    maxpooling2d_6 (MaxPooling2D)    (None, 14, 14, 640)   0           merge_3[0][0]                    
    convolution2d_20 (Convolution2D) (None, 14, 14, 256)   164096      merge_3[0][0]                    
    convolution2d_22 (Convolution2D) (None, 14, 14, 192)   166080      convolution2d_21[0][0]           
    convolution2d_24 (Convolution2D) (None, 14, 14, 64)    51264       convolution2d_23[0][0]           
    convolution2d_25 (Convolution2D) (None, 14, 14, 128)   82048       maxpooling2d_6[0][0]             
    merge_4 (Merge)                  (None, 14, 14, 640)   0           convolution2d_20[0][0]           
    convolution2d_27 (Convolution2D) (None, 14, 14, 112)   71792       merge_4[0][0]                    
    convolution2d_29 (Convolution2D) (None, 14, 14, 32)    20512       merge_4[0][0]                    
    maxpooling2d_7 (MaxPooling2D)    (None, 14, 14, 640)   0           merge_4[0][0]                    
    convolution2d_26 (Convolution2D) (None, 14, 14, 224)   143584      merge_4[0][0]                    
    convolution2d_28 (Convolution2D) (None, 14, 14, 224)   226016      convolution2d_27[0][0]           
    convolution2d_30 (Convolution2D) (None, 14, 14, 64)    51264       convolution2d_29[0][0]           
    convolution2d_31 (Convolution2D) (None, 14, 14, 128)   82048       maxpooling2d_7[0][0]             
    merge_5 (Merge)                  (None, 14, 14, 640)   0           convolution2d_26[0][0]           
    convolution2d_33 (Convolution2D) (None, 14, 14, 128)   82048       merge_5[0][0]                    
    convolution2d_35 (Convolution2D) (None, 14, 14, 32)    20512       merge_5[0][0]                    
    maxpooling2d_8 (MaxPooling2D)    (None, 14, 14, 640)   0           merge_5[0][0]                    
    convolution2d_32 (Convolution2D) (None, 14, 14, 192)   123072      merge_5[0][0]                    
    convolution2d_34 (Convolution2D) (None, 14, 14, 256)   295168      convolution2d_33[0][0]           
    convolution2d_36 (Convolution2D) (None, 14, 14, 64)    51264       convolution2d_35[0][0]           
    convolution2d_37 (Convolution2D) (None, 14, 14, 128)   82048       maxpooling2d_8[0][0]             
    merge_6 (Merge)                  (None, 14, 14, 640)   0           convolution2d_32[0][0]           
    convolution2d_39 (Convolution2D) (None, 14, 14, 144)   92304       merge_6[0][0]                    
    convolution2d_41 (Convolution2D) (None, 14, 14, 32)    20512       merge_6[0][0]                    
    maxpooling2d_9 (MaxPooling2D)    (None, 14, 14, 640)   0           merge_6[0][0]                    
    convolution2d_38 (Convolution2D) (None, 14, 14, 160)   102560      merge_6[0][0]                    
    convolution2d_40 (Convolution2D) (None, 14, 14, 288)   373536      convolution2d_39[0][0]           
    convolution2d_42 (Convolution2D) (None, 14, 14, 64)    51264       convolution2d_41[0][0]           
    convolution2d_43 (Convolution2D) (None, 14, 14, 128)   82048       maxpooling2d_9[0][0]             
    merge_7 (Merge)                  (None, 14, 14, 640)   0           convolution2d_38[0][0]           
    convolution2d_44 (Convolution2D) (None, 14, 14, 160)   102560      merge_7[0][0]                    
    convolution2d_46 (Convolution2D) (None, 14, 14, 64)    41024       merge_7[0][0]                    
    convolution2d_45 (Convolution2D) (None, 7, 7, 256)     368896      convolution2d_44[0][0]           
    convolution2d_47 (Convolution2D) (None, 7, 7, 128)     204928      convolution2d_46[0][0]           
    maxpooling2d_10 (MaxPooling2D)   (None, 7, 7, 640)     0           merge_7[0][0]                    
    merge_8 (Merge)                  (None, 7, 7, 1024)    0           convolution2d_45[0][0]           
    convolution2d_49 (Convolution2D) (None, 7, 7, 192)     196800      merge_8[0][0]                    
    convolution2d_51 (Convolution2D) (None, 7, 7, 48)      49200       merge_8[0][0]                    
    maxpooling2d_11 (MaxPooling2D)   (None, 7, 7, 1024)    0           merge_8[0][0]                    
    convolution2d_48 (Convolution2D) (None, 7, 7, 384)     393600      merge_8[0][0]                    
    convolution2d_50 (Convolution2D) (None, 7, 7, 384)     663936      convolution2d_49[0][0]           
    convolution2d_52 (Convolution2D) (None, 7, 7, 128)     153728      convolution2d_51[0][0]           
    convolution2d_53 (Convolution2D) (None, 7, 7, 128)     131200      maxpooling2d_11[0][0]            
    merge_9 (Merge)                  (None, 7, 7, 1024)    0           convolution2d_48[0][0]           
    convolution2d_55 (Convolution2D) (None, 7, 7, 192)     196800      merge_9[0][0]                    
    convolution2d_57 (Convolution2D) (None, 7, 7, 48)      49200       merge_9[0][0]                    
    maxpooling2d_12 (MaxPooling2D)   (None, 7, 7, 1024)    0           merge_9[0][0]                    
    convolution2d_54 (Convolution2D) (None, 7, 7, 384)     393600      merge_9[0][0]                    
    convolution2d_56 (Convolution2D) (None, 7, 7, 384)     663936      convolution2d_55[0][0]           
    convolution2d_58 (Convolution2D) (None, 7, 7, 128)     153728      convolution2d_57[0][0]           
    convolution2d_59 (Convolution2D) (None, 7, 7, 128)     131200      maxpooling2d_12[0][0]            
    merge_10 (Merge)                 (None, 7, 7, 1024)    0           convolution2d_54[0][0]           
    averagepooling2d_1 (AveragePoolin(None, 1, 1, 1024)    0           merge_10[0][0]                   
    flatten_1 (Flatten)              (None, 1024)          0           averagepooling2d_1[0][0]         
    dense_1 (Dense)                  (None, 128)           131200      flatten_1[0][0]                  
    lambda_1 (Lambda)                (None, 128)           0           dense_1[0][0]                    
    Total params: 7456944
    • Rob
      Rob almost 7 years
      @DalekSupreme were you able to successfully implement FaceNet in Keras? I am working on a project and would love to know if someone pulled this off.
  • DalekSupreme
    DalekSupreme over 7 years
    Thanks for the idea. Unfortunately, the last Lambda layer (lambda_1) already does that: it normalizes the embedding as suggested in the FaceNet paper. Why should I put the normalization after the average pool and not after the dense layer?
  • chris
    chris over 7 years
    Oh, hmm. I'm not sure about that. That's just where I did it (for Siamese networks, which is a similar idea).
  • user8523104
    user8523104 over 3 years
    I am looking for a simple Keras metric-learning example. Could you please share one? I would really appreciate it.