NTT с умножением Монтгомери
Последние несколько дней я пытался помочь мистеру Спектре, который из-за проблем совместимости должен был написать свое собственное Теоретическое преобразование чисел для БПФ-умножения.
Модульная арифметика и NTT (конечное поле DFT) оптимизации
У него есть тот, который работает просто отлично, но он задавался вопросом, есть ли способы ускорить его. Одна мысль, которая пришла в голову, состояла в том, чтобы использовать Умножение Монтгомери, чтобы избежать чрезмерного разделения. Я использовал это в прошлом, но по некоторым причинам я не могу заставить это работать здесь, и я не уверен, является ли это проблемой с Умножением Монтгомери или NTT.
Он использует 32-битный размер слова, поэтому сокращение также составляет 2^32, а основной модуль - 3221225473. Использование Ext. Евклидов алгоритм, я нашел обратное:
2^32 * 2415919104 = (3221225473 * 3221225471) + 1
Ниже приведен код, над которым я работаю, с основной функцией, которая его вызывает.
ПРИМЕЧАНИЕ. В настоящее время я не беспокоюсь об обратном преобразовании, поскольку нет смысла, если обычное преобразование вообще не работает.
#include <string.h>
#ifndef uint32
#define uint32 unsigned long int
#endif
#ifndef uint64
#define uint64 unsigned long long int
#endif
class montgom_ntt // number theoretic transform
{
public:
montgom_ntt()
{
r = 0; L = 0;
W = 0, N = 0;
}
// main interface
void NTT(uint32 *dst, uint32 *src, uint32 n = 0); // uint32 dst[n] = fast NTT(uint32 src[n])
private:
bool init(uint32 n); // init r,L,p,W,iW,rN
void NTT_calc(uint32 *dst, uint32 *src, uint32 n, uint32 w); // uint32 dst[n] = fast NTT(uint32 src[n])
void NTT_fast(uint32 *dst, uint32 *src, uint32 n, uint32 w); // uint32 dst[n] = fast NTT(uint32 src[n])
void NTT_fast(uint32 *dst, const uint32 *src, uint32 n, uint32 w);
// uint32 arithmetics
public:
uint32 montgom_in(uint32 n);
uint32 montgom_out(uint32 n);
void montgom_in_arr(uint32* dst, const uint32* src, uint32 n);
void montgom_out_arr(uint32* dst, const uint32* src, uint32 n);
private:
// modular arithmetics
inline uint32 modadd(uint32 a, uint32 b);
inline uint32 modsub(uint32 a, uint32 b);
inline uint32 modmul(uint32 a, uint32 b);
inline uint32 modpow(uint32 a, uint32 b);
uint32 r, L, N, W;
const uint32 p = 0xC0000001;
const uint64 px = 0xC0000001;
};
//---------------------------------------------------------------------------
bool montgom_ntt::init(uint32 n)
{
// (max(src[])^2)*n < p else NTT overflow can ocur !!!
r = 2;
if ((n < 2) || (n > 0x10000000))
{
r = 0; L = 0; W = 0; // p = 0;
iW = 0; rN = 0; N = 0;
return false;
}
L = 0x30000000 / n; // 32:30 bit best for unsigned 32 bit
N = n; // size of vectors [uint32s]
W = modpow(r, L); // Wn for NTT
W = montgom_in(W);
return true;
}
//---------------------------------------------------------------------------
void montgom_ntt::NTT(uint32 *dst, uint32 *src, uint32 n)
{
if (n > 0)
{
init(n);
}
NTT_fast(dst, src, N, W);
}
//---------------------------------------------------------------------------
void montgom_ntt::NTT_fast(uint32 *dst, uint32 *src, uint32 n, uint32 w)
{
if (n > 1)
{
if (dst != src)
{
NTT_calc(dst, src, n, w);
}
else
{
uint32* temp = new uint32[n];
memcpy(temp, src, sizeof(uint32) * n);
NTT_calc(dst, temp, n, w);
delete[] temp;
}
}
else if (n == 1)
{
dst[0] = src[0];
}
}
void montgom_ntt::NTT_calc(uint32 *dst, uint32 *src, uint32 n, uint32 w)
{
if (n > 1)
{
uint32 i, j, a0, a1,
n2 = n >> 1,
w2 = modmul(w, w);
// reorder even,odd
for (i = 0, j = 0; i < n2; i++, j += 2)
{
dst[i] = src[j];
}
for (j = 1; i < n; i++, j += 2)
{
dst[i] = src[j];
}
// recursion
if (n2 > 1)
{
NTT_calc(src, dst, n2, w2); // even
NTT_calc(src + n2, dst + n2, n2, w2); // odd
}
else if (n2 == 1)
{
src[0] = dst[0];
src[1] = dst[1];
}
// restore results
w2 = 1, i = 0, j = n2;
a0 = src[i];
a1 = src[j];
dst[i] = modadd(a0, a1);
dst[j] = modsub(a0, a1);
while (++i < n2)
{
w2 = modmul(w2, w);
j++;
a0 = src[i];
a1 = modmul(src[j], w2);
dst[i] = modadd(a0, a1);
dst[j] = modsub(a0, a1);
}
}
}
//---------------------------------------------------------------------------
void montgom_ntt::NTT_slow(uint32 *dst, uint32 *src, uint32 n, uint32 w)
{
uint32 i, j, wj, wi, a,
n2 = n >> 1;
for (wj = 1, j = 0; j < n; j++)
{
a = 0;
for (wi = 1, i = 0; i < n; i++)
{
a = modadd(a, modmul(wi, src[i]));
wi = modmul(wi, wj);
}
dst[j] = a;
wj = modmul(wj, w);
}
}
//---------------------------------------------------------------------------
uint32 montgom_ntt::montgom_in(uint32 n)
{
uint64 N = n;
N = (N << 32) % px;
return N;
}
//---------------------------------------------------------------------------
uint32 montgom_ntt::montgom_out(uint32 n)
{
const uint64 C = 0x90000000;
uint64 N = n;
N *= C;
N %= px;
return N;
}
//---------------------------------------------------------------------------
void montgom_ntt::montgom_in_arr(uint32* dst, const uint32* src, uint32 n)
{
uint32 I = 0;
do
{
dst[I] = montgom_in(src[I]);
} while (++I < n);
}
//---------------------------------------------------------------------------
void montgom_ntt::montgom_out_arr(uint32* dst, const uint32* src, uint32 n)
{
uint32 I = 0;
do
{
dst[I] = montgom_out(src[I]);
} while (++I < n);
}
//---------------------------------------------------------------------------
uint32 montgom_ntt::modadd(uint32 a, uint32 b)
{
uint32 n = a + b;
if (n < a)
{
n -= p;
}
else if (n >= p)
{
n -= p;
}
return n;
}
//---------------------------------------------------------------------------
uint32 montgom_ntt::modsub(uint32 a, uint32 b)
{
uint32 d;
d = a - b;
if(a < b)
{
d += p;
d = (a + p) - b;
}
return d;
}
//---------------------------------------------------------------------------
uint32 montgom_ntt::modmul(uint32 a, uint32 b)
{
uint64 A(a), B(b), C;
uint32 R;
A *= B;
C = A & 0xFFFFFFFF;
C *= 0xBFFFFFFF;
C = (C & 0xFFFFFFFF) * px;
C += A;
R = (C >> 32);
if(C < A)
{
R -= p;
}
if(R >= p)
{
R -= p;
}
return R;
}
//---------------------------------------------------------------------------
uint32 montgom_ntt::modpow(uint32 a, uint32 b)
{
//*
uint64 D, M, A;
P = p; A = a;
M = 0llu - (b & 1);
D = (M & A) | ((~M) & 1);
while ((b >>= 1) != 0)
{
A = (A * A) % P;
if ((b & 1) == 1)
{
D = (D * A) % P;
}
}
return (uint32)D;
}
и вот главное
void main()
{
montgom_ntt F;
uint32 Tran[8];
uint32 Arr[8] =
{
0x2923, 0xbe84,
0xe16c, 0xd6ae,
0, 0, 0, 0
};
F.montgom_in_arr(Arr1, Arr1, Len);
F.NTT(Tran, Arr, Len);
F.montgom_out_arr(Tran, Tran, Len);
}
У меня такое ощущение, что это что-то действительно простое, но я не могу понять, что это такое. Спасибо за любую помощь, которую вы, ребята, можете предоставить!
[Править] Итак, чтобы исключить это, я изменил функцию модуля так, чтобы она преобразовывала свои данные из формы Монтгомери в обычную форму, выполнил стандартный (A * B) % p, а затем преобразовал его обратно в форму Монтгомери и Я все еще получил тот же, неправильный ответ. Это заставляет меня думать, что проблема заключается в преобразовании в форму Монтгомери и обратно, но я понятия не имею, что я сделал не так.
uint32 montgom_ntt::modmul(uint32 a, uint32 b)
{
uint64 A, B, C;
A = montgom_out(a);
B = montgom_out(b);
C = (A * B) % px;
return montgom_in(C);
/*
uint64 A(a), B(b), C;
uint32 R;
A *= B;
C = A & 0xFFFFFFFF;
C *= 0xBFFFFFFF;
C = (C & 0xFFFFFFFF) * px;
C += A;
R = (C >> 32);
if(C < A)
{
R -= p;
}
if(R >= p)
{
R -= p;
}
return R;
*/
}