Indexing a multi-dimensional tensor with a tensor in PyTorch
У меня есть следующий код:
a = torch.randint(0,10,[3,3,3,3])
b = torch.LongTensor([1,1,1,1])
I have a multi-dimensional index b
and want to use it to select a single cell in a
, If b wasn't a tensor, I could do:
a[1,1,1,1]
Which returns the correct cell, but:
a[b]
Doesn't work, because it just selects a[1]
четыре раза.
Как я могу это сделать? Спасибо
2 ответа
Решение
Более элегантное (и более простое) решение может быть просто b
как кортеж:
a[tuple(b)]
Out[10]: tensor(5.)
Мне было любопытно посмотреть, как это работает с "обычным" numpy, и нашел соответствующую статью, объясняющую это довольно хорошо здесь.
Вы можете разделить b
в 4 с помощью chunk
, а затем использовать куски b
Индексировать определенный элемент, который вы хотите:
>> a = torch.arange(3*3*3*3).view(3,3,3,3)
>> b = torch.LongTensor([[1,1,1,1], [2,2,2,2], [0, 0, 0, 0]]).t()
>> a[b.chunk(chunks=4, dim=0)] # here's the trick!
Out[24]: tensor([[40, 80, 0]])
Что приятно, так это то, что он может быть легко обобщен для любого измерения a
нужно просто сделать количество патронов равным размеру a
,