How to use argmax tensorflow function in 3d array?
Я хочу знать, как использовать tf.argmax
in 3D array.
My input data is like that:
[[[0, -1, 5, 2, 1], [2, 2, 3, 2, 5], [6, 1, 2, 4, -1]],
[[-1, -2, 3, 2, 1], [0, 3, 2, 7, -1], [-1, 5, 2, 1, 3]]]
And I want to get the output of argmax by this input data like this:
[[2, 4, 0], [2, 3, 1]]
И я хочу использовать softmax_cross_entropy_with_logits
function in this format.
Как я должен использовать tf.nn.softmax_cross_entropy_with_logits
функция и tf.equal(tf.argmax)
а также tf.reduce_mean(tf.cast)
?
1 ответ
Решение
Ты можешь использовать tf.argmax
вместе axis=3
a = tf.constant([[[0, -1, 5, 2, 1], [2, 2, 3, 2, 5], [6, 1, 2, 4, -1]],
[[-1, -2, 3, 2, 1], [0, 3, 2, 7, -1], [-1, 5, 2, 1, 3]]])
b = tf.argmax(a, axis=2)