
Creating `input_fn` From Iterator

Most tutorials focus on the case where the entire training dataset fits into memory. However, I have an iterator which acts as an infinite stream of (features, labels)-tuples (crea

Solution 1:

The input_fn argument is used throughout training, but the function itself is called only once. Writing a sophisticated input_fn that goes beyond returning a constant array, as explained in the tutorial, is therefore not straightforward.

TensorFlow provides two examples of such non-trivial input_fn, for NumPy and pandas arrays, but both start from an array in memory, so they do not help with your problem.

You could also look at their code by following the links above to see how they implement an efficient non-trivial input_fn, but you may find that it requires more code than you would like.

If you are willing to use the lower-level interface of TensorFlow, things are IMHO simpler and more flexible. There is a tutorial that covers most needs, and the proposed solutions are easier to implement.

In particular, if you already have an iterator that returns data as you described in your question, using placeholders (section "Feeding" in the previous link) should be straightforward.
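As a rough illustration of the feeding approach, the sketch below groups an iterator of (features, labels) tuples into fixed-size batches. It uses only the standard library. The names `batches` and `example_stream` are hypothetical and do not come from TensorFlow; `example_stream` merely stands in for the infinite iterator described in the question.

```python
import itertools


def batches(stream, batch_size):
    """Yield (features, labels) batches from an iterable of (features, labels) tuples."""
    stream = iter(stream)  # make sure islice advances instead of restarting
    while True:
        chunk = list(itertools.islice(stream, batch_size))
        if not chunk:
            return  # stream exhausted; never reached for an infinite stream
        features, labels = zip(*chunk)
        yield list(features), list(labels)


def example_stream():
    """Hypothetical stand-in for the infinite iterator from the question."""
    for i in itertools.count():
        yield [float(i), float(i) * 2.0], i % 2
```

In a training loop, each batch would then be passed to `sess.run` through `feed_dict`, mapping the feature and label placeholders (whatever they are called in your graph) to the two lists.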

Solution 2:

I found a pull request which converts a generator to an input_fn: https://github.com/tensorflow/tensorflow/pull/7045/files

The relevant part is

  def _generator_input_fn():
    """generator input function."""
    queue = feeding_functions.enqueue_data(
      x,
      queue_capacity,
      shuffle=shuffle,
      num_threads=num_threads,
      enqueue_size=batch_size,
      num_epochs=num_epochs)

    features = (queue.dequeue_many(batch_size) if num_epochs is None
                else queue.dequeue_up_to(batch_size))
    if not isinstance(features, list):
      features = [features]
    features = dict(zip(input_keys, features))
    if target_key is not None:
      if len(target_key) > 1:
        target = {key: features.pop(key) for key in target_key}
      else:
        target = features.pop(target_key[0])
      return features, target
    return features
  return _generator_input_fn
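The last part of the excerpt separates the target column(s) from the feature dictionary. The plain-Python sketch below mimics only that step, with ordinary lists in place of tensors. The function name `split_features_and_target` is hypothetical, and accepting a single string for `target_key` is an assumption made here for convenience; the excerpt itself indexes `target_key` as a list.

```python
def split_features_and_target(batch, target_key=None):
    """Separate target column(s) from a dict batch, as the excerpt does after dequeuing."""
    features = dict(batch)  # shallow copy so the caller's dict is left untouched
    if target_key is None:
        return features
    if isinstance(target_key, str):
        target_key = [target_key]  # assumption: normalize a single key to a list
    if len(target_key) > 1:
        # Several target keys: the target is itself a dict
        target = {key: features.pop(key) for key in target_key}
    else:
        # One target key: the target is the bare value
        target = features.pop(target_key[0])
    return features, target
```

With one target key the target is returned as a bare value, and with several keys it is returned as a dictionary, which matches the two branches in the excerpt.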

Solution 3:

from tensorflow.contrib.learn.python.learn.learn_io import generator_io
import numpy as np

# define generator
def generator():
    for index in range(2):
        yield {'a': np.ones(1) * index,
               'b': np.ones(1) * index + 32,
               'label': np.ones(1) * index - 32}

input_fn = generator_io.generator_input_fn(generator, target_key='label', batch_size=2, shuffle=False, num_epochs=1)
features, target = input_fn()

Refer to the test case https://github.com/tensorflow/tensorflow/pull/7045/files
