How to compute cross-attention between a 3D tensor and a 4D tensor? [closed]

I'm currently implementing an idea of mine that involves computing cross-attention between a 3D tensor and a 4D tensor. I want to add an attention mechanism to the set abstraction module of PointNet++.
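For context, here is a minimal sketch of where such a step might sit. It assumes the usual set-abstraction pattern, in which the K grouped neighbor features of each reference point are pooled into a single vector. It replaces the max-pool with attention-weighted pooling that uses learned linear projections for queries, keys and values. The class name `AttentionPool` and its layers are hypothetical and not part of PointNet++. The shared MLP that normally runs before pooling is left out.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class AttentionPool(nn.Module):
    """Hypothetical stand-in for the max-pool over neighbors in a
    PointNet++-style set abstraction layer: each reference point's feature
    (query) attends over its K grouped neighbor features (keys/values)."""

    def __init__(self, channels):
        super().__init__()
        # Learned projections (an assumption; plain features would also work)
        self.q_proj = nn.Linear(channels, channels)
        self.k_proj = nn.Linear(channels, channels)
        self.v_proj = nn.Linear(channels, channels)
        self.scale = 1.0 / math.sqrt(channels)

    def forward(self, new_features, grouped_features):
        # new_features: [B, N, C], grouped_features: [B, N, K, C]
        q = self.q_proj(new_features)        # [B, N, C]
        k = self.k_proj(grouped_features)    # [B, N, K, C]
        v = self.v_proj(grouped_features)    # [B, N, K, C]
        # Dot product of each query with its own K neighbors -> [B, N, K]
        scores = torch.einsum("bnc,bnkc->bnk", q, k) * self.scale
        weights = F.softmax(scores, dim=-1)  # normalized over the K neighbors
        # Attention-weighted sum of neighbor values -> [B, N, C]
        return torch.einsum("bnk,bnkc->bnc", weights, v)


# Max-pool baseline vs. attention pooling on small dummy tensors
B, N, K, C = 2, 16, 8, 32
new_features = torch.rand(B, N, C)
grouped_features = torch.rand(B, N, K, C)
max_pooled = grouped_features.max(dim=2).values                 # [B, N, C]
attn_pooled = AttentionPool(C)(new_features, grouped_features)  # [B, N, C]
print(max_pooled.shape, attn_pooled.shape)
```

Both paths reduce the neighbor dimension K and return one C-dimensional vector per reference point. That makes the attention version a possible drop-in replacement at the pooling step.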

For example:

import torch

# a batch of point clouds and their features, 
# 4 is the batch size, 80000 is the number of points for each scan
xyz = torch.rand(4, 80000, 3)
features = torch.rand(4, 80000, 128)

# sample some reference point from raw point clouds
new_xyz = torch.rand(4, 1024, 3)
new_features = torch.rand(4, 1024, 128)

# For each reference point, ball query is adopted to find its 32 neighbor points
grouped_xyz = torch.rand(4, 1024, 32, 3)
grouped_features = torch.rand(4, 1024, 32, 128)

# Then I want to calculate the cross attention between each reference 
# point and its corresponding neighbor points
# (between new_features and grouped_features)
# Now I am trying to do something as follows:
new_features = new_features.view(4*1024, 1, 128)
grouped_features = grouped_features.view(4*1024, 32, 128)

# Then the cross attention can be calculated directly using a built-in PyTorch function
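Here is a minimal sketch of that last step. It assumes the PyTorch function meant here is `torch.nn.MultiheadAttention`, which the question does not name. Folding the 1024 reference points into the batch turns each point into a length-1 query sequence that attends over its 32 neighbors. This module adds learned input and output projections of its own, and `num_heads=4` is an arbitrary choice.

```python
import torch
import torch.nn as nn

B, N, K, C = 4, 1024, 32, 128
new_features = torch.rand(B, N, C)         # one feature vector per reference point
grouped_features = torch.rand(B, N, K, C)  # K neighbor features per reference point

# Fold the point dimension into the batch, as in the question:
# every reference point becomes its own length-1 query sequence.
query = new_features.reshape(B * N, 1, C)          # [B*N, 1, C]
key_value = grouped_features.reshape(B * N, K, C)  # [B*N, K, C]

# embed_dim must be divisible by num_heads
mha = nn.MultiheadAttention(embed_dim=C, num_heads=4, batch_first=True)
out, weights = mha(query, key_value, key_value)    # out: [B*N, 1, C], weights: [B*N, 1, K]

# Restore the [batch, points, channels] layout
out = out.reshape(B, N, C)
print(out.shape, weights.shape)
```

By default `weights` is averaged over heads (`average_attn_weights=True`), so it has shape [4096, 1, 32], with one distribution over the 32 neighbors per reference point.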



1 Answer

I'm not sure I understood the question correctly, but here is a minimal example of what I think you want to achieve:

import math

import torch
import torch.nn.functional as F


def cross_attention(input_3d, input_4d, mask=None, dropout=0.0, training=False, return_scores=False):
    """
    Scaled dot-product cross-attention: each query vector in input_3d
    attends over its own group of key/value vectors in input_4d.

    :param input_3d: queries, [batch_size, num_points, hidden_size]
    :param input_4d: keys/values, [batch_size, num_points, num_neighbors, hidden_size]
    :param mask: optional bool tensor [batch_size, num_points, num_neighbors];
                 True marks neighbors to ignore
    :param dropout: dropout probability applied to the attention weights
    :param training: dropout is only applied when True
    :param return_scores: also return the attention weights
    :return: attended features [batch_size, num_points, hidden_size],
             plus weights [batch_size, num_points, num_neighbors] if return_scores is True
    """
    hidden_size = input_3d.size(-1)
    # [B, N, C] -> [B, N, 1, C] so each query lines up with its own group of neighbors
    query = input_3d.unsqueeze(2)
    # [B, N, 1, C] @ [B, N, C, K] -> [B, N, 1, K], scaled by sqrt(C)
    attn = torch.matmul(query, input_4d.transpose(-2, -1)) / math.sqrt(hidden_size)
    # [B, N, 1, K] -> [B, N, K]
    attn = attn.squeeze(2)
    if mask is not None:
        attn = attn.masked_fill(mask, float("-inf"))
    # Normalize over the neighbors of each reference point
    attn = F.softmax(attn, dim=-1)
    attn = F.dropout(attn, p=dropout, training=training)
    # [B, N, 1, K] @ [B, N, K, C] -> [B, N, 1, C] -> [B, N, C]
    attn_output = torch.matmul(attn.unsqueeze(2), input_4d).squeeze(2)
    if return_scores:
        return attn_output, attn
    return attn_output
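The same per-point attention, without learned projections, can also be written with `torch.einsum`. This avoids the unsqueeze/transpose bookkeeping. The sketch below uses the question's shapes and checks the result against the reshape-and-`torch.bmm` formulation. Scaling by sqrt(C) follows the usual scaled dot-product convention.

```python
import math

import torch
import torch.nn.functional as F

B, N, K, C = 4, 1024, 32, 128
new_features = torch.rand(B, N, C)         # queries: one per reference point
grouped_features = torch.rand(B, N, K, C)  # keys/values: K neighbors per reference point

# einsum version: contract over channels c, keep batch b, point n, neighbor k
scores = torch.einsum("bnc,bnkc->bnk", new_features, grouped_features) / math.sqrt(C)
weights = F.softmax(scores, dim=-1)                                   # [B, N, K]
out_einsum = torch.einsum("bnk,bnkc->bnc", weights, grouped_features)  # [B, N, C]

# Reference: fold points into the batch, as in the question, and use torch.bmm
q = new_features.reshape(B * N, 1, C)       # [B*N, 1, C]
kv = grouped_features.reshape(B * N, K, C)  # [B*N, K, C]
w = F.softmax(torch.bmm(q, kv.transpose(1, 2)) / math.sqrt(C), dim=-1)  # [B*N, 1, K]
out_bmm = torch.bmm(w, kv).reshape(B, N, C)                            # [B, N, C]

print(torch.allclose(out_einsum, out_bmm, atol=1e-5))
```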