индекс 43280 выходит за пределы для измерения 0 с размером 32

def train(epoch):model.train()loss_all = 0 для данных в train_loader:data = data.to(device)optimizer.zero_grad()output = model(data.x.float(), data.edge_index, data. пакет) потеря = F.nll_потеря (выход, данные.y) потеря.назад() потеря_все += данные.количество_графов * потеря.элемент() оптимизатор.шаг() возврат потеря_все/длина(поезд_набор данных)

      def test(loader):
    model.eval()
    correct = 0
    for data in loader:
        data = data.to(device)
        #for i in data.batch:
        pred = model(data.x.float(), data.edge_index, data.batch).max(dim=1)[1]
        correct += pred.eq(data.y).sum().item()
    return correct / len(loader.dataset)

train_loader = DataLoader(train_dataset, batch_size=32)
test_loader = DataLoader(test_dataset, batch_size=32)

for epoch in range(1, 201):
    loss = train(epoch)
    train_acc = test(train_loader)
    test_acc = test(test_loader)
    print('Epoch: {:03d}, Loss: {:.5f}, Train Acc: {:.5f}, Test Acc: {:.5f}'.
          format(epoch, loss, train_acc, test_acc))

Сообщение об ошибке:

Это дает ошибку времени выполнения: «индекс 43280 выходит за пределы для измерения 0 с размером 32».

      /usr/local/lib/python3.7/dist-packages/torch_geometric/deprecation.py:12: UserWarning: 'data.DataLoader' is deprecated, use 'loader.DataLoader' instead
  warnings.warn(out)
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-36-5c532377865e> in <module>()
     36 
     37 for epoch in range(1, 201):
---> 38     loss = train(epoch)
     39     train_acc = test(train_loader)
     40     test_acc = test(test_loader)

7 frames
/usr/local/lib/python3.7/dist-packages/torch_scatter/scatter.py in scatter_sum(src, index, dim, out, dim_size)
     19             size[dim] = int(index.max()) + 1
     20         out = torch.zeros(size, dtype=src.dtype, device=src.device)
---> 21         return out.scatter_add_(dim, index, src)
     22     else:
     23         return out.scatter_add_(dim, index, src)

RuntimeError: index 43280 is out of bounds for dimension 0 with size 32

What can be done to solve this issue, thanks for your help.

0 ответов

Другие вопросы по тегам