Making Life Better (read: Easier) with the TensorFlow Dataset API

Andika Tanuwijaya
4 min read · Jan 13, 2018

TensorFlow is, in my opinion, the most powerful framework for deep learning applications, and for any kind of graph computation really. It has a large number of high- and low-level APIs, the best visualization tool out there (TensorBoard), very active development (IYKWIM), and a vast community of both users and TensorFlow developers. Despite all that, there is one thing that can be overwhelming: building the training input pipeline.

There are several options one can take for the training input pipeline. Most tutorials out there either use an iterator imported from Keras such as DirectoryIterator, build their own iterator, or use a queue runner in the process. Using an iterator (combined with the feed_dict mechanism) is fine for toy deep learning cases, since the code is less complex than queue-runner-based pipelines. The problem with this technique is that it is relatively slow: there is overhead in reading the data, so the data cannot be fed to the GPU fast enough and training takes longer. This bottleneck was addressed with the introduction of the queue runner, which enables multi-threading. The idea of the queue runner is that you have a queue that acts as a data pipe and is automatically refilled whenever its fill level drops below a certain threshold. Unfortunately, the queue runner's performance is limited by the Python GIL, so it cannot achieve optimal multi-threading performance, and the resulting code is arguably more complex than the other techniques (very subjective).

# Create a queue runner that will run 4 threads in parallel to enqueue
# examples.
qr = tf.train.QueueRunner(queue, [enqueue_op] * 4)

# Launch the graph.
sess = tf.Session()
# Create a coordinator, launch the queue runner threads.
coord = tf.train.Coordinator()
enqueue_threads = qr.create_threads(sess, coord=coord, start=True)
# Run the training loop, controlling termination with the coordinator.
for step in xrange(1000000):
    if coord.should_stop():
        break
    sess.run(train_op)
# When done, ask the threads to stop.
coord.request_stop()
# And wait for them to actually do it.
coord.join(enqueue_threads)

But worry not, friends: last year the TensorFlow team introduced the Dataset API, which, according to the performance guide, performs its multi-threading at the C++ level. Not only that, the API is a lot cleaner and nicer to work with than the queue runner. In the next sections we will look into how to use the Dataset API in combination with TFRecords for an image segmentation input pipeline. Without further ado, let's get started!

We will use the portrait segmentation dataset from the Automatic Portrait Segmentation for Image Stylization paper, which can be downloaded from here (I stole the download script from this repository, removed some missing and corrupt data, and converted the .mat files to .jpg). We will use the generated image and mask folders to create the TFRecords. I will not explain this part in this post, but you can use this script and run it like this.

python create_tfrecords.py --image_dir images_data_crop/ --mask_dir images_mask/ --train_size 1 --validation_size 0

After the TFRecords are created, we can read them using the Dataset API. We will create a function to read and parse the data from each record.

def _extract_features(example):
    features = {
        "image": tf.FixedLenFeature((), tf.string),
        "mask": tf.FixedLenFeature((), tf.string)
    }
    parsed_example = tf.parse_single_example(example, features)
    images = tf.cast(tf.image.decode_jpeg(parsed_example["image"]), dtype=tf.float32)
    images.set_shape([800, 600, 3])
    masks = tf.cast(tf.image.decode_jpeg(parsed_example["mask"]), dtype=tf.float32) / 255.
    masks.set_shape([800, 600, 1])
    return images, masks

In this function, we first declare the keys we want to extract from each record, which are image and mask in this example. Since both the image and the mask are originally JPEG files, we use tf.image.decode_jpeg to decode the byte strings, and then set their shapes (this is needed so that the shapes are known during graph construction).

Next, we will create the dataset and the iterator that can be used during graph computation in a session. We pass a list of our TFRecord filenames (or just a single filename if there is only one) to TFRecordDataset. Then we map each record through the parsing function created earlier. We can also shuffle the records, batch them, and repeat them for the number of epochs we want (shuffling should come before batching, otherwise we only shuffle whole batches rather than individual records). The iterator can be created with make_one_shot_iterator, which is basically single-use, or make_initializable_iterator, which can be initialized/reset whenever you want. You can read more about it here.

dataset = tf.data.TFRecordDataset(filenames)
dataset = dataset.map(_extract_features)
# Shuffle individual records before batching them.
dataset = dataset.shuffle(buffer_size=50)
dataset = dataset.batch(8)
dataset = dataset.repeat(1)
iterator = dataset.make_one_shot_iterator()

The last important piece is getting the batch tensors from the iterator. They automatically yield the next batch whenever we evaluate them via a TensorFlow session run, and that's it.

next_images, next_masks = iterator.get_next()
images, masks = sess.run([next_images, next_masks])

Another thing to note is that both the one-shot and the initializable iterator throw an OutOfRangeError when they run past the last batch. You can catch this exception to detect the end of an epoch and, for the initializable iterator, re-initialize it for the next epoch; alternatively, you can use MonitoredTrainingSession to handle it automatically.

Thanks for reading this post and please give any feedback :)

PS: Code is available in my repository.
