pyTorch может вернуться назад дважды без установки retain_graph=True

Как указано в уроке по PyTorch,

если вы даже хотите выполнить обратную часть в некоторой части графика дважды, вам нужно передать retain_graph = True во время первого прохода.

Однако я обнаружил, что следующий фрагмент кода на самом деле работает без этого. Я использую pyTorch-0.4

x = torch.ones(2, 2, requires_grad=True)
y = x + 2
y.backward(torch.ones(2, 2)) # Note I do not set retain_graph=True
y.backward(torch.ones(2, 2)) # But it can still work!
print x.grad

выход:

tensor([[ 2.,  2.], 
        [ 2.,  2.]]) 

Кто-нибудь может объяснить? Заранее спасибо!

0 ответов

Причина, почему это работает без retain_graph=True в вашем случае у вас очень простой график, который, вероятно, не будет иметь внутренних промежуточных буферов, в свою очередь, никакие буферы не будут освобождены, поэтому нет необходимости использовать retain_graph=True,

Но все меняется при добавлении еще одного дополнительного вычисления в ваш график:

Код:

x = torch.ones(2, 2, requires_grad=True)
v = x.pow(3)
y = v + 2

y.backward(torch.ones(2, 2))

print('Backward 1st time w/o retain')
print('x.grad:', x.grad)

print('Backward 2nd time w/o retain')

try:
    y.backward(torch.ones(2, 2))
except RuntimeError as err:
    print(err)

print('x.grad:', x.grad)

Выход:

Backward 1st time w/o retain
x.grad: tensor([[3., 3.],
                [3., 3.]])
Backward 2nd time w/o retain
Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
x.grad: tensor([[3., 3.],
                [3., 3.]]).

В этом случае дополнительный внутренний v.grad будет вычислено, но torch не хранит промежуточные значения (промежуточные градиенты и т. д.), а с retain_graph=Falsev.grad будет освобожден после первого backward,

Итак, если вы хотите сделать backprop второй раз, вам нужно указать retain_graph=True "держать" график.

Код:

x = torch.ones(2, 2, requires_grad=True)
v = x.pow(3)
y = v + 2

y.backward(torch.ones(2, 2), retain_graph=True)

print('Backward 1st time w/ retain')
print('x.grad:', x.grad)

print('Backward 2nd time w/ retain')

try:
    y.backward(torch.ones(2, 2))
except RuntimeError as err:
    print(err)
print('x.grad:', x.grad)

Выход:

Backward 1st time w/ retain
x.grad: tensor([[3., 3.],
                [3., 3.]])
Backward 2nd time w/ retain
x.grad: tensor([[6., 6.],
                [6., 6.]])
Другие вопросы по тегам