tf.data.Dataset: how to get the dataset size (number of elements in an epoch)?

Solution 1

In TF 1.x graph mode, tf.data.Dataset.list_files creates a tensor called MatchingFiles:0 (with the appropriate name-scope prefix, if applicable).

You could evaluate

tf.shape(tf.get_default_graph().get_tensor_by_name('MatchingFiles:0'))[0]

to get the number of files.
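In TensorFlow 2.x eager code there is no default graph to query, so a rough equivalent is to run the same glob yourself with `tf.io.matching_files` (which, as far as I know, wraps the same MatchingFiles op that list_files uses) before building the dataset. A minimal sketch, assuming TF 2.x; the temporary directory and file names are made up for the example:

```python
import os
import tempfile

import tensorflow as tf

# Hypothetical data directory holding three empty "images", just for the example
data_dir = tempfile.mkdtemp()
for name in ("a.png", "b.png", "c.png"):
    open(os.path.join(data_dir, name), "wb").close()

pattern = os.path.join(data_dir, "*.png")

# Expand the glob eagerly: a 1-D string tensor with one entry per matching file
files = tf.io.matching_files(pattern)
num_files = int(tf.size(files))

# Build the dataset from the same pattern, as in the question
filename_dataset = tf.data.Dataset.list_files(pattern)
print(num_files)  # 3
```

As with the graph-mode trick, this counts files, not samples, so it only gives the epoch size when each file holds a known number of samples.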

Of course, this only works in simple cases, in particular when each image file holds exactly one sample (or a known number of samples).

In more complex situations, e.g. when you do not know the number of samples in each file, you can only observe the number of samples as an epoch ends.

To do this, you can watch the epoch count kept by your Dataset: repeat() creates a private member called _count that counts the number of epochs. By observing it during iteration, you can spot when it changes and compute the dataset size from there.

Because each chained call (map(), batch(), etc.) wraps the previous dataset, this counter may be buried in the resulting hierarchy of Datasets, so we have to dig it out like this:

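import warnings
import tensorflow as tf  # TF 1.x graph-mode code

d = my_dataset
# RepeatDataset is not exposed publicly; get its class from a throwaway
# dataset (tf.data.Dataset itself is abstract and cannot be instantiated)
RepeatDataset = type(tf.data.Dataset.range(1).repeat())
try:
  # Each transformation wraps its input; walk down until repeat() is found
  while not isinstance(d, RepeatDataset):
    d = d._input_dataset
except AttributeError:
  # Reached a source dataset without _input_dataset: no repeat() in the chain
  warnings.warn('no epoch counter found')
  epoch_counter = None
else:
  epoch_counter = d._count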

Note that with this technique the computed dataset size is not exact, because the batch during which epoch_counter is incremented typically mixes samples from two successive epochs. The result is therefore only accurate to within one batch size.
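In TF 2.x eager mode the private _count attribute is not needed to find the epoch boundary: iterating the pipeline *before* repeat() yields exactly one epoch, and summing the leading dimension of every batch gives an exact count instead of one accurate only up to the batch size. This swaps plain iteration in for the epoch-counter technique above; a minimal sketch, assuming TF 2.x, with a toy range dataset standing in for a real pipeline:

```python
import tensorflow as tf

# Toy pipeline: 103 elements in batches of 10 (the last batch has 3 elements)
base = tf.data.Dataset.range(103)
batched = base.batch(10)

# One pass over the non-repeated pipeline is exactly one epoch;
# add up the real size of each batch, including the short final one
num_elements = 0
for batch in batched:
    num_elements += int(tf.shape(batch)[0])

print(num_elements)  # 103

# Only afterwards build the repeated pipeline used for training
train_ds = batched.repeat()
```

The cost is one full pass over the data before training starts, which may be slow for large datasets.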

Solution 2

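len(list(dataset)) works in eager mode, although it is obviously not a good general solution: it iterates over the whole dataset, keeps every element in memory, and never terminates on an infinite (repeated) dataset.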
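A small eager-mode illustration, assuming TF 2.x; the filtered range dataset is only an example. The list() version materializes every element, while the reduce() variant suggested in the comments still makes a full pass but keeps only a running counter:

```python
import tensorflow as tf

# cardinality() cannot see through filter(), so counting requires a full pass
ds = tf.data.Dataset.range(10).filter(lambda x: x % 2 == 0)

# Builds a Python list of all elements: simple, but O(n) memory
n_list = len(list(ds))

# Streams through the dataset keeping only a scalar counter (int32 state)
n_reduce = int(ds.reduce(0, lambda count, _: count + 1))

print(n_list, n_reduce)  # 5 5
```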

Solution 3

Take a look here: https://github.com/tensorflow/tensorflow/issues/26966

It doesn't work for TFRecord datasets (their cardinality is reported as unknown, i.e. -2), but it works fine for many other types.

TL;DR:

num_elements = tf.data.experimental.cardinality(dataset).numpy()
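To see what "doesn't work for TFRecord datasets" means in practice: when a source cannot be sized without reading it, cardinality returns the sentinel tf.data.experimental.UNKNOWN_CARDINALITY (-2) instead of a count. A sketch assuming TF 2.x; the temporary TFRecord file is written only for the example:

```python
import os
import tempfile

import tensorflow as tf

# In-memory source: the size is known without iterating
ds = tf.data.Dataset.range(42).batch(8)
print(tf.data.experimental.cardinality(ds).numpy())  # 6 (five full batches + one partial)

# Write a tiny throwaway TFRecord file with three records
path = os.path.join(tempfile.mkdtemp(), "example.tfrecord")
with tf.io.TFRecordWriter(path) as writer:
    for i in range(3):
        writer.write(f"record-{i}".encode())

# A TFRecord source must be read to be counted, so cardinality is unknown
records = tf.data.TFRecordDataset(path)
card = tf.data.experimental.cardinality(records).numpy()
print(card == tf.data.experimental.UNKNOWN_CARDINALITY)  # True
```

For such datasets you are back to counting by iteration, as in the other answers.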

Solution 4

As of TensorFlow 2.3, one can use:

dataset.cardinality().numpy()

Note that .cardinality() was promoted to a method of tf.data.Dataset; before that it was only available in the experimental package, as tf.data.experimental.cardinality().

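Note that after a filter() operation, cardinality() can return -2 (tf.data.UNKNOWN_CARDINALITY), because the number of elements that pass the filter is not known until the dataset is iterated.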
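A sketch of the sentinel values involved, assuming TF 2.3 or later. When you know the true size of a filtered dataset from elsewhere, tf.data.experimental.assert_cardinality can re-attach it; it raises an error during iteration if the actual count turns out to differ:

```python
import tensorflow as tf

ds = tf.data.Dataset.range(10)
print(ds.cardinality().numpy())  # 10

# filter(): the number of surviving elements is unknown until iteration
evens = ds.filter(lambda x: x % 2 == 0)
print(evens.cardinality().numpy())  # -2, i.e. tf.data.UNKNOWN_CARDINALITY

# repeat() without a count: infinite
print(ds.repeat().cardinality().numpy())  # -1, i.e. tf.data.INFINITE_CARDINALITY

# Re-attach a size known independently (5 even numbers in 0..9)
evens_known = evens.apply(tf.data.experimental.assert_cardinality(5))
print(evens_known.cardinality().numpy())  # 5
```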

Solution 5

UPDATE:

Use tf.data.experimental.cardinality(dataset) - see here.


In the case of TensorFlow Datasets (tfds), you can use _, info = tfds.load(name, with_info=True) and then read info.splits['train'].num_examples. However, this reports the size of the predefined split, so it does not reflect a custom split such as split='train[:70%]'.

So you may either count your files or iterate over the dataset (as described in other answers):

num_training_examples = 0
num_validation_examples = 0

for example in training_set:
    num_training_examples += 1

for example in validation_set:
    num_validation_examples += 1
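Combining the answers above, one hedged approach is to trust cardinality() when it is known and fall back to a full counting pass otherwise. dataset_size below is a hypothetical helper, not a TensorFlow API; it assumes TF 2.3+ and eager execution, and deliberately refuses infinite datasets. The two datasets are toy stand-ins for training_set and validation_set:

```python
import tensorflow as tf


def dataset_size(ds: tf.data.Dataset) -> int:
    """Number of elements in one pass over `ds` (hypothetical helper)."""
    card = int(ds.cardinality())
    if card == tf.data.INFINITE_CARDINALITY:
        raise ValueError("dataset is infinite (e.g. repeat() without a count)")
    if card == tf.data.UNKNOWN_CARDINALITY:
        # Size not known statically (filter, TFRecord, generators, ...):
        # count with a full pass, keeping only a scalar counter in memory
        card = int(ds.reduce(tf.constant(0, tf.int64), lambda n, _: n + 1))
    return card


training_set = tf.data.Dataset.range(100)
validation_set = tf.data.Dataset.range(100).filter(lambda x: x % 4 == 0)
print(dataset_size(training_set), dataset_size(validation_set))  # 100 25
```

The fast path costs nothing; the fallback has the same cost as the explicit loops above.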

Author: nessuno

Updated on July 21, 2022

Comments

  • nessuno
    nessuno almost 2 years

    Let's say I have defined a dataset in this way:

    filename_dataset = tf.data.Dataset.list_files("{}/*.png".format(dataset))
    

    how can I get the number of elements that are inside the dataset (hence, the number of single elements that compose an epoch)?

    I know that tf.data.Dataset already knows the size of the dataset, because the repeat() method allows repeating the input pipeline for a specified number of epochs. So there must be a way to get this information.

    • P-Gn
      P-Gn almost 6 years
      Do you need to have this information before the first epoch completed, or is it okay to compute it after?
    • nessuno
      nessuno almost 6 years
      Before the first epoch completed
    • benjaminplanche
      benjaminplanche almost 6 years
      Working as an iterator, I don't think a Dataset knows the total number of elements before reaching the last one - then it starts repeating over if requested (c.f. source repeat_dataset_op.cc)
    • GPhilo
      GPhilo almost 6 years
      Can't you just list the files in "{}/*.png".format(dataset) beforehand (say via glob or os.listdir), get the length of that list and then pass the list to a Dataset? Datasets don't natively have access to the number of items they contain (knowing that number would require a full pass over the dataset, and you still have the case of unlimited datasets coming from streaming data or generators)
    • nessuno
      nessuno almost 6 years
      @GPhilo I could only in this particular case, but I'd like to have a more general solution.
    • nessuno
      nessuno almost 6 years
      @user1735003 thank you for your answer, I'm gonna test it soon. Can you please also add the option to get the size after the end of the first epoch?
    • GPhilo
      GPhilo almost 6 years
      @nessuno the thing is, there is no general solution, because Datasets don't know their size. If you have TFRecord datasets, for example, there is no way for you to know at creation time how many samples your dataset contains. The only way is to count them as you go, or do a full pass over the dataset before you start training (which, depending on your dataset's size, can be quite slow)
    • nessuno
      nessuno almost 6 years
      @GPhilo understood, thank you for the explanation! However the answer of user1735003 perfectly fits my needs
    • irudyak
      irudyak almost 4 years
      From what I can see in official tf tutorials - they count files before creating a dataset, not number of elements in a dataset.
  • Happy Gene
    Happy Gene over 4 years
    Alternatively, a more concise way to add things up in TF 2.0: count = dataset.reduce(0, lambda x, _: x + 1)
  • CSharp
    CSharp over 4 years
    I found you have to call numpy() on count to get the actual value otherwise count is a tensor. i.e: count = dataset.reduce(0, lambda x, _: x + 1).numpy()
  • yrekkehs
    yrekkehs over 4 years
    It defeats the purpose of it being an iterator. Calling list() runs the entire thing in a single shot. It works for smaller amounts of data, but can likely take too many resources for larger datasets.
  • markemus
    markemus over 4 years
    @yrekkehs absolutely, that's why it's not a good general solution. But it works.
  • yrekkehs
    yrekkehs over 4 years
    @markemus Didn't mean to sound contentious, I was just trying to answer PhonoDots. :)
  • markemus
    markemus over 4 years
    @yrekkehs gotcha, and I agree :)
  • bachr
    bachr over 3 years
    train_ds.cardinality().numpy() is giving me -2!!!
  • Timbus Calin
    Timbus Calin over 3 years
    It's giving you -2 because you have used .filter() somewhere in your code
  • Timbus Calin
    Timbus Calin over 3 years
    You can check it yourself: this works before filter() is applied :D
  • Li-Pin Juan
    Li-Pin Juan about 3 years
    Hi, I think you are wrong. len() is not applicable to a tf.data.Dataset object. Based on the discussion in this thread, this feature is unlikely to arrive in the near future.
  • Li-Pin Juan
    Li-Pin Juan about 3 years
    I think your solution is incorrect. The returned object, ds, is not the same as what split['train'] represents. You can see what I mean with this: (train, val), info = tfds.load('oxford_iiit_pet', split=['train[:70%]','train[70%:]'], shuffle_files=True, as_supervised=True). The sizes of the subdatasets train and val change as we modify the percentage specified in the split= argument. However, info.splits['train'].num_examples stays fixed at 3680.
  • alzoubi36
    alzoubi36 about 3 years
    Hey, I would not describe it as not applicable. I had a dataset of 391 images and it returned exactly that.
  • Li-Pin Juan
    Li-Pin Juan about 3 years
    I know it works in some cases, but in general it doesn't. len() cannot be applied to a Dataset object such as tfds.load('tf_flowers')['train'].repeat(), for example, because its size is infinite.
  • Sabito 錆兎 stands with Ukraine
    Sabito 錆兎 stands with Ukraine about 3 years
    Are you trying to answer the question or are you asking a question?
  • krenerd
    krenerd about 3 years
    @Yatin I found a very fast solution(the second code snippet), but I also want to understand how this works behind the scenes, and how to clean it up.
  • Li-Pin Juan
    Li-Pin Juan about 3 years
    The method doesn't work for MapDataset object.
  • learner
    learner almost 3 years
    Using len(tfdataset) raises TypeError: dataset length is unknown.
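Regarding the TypeError in the last comment: in TF 2.3 and later, len() on a dataset (as I understand it) simply defers to cardinality(), so it works in eager mode when the size is known without iterating and raises TypeError when the size is unknown or infinite. A minimal sketch, assuming TF 2.3+ in eager mode:

```python
import tensorflow as tf

ds = tf.data.Dataset.range(10)
print(len(ds))  # 10

errors = {}
for name, d in [("filtered", ds.filter(lambda x: x > 0)), ("repeated", ds.repeat())]:
    try:
        len(d)
    except TypeError as err:
        # Unknown or infinite cardinality: len() cannot answer without iterating
        errors[name] = str(err)
        print(name, "->", err)
```

So len() succeeding on one dataset and failing on another, as reported in the comments, is consistent with this behavior.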