Implementing contrastive loss and triplet loss in TensorFlow


Solution 1

Update (2018/03/19): I wrote a blog post detailing how to implement triplet loss in TensorFlow.


You need to implement the contrastive loss or the triplet loss yourself, but once you know the pairs or triplets this is quite easy.


Contrastive Loss

Suppose you have pairs of data and their label as input (positive or negative, i.e. same class or different class). For instance, with images of size 28x28x1 as input:

left = tf.placeholder(tf.float32, [None, 28, 28, 1])
right = tf.placeholder(tf.float32, [None, 28, 28, 1])
label = tf.placeholder(tf.float32, [None])  # 0 if same, 1 if different; float and shape [None] so it broadcasts correctly with d below
margin = 0.2

left_output = model(left)    # shape [None, 128]
right_output = model(right)  # shape [None, 128]

# squared Euclidean distance between the two embeddings, shape [None]
d = tf.reduce_sum(tf.square(left_output - right_output), 1)
d_sqrt = tf.sqrt(d + 1e-7)  # small epsilon avoids a NaN gradient of sqrt at d = 0

# same pairs (label 0) are pulled together, different pairs (label 1)
# are pushed apart until they are at least margin away
loss = label * tf.square(tf.maximum(0., margin - d_sqrt)) + (1 - label) * d

loss = 0.5 * tf.reduce_mean(loss)
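
To train, you would then feed batches of pairs through the placeholders. A sketch (generate_pairs and num_steps are hypothetical names, not part of the answer above):

train_op = tf.train.AdamOptimizer(1e-4).minimize(loss)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for _ in range(num_steps):  # num_steps is a hypothetical training budget
        l_batch, r_batch, y_batch = generate_pairs()  # hypothetical pair sampler
        sess.run(train_op, feed_dict={left: l_batch, right: r_batch, label: y_batch})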

Triplet Loss

Same as with contrastive loss, but with triplets (anchor, positive, negative). You don't need labels here.

anchor_output = ...  # shape [None, 128]
positive_output = ...  # shape [None, 128]
negative_output = ...  # shape [None, 128]

# squared distances anchor-positive and anchor-negative, each of shape [None]
d_pos = tf.reduce_sum(tf.square(anchor_output - positive_output), 1)
d_neg = tf.reduce_sum(tf.square(anchor_output - negative_output), 1)

# margin as defined above; the loss is zero once d_neg > d_pos + margin
loss = tf.maximum(0., margin + d_pos - d_neg)
loss = tf.reduce_mean(loss)

The real trouble when implementing triplet loss or contrastive loss in TensorFlow is how to sample the triplets or pairs. I will focus on generating triplets because it is harder than generating pairs.

The easiest way is to generate them outside of the TensorFlow graph, i.e. in Python, and feed them to the network through the placeholders. You select images three at a time, with the first two from the same class and the third from another class, then perform a feedforward on these triplets and compute the triplet loss, as sketched below.
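
For instance, a minimal numpy sketch of such offline sampling (images and labels are assumed to be arrays holding the dataset, with at least two examples per class):

import numpy as np

def sample_triplets(images, labels, batch_size):
    # Randomly pick (anchor, positive, negative) examples.
    anchors, positives, negatives = [], [], []
    classes = np.unique(labels)
    for _ in range(batch_size):
        c = np.random.choice(classes)                    # class of anchor and positive
        same = np.where(labels == c)[0]
        diff = np.where(labels != c)[0]
        a, p = np.random.choice(same, 2, replace=False)  # assumes >= 2 examples per class
        n = np.random.choice(diff)
        anchors.append(images[a])
        positives.append(images[p])
        negatives.append(images[n])
    return np.stack(anchors), np.stack(positives), np.stack(negatives)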

The issue here is that generating good triplets is complicated. We want triplets with a positive loss (otherwise the loss is 0, there is no gradient, and the network doesn't learn).
But to know whether a triplet is good or not you need to compute its loss, so you already make one feedforward through the network...
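
A common two-pass workaround (a sketch; sess, model_output, inputs, pool_images and candidate_triplets are illustrative names): embed a pool of images once, filter the candidate triplets in numpy, and run the actual training step only on the survivors.

import numpy as np

# First pass: embed a pool of images (emb has shape [N, 128]).
emb = sess.run(model_output, feed_dict={inputs: pool_images})

def has_positive_loss(a, p, n, margin=0.2):
    # Keep only triplets whose loss is currently positive.
    d_pos = np.sum((emb[a] - emb[p]) ** 2)
    d_neg = np.sum((emb[a] - emb[n]) ** 2)
    return margin + d_pos - d_neg > 0

hard_triplets = [(a, p, n) for (a, p, n) in candidate_triplets
                 if has_positive_loss(a, p, n)]
# Second pass: feed only hard_triplets to the training op.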

Clearly, implementing triplet loss in TensorFlow is hard, and there are ways to make it more efficient than sampling in Python, but explaining them would require a whole blog post!

Solution 2

Triplet loss with semihard negative mining is now implemented in tf.contrib.losses.metric_learning, as follows:

tf.contrib.losses.metric_learning.triplet_semihard_loss(
    labels,
    embeddings,
    margin=1.0
)

Args:

  • labels: 1-D tf.int32 Tensor with shape [batch_size] of multiclass integer labels.

  • embeddings: 2-D float Tensor of embedding vectors. Embeddings should be L2-normalized.

  • margin: Float, margin term in the loss definition.

Returns:

  • triplet_loss: tf.float32 scalar.
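
A minimal usage sketch (assuming TensorFlow 1.x, where tf.contrib is still available; model, inputs and labels are illustrative names):

embeddings = model(inputs)  # e.g. shape [batch_size, 128]
embeddings = tf.nn.l2_normalize(embeddings, axis=1)  # the op expects L2-normalized embeddings

loss = tf.contrib.losses.metric_learning.triplet_semihard_loss(
    labels=labels,          # 1-D tf.int32 class labels, shape [batch_size]
    embeddings=embeddings,
    margin=1.0)
train_op = tf.train.AdamOptimizer(1e-4).minimize(loss)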

For further information, check the link below:

https://www.tensorflow.org/versions/master/api_docs/python/tf/contrib/losses/metric_learning/triplet_semihard_loss

Solution 3

Tiago, I don't think you are using the same formula Olivier gave. Here is the right code (not sure it will work though, just fixing the formula):

def compute_euclidean_distance(x, y):
    """
    Computes the squared Euclidean distance between two tensors
    """

    d = tf.reduce_sum(tf.square(tf.subtract(x, y)), 1)
    return d


def compute_contrastive_loss(left_feature, right_feature, label, margin):

    """
    Compute the contrastive loss as in


    L = 0.5 * (1-Y) * D^2 + 0.5 * Y * {max(0, margin - D)}^2

    where D is the Euclidean distance between the two features.

    **Parameters**
     left_feature: First element of the pair
     right_feature: Second element of the pair
     label: Label of the pair (0 if same, 1 if different)
     margin: Contrastive margin

    **Returns**
     Return the loss operation

    """

    label = tf.to_float(label)
    one = tf.constant(1.0)

    d = compute_euclidean_distance(left_feature, right_feature)
    d_sqrt = tf.sqrt(d + 1e-7)  # epsilon avoids a NaN gradient of sqrt at d = 0
    first_part = tf.multiply(one - label, d)  # (1-Y) * D^2

    max_part = tf.square(tf.maximum(margin - d_sqrt, 0))
    second_part = tf.multiply(label, max_part)  # Y * {max(margin - D, 0)}^2

    loss = 0.5 * tf.reduce_mean(first_part + second_part)

    return loss
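
A hypothetical usage sketch, wiring this into the pair setup from Solution 1 (model, left, right, label and margin as defined there):

left_feature = model(left)    # shape [None, 128]
right_feature = model(right)  # shape [None, 128]

loss = compute_contrastive_loss(left_feature, right_feature, label, margin)
train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)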

Comments

  • Tiago Freitas Pereira
    Tiago Freitas Pereira over 2 years

    I started to play with TensorFlow two days ago and I'm wondering whether the triplet and contrastive losses are implemented.

    I've been looking at the documentation, but I haven't found any example or description of them.

  • Tiago Freitas Pereira
    Tiago Freitas Pereira almost 8 years
    Hi Wassim, thanks for the fix; just a patch in your code: d_sqrt = tf.sqrt(compute_euclidean_distance(left_feature, right_feature)). But even with this fix, I get very low accuracy (though the loss decreases as expected).
  • Wassim Gr
    Wassim Gr almost 8 years
    @TiagoFreitasPereira I am having the same problem with my triplet loss implementation. I will notify you if I find a solution...
  • Tiago Freitas Pereira
    Tiago Freitas Pereira almost 8 years
    Hey @Wassim, thanks. If it is easier, you can try to bootstrap my project (github.com/tiagofrepereira2012/examples.tensorflow).
  • Wassim Gr
    Wassim Gr almost 8 years
    @TiagoFreitasPereira, it seems like it has to do with the way we implement the accuracy computation. When using triplet loss or contrastive loss you can't compute accuracy using label verification (because the network wasn't trained to differentiate the 10 classes); instead, you have to compute accuracy by evaluating whether the network guessed that two elements are from the same class or not.
  • Wassim Gr
    Wassim Gr almost 8 years
    See section 4 and 5.6 of this paper arxiv.org/pdf/1503.03832v3.pdf
  • Tiago Freitas Pereira
    Tiago Freitas Pereira almost 8 years
    Hi @Wassim, yes, I understand that, but my goal here is to train the Siamese net (or the triplet) and use one of the fully connected layers (fc1 or fc2 in my code) as features. In our example, since the network is good at differentiating digits, the trained features must be good.
  • Olivier Moindrot
    Olivier Moindrot almost 8 years
    The features will be good but you need to add a final softmax (and retrain it) on them
  • Wassim Gr
    Wassim Gr almost 8 years
    The trained features are indeed good, but after you apply a softmax you have to beware of the indices: since we don't feed labels, an activation of the first neuron of the softmax layer doesn't necessarily signify that it detected a 0; it could be any of the other digits. Usually, if you train using contrastive/triplet loss you're aiming to use the network for comparison rather than classification.
  • weitang114
    weitang114 almost 8 years
    Hi @Olivier, I am very interested in the sampling part. Would you or have you posted a blog about it? I am doing just what you said: feed forward once, compute the losses for all possible triplets, filter out invalid ones, and sample a batch to do another forward+backward...
  • Olivier Moindrot
    Olivier Moindrot almost 8 years
    Didn't write any blog post. One key insight is to compute all the possible triplets as explained in OpenFace; my answer above contains the old solution. To remove the middle sess.run() call, you can add a tf.py_func operation inside the graph to filter out the bad triplets.
  • Olivier Moindrot
    Olivier Moindrot almost 8 years
    @weitang114: Another way for the 2nd part is to just compute the loss for all the triplets, removing only the invalid triplets (i.e. (+, +, +)), which can be computed in advance. This converges well, surprisingly.
  • weitang114
    weitang114 over 7 years
    Thank you for this advice. I didn't get the idea at that moment, but found it very useful recently. This process, implemented in tf, helped me reduce the training time from 5 days to 1 day. :)
  • Olivier Moindrot
    Olivier Moindrot over 7 years
    @weitang114: Yeah it's very convenient. Did you implement it without tf.py_func (the second idea I gave)?
  • weitang114
    weitang114 over 7 years
    No. I implemented it almost the same as what is said in the OpenFace article. I used tf.nn.relu() to filter out useless losses and counted how many losses are left, say C; the mean loss is then sum(losses)/C. (See the sketch after these comments.)
  • Hello Lili
    Hello Lili over 6 years
    @weitang114 how did you manage to select only the valid triplets for training? When I verify if the loss is > 0 for a set of triplets (anchor image, positive image, negative image), I have to feed the triplets again to the model to calculate the gradients. And because I use dropout, the same triplets might give loss 0 at the next feed. I'm stuck.
  • Hello Lili
    Hello Lili over 6 years
    @OlivierMoindrot can you please provide an example of filtering bad triplets using py_func?
  • Sunil
    Sunil over 6 years
    Link only answers? Include some relevant portions from the link here.
  • Machavity
    Machavity over 6 years
    While this link might provide some limited, immediate help, an answer should include sufficient context around the link so your fellow users will have some idea what it is and why it’s there. Always quote the most relevant part of an important link, to make it more useful to future readers with other, similar questions. In addition, other users tend to respond negatively to answers which are barely more than a link to an external site, and they might be deleted.
  • lincr
    lincr over 6 years
    Why is d used instead of d_sqrt at the end of the first assignment to loss in the contrastive loss?
  • Olivier Moindrot
    Olivier Moindrot over 6 years
    This is the formula for contrastive loss
  • Olivier Moindrot
    Olivier Moindrot about 6 years
    @HelloLili: I finally wrote that blog post. Here it is: omoindrot.github.io/triplet-loss
  • Olivier Moindrot
    Olivier Moindrot about 6 years
    @weitang114: I finally wrote that blog post. Here it is: omoindrot.github.io/triplet-loss
  • C. Wang
    C. Wang about 6 years
    The aforementioned contrastive loss code should be modified a little to avoid a NaN error, i.e. d_sqrt = tf.sqrt(d + 1e-7). We used the code and found the bug.
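
As a footnote to weitang114's comment above, the counting trick looks roughly like this in TensorFlow (a sketch; d_pos, d_neg and margin as in Solution 1; the tf.maximum guard against a batch with no active triplets is an addition):

losses = tf.nn.relu(margin + d_pos - d_neg)  # easy triplets contribute exactly 0
num_active = tf.reduce_sum(tf.cast(tf.greater(losses, 0.), tf.float32))  # this is C
mean_loss = tf.reduce_sum(losses) / tf.maximum(num_active, 1.0)  # sum(losses)/C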