PyTorch: неявные градиенты возвращают (нет) метаградиент
Я пытаюсь реализовать алгоритм неявных градиентов [1, 2, 3] для оптимизации некоторых метапараметров (в моем случае параметров функции потерь). Однако создаваемые (мета-)градиенты всегда равны None. Могу ли я получить некоторую помощь в определении проблемы, и как я могу решить эту проблему?
Ниже я прикрепил упрощенный код, который воспроизводит ошибку.
from sklearn.datasets import make_regression
import torch
# Creating a meta-network for representing the loss function.
class MetaNetwork(torch.nn.Module):
def __init__(self):
super(MetaNetwork, self).__init__()
self.model = torch.nn.Sequential(
torch.nn.Linear(2, 10),
torch.nn.ReLU(),
torch.nn.Linear(10, 1),
torch.nn.Softplus()
)
def forward(self, y_pred, y_target):
return self.model(torch.cat((y_pred, y_target), dim=1)).mean()
# Creating a base-network for learning the model of the data.
class BaseNetwork(torch.nn.Module):
def __init__(self):
super(BaseNetwork, self).__init__()
self.model = torch.nn.Sequential(
torch.nn.Linear(1, 10),
torch.nn.ReLU(),
torch.nn.Linear(10, 1)
)
def forward(self, x):
return self.model(x)
# Generating some synthetic training and validation data.
X_train, y_train = make_regression(n_samples=100, n_features=1, n_informative=1, noise=0.1, random_state=1)
X_valid, y_valid = make_regression(n_samples=100, n_features=1, n_informative=1, noise=0.1, random_state=2)
# Converting data into the correct format.
X_train, y_train = torch.tensor(X_train).float(), torch.unsqueeze(torch.tensor(y_train).float(), 1)
X_valid, y_valid = torch.tensor(X_valid).float(), torch.unsqueeze(torch.tensor(y_valid).float(), 1)
# Creating our base and meta models, as well as the base optimizer.
meta_network, base_network = MetaNetwork(), BaseNetwork()
base_optimizer = torch.optim.SGD(base_network.parameters(), lr=0.01)
# Training the model using the meta-network as the loss function.
for i in range(10):
base_optimizer.zero_grad()
yp = base_network(X_train)
base_loss = meta_network(yp, y_train)
base_loss.backward()
base_optimizer.step()
meta_loss_fn = torch.nn.MSELoss()
# Computing the training and validation (meta) loss.
train_loss = meta_loss_fn(base_network(X_train), y_train)
validation_loss = meta_loss_fn(base_network(X_valid), y_valid)
# Gradient of the validation loss with respect to the base model weights.
dloss_val_dparams = torch.autograd.grad(validation_loss, base_network.parameters(),
retain_graph=True, allow_unused=True)
# Gradient of the training loss with respect to the base model weights.
dloss_train_dparams = torch.autograd.grad(train_loss, base_network.parameters(),
create_graph=True, allow_unused=True)
p = v = dloss_val_dparams
for _ in range(10):
grad = torch.autograd.grad(dloss_train_dparams, base_network.parameters(),
grad_outputs=v, retain_graph=True, allow_unused=True)
grad = [g * 0.01 for g in grad]
v = [curr_v - curr_g for (curr_v, curr_g) in zip(v, grad)]
p = [curr_p + curr_v for (curr_p, curr_v) in zip(p, v)]
v2 = list(0.01 * pp for pp in p)
v3 = torch.autograd.grad(dloss_train_dparams, meta_network.parameters(), grad_outputs=v2, allow_unused=True)
print("Meta Gradient", v3)
[1] Раджесваран, А., Финн, К., Какаде, С.М., и Левин, С. (2019). Мета-обучение с неявными градиентами.
[2] Лоррейн Дж., Викол П. и Дювено Д. (2020 г., июнь). Оптимизация миллионов гиперпараметров путем неявного дифференцирования.
[3] Гао, Б., Гоук, Х., Ян, Ю., и Хоспедалес, Т. (2021). Обучение функции потерь для обобщения предметной области с помощью неявного градиента.