Проблема при совмещении репараметризации и автоматического перебора

Я пытаюсь повторить в пиро на annotators.py например , в numpyro.

Более подробно, повторная параметризация создает проблемы в сочетании с автоматическим перечислением дискретных переменных в pyro.

Код ниже точно такой же, как в примере numpyro, за исключением очевидных переводов с jax на torch.

      def hierarchical_dawid_skene(positions: torch.Tensor, annotations: torch.Tensor) -> None:
    """
    This model corresponds to the plate diagram in Figure 4 of reference [1].
    """
    num_annotators = positions.unique().numel()
    num_classes = annotations.unique().numel()
    num_items, num_positions = annotations.shape

    # debugging
    print(f"{num_classes=}, {num_annotators=}, {num_items=}, {num_positions=}")

    with pyro.plate("class", num_classes):
        # NB: we define `beta` as the `logits` of `y` likelihood; but `logits` is
        # invariant up to a constant, so we'll follow [1]: fix the last term of `beta`
        # to 0 and only define hyperpriors for the first `num_classes - 1` terms.
        zeta = pyro.sample("zeta", dist.Normal(0, 1).expand([num_classes - 1]).to_event(1))
        omega = pyro.sample("Omega", dist.HalfNormal(1).expand([num_classes - 1]).to_event(1))

    with pyro.plate("annotator", num_annotators, dim=-2):
        with pyro.plate("class_abilities", num_classes):
            # non-centered parameterization
            with reparam(config={"beta": LocScaleReparam(centered=0.)}):    # <- with this it does not work, beta is reshaped
                beta = pyro.sample("beta", dist.Normal(zeta, omega).to_event(1)).
            # pad 0 last dimension
            beta = F.pad(beta, [0, 1] + [0, 0] * (beta.dim() - 1))

    pi = pyro.sample("pi", dist.Dirichlet(torch.ones(num_classes)))

    with pyro.plate("item", num_items, dim=-2):
        c = pyro.sample("c", dist.Categorical(probs=pi))
        
        # debugging
        print(f"{c.shape=}, {beta.shape=}")

        with pyro.plate("position", num_positions):
            logits = Vindex(beta)[positions, c, :]
            pyro.sample("y", dist.Categorical(logits=logits), obs=annotations)

В первой итерации MCMC с NUTS (следуя примеру numpyro) я получаю следующие отладочные отпечатки

      num_classes=4, num_annotators=5, num_items=10, num_positions=7
c.shape=torch.Size([10, 1]), beta.shape=torch.Size([5, 4, 4])

Во второй итерации, когда c перечислено, я получаю следующие отладочные отпечатки

      num_classes=4, num_annotators=5, num_items=10, num_positions=7
c.shape=torch.Size([4, 1, 1]), beta.shape=torch.Size([4, 4])

Меня озадачивает тот факт, что beta размер изменен с (5, 4, 4) к (4, 4). Этого не происходит, когда я снимаю репараметризацию.

Есть ли предложения о том, где искать, чтобы понять, что происходит?

Заранее большое спасибо за ваше время.

С уважением, Пьетро

На пирофоруме тоже размещено: вопрос .

0 ответов

Другие вопросы по тегам