IndexError: индекс кортежа выходит за пределы допустимого диапазона в Graphsage

Я пытаюсь создать графовую нейронную сеть для прогнозирования краев и получил эту ошибку. Был бы очень признателен, если бы кто-то мог мне помочь.

      from sklearn.metrics import roc_auc_score
model = GraphSAGE(train_g.ndata['congestion_onehot'].shape[1],16)
# You can replace DotPredictor with MLPPredictor.
#pred = MLPPredictor(16)
pred = DotPredictor()

def compute_loss(pos_score, neg_score):
    scores = torch.cat([pos_score, neg_score])
    labels = torch.cat([torch.ones(pos_score.shape[0]), torch.zeros(neg_score.shape[0])])
    return F.binary_cross_entropy_with_logits(scores, labels)

def compute_auc(pos_score, neg_score):
    scores = torch.cat([pos_score, neg_score]).numpy()
    labels = torch.cat(
        [torch.ones(pos_score.shape[0]), torch.zeros(neg_score.shape[0])]).numpy()
    return roc_auc_score(labels, scores)

Ошибка была:

      ---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-56-d9c7e915d747> in <module>()
      1 from sklearn.metrics import roc_auc_score
----> 2 model = GraphSAGE(train_g.ndata['congestion_onehot'].shape[1],16)
      3 # You can replace DotPredictor with MLPPredictor.
      4 #pred = MLPPredictor(16)
      5 pred = DotPredictor()

IndexError: tuple index out of range

Если это поможет

      train_g
      Graph(num_nodes=4333, num_edges=60222,
      ndata_schemes={'congestion_onehot': Scheme(shape=(), dtype=torch.int64)}
      edata_schemes={'weight': Scheme(shape=(), dtype=torch.float64)})

0 ответов

Другие вопросы по тегам