Классификация графов Stellargraph: как делать прогнозы
Я использую этот код:
Все работает с точностью 80%. Мои данные для обучения представляют собой графики с метками 0 или 1. Большой вопрос в том, как делать прогнозы сейчас.
Я пробую следующее:
gen2 = PaddedGraphGenerator(graphs=sg_test)
test2 = gen2.flow(sg_test)
pred = model.predict(test2)
Я получаю прогнозы, но не думаю, что делаю это правильно.
sg_test
представляет собой список графиков звездного графа, которые модель никогда раньше не видела. У меня есть метки, но я не буду показывать модель, поскольку пытаюсь предсказать метки.
Результаты следующие:
0
0 0.176684
1 0.001646
2 0.187621
3 0.173595
4 0.388054
5 0.293297
6 0.236243
7 0.078395
8 0.296253
9 0.2984
Пожалуйста помоги.
1 ответ
Сигмоидальная функция используется для двухклассовой логистической регрессии, это классический подход машинного обучения, не специфичный для графового машинного обучения. Последний слой сети, о котором вы говорите, - это
predictions = Dense(units=1, activation="sigmoid")(x_out)
Это означает, что на выходе выводится значение в интервале [0,1], соответствующее вероятности того, что метка равна 0 или 1. Вы должны сами определить, где находится порог. Если вы установите его на 0,5, это означает, что любое значение выше этого соответствует метке 1. Если вы хотите большей уверенности, вы можете взять только то, что выше, скажем, 0,7. Итак, все ваши прогнозы в вашем вопросе соответствуют метке 0.
Погуглите немного, прочтите несколько книг по машинному обучению или посмотрите (из многих) эту статью, в которой объясняется, как смотреть на сигмоид с точки зрения проблемы классификации.