Проблемы с ограничениями JIT и Numpy Jax
Недавно я начал экспериментировать с интересной библиотекой Python Jax, которая содержит улучшенный Numpy, а также автоматический дифференциатор. Я хотел попытаться создать грубый "дифференцируемый рендерер", написав шейдер и функцию потерь на python, а затем используя Jax AD для поиска градиента. Затем мы должны иметь возможность инвертировать изображение, запустив градиентный спуск для этого градиента потерь. Я добился неплохой работы с простыми шейдерами, но столкнулся с проблемами при использовании логических выражений. Это код моего шейдера, который генерирует узор шахматной доски:
import jax.numpy as np
class CheckerShader:
def __init__(self, scale: float, color1: np.ndarray, color2: np.ndarray):
self.color1 = None
self.color2 = None
self.scale = None
self.scale_min = 0
self.scale_max = 20
self.color1 = color1
self.color2 = color2
self.scale = scale * 20
def checker(self, x: float, y: float) -> float:
xi = np.abs(np.floor(x))
yi = np.abs(np.floor(y))
first_col = np.mod(xi, 2) == np.mod(yi, 2)
return first_col
def shade(self, x: float, y: float):
x = x * self.scale
y = y * self.scale
first_col = self.checker(x, y)
if first_col:
return self.color1
else:
return self.color2
И это моя функция рендеринга, которая в первую очередь терпит неудачу при JIT:
import jax.numpy as np
import numpy as onp
import jax
def render(scale, c1, c2):
img = onp.zeros((WIDTH, HEIGHT, CHANNELS))
sh = CheckerShader(scale, c1, c2)
jit_func = jax.jit(sh.shade)
for y in range(HEIGHT):
for x in range(WIDTH):
val = jit_func(x / WIDTH, y / HEIGHT)
img[y, x, :] = val
return img
Я получаю следующее сообщение об ошибке:
TypeError: Abstract value passed to `bool`, which requires a concrete value. The function to be transformed can't be traced at the required level of abstraction. If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions instead.
и я предполагаю, что это потому, что вы не можете запустить JIT для функции с логическими значениями, значения которых зависят от чего-то, что не решено во время компиляции. Но как его переписать для работы с JIT? Без JIT это мучительно медленно.
Другой вопрос, который у меня есть: могу ли я что-то сделать, чтобы ускорить Jax's Numpy в целом? Рендеринг моего изображения (100x100 пикселей) с помощью обычного Numpy занимает несколько миллисекунд, но с Jax Numpy это занимает секунды! Спасибо
2 ответа
Заменить
if first_col:
return self.color1
else:
return self.color2
с участием
return np.where(first_col, self.color1, self.color2)
Но как его переписать для работы с JIT?
У Иво есть хороший ответ - просто используйте np.where
.
Другой вопрос, который у меня есть: могу ли я что-то сделать, чтобы ускорить Jax's Numpy в целом?
Вероятно, это происходит по трем причинам.
Первый - это природа JITing. При первом запуске кода он будет медленным, но если вы запустите один и тот же код несколько раз, скорость должна увеличиться. Я бы также попытался выполнить JIT для всей функции рендеринга, если это возможно, если вы планируете запускать это несколько раз.
Вторая причина в том, что переключение между numpy и jax.numpy будет очень медленным. Ты пишешь
img = onp.zeros((WIDTH, HEIGHT, CHANNELS))
но это будет намного быстрее, если вы напишете
img = np.zeros((WIDTH, HEIGHT, CHANNELS))
В-третьих, вы выполняете цикл по ширине и высоте, а не используете векторизованные операции. Я не понимаю, почему вы не можете сделать это в полностью векторизованной форме.