PyTorch - Перезаписанные переменные остаются в графе?

Мне интересно, хранятся ли в вычислительном графе PyTorch тензорные средства PyTorch, в которых переменные Python перезаписываются?


Итак, вот небольшой пример, где у меня есть модель RNN, где скрытые состояния (и некоторые другие переменные) сбрасываются после каждой итерации, backward() называется позже.

Пример:

for i in range(5):
   output = rnn_model(inputs[i])
   loss += criterion(output, target[i])
   ## hidden states are overwritten with a zero vector
   rnn_model.reset_hidden_states() 
loss.backward()

Итак, мой вопрос:

  • Есть ли проблема в перезаписи скрытых состояний перед вызовом? backward()?

  • Или вычислительный граф хранит в памяти необходимую информацию о скрытых состояниях предыдущих итераций для вычисления градиентов?

  • Редактировать: было бы здорово иметь официальное заявление для этого. например, указав, что все переменные, относящиеся к CG, сохранены - независимо от того, есть ли еще другие ссылки на python для этих переменных. Я предполагаю, что в самом графе есть ссылка, которая не позволяет сборщику мусора удалить его. Но я хотел бы знать, так ли это на самом деле.

Заранее спасибо!

1 ответ

Я думаю, что это нормально, чтобы сбросить, прежде чем назад. График сохраняет необходимую информацию.

class A (torch.nn.Module):
     def __init__(self):
         super().__init__()
         self.f1 = torch.nn.Linear(10,1)
     def forward(self, x):
         self.x = x 
         return torch.nn.functional.sigmoid (self.f1(self.x))
     def reset_x (self):
        self.x = torch.zeros(self.x.shape) 
net = A()
net.zero_grad()
X = torch.rand(10,10) 
loss = torch.nn.functional.binary_cross_entropy(net(X), torch.ones(10,1))
loss.backward()
params = list(net.parameters())
for i in params: 
    print(i.grad)
net.zero_grad() 

loss = torch.nn.functional.binary_cross_entropy(net(X), torch.ones(10,1))
net.reset_x()
print (net.x is X)
del X
loss.backward()     
params = list(net.parameters())
for i in params:
    print(i.grad)

В приведенном выше коде я печатаю грады с / без сброса ввода x. Градиент зависит от x точно, и его сброс не имеет значения. Поэтому я думаю, что граф сохраняет информацию для выполнения обратной операции.

Другие вопросы по тегам