Работа с Google JAX на сервере gunicorn/flask
Я хочу обслуживать приложение, которое обрабатывает данные во фреймворке Googles JAX с помощью flask и gunicorn.
Если запустить внутри колбы, все работает нормально. Как только я запускаю приложение в gunicorn, каждая часть, связанная с jax, приводит к тому, что рабочий процесс умирает без возникновения каких-либо исключений. Я пробовал использовать как синхронизацию, так и gthreads в качестве рабочих, но с тем же результатом.
Я попытался проверить, может ли JAX обрабатывать многопроцессорность и многопоточность, заключая одни и те же вызовы в ThreadPoolExecutor и ProcessPoolExecutor, и это работает безупречно.
import jax
import logging
logging.basicConfig(format="%(asctime)s | %(name)12.12s | %(message)s")
logger = logging.getLogger("Main")
logger.setLevel(logging.DEBUG)
from concurrent.futures import ThreadPoolExecutor, ProcessPoolExecutor, as_completed
from fit.optimization.vectorize import BatchNumpyInterface, batch_calculate_fit
def warmup():
logger.debug("Warmup")
data = BatchNumpyInterface.generate_dummy()
batch_calculate_fit(data)
logger.debug("Warmed up")
def run_fn():
logger.debug("Creating data")
data = BatchNumpyInterface.generate_dummy(100)
logger.debug("Predicting %s in batches", 100)
result = batch_calculate_fit(data)
logger.debug("Done")
return float(result[0][0]), float(result[1][0])
#with ThreadPoolExecutor(max_workers=4) as executor:
with ProcessPoolExecutor(max_workers=4) as executor:
results = []
for i in range(4):
results.append(executor.submit(warmup))
for res in as_completed(results):
continue
results = []
for i in range(10):
future = executor.submit(run_fn)
results.append(future)
for res in as_completed(results):
print(res.result())
Во время отладки каждый раз, когда я проверяю JAX DeviceArray, приложение вылетает. То же самое касается перехода через первый расчет с JAX.
Любая помощь приветствуется!