Как заставить Pyro подчиняться только апостериорной части для подмножества образцов сайтов?

Недавно я узнал, как использовать MCMC для определения входных параметров нейронной сети. После запуска кода я заметил, что MCMC очень долго завершает работу. Под очень длинным я имею в виду, что запуск MCMC с 40 шагами прогрева и 100 выборками может занять 1-3 часа (или даже больше) в зависимости от параметров ядра. Для меня это очень неприемлемо, особенно потому, что модель нейронной сети, которую я использую, может обрабатывать тысячи изображений в минуту. Затем я узнал, что это, вероятно, связано с тем, что MCMC фактически оценивает распределение для всех моих образцов сайтов, которые имеют очень большие размеры (см. Размеры ниже в коде). Однако меня интересует только выведение апостериорного распределения для одной из переменных после согласования вывода модели с наблюдением. x_q(изображение). Поэтому мне интересно, есть ли способ заставить MCMC использовать только нижнее апостериорное распределение для v_q а остальные переменные как-то игнорирует?

Вот моя попытка: я установил z на z=torch.distributions.Normal(mu_q, std_q).rsample()вывод MCMC выполняется очень быстро, но очень близок к 0 (например, 0,01) и продолжает снижаться (он упал до 1e-140 после 200 шагов разогрева). Так что я считаю, что это неправильно.

Кроме того, вывод будет намного быстрее, если я установлю adapt_step_size к False но результаты действительно плохие ( acc. prob составляет около 0,1), когда я вручную установил step_size к значению, с которым выполнялся предыдущий вывод adap_step_size нашел.

Вот код

      class model(nn.Module):
    def __init__(self):
        super(model, self).__init__()
        .
        . # define model layers
        .

    def infer_v_q(self, x, v, x_q, dataset):
        .
        . # do some other operations here to get c_e, h_e, h_g, u and some other variables
        .

        # x.shape = (batch_dim, 3, 64, 64), batch_dim=36
        with pyro.plate("data", x.shape[0]):
                # v_q.shape = (batch_dim, 7), batch_dim=36
                v_q = pyro.sample('v_q', pyro.distributions.Uniform(v_q_min, v_q_max).to_event(1))

                for l in range(self.L):
                    # note that the followings are done L times so I have L number of z sample sites
                    c_e, h_e = self.inference_network(x_q, v_q, r, c_e, h_e, h_g, u)
                    
                    mu_q, logvar_q = torch.split(self.hidden_state_network(h_e), 1, dim=1)
                    std_q = torch.exp(0.5*logvar_q)
                    # z.shape = (batch_dim, 1, 16, 16)
                    z = pyro.sample("z"+str(l), pyro.distributions.Normal(mu_q, std_q).to_event(3))

                    c_g, h_g, u = self.generation_network(v_q, r, c_g, h_g, u, z)
                # x.shape = (batch_dim, 3, 64, 64)
                return pyro.sample("x", pyro.distributions.Normal(self.image_generator_network(u), 0.001).to_event(3), obs=x_q)

# Here's how I run MCMC inference
nuts_kernel = pyro.infer.NUTS(model.infer_v_q, adapt_step_size=True, step_size=1e-9, jit_compile=True, ignore_jit_warnings=True)
mcmc = pyro.infer.MCMC(nuts_kernel, num_samples=200, warmup_steps=100, num_chains=1)
mcmc.run(x, v, x_q)

# get samples
mcmc.get_samples()["v_q"] # I only need this

# I don't want MCMC to get a posterior for the "z" variables:
mcmc.get_samples()["z1"]
.
.
.
mcmc.get_samples()["z7"]

А вот и мои образцы сайтов:

      Sample Sites:
data dist |
value 36 |
v_q dist 36 | 7
value 36 | 7
z0 dist 36 | 1 16 16
value 36 | 1 16 16
z1 dist 36 | 1 16 16
value 36 | 1 16 16
z2 dist 36 | 1 16 16
value 36 | 1 16 16
z3 dist 36 | 1 16 16
value 36 | 1 16 16
z4 dist 36 | 1 16 16
value 36 | 1 16 16
z5 dist 36 | 1 16 16
value 36 | 1 16 16
z6 dist 36 | 1 16 16
value 36 | 1 16 16
z7 dist 36 | 1 16 16
value 36 | 1 16 16

0 ответов

Другие вопросы по тегам