NTT с умножением Монтгомери

Последние несколько дней я пытался помочь мистеру Спектре, который из-за проблем совместимости должен был написать свое собственное Теоретическое преобразование чисел для БПФ-умножения.
Модульная арифметика и NTT (конечное поле DFT) оптимизации

У него есть тот, который работает просто отлично, но он задавался вопросом, есть ли способы ускорить его. Одна мысль, которая пришла в голову, состояла в том, чтобы использовать Умножение Монтгомери, чтобы избежать чрезмерного разделения. Я использовал это в прошлом, но по некоторым причинам я не могу заставить это работать здесь, и я не уверен, является ли это проблемой с Умножением Монтгомери или NTT.

Он использует 32-битный размер слова, поэтому сокращение также составляет 2^32, а основной модуль - 3221225473. Использование Ext. Евклидов алгоритм, я нашел обратное:
2^32 * 2415919104 = (3221225473 * 3221225471) + 1

Ниже приведен код, над которым я работаю, с основной функцией, которая его вызывает.
ПРИМЕЧАНИЕ. В настоящее время я не беспокоюсь об обратном преобразовании, поскольку нет смысла, если обычное преобразование вообще не работает.

#include <string.h>


#ifndef uint32
#define uint32 unsigned long int
#endif

#ifndef uint64
#define uint64 unsigned long long int
#endif


class montgom_ntt                                   // number theoretic transform
{
public:
    montgom_ntt()
    {
        r = 0; L = 0; 
        W = 0, N = 0;
    }
    // main interface
    void  NTT(uint32 *dst, uint32 *src, uint32 n = 0);             // uint32 dst[n] = fast  NTT(uint32 src[n])

private:
    bool init(uint32 n);                                     // init r,L,p,W,iW,rN
    void NTT_calc(uint32 *dst, uint32 *src, uint32 n, uint32 w);  // uint32 dst[n] = fast  NTT(uint32 src[n])

    void  NTT_fast(uint32 *dst, uint32 *src, uint32 n, uint32 w);  // uint32 dst[n] = fast  NTT(uint32 src[n])
    void NTT_fast(uint32 *dst, const uint32 *src, uint32 n, uint32 w);

    // uint32 arithmetics
    public:
    uint32 montgom_in(uint32 n);
    uint32 montgom_out(uint32 n);
    void montgom_in_arr(uint32* dst, const uint32* src, uint32 n);
    void montgom_out_arr(uint32* dst, const uint32* src, uint32 n);

    private:

    // modular arithmetics
    inline uint32 modadd(uint32 a, uint32 b);
    inline uint32 modsub(uint32 a, uint32 b);
    inline uint32 modmul(uint32 a, uint32 b);
    inline uint32 modpow(uint32 a, uint32 b);

    uint32 r, L, N, W;

    const uint32 p = 0xC0000001;
    const uint64 px = 0xC0000001;
};

//---------------------------------------------------------------------------
bool montgom_ntt::init(uint32 n)
{
    // (max(src[])^2)*n < p else NTT overflow can ocur !!!
    r = 2;

    if ((n < 2) || (n > 0x10000000))
    {
        r = 0; L = 0; W = 0; // p = 0;
        iW = 0; rN = 0; N = 0;
        return false;
    }
    L = 0x30000000 / n; // 32:30 bit best for unsigned 32 bit

    N = n;               // size of vectors [uint32s]
    W = modpow(r, L); // Wn for NTT
    W = montgom_in(W);

    return true;
}

//---------------------------------------------------------------------------
void montgom_ntt::NTT(uint32 *dst, uint32 *src, uint32 n)
{
    if (n > 0)
    {
        init(n);
    }
    NTT_fast(dst, src, N, W);
}


//---------------------------------------------------------------------------
void montgom_ntt::NTT_fast(uint32 *dst, uint32 *src, uint32 n, uint32 w)
{
    if (n > 1)
    {
        if (dst != src)
        {
            NTT_calc(dst, src, n, w);
        }
        else
        {
            uint32* temp = new uint32[n];
            memcpy(temp, src, sizeof(uint32) * n);
            NTT_calc(dst, temp, n, w);

            delete[] temp;
        }
    }
    else if (n == 1)
    {
        dst[0] = src[0];
    }
}


void montgom_ntt::NTT_calc(uint32 *dst, uint32 *src, uint32 n, uint32 w)
{
    if (n > 1)
    {
        uint32 i, j, a0, a1,
            n2 = n >> 1,
            w2 = modmul(w, w);

        // reorder even,odd
        for (i = 0, j = 0; i < n2; i++, j += 2)
        {
            dst[i] = src[j];
        }
        for (j = 1; i < n; i++, j += 2)
        {
            dst[i] = src[j];
        }
        // recursion
        if (n2 > 1)
        {
            NTT_calc(src, dst, n2, w2);  // even
            NTT_calc(src + n2, dst + n2, n2, w2);  // odd
        }
        else if (n2 == 1)
        {
            src[0] = dst[0];
            src[1] = dst[1];
        }

        // restore results

        w2 = 1, i = 0, j = n2;
        a0 = src[i];
        a1 = src[j];

        dst[i] = modadd(a0, a1);
        dst[j] = modsub(a0, a1);
        while (++i < n2)
        {
            w2 = modmul(w2, w);
            j++;
            a0 = src[i];
            a1 = modmul(src[j], w2);
            dst[i] = modadd(a0, a1);
            dst[j] = modsub(a0, a1);
        }
    }
}

