Что означает параметр retain_graph в методе переменной backward()?
Я прохожу учебник по нейронной передаче Pytorch и запутался в использовании retain_variable
(устарел, теперь называется retain_graph
). Пример кода показывает:
class ContentLoss(nn.Module):
def __init__(self, target, weight):
super(ContentLoss, self).__init__()
self.target = target.detach() * weight
self.weight = weight
self.criterion = nn.MSELoss()
def forward(self, input):
self.loss = self.criterion(input * self.weight, self.target)
self.output = input
return self.output
def backward(self, retain_variables=True):
#Why is retain_variables True??
self.loss.backward(retain_variables=retain_variables)
return self.loss
Из документации
retain_graph (bool, необязательно) - если False, график, используемый для вычисления града, будет освобожден. Обратите внимание, что почти во всех случаях установка этой опции в True не требуется и часто может быть обойдена гораздо более эффективным способом. По умолчанию используется значение create_graph.
Итак, установив retain_graph= True
Мы не освобождаем память, выделенную для графика на обратном проходе. В чем преимущество сохранения этой памяти, зачем она нам нужна?
2 ответа
@cleros хорош в том, что касается использования retain_graph=True
, По сути, он будет хранить любую необходимую информацию для вычисления определенной переменной, чтобы мы могли выполнить обратную передачу по ней.
Наглядный пример
Предположим, что у нас есть граф вычислений, показанный выше. Переменная d
а также e
это выход, и a
это вход. Например,
import torch
from torch.autograd import Variable
a = Variable(torch.rand(1, 4), requires_grad=True)
b = a**2
c = b*2
d = c.mean()
e = c.sum()
когда мы делаем d.backward()
Это нормально. После этого вычисления часть графика, которая вычисляет d
будет освобожден по умолчанию для экономии памяти. Так что, если мы делаем e.backward()
появится сообщение об ошибке. Для того, чтобы сделать e.backward()
, мы должны установить параметр retain_graph
в True
в d.backward()
т.е.
d.backward(retain_graph=True)
Пока вы используете retain_graph=True
в своем обратном методе вы можете сделать обратное в любое время:
d.backward(retain_graph=True) # fine
e.backward(retain_graph=True) # fine
d.backward() # also fine
e.backward() # error will occur!
Более полезное обсуждение можно найти здесь.
Реальный вариант использования
Прямо сейчас, реальным вариантом использования является многозадачное обучение, когда у вас есть множественные потери, которые могут быть на разных уровнях. Предположим, что у вас есть 2 потери: loss1
а также loss2
и они проживают в разных слоях. Для того, чтобы backprop градиент loss1
а также loss2
по отношению к изучаемому весу вашей сети самостоятельно. Вы должны использовать retain_graph=True
в backward()
метод в первом обратном распространении потерь.
# suppose you first back-propagate loss1, then loss2 (you can also do the reverse)
loss1.backward(retain_graph=True)
loss2.backward() # now the graph is freed, and next process of batch gradient descent is ready
optimizer.step() # update the network parameters
Это очень полезная функция, когда у вас есть несколько выходов из сети. Вот полностью вымышленный пример: представьте, что вы хотите построить какую-то случайную сверточную сеть, в которой вы можете задать два вопроса: содержит ли входное изображение кошку и содержит ли автомобиль изображение?
Один из способов сделать это состоит в том, чтобы иметь сеть, которая разделяет сверточные слои, но которая имеет два параллельных классификационных слоя, следующие (простите мой ужасный график ASCII, но предполагается, что это три конвектора, за которыми следуют три полностью связанных слоя, один для кошек). и один для авто):
-- FC - FC - FC - cat?
Conv - Conv - Conv -|
-- FC - FC - FC - car?
Учитывая картину, на которой мы хотим запустить обе ветви, при обучении сети мы можем сделать это несколькими способами. Во-первых (что, вероятно, было бы лучше всего здесь, иллюстрируя, насколько плох пример), мы просто рассчитываем потери на обеих оценках и суммируем потери, а затем получаем обратное распространение.
Однако есть еще один сценарий, в котором мы хотим сделать это последовательно. Сначала мы хотим сделать backprop через одну ветку, а затем через другую (у меня был этот вариант использования раньше, поэтому он не полностью составлен). В этом случае работает .backward()
на одном графике также будет уничтожена любая информация о градиенте в сверточных слоях, и сверточные вычисления второй ветви (так как они являются единственными, совместно используемыми с другой ветвью) больше не будут содержать граф! Это означает, что когда мы пытаемся выполнить обратный переход через вторую ветвь, Pytorch выдаст ошибку, так как не может найти график, соединяющий вход с выходом! В этих случаях мы можем решить проблему, просто сохранив график на первом обратном проходе. Тогда график не будет использован, а будет использован только при первом обратном проходе, который не требует его сохранения.
РЕДАКТИРОВАТЬ: Если вы сохраняете граф на всех обратных проходах, неявные определения графа, прикрепленные к выходным переменным, никогда не будут освобождены. Здесь также может быть случай использования, но я не могу придумать один. В общем, вы должны убедиться, что последний проход назад освобождает память, не сохраняя графическую информацию.
Что касается того, что происходит для нескольких обратных проходов: как вы уже догадались, pytorch накапливает градиенты, добавляя их на месте (к переменным / параметрам .grad
имущество). Это может быть очень полезно, поскольку это означает, что циклическое выполнение пакета и его обработка по одному за раз, накапливая градиенты в конце, будет выполнять тот же шаг оптимизации, что и полное пакетное обновление (которое только суммирует все градиенты как Что ж). В то время как полностью пакетное обновление может быть распараллелено больше и, таким образом, в целом является предпочтительным, существуют случаи, когда пакетное вычисление либо очень, очень сложно реализовать, либо просто невозможно. Однако, используя это накопление, мы все еще можем положиться на некоторые хорошие стабилизирующие свойства, которые дает дозирование. (Если не на прирост производительности)