How do you decode one-hot labels in TensorFlow?
Solution 1
You can find the index of the largest element using tf.argmax. Since your one-hot vector is one-dimensional, with a single 1 and all other entries 0, that index is the encoded class. This works assuming you are dealing with a single vector:
index = tf.argmax(one_hot_vector, axis=0)
For the more standard matrix of shape [batch_size, num_classes], use axis=1 to get a result of shape (batch_size,), i.e. one integer label per row.
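A minimal sketch of both cases, assuming TensorFlow 2.x with eager execution (so results can be read with .numpy()); the example vectors are made up for illustration:

```python
import tensorflow as tf

# A single one-hot vector (1-D): reduce along axis 0.
one_hot_vector = tf.constant([0, 0, 1, 0])
index = tf.argmax(one_hot_vector, axis=0)
print(index.numpy())  # 2

# A batch of one-hot rows (batch_size x num_classes): reduce along axis 1.
one_hot_batch = tf.constant([[0, 0, 1, 0],
                             [1, 0, 0, 0]])
indices = tf.argmax(one_hot_batch, axis=1)
print(tuple(indices.shape))  # (2,)
print(indices.numpy())       # [2 0]
```

Note that tf.argmax returns int64 by default; pass output_type=tf.int32 if you need a different integer type.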
Solution 2
Since a one-hot encoding is typically just a matrix with batch_size rows and num_classes columns, where each row is all zeros except for a single non-zero entry at the chosen class, you can use tf.argmax() to recover a vector of integer labels:
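import tensorflow as tf

BATCH_SIZE = 3
NUM_CLASSES = 4
one_hot_encoded = tf.constant([[0, 1, 0, 0],
                               [1, 0, 0, 0],
                               [0, 0, 0, 1]])
# Compute the argmax across the columns (axis=1): one label per row.
decoded = tf.argmax(one_hot_encoded, axis=1)
# TensorFlow 2.x runs eagerly, so the result can be read directly;
# in TensorFlow 1.x graph mode, use sess.run(decoded) inside a Session instead.
print(decoded.numpy())  # ==> [1 0 3]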
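Since the question builds its labels with tf.one_hot, a quick round-trip check can confirm that tf.argmax undoes that encoding. This is a sketch assuming TensorFlow 2.x and labels in the range [0, num_classes); tf.one_hot turns an out-of-range label into an all-zero row, which argmax cannot decode meaningfully.

```python
import tensorflow as tf

labels = tf.constant([1, 0, 3])
num_classes = 4

# Encode integer labels as one-hot rows of length num_classes (float32).
one_hot = tf.one_hot(labels, depth=num_classes)

# Decode: the position of the 1 in each row is the original label.
# tf.argmax returns int64, so cast back to the labels' dtype (int32).
decoded = tf.cast(tf.argmax(one_hot, axis=1), labels.dtype)

print(one_hot.numpy())
print(decoded.numpy())  # [1 0 3]
```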
Solution 3
Outside of a TensorFlow graph you can do the same with NumPy: encode with Keras's to_categorical and decode each row with np.argmax.
import numpy as np
from tensorflow.keras.utils import to_categorical

data = np.array([1, 5, 3, 8])
print(data)

def encode(data):
    print('Shape of data (BEFORE encode): %s' % str(data.shape))
    encoded = to_categorical(data)
    print('Shape of data (AFTER encode): %s\n' % str(encoded.shape))
    return encoded

encoded_data = encode(data)
print(encoded_data)

def decode(datum):
    # The index of the single 1 in a one-hot row is the original integer label.
    return np.argmax(datum)

decoded_Y = []
print("****************************************")
for i in range(encoded_data.shape[0]):
    datum = encoded_data[i]
    print('index: %d' % i)
    print('encoded datum: %s' % datum)
    decoded_datum = decode(datum)
    print('decoded datum: %s' % decoded_datum)
    decoded_Y.append(decoded_datum)
    print("****************************************")
print(decoded_Y)
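The per-row loop can be replaced by a single vectorized call with axis=1. To keep this sketch self-contained without Keras, it builds the same one-hot matrix with np.eye (to_categorical infers 9 classes from the largest label, 8, so the rows match):

```python
import numpy as np

data = np.array([1, 5, 3, 8])

# Row k of np.eye(9) is the one-hot vector for class k.
encoded_data = np.eye(9)[data]

# One argmax per row recovers all labels at once.
decoded = np.argmax(encoded_data, axis=1)
print(decoded)  # [1 5 3 8]
```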
Comments
Matt Camp almost 4 years
Been looking, but can't seem to find any examples of how to decode or convert a one-hot value back to a single integer in TensorFlow.
I used tf.one_hot and was able to train my model, but am a bit confused about how to make sense of the label after classification. My data is fed in via a TFRecords file that I created. I thought about storing a text label in the file but wasn't able to get it to work. It appeared as if TFRecords couldn't store text strings, or maybe I was mistaken.
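To turn decoded integers back into readable labels, one option is to keep the class names in an ordered list and look them up with tf.gather. TFRecords can hold strings too, as UTF-8 bytes in a BytesList feature. A sketch assuming TensorFlow 2.x; the class names and the label_text feature key are hypothetical:

```python
import tensorflow as tf

# Hypothetical class names: position i must match the integer label i
# that was passed to tf.one_hot during training.
class_names = tf.constant(["cat", "dog", "bird", "fish"])

# Model outputs (one-hot here; softmax probabilities decode the same way).
predictions = tf.constant([[0.0, 1.0, 0.0, 0.0],
                           [0.0, 0.0, 0.0, 1.0]])
indices = tf.argmax(predictions, axis=1)

# Look up the class name for each decoded index.
text_labels = tf.gather(class_names, indices)
print(text_labels.numpy())  # [b'dog' b'fish']

# Storing a text label in a TFRecord: strings go in a BytesList feature.
# "label_text" is a hypothetical feature key.
example = tf.train.Example(features=tf.train.Features(feature={
    "label_text": tf.train.Feature(bytes_list=tf.train.BytesList(value=[b"dog"])),
}))
serialized = example.SerializeToString()

# Reading it back: parse the serialized record as a scalar tf.string.
parsed = tf.io.parse_single_example(
    serialized, {"label_text": tf.io.FixedLenFeature([], tf.string)})
print(parsed["label_text"].numpy())  # b'dog'
```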
martianwars, over 7 years: The OP seems to be using just a vector, since he mentions he wants a single integer from a one-hot value.