Преобразование графа PyG в граф NetworkX

Я пытаюсь преобразовать свой график PyG в график NetworkX, используя

Согласно документам , я могу дополнительно передавать атрибуты node и edge как итерации str в дополнение к объекту Data.

Ниже приведены списки атрибутов узлов и ребер со значениями, преобразованными в строки:

      Nodes:  ['3.3375725746154785', '2.0086510181427',..., '1.5960148572921753', '3.621992349624634']

Edges:  ['0.9940207804344958', '0.48573804411542043', ..., '0.7245483440145621', '0.24117984598949904']

to_networkxработает нормально, когда я передаю ему только объект данных. Однако, когда я также передаю эти списки атрибутов, я получаю следующую ошибку:

      G[u][v][key] = values[key][i]
KeyError: '0.30194718370332896'

Я просмотрел исходный код, но не могу понять, что он делает. Может ли кто-нибудь помочь объяснить, что не так с моими списками атрибутов и что мне нужно изменить, чтобы они были приняты.

Что я могу понять, так это то, что эта ошибка конкретно относится к моим атрибутам края. Если я их удалю, я получу следующую аналогичную ошибку, связанную с атрибутами узла:

      feat_dict.update({key: values[key][i]})
KeyError: '0.0'

1 ответ

Вам нужно передать имена атрибутов в виде списка:

      to_networkx(<PyTorchGeometricDataObject>, node_attrs=[<Name of Node Attribute 1>, <Name of Node Attributes 2>, ... ], edge_attr=[<Edge Attribute 1>, ...])

Или в контексте, исходя из вашего минимального примера:

      import numpy as np
import torch
from torch_geometric.data import Data
from torch_geometric.utils import to_networkx

n1 = np.repeat(np.array([0,1,2,3,4,5,6]),5)
n2 = np.array([0,1,2,3,4,0,1,2,3,4,0,1,2,3,4,0,1,2,3,4,0,1,2,3,4,0,1,2,3,4,0,1,2,3,4])
cat = np.stack((n1,n2), axis=0)
e = torch.tensor(cat, dtype=torch.long)
edge_index = e.t().clone().detach()
edge_attr = torch.tensor(np.random.rand(35,1))

x = torch.tensor([[0], [0], [0], [0], [0], [1], [1]], dtype=torch.float)

data = Data(x=x, edge_index=edge_index.t().contiguous(), edge_attr = edge_attr)
print(data)
# Data(edge_attr=[35, 1], edge_index=[2, 35], x=[7, 1])

networkX_graph = to_networkx(data, node_attrs=["x"], edge_attrs=["edge_attr"])

print(networkX_graph.nodes(data=True))
# [(0, {'x': 0.0}), (1, {'x': 0.0}),...
print(networkX_graph.edges(data=True))
# [(0, 0, {'edge_attr': 0.3412137594357493}), ...
Другие вопросы по тегам