Тестирование тензорной сети: замена in_top_k() для многослойной классификации
Я создал нейронную сеть в tenorflow. Эта сеть является многолинейной. Ergo: он пытается предсказать несколько выходных меток для одного входного набора, в данном случае три. В настоящее время я использую этот код, чтобы проверить, насколько точна моя сеть при прогнозировании трех меток:
_, indices_1 = tf.nn.top_k(prediction, 3)
_, indices_2 = tf.nn.top_k(item_data, 3)
correct = tf.equal(indices_1, indices_2)
accuracy = tf.reduce_mean(tf.cast(correct, 'float'))
percentage = accuracy.eval({champion_data:input_data, item_data:output_data})
Этот код работает отлично. Теперь проблема в том, что я пытаюсь создать код, который проверяет, находятся ли первые 3 элемента, которые он находит в indices_1, среди 5 лучших изображений в indices_2. Я знаю, что в tenorflow есть метод in_top_k(), но, насколько я знаю, он не поддерживает мультиметку. В настоящее время я пытаюсь сравнить их, используя цикл for:
_, indices_1 = tf.nn.top_k(prediction, 5)
_, indices_2 = tf.nn.top_k(item_data, 3)
indices_1 = tf.unpack(tf.transpose(indices_1, (1, 0)))
indices_2 = tf.unpack(tf.transpose(indices_2, (1, 0)))
correct = []
for element in indices_1:
for element_2 in indices_2:
if element == element_2:
correct.append(True)
else:
correct.append(False)
accuracy = tf.reduce_mean(tf.cast(correct, 'float'))
percentage = accuracy.eval({champion_data:input_data, item_data:output_data})
Однако это не работает. Код работает, но моя точность всегда 0.0.
Итак, у меня есть один из двух вопросов:
1) Существует ли простая замена in_top_k(), которая принимает многослойную классификацию, которую я могу использовать вместо пользовательского кода?
2) Если нет 1: что я делаю неправильно, что приводит к получению точности 0,0?
1 ответ
Когда вы делаете
correct = tf.equal(indices_1, indices_2)
Вы проверяете, не содержат ли эти два индекса одинаковые элементы, но содержат ли они одинаковые элементы в одинаковых позициях. Это не похоже на то, что вы хотите.
Оператор setdiff1d скажет вам, какие индексы есть в indices_1, но не в indices_2, которые вы затем можете использовать для подсчета ошибок.
Я думаю, что слишком строгая проверка правильности может привести к неправильному результату.