tf.data.Dataset: how to get the dataset size (number of elements in an epoch)?
Solution 1
tf.data.Dataset.list_files creates a tensor called MatchingFiles:0 (with the appropriate prefix if applicable). You could evaluate
tf.shape(tf.get_default_graph().get_tensor_by_name('MatchingFiles:0'))[0]
to get the number of files.
Of course, this would work in simple cases only, and in particular if you have only one sample (or a known number of samples) per image.
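In TF 2.x eager execution, the same file count can be obtained directly with tf.io.matching_files instead of digging a tensor out of the default graph. A minimal sketch (the temporary .png files are created only for illustration):

```python
import os
import tempfile

import tensorflow as tf

# Create a throwaway directory with three empty .png files to match against.
tmp = tempfile.mkdtemp()
for i in range(3):
    open(os.path.join(tmp, f"sample_{i}.png"), "wb").close()

# tf.io.matching_files returns a 1-D string tensor of matching paths;
# in eager mode its shape is concrete, so shape[0] is the file count.
num_files = tf.io.matching_files(os.path.join(tmp, "*.png")).shape[0]
```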
In more complex situations, e.g. when you do not know the number of samples in each file, you can only observe the number of samples once an epoch ends. To do this, you can watch the number of epochs counted by your Dataset: repeat() creates a member called _count that counts the number of epochs. By observing it during your iterations, you can spot when it changes and compute your dataset size from there.
This counter may be buried in the hierarchy of Datasets that is created when calling member functions successively, so we have to dig it out like this:
import warnings

d = my_dataset
# RepeatDataset does not seem to be publicly exposed; a possible workaround:
RepeatDataset = type(tf.data.Dataset.range(1).repeat())
try:
    # Walk down the chain of wrapped datasets until we find the RepeatDataset.
    while not isinstance(d, RepeatDataset):
        d = d._input_dataset
except AttributeError:
    warnings.warn('no epoch counter found')
    epoch_counter = None
else:
    epoch_counter = d._count
Note that with this technique, the computation of your dataset size is not exact, because the batch during which epoch_counter is incremented typically mixes samples from two successive epochs. So this computation is precise only up to your batch length.
Solution 2
len(list(dataset)) works in eager mode, although that's obviously not a good general solution.
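A minimal sketch of this approach (TF 2.x eager mode assumed; the range dataset is just for illustration):

```python
import tensorflow as tf

# In eager mode, list() fully materializes the dataset, so this touches
# every element -- fine for small data, costly for large datasets.
ds = tf.data.Dataset.range(10)
n = len(list(ds))
```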
Solution 3
Take a look here: https://github.com/tensorflow/tensorflow/issues/26966
It doesn't work for TFRecord datasets, but it works fine for other types.
TL;DR:
num_elements = tf.data.experimental.cardinality(dataset).numpy()
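A runnable sketch of this call (TF 2.x eager assumed; the batched range dataset is only illustrative):

```python
import tensorflow as tf

# cardinality() is computed statically from the pipeline structure,
# without iterating over the data: 100 elements / batches of 10 = 10.
ds = tf.data.Dataset.range(100).batch(10)
num_elements = tf.data.experimental.cardinality(ds).numpy()
```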
Solution 4
As of TensorFlow >= 2.3 one can use:
dataset.cardinality().numpy()
Note that the .cardinality() method was integrated into the main package (before that, it lived in the experimental package).
Note that after applying the filter() operation, this method can return -2 (unknown cardinality).
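A minimal sketch of both behaviors (TF >= 2.3, eager mode assumed):

```python
import tensorflow as tf

ds = tf.data.Dataset.range(10)
known = ds.cardinality().numpy()  # 10: the size is known statically

# filter() makes the size statically unknowable, so cardinality reports
# tf.data.experimental.UNKNOWN_CARDINALITY, whose value is -2.
unknown = ds.filter(lambda x: x % 2 == 0).cardinality().numpy()
```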
Solution 5
UPDATE: Use tf.data.experimental.cardinality(dataset) - see here.
In the case of TensorFlow Datasets (tfds) you can use _, info = tfds.load(with_info=True) and then call info.splits['train'].num_examples. But even in this case it doesn't work properly if you define your own split. So you may either count your files or iterate over the dataset (as described in other answers):
num_training_examples = 0
num_validation_examples = 0

for example in training_set:
    num_training_examples += 1

for example in validation_set:
    num_validation_examples += 1
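The explicit counting loop above can also be written as a single Dataset.reduce call, as suggested in the comments below (a minimal sketch; TF 2.x eager mode assumed):

```python
import tensorflow as tf

# reduce() streams through the dataset once, adding 1 per element;
# .numpy() converts the resulting scalar tensor to a plain Python int.
ds = tf.data.Dataset.range(7)
count = ds.reduce(0, lambda x, _: x + 1).numpy()
```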
nessuno
Updated on July 21, 2022

Comments
-
nessuno almost 2 years
Let's say I have defined a dataset in this way:
filename_dataset = tf.data.Dataset.list_files("{}/*.png".format(dataset))
How can I get the number of elements that are inside the dataset (hence, the number of single elements that compose an epoch)? I know that tf.data.Dataset already knows the dimension of the dataset, because the repeat() method allows repeating the input pipeline for a specified number of epochs. So there must be a way to get this information.
-
P-Gn almost 6 years
Do you need to have this information before the first epoch has completed, or is it okay to compute it after?
-
nessuno almost 6 years
Before the first epoch has completed.
-
benjaminplanche almost 6 years
Working as an iterator, I don't think a Dataset knows the total number of elements before reaching the last one; then it starts repeating over again if requested (c.f. source repeat_dataset_op.cc).
-
GPhilo almost 6 years
Can't you just list the files in "{}/*.png".format(dataset) beforehand (say via glob or os.listdir), get the length of that list, and then pass the list to a Dataset? Datasets don't (natively) have access to the number of items they contain (knowing that number would require a full pass over the dataset, and you still have the case of unlimited datasets coming from streaming data or generators).
-
nessuno almost 6 years@GPhilo I could only in this particular case, but I'd like to have a more general solution.
-
nessuno almost 6 years
@user1735003 Thank you for your answer, I'm gonna test it soon. Can you please also add the option to get the size after the end of the first epoch?
-
GPhilo almost 6 years
@nessuno The thing is, there is no general solution, because Datasets don't know their size. If you have TFRecord datasets, for example, there is no way for you to know at creation time how many samples your dataset contains. The only way is to count them as you go, or do a full pass over the dataset before you start training (which, depending on your dataset's size, can be quite slow).
-
nessuno almost 6 years
@GPhilo Understood, thank you for the explanation! However the answer of user1735003 perfectly fits my needs.
-
irudyak almost 4 years
From what I can see in the official tf tutorials, they count files before creating a dataset, not the number of elements in a dataset.
-
Happy Gene over 4 years
Alternatively, a more concise way to add things up in TF 2.0:
count = dataset.reduce(0, lambda x, _: x + 1)
-
CSharp over 4 years
I found you have to call numpy() on count to get the actual value, otherwise count is a tensor, i.e.: count = dataset.reduce(0, lambda x, _: x + 1).numpy()
-
yrekkehs over 4 years
It defeats the purpose of it being an iterator. Calling list() runs the entire thing in a single shot. It works for smaller amounts of data, but can likely take too many resources for larger datasets.
-
markemus over 4 years
@yrekkehs Absolutely, that's why it's not a good general solution. But it works.
-
yrekkehs over 4 years
@markemus Didn't mean to sound contentious, I was just trying to answer PhonoDots. :)
-
markemus over 4 years
@yrekkehs Gotcha, and I agree :)
-
bachr over 3 years
train_ds.cardinality().numpy() is giving me -2!!!
-
Timbus Calin over 3 years
It's giving you -2 because you have used .filter() somewhere in your code.
-
Timbus Calin over 3 years
You can try to see, this works prior to applying filter :D
-
Li-Pin Juan about 3 years
Hi, I think you are wrong. len() is not applicable to a tf.data.Dataset object. Based on the discussion in this thread, it's unlikely to have this feature in the near future.
-
Li-Pin Juan about 3 years
I think your solution is incorrect. The returned object, ds, is not the same as what split['train'] represents. You can see what I mean from this: (train, val), info = tfds.load('oxford_iiit_pet', split=['train[:70%]','train[70%:]'], shuffle_files=True, as_supervised=True). The sizes of the sub-datasets train and val change as we modify the percentage specified in the split= argument. However, info.splits['train'].num_examples is fixed at 3680.
-
alzoubi36 about 3 years
Hey, I would not describe it as not applicable. I had a dataset of 391 images and it returned exactly that.
-
Li-Pin Juan about 3 years
I know it works in some cases, but generally it doesn't. len() cannot be applied to a Dataset object like this one, for example: tfds.load('tf_flowers')['train'].repeat(), because its size is infinite.
-
Sabito 錆兎 stands with Ukraine about 3 years
Are you trying to answer the question or are you asking a question?
-
krenerd about 3 years
@Yatin I found a very fast solution (the second code snippet), but I also want to understand how this works behind the scenes, and how to clean it up.
-
Li-Pin Juan about 3 years
The method doesn't work for MapDataset objects.
-
learner almost 3 years
Using len(tfdataset) raises TypeError: dataset length is unknown.