Как я могу инициализировать скрытое состояние (перенос) GRUCell (льняное полотно) в качестве обучаемого параметра (например, с помощью model.init)
Я создаю модель GRU в Jax с помощью Flax и инициализирую параметры модели с помощью model.init следующим образом:
import jax.numpy as np
from jax import random
import flax.linen as nn
from jax.nn import initializers
class RNN(nn.Module):
n_RNN_units: int
@nn.compact
def __call__(self, carry, inputs):
carry, outputs = nn.GRUCell()(carry, inputs)
return carry, outputs
def init_state(self):
return nn.GRUCell.initialize_carry((), (), self.n_RNN_units, init_fn = initializers.zeros)
# instantiate an RNN (GRU) model
n_RNN_units = 200
model = RNN(n_RNN_units = n_RNN_units)
# initialize the parameters of the model (weights and biases)
data_dim = 20
params = model.init(carry = np.empty((n_RNN_units,)), inputs = np.empty((data_dim,)), rngs = {'params': random.PRNGKey(1)})
К сожалению для меня, параметры FrozenDict, созданные model.init, содержат только вес и смещения GRU, а не начальное скрытое состояние (перенос). Есть ли способ, которым я могу сказать model.init 1), что я также хочу узнать начальное скрытое состояние и 2) указать функцию инициализации для начального скрытого состояния.
В качестве альтернативы, если есть лучший способ сделать это, который не требует использования model.init, не стесняйтесь предлагать его.
заранее спасибо
1 ответ
Вы можете использоватьself.param
зарегистрировать тензор в качестве параметров:
@nn.compact
def __call__(self, inputs, carry=None):
if carry is None:
# Learnable initial carry
carry = self.param('carry_init', lambda rng, shape: jnp.zeros(shape), (self.n_RNN_units,))
carry, outputs = nn.GRUCell()(carry, inputs)
return carry, outputs
Теперь находится в параметрах модели послеmodel.init(rng, inputs, None)
.
Что происходит сейчас, так это то, чтоmodel.apply
принимает параметры с ним, поэтому градиенты по отношению к нему будут вычисляться как обычно с помощьюgrad
.
Точнее, когда вы делаете прогноз последовательности, вы должны начинать свои вызовы сcarry, outputs = model.apply(params, inputs)
. Он будет использовать вparams
затем для следующих вызовов используйтеcarry, outputs = model.apply(params, inputs, carry)
. Он будет использоватьcarry
сейчас иcarry_init
находится косвенно на графе вычислений выходов и переносится в качестве начального переноса, поэтому вы можете распространять на нем градиент. Однако вам следует позаботиться о потенциально сильном исчезновении градиента, если у вас есть длинные последовательности, поэтому вы можете рассмотреть возможность использования всех значений (особенно первых) ваших последовательностей для вычисления потерь или адаптации специальной скорости обучения на основе длины последовательности.
Деталиlinen.Module.param
в документации Flax Управление параметрами и состоянием .