Будет ли создание «данных» градиента путем их отсоединения реализовать MAML первого порядка с использованием более высокой библиотеки PyTorch?
Я использовал более высокую библиотеку pytorch для MAML и хотел запустить MAML первого порядка. Я до сих пор не понимал, что
track_higher_grads
(вероятно, моя ошибка, потому что в прошлом я находил документы запутанными, например, см. Что означает документация copy_initial_weights в более высокой библиотеке для Pytorch?).
Но теперь я понял, что в моем коде может быть странная версия MAML, и хотел убедиться, что она правильная.
Официальные документы говорят, что вы можете просто установить . Что я думаю, хорошо (скорее всего).
Однако я сделал необработанное число градиента, отделив его от графа вычислений, например
if self.fo: # first-order
g = g.detach() # dissallows flow of higher order grad while still letting params track gradients.
Мне было интересно, будет ли это эквивалентно
track_higher_grads=False
. В частности, я отсоединяюсь, но оставляю... вот что меня смущает.
Я считаю, что они должны быть эквивалентны/давать одинаковые результаты, но тот, у которого
track_higher_grads=True
будет медленным... он сделает всю тяжелую работу по вычислению гессиана, но внезапно будет "убит"
.detach()
.
Мой код:
class NonDiffMAML(optim.Optimizer): # copy pasted from torch.optim.SGD
def __init__(self, params, lr=required, momentum=0, dampening=0,
weight_decay=0, nesterov=False):
if lr is not required and lr < 0.0:
raise ValueError("Invalid learning rate: {}".format(lr))
if momentum < 0.0:
raise ValueError("Invalid momentum value: {}".format(momentum))
if weight_decay < 0.0:
raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
weight_decay=weight_decay, nesterov=nesterov)
if nesterov and (momentum <= 0 or dampening != 0):
raise ValueError("Nesterov momentum requires a momentum and zero dampening")
super().__init__(params, defaults)
class MAML(DifferentiableOptimizer): # copy pasted from DifferentiableSGD but with the g.detach() line of code
def _update(self, grouped_grads: _GroupedGradsType, **kwargs) -> None:
zipped = zip(self.param_groups, grouped_grads)
for group_idx, (group, grads) in enumerate(zipped):
weight_decay = group['weight_decay']
momentum = group['momentum']
dampening = group['dampening']
nesterov = group['nesterov']
for p_idx, (p, g) in enumerate(zip(group['params'], grads)):
if g is None:
continue
if weight_decay != 0:
g = _add(g, weight_decay, p)
if momentum != 0:
param_state = self.state[group_idx][p_idx]
if 'momentum_buffer' not in param_state:
buf = param_state['momentum_buffer'] = g
else:
buf = param_state['momentum_buffer']
buf = _add(buf.mul(momentum), 1 - dampening, g)
param_state['momentum_buffer'] = buf
if nesterov:
g = _add(g, momentum, buf)
else:
g = buf
if self.fo: # first-order
g = g.detach() # dissallows flow of higher order grad while still letting params track gradients.
group['params'][p_idx] = _add(p, -group['lr'], g)
higher.register_optim(NonDiffMAML, MAML)
Примечание для полноты говорит:
track_higher_grads – if True, during unrolled optimization the graph be retained, and the fast weights will bear grad funcs, so as to permit backpropagation through the optimization process. Setting this to False allows the differentiable optimizer to be used in “test mode”, without potentially tracking higher order gradients. This can be useful when running the training loop at test time, e.g. in k-shot learning experiments, without incurring a significant memory overhead.
Связанный:
- официальный fo maml: https://github.com/facebookresearch/higher/issues/63
- документы: документовhttps://higher.readthedocs.io/en/latest/optim.html
- крест: https://github.com/facebookresearch/higher/issues/128
- https://www.reddit.com/r/pytorch/comments/si5xv1/would_making_the_gradient_data_by_detaching_them/