Как я могу инициализировать скрытое состояние (перенос) GRUCell (льняное полотно) в качестве обучаемого параметра (например, с помощью model.init)

Я создаю модель GRU в Jax с помощью Flax и инициализирую параметры модели с помощью model.init следующим образом:

      import jax.numpy as np
from jax import random
import flax.linen as nn
from jax.nn import initializers

class RNN(nn.Module):
    n_RNN_units: int

    @nn.compact
    def __call__(self, carry, inputs):
        
        carry, outputs = nn.GRUCell()(carry, inputs)
        
        return carry, outputs
    
    def init_state(self):
        
        return nn.GRUCell.initialize_carry((), (), self.n_RNN_units, init_fn = initializers.zeros)

# instantiate an RNN (GRU) model
n_RNN_units = 200
model = RNN(n_RNN_units = n_RNN_units)

# initialize the parameters of the model (weights and biases)
data_dim = 20
params = model.init(carry = np.empty((n_RNN_units,)), inputs = np.empty((data_dim,)), rngs = {'params': random.PRNGKey(1)})

К сожалению для меня, параметры FrozenDict, созданные model.init, содержат только вес и смещения GRU, а не начальное скрытое состояние (перенос). Есть ли способ, которым я могу сказать model.init 1), что я также хочу узнать начальное скрытое состояние и 2) указать функцию инициализации для начального скрытого состояния.

В качестве альтернативы, если есть лучший способ сделать это, который не требует использования model.init, не стесняйтесь предлагать его.

заранее спасибо

1 ответ

Вы можете использоватьself.paramзарегистрировать тензор в качестве параметров:

      @nn.compact
def __call__(self, inputs, carry=None):
    if carry is None:
        # Learnable initial carry
        carry = self.param('carry_init', lambda rng, shape: jnp.zeros(shape), (self.n_RNN_units,))
    carry, outputs = nn.GRUCell()(carry, inputs)
    return carry, outputs

Теперь находится в параметрах модели послеmodel.init(rng, inputs, None).

Что происходит сейчас, так это то, чтоmodel.applyпринимает параметры с ним, поэтому градиенты по отношению к нему будут вычисляться как обычно с помощьюgrad.

Точнее, когда вы делаете прогноз последовательности, вы должны начинать свои вызовы сcarry, outputs = model.apply(params, inputs). Он будет использовать вparamsзатем для следующих вызовов используйтеcarry, outputs = model.apply(params, inputs, carry). Он будет использоватьcarryсейчас иcarry_initнаходится косвенно на графе вычислений выходов и переносится в качестве начального переноса, поэтому вы можете распространять на нем градиент. Однако вам следует позаботиться о потенциально сильном исчезновении градиента, если у вас есть длинные последовательности, поэтому вы можете рассмотреть возможность использования всех значений (особенно первых) ваших последовательностей для вычисления потерь или адаптации специальной скорости обучения на основе длины последовательности.

Деталиlinen.Module.paramв документации Flax Управление параметрами и состоянием .

Другие вопросы по тегам