Stellargraph не работает с перемешиванием данных
когда я запустил демонстрацию StellarGraph по классификации графов с использованием DGCNN, я получил тот же результат, что и в демонстрации.
Однако, когда я тестировал, что происходит, когда я впервые перетасовываю данные, используя следующий код:
shuffler = list(zip(graphs, graph_labels))
random.shuffle(shuffler)
graphs, graph_labels = zip(*shuffler)
Модель вообще не обучалась (точность около 50% - как и распределение данных).
кто-нибудь знает, почему это произошло? Может я не так тасовал? Или данные должны быть перетасованы в первую очередь (также почему? Это не имеет никакого смысла)? Или это ошибка в реализации StellarGraph?
1 ответ
Я нашел проблему. Это не имело ничего общего ни с алгоритмом перетасовки, ни с реализацией StellarGraph. Проблема была в демонстрации, в следующих строках:
train_gen = gen.flow(
list(train_graphs.index - 1),
targets=train_graphs.values,
batch_size=50,
symmetric_normalization=False,
)
test_gen = gen.flow(
list(test_graphs.index - 1),
targets=test_graphs.values,
batch_size=1,
symmetric_normalization=False,
)
Проблема была вызвана, в частности,
train_graphs.index - 1
а также
test_graphs.index - 1
. Индексы уже находятся в диапазоне от
0
к
n
, поэтому вычитание одного из них приведет к тому, что данные графика «сместятся» на один назад, в результате чего каждая точка данных получит метку другой точки данных.
Чтобы исправить это, просто измените их на
train_graphs.index
а также
test_graphs.index
без
-1
в конце.