Python Redis Queue (rq) - как избежать предварительной загрузки модели ML для каждого задания?
Я хочу поставить в очередь свои прогнозы ml, используя rq. Пример кода (песо-иш):
predict.py
:
import tensorflow as tf
def predict_stuff(foo):
model = tf.load_model()
result = model.predict(foo)
return result
app.py
:
from rq import Queue
from redis import Redis
from predict import predict_stuff
q = Queue(connection=Redis())
for foo in baz:
job = q.enqueue(predict_stuff, foo)
worker.py
:
import sys
from rq import Connection, Worker
# Preload libraries
import tensorflow as tf
with Connection():
qs = sys.argv[1:] or ['default']
w = Worker(qs)
w.work()
Я прочитал rq docs, объясняющие, что вы можете предварительно загружать библиотеки, чтобы избежать их импорта при каждом запуске задания (поэтому в примере кода я импортирую тензорный поток в рабочий код). Однако я также хочу перенести загрузку модели из predict_stuff
чтобы не загружать модель каждый раз, когда рабочий запускает работу. Как я могу пойти по этому поводу?
1 ответ
Я не уверен, может ли это чем-то помочь, но, следуя примеру здесь:
https://github.com/rq/rq/issues/720
Вместо совместного использования пула соединений вы можете поделиться моделью.
псевдокод:
import tensorflow as tf
from rq import Worker as _Worker
from rq.local import LocalStack
_model_stack = LocalStack()
def get_model():
"""Get Model."""
m = _model_stack.top
try:
assert m
except AssertionError:
raise('Run outside of worker context')
return m
class Worker(_Worker):
"""Worker Class."""
def work(self, burst=False, logging_level='WARN'):
"""Work."""
_model_stack.push(tf.load_model())
return super().work(burst, logging_level)
def predict_stuff_job(foo):
model = get_model()
result = model.predict(foo)
return result
Я использую нечто похожее на это для "глобальной" программы для чтения файлов, которую я написал. Загрузите экземпляр в LocalStack и попросите рабочих считывать данные из стека.
В конце концов, я не понял, как это сделать с помощью python-rq. Я переехал в сельдерей, где я сделал это так:
app.py
from tasks import predict_stuff
for foo in baz:
task = transform_xml.delay(foo)
tasks.py
import tensorflow as tf
from celery import Celery
from celery.signals import worker_process_init
cel_app = Celery('tasks')
model = None
@worker_process_init.connect()
def on_worker_init(**_):
global model
model = tf.load_model()
@cel_app.task(name='predict_stuff')
def predict_stuff(foo):
result = model.predict(foo)
return result