Проблема при совмещении репараметризации и автоматического перебора
Я пытаюсь повторить в пиро на annotators.py например , в numpyro.
Более подробно, повторная параметризация создает проблемы в сочетании с автоматическим перечислением дискретных переменных в pyro.
Код ниже точно такой же, как в примере numpyro, за исключением очевидных переводов с jax на torch.
def hierarchical_dawid_skene(positions: torch.Tensor, annotations: torch.Tensor) -> None:
"""
This model corresponds to the plate diagram in Figure 4 of reference [1].
"""
num_annotators = positions.unique().numel()
num_classes = annotations.unique().numel()
num_items, num_positions = annotations.shape
# debugging
print(f"{num_classes=}, {num_annotators=}, {num_items=}, {num_positions=}")
with pyro.plate("class", num_classes):
# NB: we define `beta` as the `logits` of `y` likelihood; but `logits` is
# invariant up to a constant, so we'll follow [1]: fix the last term of `beta`
# to 0 and only define hyperpriors for the first `num_classes - 1` terms.
zeta = pyro.sample("zeta", dist.Normal(0, 1).expand([num_classes - 1]).to_event(1))
omega = pyro.sample("Omega", dist.HalfNormal(1).expand([num_classes - 1]).to_event(1))
with pyro.plate("annotator", num_annotators, dim=-2):
with pyro.plate("class_abilities", num_classes):
# non-centered parameterization
with reparam(config={"beta": LocScaleReparam(centered=0.)}): # <- with this it does not work, beta is reshaped
beta = pyro.sample("beta", dist.Normal(zeta, omega).to_event(1)).
# pad 0 last dimension
beta = F.pad(beta, [0, 1] + [0, 0] * (beta.dim() - 1))
pi = pyro.sample("pi", dist.Dirichlet(torch.ones(num_classes)))
with pyro.plate("item", num_items, dim=-2):
c = pyro.sample("c", dist.Categorical(probs=pi))
# debugging
print(f"{c.shape=}, {beta.shape=}")
with pyro.plate("position", num_positions):
logits = Vindex(beta)[positions, c, :]
pyro.sample("y", dist.Categorical(logits=logits), obs=annotations)
В первой итерации MCMC с NUTS (следуя примеру numpyro) я получаю следующие отладочные отпечатки
num_classes=4, num_annotators=5, num_items=10, num_positions=7
c.shape=torch.Size([10, 1]), beta.shape=torch.Size([5, 4, 4])
Во второй итерации, когда
c
перечислено, я получаю следующие отладочные отпечатки
num_classes=4, num_annotators=5, num_items=10, num_positions=7
c.shape=torch.Size([4, 1, 1]), beta.shape=torch.Size([4, 4])
Меня озадачивает тот факт, что
beta
размер изменен с
(5, 4, 4)
к
(4, 4)
. Этого не происходит, когда я снимаю репараметризацию.
Есть ли предложения о том, где искать, чтобы понять, что происходит?
Заранее большое спасибо за ваше время.
С уважением, Пьетро
На пирофоруме тоже размещено: вопрос .