Stellargraph не работает с перемешиванием данных

когда я запустил демонстрацию StellarGraph по классификации графов с использованием DGCNN, я получил тот же результат, что и в демонстрации.

Однако, когда я тестировал, что происходит, когда я впервые перетасовываю данные, используя следующий код:

      shuffler = list(zip(graphs, graph_labels))
random.shuffle(shuffler)
graphs, graph_labels = zip(*shuffler)

Модель вообще не обучалась (точность около 50% - как и распределение данных).

кто-нибудь знает, почему это произошло? Может я не так тасовал? Или данные должны быть перетасованы в первую очередь (также почему? Это не имеет никакого смысла)? Или это ошибка в реализации StellarGraph?

1 ответ

Решение

Я нашел проблему. Это не имело ничего общего ни с алгоритмом перетасовки, ни с реализацией StellarGraph. Проблема была в демонстрации, в следующих строках:

      train_gen = gen.flow(
    list(train_graphs.index - 1),
    targets=train_graphs.values,
    batch_size=50,
    symmetric_normalization=False,
)

test_gen = gen.flow(
    list(test_graphs.index - 1),
    targets=test_graphs.values,
    batch_size=1,
    symmetric_normalization=False,
)

Проблема была вызвана, в частности, train_graphs.index - 1 а также test_graphs.index - 1. Индексы уже находятся в диапазоне от 0 к n, поэтому вычитание одного из них приведет к тому, что данные графика «сместятся» на один назад, в результате чего каждая точка данных получит метку другой точки данных.

Чтобы исправить это, просто измените их на train_graphs.index а также test_graphs.index без -1 в конце.

Другие вопросы по тегам