tf.argmax() для более чем одного индекса Tensorflow
В Tensorflow tf.argmax() возвращает индекс наибольшего элемента в массиве.
Однако для задач классификации с несколькими метками функция, которая возвращает N самых больших элементов в массиве, была бы очень удобной.
predicted_array: [0.4, 0.6, 0.7, 0.2, 0.9]
tf.something(predicted_array, N = 2): [2,4]
Чтобы затем сравнить его с основной истиной один горячий кодированный массив
one_hot_array: [0, 0, 1, 0, 1]
tf.something(one_hot_array, N = 2): [2,4]
Есть ли такая функция? Или что-то похожее на это?
Спасибо за любую помощь
1 ответ
Решение
Да, есть. это tf.nn.top_k
( отсюда).
Вы можете использовать его как tf.nn.top_k(predicted_array, k=2)