TensorFlow Dataset .map() API

When you use Dataset.map(map_func), TensorFlow defines a subgraph for all the ops created in the function map_func, and arranges to execute it efficiently in the same session as the rest of your graph. There is almost never any need to create a tf.Graph or tf.Session inside map_func: if your parsing function is made up of TensorFlow ops, these ops can be embedded directly in the graph that defines the input pipeline.
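To make the subgraph point concrete, here is a minimal sketch, assuming TensorFlow 1.14+ or 2.x where the tf.compat.v1 namespace is available. The function _scale and the trace_calls list are hypothetical names used only for illustration. The Python body of the map function is entered when the pipeline is built (typically once), not once per element, and it needs no tf.Graph or tf.Session of its own.

```python
import tensorflow as tf

tf1 = tf.compat.v1
tf1.disable_eager_execution()  # use session-style graph execution, as in the answer

trace_calls = []  # hypothetical bookkeeping: one entry per Python-level call of _scale


def _scale(x):
    trace_calls.append(1)
    # Only TensorFlow ops are used here, so they are embedded in the
    # input pipeline graph; no tf.Graph or tf.Session is created.
    return tf.cast(x, tf.float32) * 0.5


dataset = tf.data.Dataset.range(4).map(_scale)
next_element = tf1.data.make_one_shot_iterator(dataset).get_next()

with tf1.Session() as sess:
    # Pull all four elements through the pipeline.
    values = [float(sess.run(next_element)) for _ in range(4)]

print(values)
print("Python-level calls of _scale:", len(trace_calls))
```

Four elements flow through the pipeline, yet the Python function runs fewer times than that because TensorFlow executes the traced ops rather than calling back into Python.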

The modified version of the code using tf.data would look like this:

import tensorflow as tf
from tensorflow.contrib.framework.python.ops import audio_ops as contrib_audio

def _some_audio_preprocessing_func(filename):
    # Read the raw file contents, then decode them as single-channel WAV audio.
    wav_loader = tf.read_file(filename)
    return contrib_audio.decode_wav(wav_loader, desired_channels=1)

dataset = tf.data.Dataset.list_files('*.wav')
# map() embeds the ops above directly in the input pipeline graph.
dataset = dataset.map(_some_audio_preprocessing_func)

If your map_func contains non-TensorFlow operations that you want to apply to each element, wrap them in tf.py_func() (or use Dataset.from_generator() if the data generation process is defined in Python logic). The main performance implication is that any code running in a tf.py_func() is subject to Python's Global Interpreter Lock, so I would generally recommend finding a native TensorFlow implementation for anything that is performance-critical.
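As a rough illustration of the tf.py_func() route, the sketch below wraps a plain NumPy function inside a map function. It assumes TensorFlow 1.14+ or 2.x and reaches tf.py_func through tf.compat.v1. The helper _peak_normalize and the sample values are hypothetical and stand in for whatever non-TensorFlow logic you need.

```python
import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1
tf1.disable_eager_execution()  # use session-style graph execution, as in the answer


def _peak_normalize(samples):
    """Hypothetical NumPy-only step; runs in Python and therefore under the GIL."""
    peak = np.max(np.abs(samples))
    if peak == 0:
        return samples
    return (samples / peak).astype(np.float32)


def _map_func(samples):
    # Wrap the Python function so it can be called from the pipeline graph.
    out = tf1.py_func(_peak_normalize, [samples], tf.float32)
    # py_func outputs carry no static shape, so restore it from the input.
    out.set_shape(samples.shape)
    return out


data = np.array([[0.5, -2.0], [1.0, 4.0]], dtype=np.float32)
dataset = tf.data.Dataset.from_tensor_slices(data).map(_map_func)
next_element = tf1.data.make_one_shot_iterator(dataset).get_next()

with tf1.Session() as sess:
    first = sess.run(next_element)
    second = sess.run(next_element)

print(first, second)
```

Every element passes through the Python interpreter here, which is why a native TensorFlow op is preferable on a hot path. In TensorFlow 2.x, tf.py_function and tf.numpy_function play the same role.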

Author: lollercoaster

Updated on August 04, 2022

Comments

  • lollercoaster almost 2 years

    A couple of questions about this.

    Suppose I'd like to do something like the following in TensorFlow (assume I'm creating training examples by loading WAV files):

    import tensorflow as tf
    from tensorflow.contrib.framework.python.ops import audio_ops as contrib_audio
    from tensorflow.python.ops import io_ops

    def _some_audio_preprocessing_func(filename):
        # ... some logic here which mostly uses TensorFlow ops ...
        # A separate graph and session are created on every call.
        with tf.Session(graph=tf.Graph()) as sess:
            wav_filename_placeholder = tf.placeholder(tf.string, [])
            wav_loader = io_ops.read_file(wav_filename_placeholder)
            wav_decoder = contrib_audio.decode_wav(wav_loader, desired_channels=1)
            data = sess.run(
                [wav_decoder],
                feed_dict={wav_filename_placeholder: filename})
            return data

    dataset = tf.data.Dataset.list_files('*.wav')
    dataset = dataset.map(_some_audio_preprocessing_func)
    
    1. If I have a parse_image() function that uses tensor ops, should it be part of the main graph? Google's own audio TF tutorial appears to create a separate graph. Doesn't that defeat the point of using TensorFlow to make things faster?
    2. Do I need tf.py_func() whenever any single line isn't from the TensorFlow library? Again, I wonder what the performance implications are and when I should use it.

    Thanks!

  • Anuj over 4 years
    Hi. I have a follow-up question on this. If I use this dataset to train a model, how do I save the model after training so that the TensorFlow ops from _some_audio_preprocessing_func are also included in the final model? Thanks