Получение неправильного вывода из вызова инициализации льняной модели
Я пытаюсь создать простую нейронную сеть с использованием льна, как показано ниже.
Однакоparams
замороженный дикт я получаю в качестве выводаmodel.init
пуст вместо того, чтобы иметь параметры нейронной сети. Такжеtype(predictions)
являетсяflax.linen.combinators.Sequential
объект вместо того, чтобы бытьDeviceArray
.
Может ли кто-нибудь помочь мне понять, что не так с этим фрагментом кода?
import jax
import jax.numpy as jnp
import flax.linen as nn
import matplotlib.pyplot as plt
class MLP(nn.Module):
@nn.compact
def __call__(self, x):
return nn.Sequential(
[
nn.Dense(40),
nn.relu,
nn.Dense(40),
nn.Dense(1),
]
)
model = MLP()
dummy_input = jnp.ones((40, 40, 1))
params = model.init(jax.random.PRNGKey(0), dummy_input)
jax.tree_util.tree_map(lambda x: x.shape, params)
n = 100
x_inputs = jnp.linspace(-10, 10, n).reshape(1, -1)
y_targets = jnp.sin(x_inputs)
predictions = model.apply(params, x_inputs)
plt.plot(x_inputs.reshape(-1), y_targets.reshape(-1))
plt.plot(x_inputs.reshape(-1), predictions.reshape(-1))
1 ответ
Проблема в том, чтоnn.Sequential
возвращает функцию, которую необходимо вызвать с входными данными. Замена
class MLP(nn.Module):
@nn.compact
def __call__(self, x):
return nn.Sequential(
[
nn.Dense(40),
nn.relu,
nn.Dense(40),
nn.Dense(1),
]
)
с
class MLP(nn.Module):
@nn.compact
def __call__(self, x):
return nn.Sequential(
[
nn.Dense(40),
nn.relu,
nn.Dense(40),
nn.Dense(1),
]
)(x)
Решает проблему.