спектральная норма в модуле GCNConv
Я хочу вызвать функцию torch.nn.utilsspectral_norm на слое GCNConv.
gc1 = GCNConv(18, 16)
spectral_norm(gc1)
но я получаю следующую ошибку:
KeyError: 'weight'
это означает, что gc1._parameters не имеет веса (только смещение):
gc1._parameters
OrderedDict([('bias', Parameter containing:
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
requires_grad=True))])
Однако gc1.parameters() хранит два объекта, и один из них представляет собой матрицу 16 на 18 (весовая матрица).
for p in gc1.parameters():
print('P: ', p.shape)
P: torch.Size([16])
P: torch.Size([16, 18])
Как я могу заставить функцию спектра_нормы работать с модулем GCNConv?
1 ответ
Согласно исходному коду, параметр веса заключен в линейный модуль, содержащийся в объектах GCNConv, как
lin
.
Я предполагаю, что это должно работать:
gc1 = GCNConv(18, 16)
spectral_norm(gc1.lin)