IndexError: индекс 87 выходит за пределы измерения 0 с размером 39

Я создаю классификатор изображений, и эта ошибка возникает, когда я использую Vgg16, DenseNet и т. Д. Я видел модель после печати модели с использованием этого model = models.densenet169(pretrained=True).

Теперь это настоящая ошибка -

output = model.forward(images)
---> 90     conf_matrix = confusion_matrix(output, labels, conf_matrix)
     91     p = torch.nn.functional.softmax(output, dim=1)
     92     prediction = torch.argmax(p, dim=1)

<ipython-input-11-04b29a81a4e9> in confusion_matrix(preds, labels, conf_matrix, title, cmap)
     50     preds = torch.argmax(preds, 1)
     51     for p, t in zip(preds, labels):
---> 52         conf_matrix[p, t] += 1
     53 
     54     #print(conf_matrix)

Вот как я реализовал свою модель,

model = models.densenet169(pretrained=True)

for param in model.parameters():
    param.requires_grad = True

model.fc = nn.Sequential(nn.Linear(1664, 512),
                                 nn.ReLU(),
                                 nn.Dropout(0.4),
                                 nn.Linear(512,128),
                                 nn.ReLU(),
                                 nn.Dropout(0.4),
                                 nn.Linear(128,39),
                                 nn.LogSoftmax(dim=1))

и это часть моей матрицы путаницы -

def confusion_matrix(preds, labels, conf_matrix, title='Confusion matrix', cmap=plt.cm.Blues):
    preds = torch.argmax(preds, 1)
    for p, t in zip(preds, labels):
        conf_matrix[p, t] += 1

    #print(conf_matrix)
    #plt.imshow(conf_matrix)
    TP = conf_matrix.diag()
    for c in range(n_classes):
        idx = torch.ones(n_classes).byte()
        idx[c] = 0
        TN = conf_matrix[idx.nonzero()[:,None], idx.nonzero()].sum()
        FP = conf_matrix[c, idx].sum()
        FN = conf_matrix[idx, c].sum()

        Recall = (TP[c] / (TP[c]+FN))
        precision = (TP[c] / (TP[c]+FP))
        f1 = (2 * ((precision * Recall)/(precision + Recall)))

        #print('Class {}\nTP {}, TN {}, FP {}, FN {}'.format(c, TP[c], TN, FP, FN))
        #print('Sensitivity = {}'.format(sensitivity))
        #print('Specificity = {}'.format(specificity))

    return conf_matrix

Кто-нибудь может сказать мне, в чем проблема? Я не понимаю, что происходит, потому что с ResNet все работало нормально. Кроме этого ошибок быть не должно. Спасибо.

0 ответов

Другие вопросы по тегам