How to compute cross attention between a 3D tensor and a 4D tensor? [closed]
I am currently implementing one of my ideas, which involves computing cross attention between a 3D tensor and a 4D tensor. I want to add an attention mechanism to the set abstraction module of PointNet++.
Here is an example:
import torch
# a batch of point clouds and their features,
# 4 is the batch size, 80000 is the number of points for each scan
xyz = torch.rand(4, 80000, 3)
features = torch.rand(4, 80000, 128)
# sample some reference points from the raw point clouds
new_xyz = torch.rand(4, 1024, 3)
new_features = torch.rand(4, 1024, 128)
# For each reference point, ball query is adopted to find its 32 neighbor points
grouped_xyz = torch.rand(4, 1024, 32, 3)
grouped_features = torch.rand(4, 1024, 32, 128)
# Then I want to calculate the cross attention between each reference
# point and its corresponding neighbor points
# (between new_features and grouped_features)
# Now I am trying to do something as follows:
new_features = new_features.view(4*1024, 1, 128)
grouped_features = grouped_features.view(4*1024, 32, 128)
# Then the cross attention can be calculated directly using a PyTorch function
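For instance, I imagine something like the following sketch with torch.nn.MultiheadAttention (num_heads=4 is an arbitrary choice on my part, and batch_first=True requires PyTorch >= 1.9):
import torch.nn as nn

# one query per reference point, its 32 neighbors as keys/values
mha = nn.MultiheadAttention(embed_dim=128, num_heads=4, batch_first=True)
attn_output, attn_weights = mha(new_features, grouped_features, grouped_features)
# attn_output: [4096, 1, 128] -> reshape back to [4, 1024, 128]
attn_output = attn_output.view(4, 1024, 128)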
1 Answer
I'm not sure I understood the question correctly, but here is a minimal example of what I think you want to achieve:
import numpy as np
import torch
import torch.nn.functional as F

def cross_attention(input_3d, input_4d, dropout=0.0, training=False, mask=None, return_scores=False):
    """
    :param input_3d: [batch_size, sents_len, hidden_size]
    :param input_4d: [batch_size, hidden_size, sents_len, 1]
    :param dropout: dropout probability applied to the attention weights
    :param training: whether dropout is active
    :param mask: optional boolean mask [batch_size, sents_len, sents_len]; True marks positions to ignore
    :param return_scores: return attention scores or not
    :return: cross attention output, and the attention scores if return_scores is True
    """
    # [batch_size, hidden_size, sents_len, 1] -> [batch_size, hidden_size, sents_len]
    input_4d = input_4d.squeeze(dim=-1)
    # [batch_size, sents_len, hidden_size] x [batch_size, hidden_size, sents_len] -> [batch_size, sents_len, sents_len]
    attn = torch.matmul(input_3d, input_4d)
    if mask is not None:
        attn = attn.masked_fill(mask, -np.inf)
    # normalize the (unscaled) scores over the last dimension
    attn = F.softmax(attn, dim=-1)
    attn = F.dropout(attn, p=dropout, training=training)
    # [batch_size, sents_len, sents_len] x [batch_size, sents_len, hidden_size] -> [batch_size, sents_len, hidden_size]
    attn_output = torch.matmul(attn, input_3d)
    if return_scores:
        return attn_output, attn
    return attn_output
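A quick sanity check with random tensors, following the shapes in the docstring (the sizes here are arbitrary, chosen just to verify the output shapes):
batch_size, sents_len, hidden_size = 2, 5, 16
input_3d = torch.rand(batch_size, sents_len, hidden_size)
input_4d = torch.rand(batch_size, hidden_size, sents_len, 1)

output, scores = cross_attention(input_3d, input_4d, return_scores=True)
print(output.shape)  # torch.Size([2, 5, 16])
print(scores.shape)  # torch.Size([2, 5, 5])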