есть ли способ отслеживать грады с помощью метода self.put_variable во льне?
Я хотел бы отслеживать оценки через переменную self.put_variable. Есть ли способ сделать это возможным? Или другой способ обновить параметр, предоставленный отслеживаемому модулю?
import jax
from jax import numpy as jnp
from jax import grad,random,jit,vmap
import flax
from flax import linen as nn
class network(nn.Module):
input_size : int
output_size : int
@nn.compact
def __call__(self,x):
W = self.param('W',nn.initializers.normal(),(self.input_size,self.output_size))
b = self.param('b',nn.initializers.normal(),(self.output_size,))
self.put_variable("params","b",(x@W+b).reshape(5,))
return jnp.sum(x+b)
if __name__ == "__main__":
key = random.PRNGKey(0)
key_x,key_param,key = random.split(key,3)
x = random.normal(key_x,(1,5))
module = network(5,5)
param = module.init(key_param,x)
print(param)
#x,param = module.apply(param,x,mutable=["params"])
#print(param)
print(grad(module.apply,has_aux=True)(param,x,mutable=["params"]))
мои выходные оценки:
FrozenDict({
params: {
W: DeviceArray([[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.]], dtype=float32),
b: DeviceArray([1., 1., 1., 1., 1.], dtype=float32),
},
Что показывает, что он не отслеживает градации через метод self.variable_put, поскольку все градации до W равны нулю, а b явно зависит от W.