Получение ошибки на шаге svi из-за мультиклассового распределения в образце с использованием pyro и pytorch

Я работаю над каузальным вариационным автоэнкодером, который работает с масками сегментации классов, метками классов и причинностью (0 или 1) в качестве входных данных.

Я получаю сообщение об ошибке при работе с пакетами размером больше 1 из-за шага svi. Я использую функцию бернуллинга, потому что хочу, чтобы она изучила распределение вероятностей для нескольких классов изображения. Я думаю, что категориальное распределение здесь тоже подходит, но я тоже получаю ту же ошибку.

Когда я попытался сузить количество строк кода, которые создают проблему, я подумал, что это в функции модели:

one_vec2 = torch.ones([batch_size, self.lbl_shape[0]], **options)
class_labels = pyro.sample('class_labels', dist.Bernoulli(one_vec2*0.5), obs = lbls)      

Ошибка:

ValueError                                Traceback (most recent call last)
<ipython-input-19-8cbc046dd2c1> in <module>()
      6 vae = Vae_Model1(lbl_sz, ch, img_sz).to(device)
      7 svi = SVI(vae.model, vae.guide, optimizer, loss = Trace_ELBO())
----> 8 train(svi, train_loader, USE_CUDA)

6 frames
/usr/local/lib/python3.6/dist-packages/pyro/util.py in check_site_shape(site, max_plate_nesting)
    320                 '- enclose the batched tensor in a with plate(...): context',
    321                 '- .to_event(...) the distribution being sampled',
--> 322                 '- .permute() data dimensions']))
    323 
    324     # Check parallel dimensions on the left of max_plate_nesting.

ValueError: at site "class_labels", invalid log_prob shape
  Expected [-1], actual [32, 21]
  Try one of the following fixes:
  - enclose the batched tensor in a with plate(...): context
  - .to_event(...) the distribution being sampled
  - .permute() data dimensions

В настоящее время размер пакета составляет 32, а lbl_shape[0] - 21 (набор данных VOC (фон и другие метки))

Может ли кто-нибудь помочь мне с этим? Мы будем очень признательны. Спасибо

0 ответов

Другие вопросы по тегам