Функционализация дополнительных вычислений с прогнозированием

Я работаю с Jax и Stax. Базовый цикл прямой связи для сети stax выглядит примерно так:

def apply_fun(params, inputs, **kwargs):
    for fun, param, rng in zip(apply_funs, params):
        inputs = fun(param, inputs, **kwargs)
    return inputs

Если вы не знакомы с Jax, "funs & params" соответствуют послойным вычислениям (например, активациям или умножению матриц). Jax является функциональным, поэтому важно сохранить код без состояния и ограничить использование операторов потока управления.

Я пытаюсь реализовать алгоритм, который требует дополнительных вычислений во время цикла с прямой связью (если вам интересно, это раздел 2 Кроче и др. 18'доказуемая надежность сетей relu...).

Есть даже код PyTorch из другой бумаги: ссылка GH

Моя проблема в том, что я немного не практиковался, и код Jax, который я пытался адаптировать из приведенной выше статьи, не очень хорош, вызывает большие замедления и не компилируется в xla.

def pwl_fun(params, configs, **kwargs):
    lambdas = [np.diag(config) for config in configs]
    js = [np.diag(-2 * config + 1) for config in configs]

    # Compute Z_k = W_k * x + b_k for each layer
    wks = [params[0][0].T] #
    bks = [params[0][1].T]
    i = 0

    for p in params[1:]: # 2
        current_wk = wks[-1]
        current_bk = bks[-1]
        current_lambda = lambdas[i]
        c_len_p = (len(p) == 2)
        i = lax.cond(c_len_p, i, lambda i: i+1, i, lambda i: i)

        if c_len_p:
            precompute = np.matmul(p[0].T,current_lambda)
            wks.append(np.matmul(precompute,current_wk))
            bks.append(np.matmul(precompute,current_bk) + p[1])

    a_stack = []
    b_stack = []        
    for j, wk, bk in zip(js, wks, bks):
        a_stack.append(np.matmul(j,wk))
        b_stack.append(-np.matmul(j,bk))

    polytope_A = np.concatenate(a_stack)
    polytope_b = np.concatenate(b_stack)

    return polytope_A, polytope_b, wks[-1], bks[-1]

К вашему сведению: "конфигурации" - это список двоичных векторов разной длины.

Я выявил пару конкретных проблем:

  1. if c_len_p: Когда я перебираю слои / параметры сети, я хочу выполнить этот блок только в том случае, если параметр является матрицей весов (вы можете увидеть, что я пытался сделать с итератором 'i').
  2. Использование списков и добавления: я не думаю, что могу использовать здесь vmap, поскольку размеры продуктов могут быть разными, но я не знаю, как я могу выполнить удаление списков / добавлений.

Я был бы очень признателен за любой совет или направление по этому поводу. Как я уже сказал, мое функциональное программирование немного не практикуется.:)

0 ответов

Другие вопросы по тегам