Question:

I am using tf.contrib.learn.read_batch_features (https://www.tensorflow.org/versions/master/api_docs/python/contrib.learn/input_processing#read_batch_features) to read Example protos as part of my input function, which returns a dict of Tensor objects. After training my model, calling predict on my Estimator returns one batch of predictions as an array, which I would like to compare with the known values.

I tried to obtain the known values by calling tf.Session().run(labels), where labels is the Tensor of known values returned from the input function. However, at this point my program hangs. I suspect it is stuck in an infinite loop reading labels from disk, rather than reading just one batch as I would like.

Is this the correct way to obtain one batch of values in the labels Tensor?
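The hang can be modelled without TensorFlow. The sketch below is a pure-Python analogy, not TensorFlow code: queue.Queue stands in for the input queue that the reader builds, and the hypothetical producer function stands in for a queue-runner thread. Dequeuing from a queue that nothing is filling never returns (a timeout is used here only so the demo terminates); once the producer has been started, a single dequeue yields exactly one batch.

```python
import queue
import threading


def producer(q, records, batch_size):
    """Hypothetical stand-in for a queue-runner thread: enqueues fixed-size batches."""
    for i in range(0, len(records), batch_size):
        q.put(records[i:i + batch_size])


def try_dequeue(q, timeout):
    """Return one batch, or None if nothing arrives within `timeout` seconds."""
    try:
        return q.get(timeout=timeout)
    except queue.Empty:
        return None


labels = list(range(10))  # hypothetical label data

# No producer started: the queue stays empty. A dequeue without a timeout
# (like sess.run(labels) with no queue runners) would wait forever.
q = queue.Queue(maxsize=4)
print(try_dequeue(q, timeout=0.5))  # None

# Start the producer first, then dequeue exactly one batch.
q = queue.Queue(maxsize=4)
threading.Thread(target=producer, args=(q, labels, 4), daemon=True).start()
print(try_dequeue(q, timeout=5.0))  # [0, 1, 2, 3]
```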

Edit: I have tried starting the queue runners. Is the following correct?

_, labels = eval_input_fn()

with tf.Session().as_default():
    tf.local_variables_initializer()
    tf.train.start_queue_runners()
    label_values = labels.eval()
    print(label_values)

Answer:

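You need to run the initializers and start the queue runners before evaluating labels. Until the queue-runner threads fill the input queues, sess.run(labels) blocks forever, which is the hang you are seeing. Note that in your edit, tf.local_variables_initializer() only builds the initializer op; it has no effect until it is passed to sess.run. The whole setup is:

import tensorflow as tf

_, labels = eval_input_fn()

with tf.Session() as sess:
    # Local variables include, for example, the epoch counter that the
    # reader creates when num_epochs is set.
    sess.run([
        tf.local_variables_initializer(),
        tf.global_variables_initializer()
    ])

    # Start the threads that fill the input queues; without them every
    # dequeue (and therefore sess.run(labels)) waits forever.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)

    try:
        # Reads every batch; with num_epochs=None this loop never ends.
        # To fetch a single batch, call sess.run(labels) once instead.
        while not coord.should_stop():
            print(sess.run(labels))

    except tf.errors.OutOfRangeError as error:
        # Raised once the input has been read num_epochs times.
        coord.request_stop(error)
    finally:
        coord.request_stop()
        coord.join(threads)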
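If you want a single batch rather than the whole dataset, the same pattern applies with one sess.run(labels) between tf.train.start_queue_runners(sess=sess, coord=coord) and coord.request_stop() followed by coord.join(threads). The following is a pure-Python analogy of that start/read-once/stop/join logic, not TensorFlow code: the Coordinator class is a hypothetical minimal stand-in for tf.train.Coordinator, and queue_runner is a hypothetical producer that cycles over the data forever (like a reader with num_epochs=None).

```python
import queue
import threading


class Coordinator:
    """Hypothetical minimal analogue of tf.train.Coordinator: a stop flag plus join."""

    def __init__(self):
        self._stop_event = threading.Event()

    def should_stop(self):
        return self._stop_event.is_set()

    def request_stop(self):
        self._stop_event.set()

    def join(self, threads, timeout=5.0):
        for t in threads:
            t.join(timeout)


def queue_runner(q, coord, records, batch_size):
    """Hypothetical producer: cycles over records forever until a stop is requested."""
    i = 0
    while not coord.should_stop():
        batch = [records[(i + k) % len(records)] for k in range(batch_size)]
        try:
            # Time out so the stop flag is re-checked even when the queue is full.
            q.put(batch, timeout=0.1)
            i += batch_size
        except queue.Full:
            continue


def read_one_batch(records, batch_size):
    """Start the producer, dequeue exactly one batch, then stop and join it."""
    q = queue.Queue(maxsize=2)
    coord = Coordinator()
    threads = [threading.Thread(target=queue_runner,
                                args=(q, coord, records, batch_size))]
    for t in threads:
        t.start()
    try:
        return q.get(timeout=5.0)  # analogous to a single sess.run(labels)
    finally:
        coord.request_stop()
        coord.join(threads)


print(read_one_batch(list(range(10)), 4))  # [0, 1, 2, 3]
```

In this analogy the put timeout matters: a producer blocked indefinitely on a full queue would never re-check the stop flag, so join would wait for it until its own timeout expired.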