tf.argmax() для более чем одного индекса Tensorflow

В Tensorflow tf.argmax() возвращает индекс наибольшего элемента в массиве.

Однако для задач классификации с несколькими метками функция, которая возвращает N самых больших элементов в массиве, была бы очень удобной.

predicted_array: [0.4, 0.6, 0.7, 0.2, 0.9]
tf.something(predicted_array, N = 2): [2,4]

Чтобы затем сравнить его с основной истиной один горячий кодированный массив

one_hot_array: [0, 0, 1, 0, 1]
tf.something(one_hot_array, N = 2): [2,4]

Есть ли такая функция? Или что-то похожее на это?

Спасибо за любую помощь

1 ответ

Решение

Да, есть. это tf.nn.top_k ( отсюда).

Вы можете использовать его как tf.nn.top_k(predicted_array, k=2)

Другие вопросы по тегам