Исчезающие параметры в MAML JAX (метаобучение)
Я работаю над реализацией MAML (см. https://arxiv.org/pdf/1703.03400.pdf ) в Jax.
При обучении распределению простых задач линейной регрессии кажется, что все работает нормально (требуется некоторое время, чтобы сходиться, но в конечном итоге работает).
Однако при обучении на задачах, распределенных по типу A * sin(B + X), где A, B — случайные величины, все веса в сети сходятся к 0. Результаты обучения
Это явно неправильно. Заранее благодарим за любую оказанную помощь.
Полный код здесь https://colab.research.google.com/drive/1YoOkwo5tI42LeIbBOxpImkN55Kg9wScl?usp=sharing или см. ниже минимальный код.
Код генерации задачи:
class MAMLDataLoader:
def __init__(self, sample_task_fn, num_tasks, batch_size):
self.sample_task_fn = sample_task_fn
self.num_tasks = num_tasks
self.batch_size = batch_size
def sample_tasks(self, key):
XS = jnp.empty((self.num_tasks, 2 * self.batch_size, 1))
YS = jnp.empty((self.num_tasks, 2 * self.batch_size, 1))
for i in range(self.num_tasks):
key, subkey = random.split(key)
xs, ys = self.sample_task_fn(self.batch_size * 2, subkey)
XS = XS.at[i].set(xs)
YS = YS.at[i].set(ys)
x_train, x_test = XS[:, :self.batch_size], XS[:, self.batch_size:]
y_train, y_test = YS[:, :self.batch_size], YS[:, self.batch_size:]
return x_train, y_train, x_test, y_test
def dummy_input(self):
key = random.PRNGKey(0)
x = self.sample_task_fn(1, key)[0][0]
return x
def sample_sinusoidal_task(samples, key):
# y = a * sin(b + x)
xs_key, amplitude_key, phase_key = random.split(key, num=3)
amplitude = random.uniform(amplitude_key, (1, 1))
phase = random.uniform(phase_key, (1, 1)) * jnp.pi * 2
xs = (random.uniform(xs_key, (samples, 1)) * 4 - 2) * jnp.pi
ys = amplitude * jnp.sin(xs + phase)
return xs, ys
Вот основной код MAML:
class MAMLTrainer:
def __init__(self, model, alpha, optimiser, inner_steps=1):
self.model = model
self.alpha = alpha
self.optimiser = optimiser
self.inner_steps = inner_steps
self.jit_step = jit(self.step)
def loss(self, params, x, y):
preds = self.model.apply(params, x)
return jnp.mean(jnp.inner(y - preds, y - preds) / 2.0)
def update(self, params, x, y, inner_steps=None):
if inner_steps is None:
inner_steps = self.inner_steps
loss_grad = grad(self.loss)
def _update(i, params):
grads = loss_grad(params, x, y)
new_params = tree_map(lambda p, g: p - self.alpha * g, params, grads)
return new_params
return lax.fori_loop(0, inner_steps, _update, params)
def meta_loss(self, params, x1, y1, x2, y2):
return self.loss(self.update(params, x1, x2), x2, y2)
def batch_meta_loss(self, params, x1, y1, x2, y2):
return jnp.mean(vmap(partial(self.meta_loss, params))(x1, y1, x2, y2))
def step(self, params, optimiser, x1, y1, x2, y2):
loss, grads = value_and_grad(self.batch_meta_loss)(params, x1, y1, x2, y2)
updates, opt_state = self.optimiser.update(grads, optimiser, params)
params = optax.apply_updates(params, updates)
return params, loss
def train(self, dataloader, steps, key, params=None):
if params is None:
key, subkey = random.split(key)
params = self.model.init(subkey, dataloader.dummy_input())
optimiser = self.optimiser.init(params)
pbar, losses = tqdm(range(steps), desc='Training'), []
for epoch in pbar:
key, subkey = random.split(key)
params, loss = self.jit_step(params, optimiser, *dataloader.sample_tasks(subkey))
losses.append(loss)
if epoch % 100 == 0:
avg_loss = jnp.mean(jnp.array(losses[-100:]))
pbar.set_postfix_str(f'current_loss: {loss:.3f}, running_loss_100_epochs: {avg_loss:.3f}')
return params, jnp.array(losses)
def n_shot_learn(self, x_train, y_train, params, n):
return self.update(params, x_train, y_train, n)
Код тренировки:
class SimpleMLP(nn.Module):
features: Sequence[int]
@nn.compact
def __call__(self, inputs):
x = inputs
for i, feat in enumerate(self.features[:-1]):
x = nn.Dense(feat)(x)
x = nn.relu(x)
return nn.Dense(self.features[-1])(x)
model = SimpleMLP([64, 64, 1])
optimiser = optax.adam(1e-3)
trainer = MAMLTrainer(model, 0.1, optimiser, 1)
dataloader = MAMLDataLoader(sample_sinusoidal_task, 2, 100)
key = random.PRNGKey(0)
params, losses = trainer.train(dataloader, 10000, key)