Pytorch, не может работать backward() даже в самой простой сети без ошибки
Я новичок в pytorch и не могу запустить backward() даже в самой простой сети без выдачи ошибки. Например:
(Linear(6, 6)(Variable(torch.zeros([10, 6]))) - Variable(torch.zeros([10, 6]))).backward()
Выдает следующую ошибку
{RuntimeError}element 0 of variables does not require grad and does not have a grad_fn
Что я сделал неправильно в коде, чтобы создать эту проблему?
1 ответ
Попробуйте добавить grad_output соответствующей формы в качестве параметра для backward:
(Линейный (6, 6) (Переменная (torch.zeros([10, 6]))) - Переменная (torch.zeros([10, 6]))). Назад (torch.zeros ([10, 6]))
В следующем ответе есть более подробная информация: почему функция backward должна вызываться только для тензора с 1 элементом или с градиентами относительно переменной?
Эта ошибка возникает, когда PyTorch не может найти параметры модели, которые имеют requires_grad = True
т.е. все параметры модели имеют requires_grad = False
,
Существуют разные причины, но может случиться так, что вы замораживаете всю модель или неправильно меняете конечные слои модели - например, в остальной сети это должен быть model.fc, а не model.classifier-,
Вы всегда должны быть осторожны, где вы размещаете это:
for param in model.parameters():
param.requires_grad = False