Cast TensorFlow 2.0 BatchDataset to NumPy array
After batching a dataset, the last batch may not have the same shape as the rest of the batches. For example, if your dataset has 100 elements in total and you batch with a size of 6, the last batch will contain only 4 elements (100 = 6 * 16 + 4).
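The size of the leftover batch is simply the remainder of integer division. As a minimal sketch (batch_layout is a hypothetical helper, not a TensorFlow API), Python's built-in divmod returns both the number of full batches and the size of the last, partial one. The MNIST sizes (60,000 training and 10,000 test images) with a batch size of 64 give 937 and 156 full batches.

```python
def batch_layout(num_elements, batch_size):
    """Hypothetical helper: return (number of full batches, size of leftover batch)."""
    full_batches, remainder = divmod(num_elements, batch_size)
    return full_batches, remainder

print(batch_layout(100, 6))     # (16, 4): 16 full batches, last batch holds 4
print(batch_layout(60000, 64))  # (937, 32): MNIST training set
print(batch_layout(10000, 64))  # (156, 16): MNIST test set
```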
In such cases, you cannot convert your dataset to NumPy directly, because batches of different sizes cannot be stacked into a single array. For that reason, set the drop_remainder parameter of the batch method to True. It drops the last batch if it has fewer than batch_size elements.
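To see why the uneven last batch is a problem, here is a plain-NumPy sketch. make_batches is a hypothetical stand-in for Dataset.batch(), not TensorFlow's implementation: it slices an array into batches along the first axis and optionally drops a short final batch. np.stack rejects batches of different shapes, while dropping the remainder leaves batches it can stack.

```python
import numpy as np

def make_batches(data, batch_size, drop_remainder=False):
    """Hypothetical stand-in for Dataset.batch(): split `data` along axis 0."""
    batches = [data[i:i + batch_size] for i in range(0, len(data), batch_size)]
    # Mimic drop_remainder=True by discarding a final batch that is too short
    if drop_remainder and batches and len(batches[-1]) < batch_size:
        batches = batches[:-1]
    return batches

data = np.arange(100).reshape(100, 1)  # 100 elements, as in the example above

ragged = make_batches(data, 6)         # 16 batches of 6, plus one batch of 4
try:
    np.stack(ragged)
except ValueError as err:
    print("np.stack failed:", err)

even = make_batches(data, 6, drop_remainder=True)  # only the 16 full batches
stacked = np.stack(even)
print(stacked.shape)  # (16, 6, 1)
```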
The following code shows how to convert the dataset to NumPy once that is done.
import tensorflow as tf
import numpy as np
(train_images, _), (test_images, _) = tf.keras.datasets.mnist.load_data()
TRAIN_BUF=1000
BATCH_SIZE=64
# drop_remainder=True discards the final, smaller batch so every batch has the same shape
train_dataset = (tf.data.Dataset.from_tensor_slices(train_images)
                 .shuffle(TRAIN_BUF)
                 .batch(BATCH_SIZE, drop_remainder=True))
test_dataset = (tf.data.Dataset.from_tensor_slices(test_images)
                .shuffle(TRAIN_BUF)
                .batch(BATCH_SIZE, drop_remainder=True))
# print(train_dataset, type(train_dataset), test_dataset, type(test_dataset))
# Iterate each dataset (eager mode) and stack the equally sized batches into one array
train_np = np.stack(list(train_dataset))
test_np = np.stack(list(test_dataset))
print(type(train_np), train_np.shape)
print(type(test_np), test_np.shape)
Output:
<class 'numpy.ndarray'> (937, 64, 28, 28)
<class 'numpy.ndarray'> (156, 64, 28, 28)
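If you would rather keep every element instead of dropping the last batch, one alternative (not the method used in the answer above) is to concatenate the batches along the first axis: np.concatenate only requires the other dimensions to match, so a short final batch is fine. The sketch below uses a NumPy array as a stand-in for MNIST-shaped images; with a real batched dataset in TF 2.x eager mode you could build the list with [batch.numpy() for batch in dataset]. Note that the result has shape (num_elements, 28, 28) rather than (num_batches, batch_size, 28, 28).

```python
import numpy as np

# Stand-in for 100 MNIST-shaped images (deterministic values)
images = np.arange(100 * 28 * 28, dtype=np.float32).reshape(100, 28, 28)

# Stand-in for the batches of a dataset batched without drop_remainder:
# 16 batches of 6 images and a final batch of 4
batches = [images[i:i + 6] for i in range(0, len(images), 6)]

# Join the batches back along axis 0; the uneven last batch is kept
flat = np.concatenate(batches, axis=0)
print(flat.shape)  # (100, 28, 28)
```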
mhery
Updated on July 06, 2022
Comments
-
mhery almost 2 years
I have this code:
(train_images, _), (test_images, _) = tf.keras.datasets.mnist.load_data()
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(TRAIN_BUF).batch(BATCH_SIZE)
test_dataset = tf.data.Dataset.from_tensor_slices(test_images).shuffle(TRAIN_BUF).batch(BATCH_SIZE)
print(train_dataset, type(train_dataset), test_dataset, type(test_dataset))
And I want to cast these two BatchDataset variables to numpy arrays. Can I do it easily? I am using TF 2.0, but I have only found code to cast tf.data with TF 1.0.
-
mhery almost 5 years
This code seems right to me, but when I tried to run it on Google Colab, it got stuck on the line that transforms the data to a list.
-
Prasad almost 5 years
Currently, Google Colab still uses TensorFlow 1.14 by default, so you will have to install TF 2.0 manually by running:
!pip install tensorflow==2.0.0-rc0
After that, you will not get the freezing problem you mentioned.
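To confirm which version a notebook is actually running before trying the TF 2.0 code, you can read tf.__version__. The helper below, is_tf2_or_later (a hypothetical name, not a TensorFlow function), is a minimal sketch that parses the major version out of such a string.

```python
def is_tf2_or_later(version):
    """Hypothetical helper: True if a version string such as tf.__version__ is 2.x or newer."""
    major = version.split(".")[0]  # "2.0.0-rc0" -> "2"
    return int(major) >= 2

print(is_tf2_or_later("1.14.0"))     # False
print(is_tf2_or_later("2.0.0-rc0"))  # True
```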