Можно ли запустить scatter matmul в pytorch?
У меня есть несколько типов вложений, и каждому нужна своя линейная проекция. Я могу решить проблему с помощью цикла for типа:
for ntype in ntypes: emb_out = self.lin_layer[ntype](emb[ntype])
, но в идеале я хотел сделать какую-то операцию разброса, чтобы запустить ее параллельно. Что-то типа:
pytorch_scatter(lin_layers, embeddings, layer_map, reduce='matmul')
, где:
lin_layers.weights.shape = (emb.size, output.size)
embeddings.shape = (batch_size , emb_size)
layer_map.shape = batch_size
Если у меня есть 2 типа линейных слоев и batch_size = 5, то layer_map будет что-то вроде [1,0,1,1,0]. Было бы возможно?
Ниже приведен пример, показывающий ускорение примерно в 20 раз при удалении цикла for из процесса.
import torch
import random
device = 'cuda'
ntypes = [str(i) for i in range(20)]
emb_size = 32
seed = 42
torch.manual_seed(seed)
random.seed(seed)
#Create linear layers
lin_layers = torch.nn.ModuleDict()
for ntype in ntypes:
lin_layers[ntype] = torch.nn.Linear(emb_size, emb_size).to(device)
single_lin_layer = torch.nn.Linear(emb_size, emb_size).to(device)
#create embedding layer
emb_layer = torch.nn.Embedding(num_embeddings = 10000, embedding_dim = emb_size).to(device)
#generate random embedding indices
emb_sep = dict()
emb_cat = []
for ntype in ntypes:
emb_sep[ntype] = emb_layer(torch.randint(low = 0, high = 10000-1, size=(random.randint(0,5000), ), device=device))
emb_cat.append(emb_sep [ntype])
emb_cat = torch.cat(emb_cat)
def lin_layers_with_loop():
output = dict()
for ntype in ntypes:
output[ntype] = lin_layers[ntype](emb_sep[ntype])
return output
def lin_layers_without_loop():
output = single_lin_layer(emb_cat)
return output
#18.4X speed when device = 'cuda' with ntypes = 20
%timeit lin_layers_with_loop()
#>>>> 993 µs ± 3.88 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
%timeit lin_layers_without_loop()
#>>>> 54.1 µs ± 392 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
Связанный с этим вопрос stackoverflow: как векторизовать операцию scatter-matmul
Связанная проблема в pytorch: https://github.com/pytorch/pytorch/issues/31942
1 ответ
Только что узнал, что DGL уже работает над этой функцией: https://github.com/dmlc/dgl/pull/3641 .