Можно ли запустить scatter matmul в pytorch?

У меня есть несколько типов вложений, и каждому нужна своя линейная проекция. Я могу решить проблему с помощью цикла for типа: for ntype in ntypes: emb_out = self.lin_layer[ntype](emb[ntype]), но в идеале я хотел сделать какую-то операцию разброса, чтобы запустить ее параллельно. Что-то типа: pytorch_scatter(lin_layers, embeddings, layer_map, reduce='matmul'), где:

      lin_layers.weights.shape = (emb.size, output.size)
embeddings.shape = (batch_size , emb_size)
layer_map.shape = batch_size 

Если у меня есть 2 типа линейных слоев и batch_size = 5, то layer_map будет что-то вроде [1,0,1,1,0]. Было бы возможно?

Ниже приведен пример, показывающий ускорение примерно в 20 раз при удалении цикла for из процесса.

      import torch
import random

device = 'cuda'
ntypes = [str(i) for i in range(20)]
emb_size = 32
seed = 42
torch.manual_seed(seed)
random.seed(seed)

#Create linear layers
lin_layers = torch.nn.ModuleDict()
for ntype in ntypes:
    lin_layers[ntype] = torch.nn.Linear(emb_size, emb_size).to(device)

single_lin_layer = torch.nn.Linear(emb_size, emb_size).to(device)

#create embedding layer
emb_layer = torch.nn.Embedding(num_embeddings = 10000, embedding_dim = emb_size).to(device)

#generate random embedding indices
emb_sep = dict()
emb_cat = []

for ntype in ntypes:
    emb_sep[ntype] = emb_layer(torch.randint(low = 0, high = 10000-1, size=(random.randint(0,5000), ), device=device))
    emb_cat.append(emb_sep [ntype])

emb_cat = torch.cat(emb_cat)

def lin_layers_with_loop():

    output = dict()
    for ntype in ntypes:
        output[ntype] = lin_layers[ntype](emb_sep[ntype])

    return output

def lin_layers_without_loop():
    output = single_lin_layer(emb_cat)
    return output

#18.4X speed when device = 'cuda' with ntypes = 20

%timeit lin_layers_with_loop()
#>>>> 993 µs ± 3.88 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

%timeit lin_layers_without_loop()
#>>>> 54.1 µs ± 392 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)

Связанный с этим вопрос stackoverflow: как векторизовать операцию scatter-matmul

Связанная проблема в pytorch: https://github.com/pytorch/pytorch/issues/31942

1 ответ

Только что узнал, что DGL уже работает над этой функцией: https://github.com/dmlc/dgl/pull/3641 .

Другие вопросы по тегам