спектральная норма в модуле GCNConv

Я хочу вызвать функцию torch.nn.utilsspectral_norm на слое GCNConv.

      gc1 = GCNConv(18, 16)
spectral_norm(gc1)

но я получаю следующую ошибку:

      KeyError: 'weight'

это означает, что gc1._parameters не имеет веса (только смещение):

      gc1._parameters
OrderedDict([('bias', Parameter containing:
              tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
                     requires_grad=True))])

Однако gc1.parameters() хранит два объекта, и один из них представляет собой матрицу 16 на 18 (весовая матрица).

      for p in gc1.parameters():
  print('P: ', p.shape)
P:  torch.Size([16])
P:  torch.Size([16, 18])

Как я могу заставить функцию спектра_нормы работать с модулем GCNConv?

1 ответ

Согласно исходному коду, параметр веса заключен в линейный модуль, содержащийся в объектах GCNConv, как lin.

Я предполагаю, что это должно работать:

      gc1 = GCNConv(18, 16)
spectral_norm(gc1.lin)
Другие вопросы по тегам