PyTorch - Перезаписанные переменные остаются в графе?
Мне интересно, хранятся ли в вычислительном графе PyTorch тензорные средства PyTorch, в которых переменные Python перезаписываются?
Итак, вот небольшой пример, где у меня есть модель RNN, где скрытые состояния (и некоторые другие переменные) сбрасываются после каждой итерации,
backward()
называется позже.Пример:
for i in range(5):
output = rnn_model(inputs[i])
loss += criterion(output, target[i])
## hidden states are overwritten with a zero vector
rnn_model.reset_hidden_states()
loss.backward()
Итак, мой вопрос:
Есть ли проблема в перезаписи скрытых состояний перед вызовом?
backward()
?Или вычислительный граф хранит в памяти необходимую информацию о скрытых состояниях предыдущих итераций для вычисления градиентов?
Редактировать: было бы здорово иметь официальное заявление для этого. например, указав, что все переменные, относящиеся к CG, сохранены - независимо от того, есть ли еще другие ссылки на python для этих переменных. Я предполагаю, что в самом графе есть ссылка, которая не позволяет сборщику мусора удалить его. Но я хотел бы знать, так ли это на самом деле.
Заранее спасибо!
1 ответ
Я думаю, что это нормально, чтобы сбросить, прежде чем назад. График сохраняет необходимую информацию.
class A (torch.nn.Module):
def __init__(self):
super().__init__()
self.f1 = torch.nn.Linear(10,1)
def forward(self, x):
self.x = x
return torch.nn.functional.sigmoid (self.f1(self.x))
def reset_x (self):
self.x = torch.zeros(self.x.shape)
net = A()
net.zero_grad()
X = torch.rand(10,10)
loss = torch.nn.functional.binary_cross_entropy(net(X), torch.ones(10,1))
loss.backward()
params = list(net.parameters())
for i in params:
print(i.grad)
net.zero_grad()
loss = torch.nn.functional.binary_cross_entropy(net(X), torch.ones(10,1))
net.reset_x()
print (net.x is X)
del X
loss.backward()
params = list(net.parameters())
for i in params:
print(i.grad)
В приведенном выше коде я печатаю грады с / без сброса ввода x. Градиент зависит от x точно, и его сброс не имеет значения. Поэтому я думаю, что граф сохраняет информацию для выполнения обратной операции.