Не удалось получить значение тензора
При запуске набора данных MNIST я хочу знать, что на самом деле моя модель выводит во время обучения партии. Вот мой код:(я не добавил оптимизатор и функцию потерь):
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
INPUT_NODE = 784 # the total pixels of the input images
OUTPUT_NODE = 10 # the output varies from 0 to 9
LAYER_NODE = 500
BATCH_SIZE = 100
TRAINING_STEPS = 10
def inference(input_tensor, avg_class, weight1, biase1, weight2, biase2):
if avg_class == None:
layer = tf.nn.relu(tf.matmul(input_tensor, weight1) + biase1)
return tf.matmul(layer, weight2)+biase2
else:
layer = tf.nn.relu(tf.matmul(input_tensor, avg_class.average(weight1)) +
avg_class.average(biase1))
return tf.matmul(layer, avg_class.average(weight2)) + avg_class.average(biase2)
def train(mnist):
x = tf.placeholder(tf.float32, [None, INPUT_NODE], name = 'x-input')
y = tf.placeholder(tf.float32, [None, OUTPUT_NODE],name = 'y-input')
weight1 = tf.Variable(tf.truncated_normal([INPUT_NODE, LAYER_NODE], stddev = 0.1))
biase1 = tf.Variable(tf.constant(0.1, shape = [LAYER_NODE]))
weight2 = tf.Variable(tf.truncated_normal([LAYER_NODE, OUTPUT_NODE], stddev = 0.1))
biase2 = tf.Variable(tf.constant(0.1, shape = [OUTPUT_NODE]))
out = inference(x, None, weight1, biase1, weight2, biase2)
with tf.Session() as sess:
tf.global_variables_initializer().run()
validate_feed = {x:mnist.validation.images, y:mnist.validation.labels}
test_feed = {x:mnist.test.images, y:mnist.test.labels}
for i in range(TRAINING_STEPS):
xs, ys = mnist.train.next_batch(BATCH_SIZE)
sess.run(out, feed_dict= {x:xs, y:ys})
print(out)
def main(arg = None):
mnist = input_data.read_data_sets("/home/vincent/Tensorflow/MNIST/data/", one_hot = True)
train(mnist)
if __name__ == '__main__':
tf.app.run()
Я пытаюсь распечатать:
Тензор ("add_1:0", shape=(?, 10), dtype=float32)
Если я хочу узнать ценность out, что мне делать? Я пытался print(out.eval())
и это подняло ошибку
1 ответ
Решение
out
является тензорным объектом. Если вы хотите получить его значение, замените
sess.run(out, feed_dict= {x:xs, y:ys})
print(out)
с
res_out=sess.run(out, feed_dict= {x:xs, y:ys})
print(res_out)