TensorFlow: What Are The Input Nodes For Tf.Estimator Models

I trained a Wide & Deep model using the pre-made Estimator class (DNNLinearCombinedClassifier), by essentially following the official tutorial. I wanted to do inference without using tensorflow-serving, which requires knowing the input nodes of the model.

Solution 1:

If you want inference, but without using tensorflow-serving, you can just use the tf.estimator.Estimator predict method.

But if you want to do it manually (so that it runs faster), you need a workaround. I am not sure whether what I did is the best approach, but it worked. Here's my solution.

1) Let's do the imports and create variables and fake data:

import os
import numpy as np
from functools import partial
import pickle
import tensorflow as tf

# Size of the fake data set.
N = 10000
# Training budget for the example.
EPOCHS = 1000
# Batch size read by my_input_fn when is_training is true. It was referenced
# but never defined, which raised a NameError as soon as training started.
# NOTE(review): 32 is an arbitrary choice; tune as needed.
BATCH_SIZE = 32

# Fake data: N rows of 10 uniform random features and one random 0/1 label.
X_data = np.random.random((N, 10))
y_data = (np.random.random((N, 1)) >= 0.5).astype(int)

# Working directory (with trailing slash) used for checkpoints, the pickled
# tensor-name map and the frozen graph.
my_dir = os.getcwd() + "/"

2) Define an input_fn, which you will use to feed the Estimator. Save the tensor names in a dictionary ("input_tensor_map"), which maps each input key to its tensor name.

def my_input_fn(X, y=None, is_training=False):
    """Build an ``input_fn`` over in-memory data for ``tf.estimator.Estimator``.

    Args:
        X: Feature data. Anything that is not already a dict is wrapped as
            ``{"x": X}``.
        y: Optional labels. When ``None`` the returned ``input_fn`` yields
            features only.
        is_training: When true the dataset is repeated, shuffled and batched
            with ``BATCH_SIZE``; otherwise the batch size is 1.

    Returns:
        A ``functools.partial`` that takes no further arguments and, when
        called, returns ``(features, labels)`` or just ``features``.

    Side effects (when the returned callable runs):
        Writes ``input_tensor_map.pickle`` into ``my_dir``; it maps every
        feature key to the name of the tensor produced by the iterator.
    """

    def internal_input_fn(X, y=None, is_training=False):
        if not isinstance(X, dict):
            X = {"x": X}

        if y is None:
            dataset = tf.data.Dataset.from_tensor_slices(X)
        else:
            dataset = tf.data.Dataset.from_tensor_slices((X, y))

        if is_training:
            dataset = dataset.repeat().shuffle(100)
            batch_size = BATCH_SIZE
        else:
            batch_size = 1

        dataset = dataset.batch(batch_size)

        dataset_iter = dataset.make_initializable_iterator()

        if y is None:
            features = dataset_iter.get_next()
            labels = None
        else:
            features, labels = dataset_iter.get_next()

        # Remember which graph tensor backs each input key so the frozen
        # graph can later be fed directly through feed_dict.
        input_tensor_map = {
            input_name: tensor.name for input_name, tensor in features.items()
        }

        with open(os.path.join(my_dir, 'input_tensor_map.pickle'), 'wb') as f:
            pickle.dump(input_tensor_map, f, protocol=pickle.HIGHEST_PROTOCOL)

        # An initializable iterator has to be initialized before use.
        # NOTE(review): registering it under TABLE_INITIALIZERS presumably
        # lets the Estimator run it together with the table initializers.
        tf.add_to_collection(tf.GraphKeys.TABLE_INITIALIZERS, dataset_iter.initializer)

        return (features, labels) if labels is not None else features

    return partial(internal_input_fn, X=X, y=y, is_training=is_training)

3) Define your model, to be used in your tf.estimator.Estimator. For example:

def my_model_fn(features, labels, mode):
    """Single-unit logistic ``model_fn`` for ``tf.estimator.Estimator``.

    The ops are explicitly named "logits", "predictions" and "classes" so
    they can be fetched by name from a frozen graph.

    Args:
        features: dict of input tensors; only ``features["x"]`` is read.
        labels: binary labels (ignored in PREDICT mode).
        mode: one of ``tf.estimator.ModeKeys``.

    Returns:
        A ``tf.estimator.EstimatorSpec`` carrying predictions (PREDICT),
        loss and train op (TRAIN), or loss only (any other mode).
    """
    output = tf.layers.dense(inputs=features["x"], units=1, activation=None)
    logits = tf.identity(output, name="logits")
    prediction = tf.nn.sigmoid(logits, name="predictions")
    classes = tf.to_int64(tf.greater(logits, 0.0), name="classes")

    predictions_dict = {
        "class": classes,
        "probabilities": prediction,
    }

    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions_dict)

    # Bring the labels to the shape and dtype of `classes` so they can be
    # compared with it whatever integer type or rank the input data uses.
    int_labels = tf.reshape(tf.cast(labels, tf.int64), tf.shape(classes))

    # There is a single logit column, so the 0/1 labels are the targets as
    # they are. A depth-2 one-hot target would broadcast against that column
    # and score the logit against both 0 and 1, making the loss independent
    # of the label.
    loss = tf.losses.sigmoid_cross_entropy(
        multi_class_labels=tf.cast(int_labels, logits.dtype), logits=logits)

    tf.summary.scalar("loss", loss)

    accuracy = tf.reduce_mean(tf.to_float(tf.equal(int_labels, classes)))
    tf.summary.scalar("accuracy", accuracy)

    # Configure the Training Op (for TRAIN mode)
    if mode == tf.estimator.ModeKeys.TRAIN:
        train_op = tf.train.AdamOptimizer().minimize(loss, global_step=tf.train.get_global_step())
        return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)

    return tf.estimator.EstimatorSpec(mode=mode, loss=loss)

