Indexing a multi-dimensional tensor with a tensor in PyTorch

У меня есть следующий код:

a = torch.randint(0,10,[3,3,3,3])
b = torch.LongTensor([1,1,1,1])

I have a multi-dimensional index b and want to use it to select a single cell in a, If b wasn't a tensor, I could do:

a[1,1,1,1]

Which returns the correct cell, but:

a[b]

Doesn't work, because it just selects a[1] четыре раза.

Как я могу это сделать? Спасибо

2 ответа

Решение

Более элегантное (и более простое) решение может быть просто b как кортеж:

a[tuple(b)]
Out[10]: tensor(5.)

Мне было любопытно посмотреть, как это работает с "обычным" numpy, и нашел соответствующую статью, объясняющую это довольно хорошо здесь.

Вы можете разделить b в 4 с помощью chunk, а затем использовать куски b Индексировать определенный элемент, который вы хотите:

>> a = torch.arange(3*3*3*3).view(3,3,3,3)
>> b = torch.LongTensor([[1,1,1,1], [2,2,2,2], [0, 0, 0, 0]]).t()
>> a[b.chunk(chunks=4, dim=0)]   # here's the trick!
Out[24]: tensor([[40, 80,  0]])

Что приятно, так это то, что он может быть легко обобщен для любого измерения a нужно просто сделать количество патронов равным размеру a,

Другие вопросы по тегам