How can I implement a weighted cross-entropy loss in TensorFlow using sparse_softmax_cross_entropy_with_logits?

18,622

Solution 1

import tensorflow as tf  # written against the TensorFlow 1.x API (InteractiveSession, tf.losses)
import numpy as np

np.random.seed(123)
sess = tf.InteractiveSession()

# let's say we have the logits and labels of a batch of size 6 with 5 classes
logits = tf.constant(np.random.randint(0, 10, 30).reshape(6, 5), dtype=tf.float32)
labels = tf.constant(np.random.randint(0, 5, 6), dtype=tf.int32)

# specify some class weightings
class_weights = tf.constant([0.3, 0.1, 0.2, 0.3, 0.1])

# look up each sample's weight from its label (no one-hot label matrix needed)
weights = tf.gather(class_weights, labels)

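# compute the loss; the default reduction (SUM_BY_NONZERO_WEIGHTS) divides the
# weighted sum of per-sample losses by the number of non-zero weights
print(tf.losses.sparse_softmax_cross_entropy(labels, logits, weights).eval())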
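If you want to check what the TensorFlow call above computes without a TF 1.x session, here is a minimal NumPy sketch: gather each sample's class weight by its label, compute the per-sample sparse softmax cross entropy, then reduce as SUM_BY_NONZERO_WEIGHTS does. The helper names sparse_softmax_xent and weighted_sparse_xent are hypothetical, and this is an illustration of the math rather than TensorFlow's implementation.

```python
import numpy as np

def sparse_softmax_xent(logits, labels):
    """Per-sample cross entropy from integer labels, via a stable log-softmax."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels]

def weighted_sparse_xent(logits, labels, class_weights):
    """Weighted sum of per-sample losses divided by the number of non-zero
    weights (the SUM_BY_NONZERO_WEIGHTS reduction)."""
    weights = class_weights[labels]  # plays the role of tf.gather(class_weights, labels)
    losses = sparse_softmax_xent(logits, labels)
    num_nonzero = np.count_nonzero(weights)
    return float((weights * losses).sum() / num_nonzero) if num_nonzero else 0.0

# Same seed and draw order as the TensorFlow snippet, so the inputs match it
np.random.seed(123)
logits = np.random.randint(0, 10, 30).reshape(6, 5).astype(np.float32)
labels = np.random.randint(0, 5, 6)
class_weights = np.array([0.3, 0.1, 0.2, 0.3, 0.1], dtype=np.float32)

print(weighted_sparse_xent(logits, labels, class_weights))
```

With all class weights equal to 1.0, the result reduces to the plain mean of the per-sample losses.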

Solution 2

Specifically for binary classification, there is weighted_cross_entropy_with_logits, which computes a weighted sigmoid cross entropy (the positive-class term is scaled by pos_weight).
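The TensorFlow documentation gives the loss of weighted_cross_entropy_with_logits as pos_weight * labels * -log(sigmoid(logits)) + (1 - labels) * -log(1 - sigmoid(logits)). A NumPy sketch of that formula, rewritten in a numerically stable form, might look like this; the function name weighted_sigmoid_xent is hypothetical.

```python
import numpy as np

def weighted_sigmoid_xent(labels, logits, pos_weight):
    """Binary cross entropy where the positive-class term is scaled by pos_weight.

    Equivalent to pos_weight * z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x)),
    rewritten so that large |x| does not overflow.
    """
    z = np.asarray(labels, dtype=np.float64)
    x = np.asarray(logits, dtype=np.float64)
    log_weight = 1.0 + (pos_weight - 1.0) * z
    # log(1 + exp(-x)) computed stably as log1p(exp(-|x|)) + max(-x, 0)
    softplus_neg_x = np.log1p(np.exp(-np.abs(x))) + np.maximum(-x, 0.0)
    return (1.0 - z) * x + log_weight * softplus_neg_x

labels = np.array([1.0, 0.0, 1.0])
logits = np.array([2.0, -1.0, -3.0])
print(weighted_sigmoid_xent(labels, logits, pos_weight=3.0))
```

With pos_weight=1.0 this reduces to ordinary sigmoid cross entropy; values above 1.0 up-weight the positive examples.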

sparse_softmax_cross_entropy_with_logits is tailored for a highly efficient, unweighted operation (see SparseSoftmaxXentWithLogitsOp, which uses SparseXentEigenImpl under the hood), so it's not "pluggable".

In the multi-class case, your options are either to switch to one-hot encoding or to use the tf.losses.sparse_softmax_cross_entropy loss function in a hacky way, as already suggested, where you pass per-sample weights that depend on the labels in the current batch.
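The two options differ only in how the per-sample weights are built. This small NumPy sketch (labels chosen for illustration) shows that multiplying the one-hot label matrix by the class-weight vector gives the same per-sample weights as indexing the class-weight vector with the sparse labels, which is what tf.gather does in Solution 1.

```python
import numpy as np

num_classes = 5
class_weights = np.array([0.3, 0.1, 0.2, 0.3, 0.1])
labels = np.array([3, 0, 4, 4, 1, 2])  # illustrative sparse labels

# Option A: dense one-hot labels, then a matrix-vector product
one_hot = np.eye(num_classes)[labels]           # shape (batch, num_classes)
weights_from_one_hot = one_hot @ class_weights  # shape (batch,)

# Option B: index the class weights with the sparse labels (like tf.gather)
weights_from_gather = class_weights[labels]

print(weights_from_one_hot)
print(weights_from_gather)
```

Option B never materializes the (batch, num_classes) matrix, which is the point of the gather-based approach.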

Solution 3

Refer to this solution for "Loss function for class imbalanced binary classifier in Tensor flow". There, the class weights are multiplied by the logits, so the same approach still works with sparse_softmax_cross_entropy_with_logits.
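A minimal NumPy sketch of the logit-scaling idea, assuming the class weights are broadcast across the class axis of the logits. Note that this rescales the logits before the softmax, so it changes the predicted distribution itself rather than multiplying each sample's loss by a weight. The helper sparse_softmax_xent and the weight values are hypothetical.

```python
import numpy as np

def sparse_softmax_xent(logits, labels):
    """Per-sample cross entropy from integer labels, via a stable log-softmax."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels]

logits = np.array([[2.0, 1.0, 0.5],
                   [0.2, 1.5, 0.3]])
labels = np.array([0, 1])
class_weights = np.array([0.5, 2.0, 1.0])  # hypothetical per-class weights

# Scale each class column of the logits, then apply the unweighted sparse loss.
weighted_logits = logits * class_weights  # broadcasts over the batch axis
print(sparse_softmax_xent(weighted_logits, labels).mean())
```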

As a side note, you can pass weights directly into sparse_softmax_cross_entropy:

tf.contrib.losses.sparse_softmax_cross_entropy(logits, labels, weight=1.0, scope=None)

This method computes the cross-entropy loss using tf.nn.sparse_softmax_cross_entropy_with_logits.

Weight acts as a coefficient for the loss. If a scalar is provided, then the loss is simply scaled by the given value. If weight is a tensor of size [batch_size], then the loss weights apply to each corresponding sample.
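Both cases can be illustrated with a NumPy sketch of the weighting step alone, assuming the per-sample losses are already computed (the values below are made up) and assuming the reduction divides the weighted sum by the number of non-zero weights, which is how I understand both tf.contrib.losses and the TF 1.x tf.losses default to behave. The helper weighted_mean_loss is hypothetical.

```python
import numpy as np

def weighted_mean_loss(losses, weight):
    """Scale per-sample losses by `weight` (scalar or shape [batch_size]) and
    divide by the number of non-zero weights."""
    losses = np.asarray(losses, dtype=np.float64)
    w = np.broadcast_to(np.asarray(weight, dtype=np.float64), losses.shape)
    num_present = np.count_nonzero(w)
    return float((losses * w).sum() / num_present) if num_present else 0.0

per_sample_losses = np.array([0.8, 1.2, 0.4])  # made-up values
labels = np.array([1, 1, 2])

# Scalar weight: the mean loss is simply scaled by 2.0
scalar_result = weighted_mean_loss(per_sample_losses, 2.0)

# Per-sample weights built from labels: class 1 weighted at 0.5, class 2 at 1.0
class_weights = np.array([1.0, 0.5, 1.0])
per_sample_result = weighted_mean_loss(per_sample_losses, class_weights[labels])
print(scalar_result, per_sample_result)
```

The per-sample case reproduces the example in Matt S's comment: labels 1, 1, 2 with class 1 weighted at 50% give weights [0.5, 0.5, 1.0].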

Author: Roger Trullo

Updated on June 06, 2022

Comments

  • Roger Trullo (almost 2 years)
    I am starting to use TensorFlow (coming from Caffe), and I am using the loss sparse_softmax_cross_entropy_with_logits. The function accepts labels like 0, 1, ..., C-1 instead of one-hot encodings. Now I want to use a weighting that depends on the class label. I know this could perhaps be done with a matrix multiplication if I used softmax_cross_entropy_with_logits (one-hot encoding). Is there any way to do the same with sparse_softmax_cross_entropy_with_logits?

  • Roger Trullo (over 7 years)
    I was wondering if there was a way to avoid the one-hot labels, because in the link provided there is still a need to multiply the matrix of one-hot labels by the weight vector. Another way would be to use a weight vector of length batch size directly, but then I would have to compute this vector for every batch. How could I define it (since it depends on the labels) without having to compute the one-hot label matrix?
  • andong777 (about 7 years)
    I don't think this answer is correct. The weights in tf.contrib.losses.sparse_softmax_cross_entropy are per-sample, not per-class.
  • Matt S (over 6 years)
    It is correct; it's just annoying. You would pass a weight for each sample, and that weight would depend on the particular class of that sample. So if you had a batch of size 3 with classes 1, 1, 2, and you wanted to weight class 1 at 50%, you would use this loss function and pass the weight argument a tensor with values [0.5, 0.5, 1.0]. That would effectively weight your class... Elegant? No. Effective? Yes.
  • user3151261 (about 6 years)
    Please help! How could I pass weights to my_custom_model? stackoverflow.com/questions/49312839/…