Функционализация дополнительных вычислений с прогнозированием
Я работаю с Jax и Stax. Базовый цикл прямой связи для сети stax выглядит примерно так:
def apply_fun(params, inputs, **kwargs):
for fun, param, rng in zip(apply_funs, params):
inputs = fun(param, inputs, **kwargs)
return inputs
Если вы не знакомы с Jax, "funs & params" соответствуют послойным вычислениям (например, активациям или умножению матриц). Jax является функциональным, поэтому важно сохранить код без состояния и ограничить использование операторов потока управления.
Я пытаюсь реализовать алгоритм, который требует дополнительных вычислений во время цикла с прямой связью (если вам интересно, это раздел 2 Кроче и др. 18'доказуемая надежность сетей relu...).
Есть даже код PyTorch из другой бумаги: ссылка GH
Моя проблема в том, что я немного не практиковался, и код Jax, который я пытался адаптировать из приведенной выше статьи, не очень хорош, вызывает большие замедления и не компилируется в xla.
def pwl_fun(params, configs, **kwargs):
lambdas = [np.diag(config) for config in configs]
js = [np.diag(-2 * config + 1) for config in configs]
# Compute Z_k = W_k * x + b_k for each layer
wks = [params[0][0].T] #
bks = [params[0][1].T]
i = 0
for p in params[1:]: # 2
current_wk = wks[-1]
current_bk = bks[-1]
current_lambda = lambdas[i]
c_len_p = (len(p) == 2)
i = lax.cond(c_len_p, i, lambda i: i+1, i, lambda i: i)
if c_len_p:
precompute = np.matmul(p[0].T,current_lambda)
wks.append(np.matmul(precompute,current_wk))
bks.append(np.matmul(precompute,current_bk) + p[1])
a_stack = []
b_stack = []
for j, wk, bk in zip(js, wks, bks):
a_stack.append(np.matmul(j,wk))
b_stack.append(-np.matmul(j,bk))
polytope_A = np.concatenate(a_stack)
polytope_b = np.concatenate(b_stack)
return polytope_A, polytope_b, wks[-1], bks[-1]
К вашему сведению: "конфигурации" - это список двоичных векторов разной длины.
Я выявил пару конкретных проблем:
if c_len_p:
Когда я перебираю слои / параметры сети, я хочу выполнить этот блок только в том случае, если параметр является матрицей весов (вы можете увидеть, что я пытался сделать с итератором 'i').- Использование списков и добавления: я не думаю, что могу использовать здесь vmap, поскольку размеры продуктов могут быть разными, но я не знаю, как я могу выполнить удаление списков / добавлений.
Я был бы очень признателен за любой совет или направление по этому поводу. Как я уже сказал, мое функциональное программирование немного не практикуется.:)