Получение ошибки на шаге svi из-за мультиклассового распределения в образце с использованием pyro и pytorch
Я работаю над каузальным вариационным автоэнкодером, который работает с масками сегментации классов, метками классов и причинностью (0 или 1) в качестве входных данных.
Я получаю сообщение об ошибке при работе с пакетами размером больше 1 из-за шага svi. Я использую функцию бернуллинга, потому что хочу, чтобы она изучила распределение вероятностей для нескольких классов изображения. Я думаю, что категориальное распределение здесь тоже подходит, но я тоже получаю ту же ошибку.
Когда я попытался сузить количество строк кода, которые создают проблему, я подумал, что это в функции модели:
one_vec2 = torch.ones([batch_size, self.lbl_shape[0]], **options)
class_labels = pyro.sample('class_labels', dist.Bernoulli(one_vec2*0.5), obs = lbls)
Ошибка:
ValueError Traceback (most recent call last)
<ipython-input-19-8cbc046dd2c1> in <module>()
6 vae = Vae_Model1(lbl_sz, ch, img_sz).to(device)
7 svi = SVI(vae.model, vae.guide, optimizer, loss = Trace_ELBO())
----> 8 train(svi, train_loader, USE_CUDA)
6 frames
/usr/local/lib/python3.6/dist-packages/pyro/util.py in check_site_shape(site, max_plate_nesting)
320 '- enclose the batched tensor in a with plate(...): context',
321 '- .to_event(...) the distribution being sampled',
--> 322 '- .permute() data dimensions']))
323
324 # Check parallel dimensions on the left of max_plate_nesting.
ValueError: at site "class_labels", invalid log_prob shape
Expected [-1], actual [32, 21]
Try one of the following fixes:
- enclose the batched tensor in a with plate(...): context
- .to_event(...) the distribution being sampled
- .permute() data dimensions
В настоящее время размер пакета составляет 32, а lbl_shape[0] - 21 (набор данных VOC (фон и другие метки))
Может ли кто-нибудь помочь мне с этим? Мы будем очень признательны. Спасибо