Understanding tf.data.Dataset.interleave()

1 minute read

Published:

09/08/2024 - This is an old post I wrote in 2020 and posted on Medium.

After reading the official documentation for tf.data.Dataset.interleave(), I was still confused about how it works. It took me quite some time to figure it out. I hope this blog saves you some time.

Specifically, we look at this example from the official guide:

import tensorflow as tf

dataset = tf.data.Dataset.range(1, 6)  # ==> [ 1, 2, 3, 4, 5 ]
# NOTE: New lines in the output below indicate "block" boundaries.
dataset = dataset.interleave(
    lambda x: tf.data.Dataset.from_tensors(x).repeat(6),
    cycle_length=2, block_length=4)
list(dataset.as_numpy_iterator())

The output is

[1, 1, 1, 1,
 2, 2, 2, 2,
 1, 1,
 2, 2,
 3, 3, 3, 3,
 4, 4, 4, 4,
 3, 3,
 4, 4,
 5, 5, 5, 5,
 5, 5]

But why is that?

Let’s look at what the function does first. Suppose you have an original dataset A, and you call the function with the following signature:

interleave(
    map_func, cycle_length=None, block_length=None, num_parallel_calls=None,
    deterministic=None
)
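
In practice, map_func usually builds a dataset from each input element, and the most common use of interleave() is reading several data files concurrently. Below is a rough sketch of that pattern, assuming a recent TF 2.x release; the shard filenames are hypothetical placeholders:

# Sketch: interleave records from several TFRecord shards.
# The shard filenames below are placeholders, not real files.
files = tf.data.Dataset.from_tensor_slices(
    ["shard-0.tfrecord", "shard-1.tfrecord", "shard-2.tfrecord"])
dataset = files.interleave(
    tf.data.TFRecordDataset,              # map_func: filename -> dataset of records
    cycle_length=2,                       # read from 2 files at a time
    num_parallel_calls=tf.data.AUTOTUNE,  # overlap the file reads
    deterministic=False)                  # trade strict ordering for throughput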

What really happens behind the scenes is the following:

  1. Take cycle_length elements from dataset A and apply map_func to each of them, yielding cycle_length new datasets.
  2. Cycle through these new datasets, producing block_length consecutive elements from each in turn.
  3. Whenever one of the new datasets is fully consumed, refill its slot by going back to step 1: apply map_func to the next element of A, until A itself is exhausted. (A minimal sketch of this logic follows the list.)
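
To make these steps concrete, here is a minimal pure-Python sketch of the same round-robin logic. It is only an illustration of the semantics described above, not TensorFlow’s actual implementation, which also handles parallelism and prefetching:

from itertools import islice

def interleave(source, map_func, cycle_length, block_length):
    source = iter(source)
    # Step 1: take cycle_length elements from the source and map each
    # one to an iterator over the dataset it produces.
    active = [iter(map_func(x)) for x in islice(source, cycle_length)]
    while active:
        for i, it in enumerate(active):
            # Step 2: pull up to block_length elements from this slot.
            for _ in range(block_length):
                try:
                    yield next(it)
                except StopIteration:
                    # Step 3: this slot is exhausted; refill it from the
                    # next source element, or retire it if none is left.
                    try:
                        active[i] = iter(map_func(next(source)))
                    except StopIteration:
                        active[i] = None
                    break
        active = [it for it in active if it is not None]

# Reproduces the output above:
print(list(interleave(range(1, 6), lambda x: [x] * 6,
                      cycle_length=2, block_length=4)))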

Looking back at the example at the beginning, the underlying process can be traced step by step:
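
  1. With cycle_length=2, interleave first takes 1 and 2 from the source and maps each to a dataset of six copies: [1, 1, 1, 1, 1, 1] and [2, 2, 2, 2, 2, 2].
  2. It then pulls block_length=4 elements from each in turn: 1, 1, 1, 1, then 2, 2, 2, 2.
  3. On the next pass each dataset has only two elements left, so it emits 1, 1 and then 2, 2, and both datasets are now exhausted.
  4. The two slots are refilled with the next source elements, 3 and 4, and the pattern repeats: 3, 3, 3, 3, then 4, 4, 4, 4, then 3, 3, then 4, 4.
  5. Only 5 remains, so a single slot stays active and emits 5, 5, 5, 5 and finally 5, 5.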

Hope this blog is clear! Let me know if you have more questions/suggestions.

References:

  1. https://blog.csdn.net/menghuan
  2. https://www.tensorflow.org/api_docs/python/tf/data/Dataset#interleave