Проблемы с ограничениями JIT и Numpy Jax

Недавно я начал экспериментировать с интересной библиотекой Python Jax, которая содержит улучшенный Numpy, а также автоматический дифференциатор. Я хотел попытаться создать грубый "дифференцируемый рендерер", написав шейдер и функцию потерь на python, а затем используя Jax AD для поиска градиента. Затем мы должны иметь возможность инвертировать изображение, запустив градиентный спуск для этого градиента потерь. Я добился неплохой работы с простыми шейдерами, но столкнулся с проблемами при использовании логических выражений. Это код моего шейдера, который генерирует узор шахматной доски:

import jax.numpy as np

class CheckerShader:

    def __init__(self, scale: float, color1: np.ndarray, color2: np.ndarray):
        self.color1 = None
        self.color2 = None
        self.scale = None
        self.scale_min = 0
        self.scale_max = 20
        self.color1 = color1
        self.color2 = color2
        self.scale = scale * 20

    def checker(self, x: float, y: float) -> float:
        xi = np.abs(np.floor(x))
        yi = np.abs(np.floor(y))

        first_col = np.mod(xi, 2) == np.mod(yi, 2)
        return first_col

    def shade(self, x: float, y: float):
        x = x * self.scale
        y = y * self.scale

        first_col = self.checker(x, y)
        if first_col:
            return self.color1
        else:
            return self.color2

И это моя функция рендеринга, которая в первую очередь терпит неудачу при JIT:

import jax.numpy as np
import numpy as onp
import jax

def render(scale, c1, c2):
    img = onp.zeros((WIDTH, HEIGHT, CHANNELS))
    sh = CheckerShader(scale, c1, c2)
    jit_func = jax.jit(sh.shade)

    for y in range(HEIGHT):
        for x in range(WIDTH):
            val = jit_func(x / WIDTH, y / HEIGHT)
            img[y, x, :] = val

    return img

Я получаю следующее сообщение об ошибке:

TypeError: Abstract value passed to `bool`, which requires a concrete value. The function to be transformed can't be traced at the required level of abstraction. If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions instead.

и я предполагаю, что это потому, что вы не можете запустить JIT для функции с логическими значениями, значения которых зависят от чего-то, что не решено во время компиляции. Но как его переписать для работы с JIT? Без JIT это мучительно медленно.

Другой вопрос, который у меня есть: могу ли я что-то сделать, чтобы ускорить Jax's Numpy в целом? Рендеринг моего изображения (100x100 пикселей) с помощью обычного Numpy занимает несколько миллисекунд, но с Jax Numpy это занимает секунды! Спасибо

2 ответа

Заменить

if first_col:
    return self.color1
else:
    return self.color2

с участием

return np.where(first_col, self.color1, self.color2)

Но как его переписать для работы с JIT?

У Иво есть хороший ответ - просто используйте np.where.

Другой вопрос, который у меня есть: могу ли я что-то сделать, чтобы ускорить Jax's Numpy в целом?

Вероятно, это происходит по трем причинам.

Первый - это природа JITing. При первом запуске кода он будет медленным, но если вы запустите один и тот же код несколько раз, скорость должна увеличиться. Я бы также попытался выполнить JIT для всей функции рендеринга, если это возможно, если вы планируете запускать это несколько раз.

Вторая причина в том, что переключение между numpy и jax.numpy будет очень медленным. Ты пишешь

img = onp.zeros((WIDTH, HEIGHT, CHANNELS))

но это будет намного быстрее, если вы напишете

img = np.zeros((WIDTH, HEIGHT, CHANNELS))

В-третьих, вы выполняете цикл по ширине и высоте, а не используете векторизованные операции. Я не понимаю, почему вы не можете сделать это в полностью векторизованной форме.

Другие вопросы по тегам