Как обернуть функции PyTorch и реализовать автоград?
Я работаю над учебником PyTorch по определению новых функций автограда. Функция автограда, которую я хочу реализовать - это оболочка torch.nn.functional.max_pool1d
, Вот что у меня так далеко:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd as tag
class SquareAndMaxPool1d(tag.Function):
@staticmethod
def forward(ctx, input, kernel_size, stride=None, padding=0, dilation=1, \
return_indices=False, ceil_mode=False):
ctx.save_for_backward( input )
inputC = input.clone() #copy input
inputC *= inputC
output = F.max_pool1d(inputC, kernel_size, stride=stride, \
padding=padding, dilation=dilation, \
return_indices=return_indices, \
ceil_mode=ceil_mode)
return output
@staticmethod
def backward(ctx, grad_output):
input, = ctx.saved_tensors
grad_input = get_max_pool1d_grad_somehow(grad_output)
return 2.0*input*grad_input
Мой вопрос: как получить градиент обернутой функции? Я знаю, что, возможно, есть другие способы сделать это, учитывая, насколько простой пример, который я представляю, но то, что я хочу сделать, соответствует этой структуре и требует от меня реализации autograd
функция.
Изменить: После изучения этого сообщения в блоге я решил попробовать следующее для backward
:
def backward(ctx, grad_output):
input, output = ctx.saved_tensors
grad_input = output.backward(grad_output)
return 2.0*input*grad_input
с output
добавлены к сохраненным переменным. Затем я запускаю следующий код:
x = np.random.randn(1,1,5)
xT = torch.from_numpy(x)
xT.requires_grad=True
f = SquareAndMaxPool1d.apply
s = torch.sum(f(xT,2))
s.backward()
и я получаю Bus error: 10
,
Сказать, xT
является tensor([[[ 1.69533562, -0.21779421, 2.28693953, -0.86688095, -1.01033497]]], dtype=torch.float64)
тогда я бы ожидал, что xT.grad
является tensor([[[ 3.39067124, -0. , 9.14775812, -0. , -2.02066994]]], dtype=torch.float64)
после звонка s.backward()
(то есть 2*x*grad_of_max_pool
, с grad_of_max_pool
содержащий tensor([[[1., 0., 2., 0., 1.]]], dtype=torch.float64)
).
Я понял, почему я получаю Bus error: 10
, Похоже, что приведенный выше код приводит к рекурсивному вызову моего backward
в grad_input = output.backward(grad_output)
, Поэтому мне нужно найти какой-то другой способ получить градиент max_pool1d
, Я знаю, как реализовать это в чистом Python, но результат будет намного медленнее, чем если бы я мог обернуть код библиотеки.
1 ответ
Вы выбрали довольно неудачный пример. torch.nn.functional.max_pool1d
не является примером torch.autograd.Function
потому что это встроенный PyTorch, определенный в коде C++ и с автоматически сгенерированным связыванием Python. Я не уверен, возможно ли получить backward
свойство через его интерфейс.
Во-первых, если вы не заметили, вам не нужно писать какой-либо пользовательский код для обратного распространения этой формулы, потому что и питание, иmax_pool1d
уже определили, так что их состав также покрыт автоград. Предполагая, что ваша цель - это упражнение, я бы посоветовал вам сделать это более вручную (не прибегая кbackward
из max_pool1d
). Пример ниже
import torch
import torch.nn.functional as F
import torch.autograd as tag
class SquareAndMaxPool1d(tag.Function):
@staticmethod
def forward(ctx, input, kernel_size, **kwargs):
# we're gonna need indices for backward. Currently SquareAnd...
# never actually returns indices, I left it out for simplicity
kwargs['return_indices'] = True
input_sqr = input ** 2
output, indices = F.max_pool1d(input_sqr, kernel_size, **kwargs)
ctx.save_for_backward(input, indices)
return output
@staticmethod
def backward(ctx, grad_output):
input, indices = ctx.saved_tensors
# first we need to reconstruct the gradient of `max_pool1d`
# by putting all the output gradient elements (corresponding to
# input elements which made it through the max_pool1d) in their
# respective places, the rest has gradient of 0. We do it by
# scattering it against a tensor of 0s
grad_output_unpooled = torch.zeros_like(input)
grad_output_unpooled.scatter_(2, indices, grad_output)
# then incorporate the gradient of the "square" part of your
# operator
grad_input = 2. * input * grad_output_unpooled
# the docs for backward
# https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function.backward
# say that "it should return as many tensors, as there were inputs
# to forward()". It fails to mention that if an argument was not a
# tensor, it should return None (I remember reading this somewhere,
# but can't find it anymore). Anyway, we need to
# return a (grad_input, None) tuple to avoid a complaint that two
# outputs were expected
return grad_input, None
Затем мы можем использовать средство проверки числового градиента, чтобы убедиться, что операция работает должным образом.
f = SquareAndMaxPool1d.apply
xT = torch.randn(1, 1, 6, requires_grad=True, dtype=torch.float64)
tag.gradcheck(lambda t: f(t, 2), xT)
Извините, если это не решит ваш вопрос о том, как получить backward
из max_pool1d
, но, надеюсь, вы найдете мой ответ достаточно полезным.
Проблемы, которые у вас были с рекурсивными вызовами, на самом деле связаны с "выходом" и тем фактом, что по умолчанию "with no_grad" является поведением по умолчанию, которое, похоже, в объявлении класса унаследовано от torch.autograd.Function. Если вы проверите output.grad_fn в прямом направлении, он, вероятно, будет None, а в обратном направлении он, вероятно, будет ссылаться на объект функции
import torch
import torch.nn.functional as F
class custom_Linear(nn.Linear):
def forward(self, _input):
return Custom_Linear_AGfn_getAround.apply(_input, self.weight, self.bias)
class Custom_Linear_AGfn_getAround(torch.autograd.Function):
@staticmethod
def forward(ctx, _input, _weight, _bias):
print('Custom forward')
with torch.enable_grad():
detached_input = _input.detach()
detached_input.requires_grad_(True)
detached_weight = _weight.detach()
detached_weight.requires_grad_(True)
detached_bias = _bias.detach()
detached_bias.requires_grad_(True)
_tmp = F.linear(detached_input, detached_weight, detached_bias)
ctx.saved_input = detached_input
ctx.saved_param = detached_weight, detached_bias
ctx.save_for_backward(_tmp)
_output = _tmp.detach()
return _output
@staticmethod
def backward(ctx, grad_out):
print('Custom backward')
_tmp, = ctx.saved_tensors
_weight, _bias = ctx.saved_param
detached_input = ctx.saved_input
with torch.enable_grad():
_tmp.backward(grad_out)
return detached_input.grad, _weight.grad, _bias.grad
По сути, речь идет о построении небольшого изолированного графа для интересующей его части без нарушения работы с основным графом и использовании grad_fn и requires_grad для отслеживания графов при поиске того, что нужно отсоединить и что необходимо для изолированного графа.
О сложных частях:
- отсоединение веса и смещения: вы можете обойтись без него, но ЛИБО вы затем передадите _weight и _bias через save_for_backward и получите _weight.grad, _bias.grad как None внутри назад, НО как только за пределами _weight.grad, _bias.grad будут иметь свои правильные значения, ИЛИ вы передаете их через атрибут, например ctx.saved_param, и в этом случае вам придется вручную указать None для двух последних возвращенных значений backward (return detached_input.grad, None, None), иначе вы получите вдвое больше правильное значение, когда вы впоследствии проверяете градиент веса и смещения за пределами заднего хода.
- как сказано в начале, назад и вперед для унаследованного класса torch.autograd.Function, похоже, по умолчанию имеет поведение 'with no_grad'. Таким образом, удаление 'with torch.enable_grad():' в приведенном выше коде приведет к тому, что '_tmp.grad_fn' будет равно None (не удалось понять, почему по умолчанию для параметра gradient_fn было значение None, а для параметра requires_grad значение False в прямом направлении, несмотря на то, что для него требовался градиент для detached_input, пока я не наткнулся на: https://github.com/pytorch/pytorch/issues/7698)
- Я полагаю, но я не проверял, что вы можете получить двойной grad_fn для _output, если вы не отключите его, как когда у меня нет 'with torch.enable_grad()', и не отключаю выход, в результате чего получается '_tmp.grad_fn'равно None в прямом направлении, он получает
grad_fn в обратном направлении (и приводит к бесконечным рекурсивным вызовам).