есть ли способ отслеживать грады с помощью метода self.put_variable во льне?

Я хотел бы отслеживать оценки через переменную self.put_variable. Есть ли способ сделать это возможным? Или другой способ обновить параметр, предоставленный отслеживаемому модулю?

      import jax 
from jax import numpy as jnp 
from jax import grad,random,jit,vmap 
import flax 
from flax import linen as nn 


class network(nn.Module):
    input_size : int 
    output_size : int 
    @nn.compact
    def __call__(self,x):
        W = self.param('W',nn.initializers.normal(),(self.input_size,self.output_size))
        b = self.param('b',nn.initializers.normal(),(self.output_size,))

      
        self.put_variable("params","b",(x@W+b).reshape(5,))  
    
        return jnp.sum(x+b)


if __name__ == "__main__":
    key = random.PRNGKey(0)
    key_x,key_param,key = random.split(key,3)
    x = random.normal(key_x,(1,5))

    module = network(5,5)
    param = module.init(key_param,x)
    print(param)
    #x,param = module.apply(param,x,mutable=["params"])
    #print(param)
    print(grad(module.apply,has_aux=True)(param,x,mutable=["params"]))

мои выходные оценки:

      FrozenDict({
    params: {
        W: DeviceArray([[0., 0., 0., 0., 0.],
                     [0., 0., 0., 0., 0.],
                     [0., 0., 0., 0., 0.],
                     [0., 0., 0., 0., 0.],
                     [0., 0., 0., 0., 0.]], dtype=float32),
        b: DeviceArray([1., 1., 1., 1., 1.], dtype=float32),
    },

Что показывает, что он не отслеживает градации через метод self.variable_put, поскольку все градации до W равны нулю, а b явно зависит от W.

0 ответов

Другие вопросы по тегам