Разъяснение в понимании TorchScripts и JIT на PyTorch
Просто хотел прояснить мое понимание того, как работают JIT и TorchScripts, и прояснить конкретный пример.
Так что если я не ошибаюсь
torch.jit.script
преобразует мой метод или модуль в TorchScript. Я могу использовать свой скомпилированный модуль TorchScript в среде за пределами Python, но также могу просто использовать его в Python с предполагаемыми улучшениями и оптимизациями. Аналогичный случай с
torch.jit.trace
где вместо этого прослеживаются веса и операции, но следует примерно такой же идее.
В этом случае модуль TorchScripted, как правило, должен быть по крайней мере таким же быстрым, как типичное время вывода интерпретатора Python. Немного поэкспериментировав, я заметил, что это чаще всего медленнее, чем типичное время вывода интерпретатора, и, немного прочитав, обнаружил, что, очевидно, модуль TorchScripted необходимо немного "разогреть", чтобы достичь максимальной производительности. При этом я не заметил никаких изменений как таковых во времени вывода, оно улучшилось, но не настолько, чтобы назвать улучшение по сравнению с типичным способом работы (интерпретатор python). Кроме того, я использовал стороннюю библиотеку под названием
torch_tvm
, который при включении предположительно вдвое сокращает время вывода для любого способа изменения модуля.
Ничего из этого не произошло до сих пор, и я не могу сказать почему.
Ниже приведен мой пример кода на случай, если я сделал что-то не так:
class TrialC(nn.Module):
def __init__(self):
super().__init__()
self.l1 = nn.Linear(1024, 2048)
self.l2 = nn.Linear(2048, 4096)
self.l3 = nn.Linear(4096, 4096)
self.l4 = nn.Linear(4096, 2048)
self.l5 = nn.Linear(2048, 1024)
def forward(self, input):
out = self.l1(input)
out = self.l2(out)
out = self.l3(out)
out = self.l4(out)
out = self.l5(out)
return out
if __name__ == '__main__':
# Trial inference input
TrialC_input = torch.randn(1, 1024)
warmup = 10
# Record time for typical inference
model = TrialC()
start = time.time()
model_out = model(TrialC_input)
elapsed = time.time() - start
# Record the 10th inference time (10 warmup) for the optimized model in TorchScript
script_model = torch.jit.script(TrialC())
for i in range(warmup):
start_2 = time.time()
model_out_check_2 = script_model(TrialC_input)
elapsed_2 = time.time() - start_2
# Record the 10th inference time (10 warmup) for the optimized model in TorchScript + TVM optimization
torch_tvm.enable()
script_model_2 = torch.jit.trace(TrialC(), torch.randn(1, 1024))
for i in range(warmup):
start_3 = time.time()
model_out_check_3 = script_model_2(TrialC_input)
elapsed_3 = time.time() - start_3
print("Regular model inference time: {}s\nJIT compiler inference time: {}s\nJIT Compiler with TVM: {}s".format(elapsed, elapsed_2, elapsed_3))
И следующие результаты приведенного выше кода на моем процессоре -
Regular model inference time: 0.10335588455200195s
JIT compiler inference time: 0.11449170112609863s
JIT Compiler with TVM: 0.10834860801696777s
Любая помощь или разъяснение по этому поводу были бы очень признательны!