//---------------------------------------------------------------------------
void montgom_ntt::NTT_slow(uint32 *dst, uint32 *src, uint32 n, uint32 w)
{
    uint32 i, j, wj, wi, a,
        n2 = n >> 1;
    for (wj = 1, j = 0; j < n; j++)
    {
        a = 0;
        for (wi = 1, i = 0; i < n; i++)
        {
            a = modadd(a, modmul(wi, src[i]));
            wi = modmul(wi, wj);
        }
        dst[j] = a;
        wj = modmul(wj, w);
    }
}


//---------------------------------------------------------------------------
uint32 montgom_ntt::montgom_in(uint32 n)
{
    uint64 N = n;
    N = (N << 32) % px;
    return N;
}

//---------------------------------------------------------------------------
uint32 montgom_ntt::montgom_out(uint32 n)
{
    const uint64 C = 0x90000000;
    uint64 N = n;
    N *= C;
    N %= px;

    return N;
}

//---------------------------------------------------------------------------
void montgom_ntt::montgom_in_arr(uint32* dst, const uint32* src, uint32 n)
{
    uint32 I = 0;

    do
    {
        dst[I] = montgom_in(src[I]);
    } while (++I < n);
}

//---------------------------------------------------------------------------
void montgom_ntt::montgom_out_arr(uint32* dst, const uint32* src, uint32 n)
{
    uint32 I = 0;

    do
    {
        dst[I] = montgom_out(src[I]);
    } while (++I < n);
}

//---------------------------------------------------------------------------
uint32 montgom_ntt::modadd(uint32 a, uint32 b)
{
    uint32 n = a + b;
    if (n < a)
    {
        n -= p;
    }
    else if (n >= p)
    {
        n -= p;
    }
    return n;
}

//---------------------------------------------------------------------------
uint32 montgom_ntt::modsub(uint32 a, uint32 b)
{
    uint32 d;

    d = a - b;
    if(a < b)
    {
        d += p;
        d = (a + p) - b;
    }
    return d;
}

//---------------------------------------------------------------------------
uint32 montgom_ntt::modmul(uint32 a, uint32 b)
{
    uint64 A(a), B(b), C;
    uint32 R;

    A *= B;
    C = A & 0xFFFFFFFF;
    C *= 0xBFFFFFFF;

    C = (C & 0xFFFFFFFF) * px;
    C += A;
    R = (C >> 32);
    if(C < A)
    {
        R -= p;
    }
    if(R >= p) 
    {
        R -= p;
    }   
    return R;

}

//---------------------------------------------------------------------------
uint32 montgom_ntt::modpow(uint32 a, uint32 b)
{
    //*
    uint64 D, M, A;

    P = p; A = a;
    M = 0llu - (b & 1);
    D = (M & A) | ((~M) & 1);

    while ((b >>= 1) != 0)
    {
        A = (A * A) % P;

        if ((b & 1) == 1)
        {
            D = (D * A) % P;
        }
    }
    return (uint32)D;
}

и вот главное

void main()
{
    montgom_ntt F;

    uint32 Tran[8];
    uint32 Arr[8] = 
    {
        0x2923, 0xbe84,
        0xe16c, 0xd6ae,
        0, 0, 0, 0
    };

    F.montgom_in_arr(Arr1, Arr1, Len);
    F.NTT(Tran, Arr, Len);
    F.montgom_out_arr(Tran, Tran, Len);

}

У меня такое ощущение, что это что-то действительно простое, но я не могу понять, что это такое. Спасибо за любую помощь, которую вы, ребята, можете предоставить!

[Править] Итак, чтобы исключить это, я изменил функцию модуля так, чтобы она преобразовывала свои данные из формы Монтгомери в обычную форму, выполнил стандартный (A * B) % p, а затем преобразовал его обратно в форму Монтгомери и Я все еще получил тот же, неправильный ответ. Это заставляет меня думать, что проблема заключается в преобразовании в форму Монтгомери и обратно, но я понятия не имею, что я сделал не так.

uint32 montgom_ntt::modmul(uint32 a, uint32 b)
{
    uint64 A, B, C;

    A = montgom_out(a);
    B = montgom_out(b);
    C = (A * B) % px;
    return montgom_in(C);        

    /*
    uint64 A(a), B(b), C;
    uint32 R;

    A *= B;
    C = A & 0xFFFFFFFF;
    C *= 0xBFFFFFFF;

    C = (C & 0xFFFFFFFF) * px;
    C += A;
    R = (C >> 32);
    if(C < A)
    {
        R -= p;
    }
    if(R >= p) 
    {
        R -= p;
    }   
    return R;
    */
}

0 ответов

Другие вопросы по тегам