4) Train and freeze your model. The freeze method is from "TensorFlow: How to freeze a model and serve it with a python API", to which I added a tiny modification.

def freeze_graph(output_node_names):
    """Extract the sub graph defined by the output nodes and convert
    all its variables into constants.

    The latest checkpoint is looked up in the module-level ``my_dir`` and
    the frozen graph is written next to it as ``frozen_model.pb``.

    Args:
        output_node_names: a string containing all the output node names,
            comma separated. ``None`` falls back to ``'loss'``.

    Returns:
        The frozen ``GraphDef``, or ``-1`` when ``output_node_names`` is an
        empty string.

    Raises:
        AssertionError: if ``my_dir`` does not exist or holds no checkpoint.
    """
    if output_node_names is None:
        output_node_names = 'loss'

    if not tf.gfile.Exists(my_dir):
        raise AssertionError(
            "Export directory doesn't exists. Please specify an export "
            "directory: %s" % my_dir)

    if not output_node_names:
        print("You need to supply the name of a node to --output_node_names.")
        return -1

    # Retrieve the full path of the latest checkpoint; fail with a clear
    # message instead of an AttributeError when nothing has been saved yet.
    checkpoint = tf.train.get_checkpoint_state(my_dir)
    if checkpoint is None or not checkpoint.model_checkpoint_path:
        raise AssertionError("No checkpoint found in directory: %s" % my_dir)
    input_checkpoint = checkpoint.model_checkpoint_path

    # The frozen graph is stored in the same directory as the checkpoint.
    absolute_model_dir = os.path.dirname(input_checkpoint)
    output_graph = os.path.join(absolute_model_dir, "frozen_model.pb")

    # Clear devices to allow TensorFlow to control on which device it will
    # load operations.
    clear_devices = True

    # Start a session using a temporary fresh Graph.
    with tf.Session(graph=tf.Graph()) as sess:
        # Import the meta graph into the current default Graph.
        saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=clear_devices)

        # Restore the weights.
        saver.restore(sess, input_checkpoint)

        # Use a built-in TF helper to export variables to constants.
        output_graph_def = tf.graph_util.convert_variables_to_constants(
            sess,  # The session is used to retrieve the weights
            tf.get_default_graph().as_graph_def(),  # The graph_def is used to retrieve the nodes
            [name.strip() for name in output_node_names.split(",")]  # Selects the useful nodes
        )

        # Finally serialize and dump the output graph to the filesystem.
        with tf.gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())
        print("%d ops in the final graph." % len(output_graph_def.node))

    return output_graph_def

# *****************************************************************************


estimator = tf.estimator.Estimator(model_fn=my_model_fn, model_dir=my_dir)

# Train and freeze only when the model directory holds no checkpoint yet.
if estimator.latest_checkpoint() is None:
    estimator.train(input_fn=my_input_fn(X_data, y_data, True), steps=EPOCHS)
    # Keep every node that the inference code fetches by name.
    freeze_graph("logits,predictions,classes")

5) Finally, you can use the frozen graph for inference; the input tensor names are in the dictionary that you saved. Again, the method to load the frozen model is from "TensorFlow: How to freeze a model and serve it with a python API".

def load_frozen_graph(prefix="frozen_graph"):
    """Load ``frozen_model.pb`` from ``my_dir`` into a fresh ``tf.Graph``.

    Args:
        prefix: name scope prepended to every imported op, so a node ``foo``
            becomes ``<prefix>/foo`` in the returned graph.

    Returns:
        The new ``tf.Graph`` containing the imported graph definition.
    """
    frozen_graph_filename = os.path.join(my_dir, "frozen_model.pb")

    # Load the protobuf file from disk and parse it to retrieve the
    # unserialized graph_def.
    with tf.gfile.GFile(frozen_graph_filename, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    # Import the graph_def into a new Graph and return it.
    with tf.Graph().as_default() as graph:
        # The name argument prefixes every op/node in the graph.
        tf.import_graph_def(graph_def, name=prefix)

    return graph

# *****************************************************************************

X_test = {"x": np.random.random((N // 2, 10))}

prefix = "frozen_graph"
graph = load_frozen_graph(prefix)

# List every op so the prefixed node names can be inspected.
for op in graph.get_operations():
    print(op.name)

# The map was pickled locally by the input_fn; never unpickle a file that
# comes from an untrusted source.
with open(os.path.join(my_dir, 'input_tensor_map.pickle'), 'rb') as f:
    input_tensor_map = pickle.load(f)

with tf.Session(graph=graph) as sess:
    # Feed each test array straight into the tensor recorded for its key,
    # taking the import prefix into account.
    input_feed = {
        graph.get_tensor_by_name(prefix + "/" + tensor_name): X_test[key]
        for key, tensor_name in input_tensor_map.items()
    }

    logits = graph.get_operation_by_name(prefix + "/logits").outputs[0]
    probabilities = graph.get_operation_by_name(prefix + "/predictions").outputs[0]
    classes = graph.get_operation_by_name(prefix + "/classes").outputs[0]

    logits_values, probabilities_values, classes_values = sess.run(
        [logits, probabilities, classes], feed_dict=input_feed)

