Компиляция TVM создает несовпадающие матрицы

Я определил следующую модель на основе BERT, используя PyTorch:

       class BERTGRUSentiment(nn.Module):
   13     def __init__(self,
   14                  bert,
   15                  hidden_dim,
   16                  output_dim,
   17                  n_layers,
   18                  bidirectional,
   19                  dropout=0):
S> 20         
   21         super().__init__()
S> 22         
   23         self.bert = bert
S> 24         
   25         embedding_dim = bert.config.to_dict()['hidden_size']
S> 26         
   27         self.rnn = nn.GRU(embedding_dim,
   28                           hidden_dim,
S> 29                           num_layers = n_layers,
S> 30                           bidirectional = True,
S> 31                           batch_first = True,
S> 32                           dropout = 0 if (n_layers < 2)  else dropout)
S> 33         
S> 34         self.out = nn.Linear(hidden_dim * 2 if bidirectional else hidden_dim, output_dim)
S> 35         
   36         self.dropout = nn.Dropout(dropout)
S> 37         
   38     def forward(self, text):
S> 39         
   40         # text = [batch size, sent len]
S> 41                 
   42         with torch.no_grad():
   43             embedded = self.bert(text)[0]
S> 44                 
S> 45         #embedded = [batch size, sent len, emb dim]
S> 46         
   47         _, hidden = self.rnn(embedded)
S> 48         
S> 49         #hidden = [n layers * n directions, batch size, emb dim]
S> 50         hidden = torch.cat((hidden[-2,:,:], hidden[-1,:,:]), dim = 1)
S> 51         
   52         hidden = self.dropout(hidden)
S> 53                     
S> 54         #hidden = [batch size, hid dim]
S> 55         
   56         output = self.out(hidden)
S> 57         
S> 58         #output = [batch size, out dim]
S> 59         
   60         return output

Я могу без проблем обучать и запускать эту модель.

Однако, когда я пытаюсь использовать ApacheTVMскомпилировать модель следующим образом:

      

  1 import model
  2 import torch
  3 from transformers import BertModel
  4 import tvm
  5 from tvm import relay
  6 
  7 import sys 
  8 sys.setrecursionlimit(1000000)
  9 
 10 bert = BertModel.from_pretrained('bert-base-uncased')
 11 embedding_dim = bert.config.to_dict()['max_position_embeddings']
 12 
 13 device = "cuda"
 14 
 15 HIDDEN_DIM = 256   
 16 OUTPUT_DIM = 1     
 17 N_LAYERS = 2       
 18 BIDIRECTIONAL = True
 19 DROPOUT = 0        
 20 
 21 bert = model.BERTGRUSentiment(bert, HIDDEN_DIM, OUTPUT_DIM, N_LAYERS, BIDIRECTIONAL, DROPOUT)
 22 bert.eval()
 23 bert.to(device)
 24 
 25 print(f"cuda: {next(bert.parameters()).is_cuda}")
 26 print(f"training: {bert.training}")
 27 
 28 input_name = "text"         
 29 input_shape = [1, embedding_dim]
 30 
 31 shape_list = [(input_name, input_shape)]
 32 
 33 example = model.preprocess(model.tokenizer, "it was ok").unsqueeze(0)
 34 
 35 scripted_model = torch.jit.trace(bert, example).eval()
 36 # scripted_model = torch.jit.script(bert)
 37 
 38 print(scripted_model.graph)
 39 
 40 print("starting")
 41 
 42 mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)
 43 
 44 print(mod)        
 45 
 46 target = tvm.target.cuda()
 47 
 48 with tvm.transform.PassContext(opt_level=3):
 49     lib = relay.build(mod, target=target, params=params)
 50 
 51 lib.export_library("compiled.so")
 52 

Я получаю следующую ошибку:

      Traceback (most recent call last):
  File "compile.py", line 42, in <module>
    mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)
  File "/home/moe/tvm/python/tvm/relay/frontend/pytorch.py", line 5008, in from_pytorch
    outputs = converter.convert_operators(operator_nodes, outputs, ret_name)
  File "/home/moe/tvm/python/tvm/relay/frontend/pytorch.py", line 4272, in convert_operators
    _get_input_types(op_node, outputs, default_dtype=self.default_dtype),
  File "/home/moe/tvm/python/tvm/relay/frontend/pytorch.py", line 1783, in linear
    [inputs[0], _op.transpose(inputs[1], axes=(1, 0))], input_types[:2]
  File "/home/moe/tvm/python/tvm/relay/frontend/pytorch.py", line 1976, in matmul
    raise AssertionError(msg)
AssertionError: Tensors being multiplied do not have compatible shapes.

Я добавил ведение журнала в файлы tvm, и кажется, что рассматриваемые размеры(2,2, 256)и(1, 512). Также кажется, что эта ошибка возникает и вблизи уровня GRU. Могу ли я предпринять какие-либо дополнительные действия для устранения или устранения этой проблемы?

0 ответов

Другие вопросы по тегам