RandomLinkSplit не работает с HeteroData
у меня серьезные проблемы с
torch-geometric
при работе с моими собственными данными. Я пытаюсь построить граф, который имеет 4 различных объекта узла (из которых только 1 имеет некоторые функции узла, остальные - простые узлы) и 5 различных типов ребер (из которых только один имеет вес). Мне удалось сделать это, построив
HeteroData()
объекта и загрузки различных матриц с метками, атрибутами и т.д.
Проблема возникает, когда я пытаюсь позвонить
RandomLinkSplit
. Вот как выглядит мой звонок:
import torch_geometric.transforms as T
transform = T.RandomLinkSplit(
num_val = 0.1,
num_test = 0.1,
edge_types = [('Patient', 'suffers_from', 'Diagnosis'),
('bla', 'bla', 'bla') #I copy all the edge types here
],
)
но я получаю пустой
AssertionError
при условии:
assert is instance(rev_edge_types, list)
Поэтому я подумал, что мне нужно преобразовать граф в неориентированный (по какой-то странной причине), как это делается в учебнике, а затем сэмплировать также обратные ребра (хотя они мне не нужны):
import torch_geometric.transforms as T
data = T.ToUndirected()(data)
transform = T.RandomLinkSplit(
num_val = 0.1,
num_test = 0.1,
edge_types = [('Patient', 'suffers_from', 'Diagnosis'),
('bla', 'bla', 'bla') #I copy all the edge types here
],
rev_edge_types = [('Diagnosis', 'rev_suffers_from', 'Patient'),
...
]
)
но на этот раз я получаю ошибку
unsupported operand type(s) for *: 'Tensor' and 'NoneType'
.
У кого-нибудь из экспертов есть идеи, почему это происходит? Я просто пытаюсь выполнить разделение тестов поезда, и из документов, которые я читал, гетерогенные графики должны хорошо поддерживаться, но я не понимаю, почему это не работает, и я пробовал разные вещи в течение довольно много времени.
Любая помощь будет оценена по достоинству! Спасибо
1 ответ
Вам следует попытаться разделить каждое ребро и тренироваться на одном типе ребра за раз.
transform = T.RandomLinkSplit(
num_val = 0.1,
num_test = 0.1,
edge_types = ('Patient', 'suffers_from', 'Diagnosis'),
rev_edge_types = ('Diagnosis', 'rev_suffers_from', 'Patient')
)