IndexError: индекс кортежа выходит за пределы допустимого диапазона в Graphsage
Я пытаюсь создать графовую нейронную сеть для прогнозирования краев и получил эту ошибку. Был бы очень признателен, если бы кто-то мог мне помочь.
from sklearn.metrics import roc_auc_score
model = GraphSAGE(train_g.ndata['congestion_onehot'].shape[1],16)
# You can replace DotPredictor with MLPPredictor.
#pred = MLPPredictor(16)
pred = DotPredictor()
def compute_loss(pos_score, neg_score):
scores = torch.cat([pos_score, neg_score])
labels = torch.cat([torch.ones(pos_score.shape[0]), torch.zeros(neg_score.shape[0])])
return F.binary_cross_entropy_with_logits(scores, labels)
def compute_auc(pos_score, neg_score):
scores = torch.cat([pos_score, neg_score]).numpy()
labels = torch.cat(
[torch.ones(pos_score.shape[0]), torch.zeros(neg_score.shape[0])]).numpy()
return roc_auc_score(labels, scores)
Ошибка была:
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
<ipython-input-56-d9c7e915d747> in <module>()
1 from sklearn.metrics import roc_auc_score
----> 2 model = GraphSAGE(train_g.ndata['congestion_onehot'].shape[1],16)
3 # You can replace DotPredictor with MLPPredictor.
4 #pred = MLPPredictor(16)
5 pred = DotPredictor()
IndexError: tuple index out of range
Если это поможет
train_g
Graph(num_nodes=4333, num_edges=60222,
ndata_schemes={'congestion_onehot': Scheme(shape=(), dtype=torch.int64)}
edata_schemes={'weight': Scheme(shape=(), dtype=torch.float64)})