Операции на месте с PyTorch
Мне было интересно, как работать с операциями на месте в PyTorch. Как я помню, использование операции на месте с autograd всегда было проблематичным.
И на самом деле я удивлен, что этот код ниже работает, хотя я не проверял его, я думаю, что этот код вызвал бы ошибку в версии 0.3.1
,
В основном, я хочу сделать, это установить определенную позицию тензорного вектора на определенное значение, например:
my_tensor[i] = 42
Рабочий пример кода:
# test parameter a
a = torch.rand((2), requires_grad=True)
print('a ', a)
b = torch.rand(2)
# calculation
c = a + b
# performing in-place operation
c[0] = 0
print('c ', c)
s = torch.sum(c)
print('s ', s)
# calling backward()
s.backward()
# optimizer step
optim = torch.optim.Adam(params=[a], lr=0.5)
optim.step()
# changed parameter a
print('changed a', a)
Выход:
a tensor([0.2441, 0.2589], requires_grad=True)
c tensor([0.0000, 1.1511], grad_fn=<CopySlices>)
s tensor(1.1511, grad_fn=<SumBackward0>)
changed a tensor([ 0.2441, -0.2411], requires_grad=True)
Так очевидно в версии 0.4.1
, это работает просто отлично без предупреждений или ошибок.
Ссылаясь на эту статью в документации: автоград-механика
Поддержка операций на месте в autograd является сложной задачей, и в большинстве случаев мы не рекомендуем их использовать. Агрессивное освобождение и повторное использование буфера Autograd делает его очень эффективным, и очень мало случаев, когда операции на месте фактически уменьшают использование памяти на сколько-нибудь значительную величину. Если вы не работаете под сильным давлением памяти, вам, возможно, никогда не понадобится их использовать.
Но даже при том, что это работает, использование операций на месте в большинстве случаев не рекомендуется.
Итак, мои вопросы:
Насколько использование операций на месте влияет на производительность?
Как обойти использование операций на месте в тех случаях, когда я хочу установить для одного элемента тензора определенное значение?
Заранее спасибо!
3 ответа
Я не уверен, насколько операции на месте влияют на производительность, но я могу ответить на второй запрос. Вы можете использовать маску вместо оперативных операций.
a = torch.rand((2), requires_grad=True)
print('a ', a)
b = torch.rand(2)
# calculation
c = a + b
# performing in-place operation
mask = np.zeros(2)
mask[1] =1
mask = torch.tensor(mask)
c = c*mask
...
Возможно, это не прямой ответ на ваш вопрос, а просто для информации.
Операции на месте работают для тензоров, не являющихся листовыми, в вычислительном графе.
Тензоры листьев - это тензоры, которые являются «концом» вычислительного графа. Официально (от
Для тензоров, у которых require_grad имеет значение True, они будут листовыми тензорами, если они были созданы пользователем. Это означает, что они не являются результатом операции, и поэтому grad_fn равен None.
Пример, который работает без ошибок:
a = torch.tensor([3.,2.,7.], requires_grad=True)
print(a) # tensor([3., 2., 7.], requires_grad=True)
b = a**2
print(b) # tensor([ 9., 4., 49.], grad_fn=<PowBackward0>)
b[1] = 0
print(b) # tensor([ 9., 0., 49.], grad_fn=<CopySlices>)
c = torch.sum(2*b)
print(c) # tensor(116., grad_fn=<SumBackward0>)
c.backward()
print(a.grad) # tensor([12., 0., 28.])
С другой стороны, операции на месте не работают для листовых тензоров.
Пример, вызывающий ошибку:
a = torch.tensor([3.,2.,7.], requires_grad=True)
print(a) # tensor([3., 2., 7.], requires_grad=True)
a[1] = 0
print(a) # tensor([3., 0., 7.], grad_fn=<CopySlices>)
b = a**2
print(b) # tensor([ 9., 0., 49.], grad_fn=<PowBackward0>)
c = torch.sum(2*b)
print(c) # tensor(116., grad_fn=<SumBackward0>)
c.backward() # Error occurs at this line.
# RuntimeError: leaf variable has been moved into the graph interior
Я полагаю, что
старый b ---(CopySlices)----> новый b
С другой стороны, тензор
Это всего лишь мое личное мнение, поэтому обращайтесь к официальным документам.
Примечание:
Хотя операции на месте работают для промежуточных тензоров, будет безопасно использовать как можно больше клонирования и отсоединения, когда вы выполняете некоторые операции на месте, чтобы явно создать новый тензор, который не зависит от вычислительного графа.
Для вашего второго запроса, когда вы выполняете c[i] = i
или аналогичные операции, __setitem__
обычно называется. Чтобы выполнить эту операцию на месте, вы можете попробовать вызвать__setitem__
функция (если это то, что выполняет c[i] = i
операция.