Пользовательская функция JAX VJP для нескольких входных переменных не работает для NumPyro/HMC-NUTS

Я пытаюсь использовать пользовательскую функцию VJP (вектор-якобианский продукт) в качестве модели для HMC-NUTS в numpyro. Мне удалось создать функцию с одной переменной, которая работает для HMC-NUTS, следующим образом:

      import jax.numpy as jnp
from jax import custom_vjp

@custom_vjp
def h(x):
    return jnp.sin(x)

def h_fwd(x):
    return h(x), jnp.cos(x)

def h_bwd(res, u):
    cos_x  = res 
    return (cos_x * u,)

h.defvjp(h_fwd, h_bwd)

Здесь я вручную определил h(x) = sin(x). Затем я сделал тестовые данные как

      import numpy as np
np.random.seed(32)
sigin=0.3
N=20
x=np.sort(np.random.rand(N))*4*np.pi
data=hv(x)+np.random.normal(0,sigin,size=N)

данные испытаний

В этом случае мне удалось выполнить HMC-NUTS в NumPyro как

      import numpyro
import numpyro.distributions as dist

def model(x,y):
    sigma = numpyro.sample('sigma', dist.Exponential(1.))
    x0 = numpyro.sample('x0', dist.Uniform(-1.,1.))
    #mu=jnp.sin(x-x0)
    #mu=hv(x-x0)
    mu=h(x-x0)
    numpyro.sample('y', dist.Normal(mu, sigma), obs=y)

from jax import random
from numpyro.infer import MCMC, NUTS

rng_key = random.PRNGKey(0)
rng_key, rng_key_ = random.split(rng_key)
num_warmup, num_samples = 1000, 2000
kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup, num_samples)
mcmc.run(rng_key_, x=x, y=data)
mcmc.print_summary()

Оно работает.

      sample: 100%|██████████| 3000/3000 [00:15<00:00, 193.84it/s, 3 steps of size 7.67e-01. acc. prob=0.92]

                mean       std    median      5.0%     95.0%     n_eff     r_hat
     sigma      0.35      0.06      0.34      0.26      0.45   1178.07      1.00
        x0      0.07      0.11      0.07     -0.11      0.26   1243.73      1.00

Number of divergences: 0

Однако, если я определю функцию с несколькими переменными как,

      @custom_vjp
def h(x,A):
    return A*jnp.sin(x)

def h_fwd(x, A):
    res = (A*jnp.cos(x), jnp.sin(x))
    return h(x,A), res

def h_bwd(res, u):
    A_cos_x, sin_x = res
    return (A_cos_x * u, sin_x * u)

h.defvjp(h_fwd, h_bwd)

затем выполните HMC-NUTS как

      def model(x,y):
    sigma = numpyro.sample('sigma', dist.Exponential(1.))
    x0 = numpyro.sample('x0', dist.Uniform(-1.,1.))
    A = numpyro.sample('A', dist.Exponential(1.))
    mu=h(x-x0,A)
    numpyro.sample('y', dist.Normal(mu, sigma), obs=y)

rng_key = random.PRNGKey(0)
rng_key, rng_key_ = random.split(rng_key)
num_warmup, num_samples = 1000, 2000
kernel = NUTS(model)
mcmc = MCMC(kernel, num_warmup, num_samples)
mcmc.run(rng_key_, x=x, y=data)
mcmc.print_summary()

тогда я получил ошибку как

      TypeError: mul got incompatible shapes for broadcasting: (3,), (22,).

Я подозреваю, что форма вывода в моей функции была неправильной. Но я не мог понять, в чем дело, после различных попыток изменения формы.

1 ответ

      def model(x,y):
sigma = numpyro.sample('sigma', dist.Exponential(1.))
x0 = numpyro.sample('x0', dist.Uniform(-1.,1.))
A = numpyro.sample('A', dist.Exponential(1.))
hv=vmap(h,(0,None),0)
mu=hv(x-x0,A)
numpyro.sample('y', dist.Normal(mu, sigma), obs=y)

vmap решил эту проблему.

Другие вопросы по тегам