Слишком медленный первый запуск модели TorchScript и ее реализация во Flask
Я пытаюсь развернуть модель с факелами в Python и Flask. Как я понял (по крайней мере, как упоминалось здесь), скриптовые модели необходимо "разогреть" перед использованием, поэтому первый запуск таких моделей занимает гораздо больше времени, чем последующие. Мой вопрос: есть ли способ загрузить модели со сценариями torchscripted в Flask route и прогнозировать без потери времени на "червячок"? Можно ли хранить где-нибудь "прогретую" модель, чтобы не прогреваться при каждой просьбе? Я написал простой код, воспроизводящий этап "разминки":
import torchvision, torch, time
model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
model = torch.jit.script(model)
model.eval()
x = [torch.randn((3,224,224))]
for i in range(3):
start = time.time()
model(x)
print(‘Time elapsed: {}’.format(time.time()-start))
Выход:
Time elapsed: 38.29<br>
Time elapsed: 6.65<br>
Time elapsed: 6.65<br>
И код Flask:
import torch, torchvision, os, time
from flask import Flask
app = Flask(__name__)
@app.route('/')
def test_scripted_model(path='/tmp/scripted_model.pth'):
if os.path.exists(path):
model = torch.jit.load(path, map_location='cpu')
else:
model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True)
model = torch.jit.script(model)
torch.jit.save(model, path)
model.eval()
x = [torch.randn((3, 224, 224))]
out = ''
for i in range(3):
start = time.time()
model(x)
out += 'Run {} time: {};\t'.format(i+1, round((time.time() - start), 2))
return out
if __name__ == '__main__':
app.run(host='0.0.0.0', port=5000, debug=False)
Выход:
Run 1 time: 46.01; Run 2 time: 8.76; Run 3 time: 8.55;
ОС: Ubuntu 18.04 и Windows10
Версия Python: 3.6.9
Flask: 1.1.1
Torch: 1.4.0
Torchvision: 0.5.0
Обновить:
Решенная проблема "разминки" как:
with torch.jit.optimized_execution(False):
model(x)
Обновление 2: решена проблема Flask (как указано ниже) с созданием глобального объекта модели Python перед запуском сервера и его разогревом. Затем в каждом запросе модель готова к использованию.
model = torch.jit.load(path, map_location='cpu').eval()
model(x)
app = Flask(__name__)
а затем в @app.route:
@app.route('/')
def test_scripted_model():
global model
...
...
1 ответ
Могу ли я хранить где-нибудь "прогретую" модель, чтобы не прогреваться при каждой просьбе?
Да, просто создайте экземпляр своей модели вне test_scripted_model
функция и обращайтесь к ней изнутри функции.