Как заставить Pyro подчиняться только апостериорной части для подмножества образцов сайтов?
Недавно я узнал, как использовать MCMC для определения входных параметров нейронной сети. После запуска кода я заметил, что MCMC очень долго завершает работу. Под очень длинным я имею в виду, что запуск MCMC с 40 шагами прогрева и 100 выборками может занять 1-3 часа (или даже больше) в зависимости от параметров ядра. Для меня это очень неприемлемо, особенно потому, что модель нейронной сети, которую я использую, может обрабатывать тысячи изображений в минуту. Затем я узнал, что это, вероятно, связано с тем, что MCMC фактически оценивает распределение для всех моих образцов сайтов, которые имеют очень большие размеры (см. Размеры ниже в коде). Однако меня интересует только выведение апостериорного распределения для одной из переменных после согласования вывода модели с наблюдением.
x_q
(изображение). Поэтому мне интересно, есть ли способ заставить MCMC использовать только нижнее апостериорное распределение для
v_q
а остальные переменные как-то игнорирует?
Вот моя попытка: я установил z на
z=torch.distributions.Normal(mu_q, std_q).rsample()
вывод MCMC выполняется очень быстро, но очень близок к 0 (например, 0,01) и продолжает снижаться (он упал до 1e-140 после 200 шагов разогрева). Так что я считаю, что это неправильно.
Кроме того, вывод будет намного быстрее, если я установлю
adapt_step_size
к
False
но результаты действительно плохие (
acc. prob
составляет около 0,1), когда я вручную установил
step_size
к значению, с которым выполнялся предыдущий вывод
adap_step_size
нашел.
Вот код
class model(nn.Module):
def __init__(self):
super(model, self).__init__()
.
. # define model layers
.
def infer_v_q(self, x, v, x_q, dataset):
.
. # do some other operations here to get c_e, h_e, h_g, u and some other variables
.
# x.shape = (batch_dim, 3, 64, 64), batch_dim=36
with pyro.plate("data", x.shape[0]):
# v_q.shape = (batch_dim, 7), batch_dim=36
v_q = pyro.sample('v_q', pyro.distributions.Uniform(v_q_min, v_q_max).to_event(1))
for l in range(self.L):
# note that the followings are done L times so I have L number of z sample sites
c_e, h_e = self.inference_network(x_q, v_q, r, c_e, h_e, h_g, u)
mu_q, logvar_q = torch.split(self.hidden_state_network(h_e), 1, dim=1)
std_q = torch.exp(0.5*logvar_q)
# z.shape = (batch_dim, 1, 16, 16)
z = pyro.sample("z"+str(l), pyro.distributions.Normal(mu_q, std_q).to_event(3))
c_g, h_g, u = self.generation_network(v_q, r, c_g, h_g, u, z)
# x.shape = (batch_dim, 3, 64, 64)
return pyro.sample("x", pyro.distributions.Normal(self.image_generator_network(u), 0.001).to_event(3), obs=x_q)
# Here's how I run MCMC inference
nuts_kernel = pyro.infer.NUTS(model.infer_v_q, adapt_step_size=True, step_size=1e-9, jit_compile=True, ignore_jit_warnings=True)
mcmc = pyro.infer.MCMC(nuts_kernel, num_samples=200, warmup_steps=100, num_chains=1)
mcmc.run(x, v, x_q)
# get samples
mcmc.get_samples()["v_q"] # I only need this
# I don't want MCMC to get a posterior for the "z" variables:
mcmc.get_samples()["z1"]
.
.
.
mcmc.get_samples()["z7"]
А вот и мои образцы сайтов:
Sample Sites:
data dist |
value 36 |
v_q dist 36 | 7
value 36 | 7
z0 dist 36 | 1 16 16
value 36 | 1 16 16
z1 dist 36 | 1 16 16
value 36 | 1 16 16
z2 dist 36 | 1 16 16
value 36 | 1 16 16
z3 dist 36 | 1 16 16
value 36 | 1 16 16
z4 dist 36 | 1 16 16
value 36 | 1 16 16
z5 dist 36 | 1 16 16
value 36 | 1 16 16
z6 dist 36 | 1 16 16
value 36 | 1 16 16
z7 dist 36 | 1 16 16
value 36 | 1 16 16