Сколько 64-битных умножений необходимо для вычисления младших 128-бит 64-битного 128-битного продукта?
Учтите, что вы хотите вычислить младшие 128-битный результат умножения 64-битного и 128-битного числа без знака, и что самое большое умножение, которое у вас есть, - это C-подобное 64-битное умножение, которое занимает два 64-битных входы без знака и возвращают младшие 64 бита результата.
Сколько умножений нужно?
Конечно, вы можете сделать это с восемью: разбить все входы на 32-битные порции и использовать 64-битное умножение, чтобы выполнить 4 * 2 = 8 требуемых умножений на всю ширину 32*32->64, но можно ли сделать лучше?
Конечно, алгоритм должен делать только "разумное" количество сложений или другую базовую арифметику поверх умножений (меня не интересуют решения, которые заново изобретают умножение в виде цикла сложения и, следовательно, требуют "умножения на ноль").
1 ответ
Четыре, но это становится немного сложнее.
Пусть a и b будут числами, которые должны быть умножены, где 0 и a 1 являются младшим и старшим 32 битами a соответственно, а b 0, b 1, b 2, b 3 являются 32-битными группами b, из от низкого до высокого соответственно.
Желаемым результатом является остаток от (a 0 + a 1 • 2 32) • (b 0 + b 1 • 2 32 + b 2 • 2 64 + b 3 • 2 96) по модулю 2 128.
Мы можем переписать это как (a 0 + a 1 • 2 32) • (b 0 + b 1 • 2 32) + (a 0 + a 1 • 2 32) • (b 2 • 2 64 + b 3 • 2 96) по модулю 2 128.
Оставшаяся часть последнего слагаемого по модулю 2 128 может быть вычислена как одиночное 64-разрядное на 64-разрядное умножение (результат которого неявно умножается на 2 64).
Затем первый член можно вычислить с помощью трех умножений, используя тщательно выполненный шаг Карацубы. Простая версия будет включать в себя 33-битный на 33-битный или 66-битный продукт, который недоступен, но есть более хитрая версия, которая этого избегает:
z0 = a0 * b0
z2 = a1 * b1
z1 = abs(a0 - a1) * abs(b0 - b1) * sgn(a0 - a1) * sgn(b1 - b0) + z0 + z2
Последняя строка содержит только одно умножение; два других псевдомультипликации - просто условные отрицания. Абсолютная разница и условное отрицание раздражают в реализации на чистом С, но это можно сделать.
Конечно, без Карацубы 5 умножений.
Карацуба прекрасна, но в наши дни умножение 64 x 64 может быть выполнено за 3 такта, а новое можно запланировать каждые часы. Таким образом, накладные расходы на работу со знаками и тем, что не может быть, могут быть значительно больше, чем экономия одного умножения.
Для простого умножения 64 x 64 нужно:
r0 = a0*b0
r1 = a0*b1
r2 = a1*b0
r3 = a1*b1
where need to add r0 = r0 + (r1 << 32) + (r2 << 32)
and add r3 = r3 + (r1 >> 32) + (r2 >> 32) + carry
where the carry is the carry from the additions to r0, and result is r3:r0.
typedef struct { uint64_t w0, w1 ; } uint64x2_t ;
uint64x2_t
mulu64x2(uint64_t x, uint64_t m)
{
uint64x2_t r ;
uint64_t r1, r2, rx, ry ;
uint32_t x1, x0 ;
uint32_t m1, m0 ;
x1 = (uint32_t)(x >> 32) ;
x0 = (uint32_t)x ;
m1 = (uint32_t)(m >> 32) ;
m0 = (uint32_t)m ;
r1 = (uint64_t)x1 * m0 ;
r2 = (uint64_t)x0 * m1 ;
r.w0 = (uint64_t)x0 * m0 ;
r.w1 = (uint64_t)x1 * m1 ;
rx = (uint32_t)r1 ;
rx = rx + (uint32_t)r2 ; // add the ls halves, collecting carry
ry = r.w0 >> 32 ; // pick up ms of r0
r.w0 += (rx << 32) ; // complete r0
rx += ry ; // complete addition, rx >> 32 == carry !
r.w1 += (r1 >> 32) + (r2 >> 32) + (rx >> 32) ;
return r ;
}
Для Карацубы предложено:
z1 = abs(a0 - a1) * abs(b0 - b1) * sgn(a0 - a1) * sgn(b1 - b0) + z0 + z2
сложнее, чем кажется... для начала, если z1
составляет 64 бита, тогда необходимо каким-то образом собрать перенос, который может сгенерировать это добавление... и это осложняется проблемами со знаком.
z0 = a0*b0
z1 = ax*bx -- ax = (a1 - a0), bx = (b0 - b1)
z2 = a1*b1
where need to add r0 = z0 + (z1 << 32) + (z0 << 32) + (z2 << 32)
and add r1 = z2 + (z1 >> 32) + (z0 >> 32) + (z2 >> 32) + carry
where the carry is the carry from the additions to create r0, and result is r1:r0.
where must take into account the signed-ness of ax, bx and z1.
uint64x2_t
mulu64x2_karatsuba(uint64_t a, uint64_t b)
{
uint64_t a0, a1, b0, b1 ;
uint64_t ax, bx, zx, zy ;
uint as, bs, xs ;
uint64_t z0, z2 ;
uint64x2_t r ;
a0 = (uint32_t)a ; a1 = a >> 32 ;
b0 = (uint32_t)b ; b1 = b >> 32 ;
z0 = a0 * b0 ;
z2 = a1 * b1 ;
ax = (uint64_t)(a1 - a0) ;
bx = (uint64_t)(b0 - b1) ;
as = (uint)(ax > a1) ; // sign of magic middle, a
bs = (uint)(bx > b0) ; // sign of magic middle, b
xs = (uint)(as ^ bs) ; // sign of magic middle, x = a * b
ax = (uint64_t)((ax ^ -(uint64_t)as) + as) ; // abs magic middle a
bx = (uint64_t)((bx ^ -(uint64_t)bs) + bs) ; // abs magic middle b
zx = (uint64_t)(((ax * bx) ^ -(uint64_t)xs) + xs) ;
xs = xs & (uint)(zx != 0) ; // discard sign if z1 == 0 !
zy = (uint32_t)zx ; // start ls half of z1
zy = zy + (uint32_t)z0 + (uint32_t)z2 ;
r.w0 = z0 + (zy << 32) ; // complete ls word of result.
zy = zy + (z0 >> 32) ; // complete carry
zx = (zx >> 32) - ((uint64_t)xs << 32) ; // start ms half of z1
r.w1 = z2 + zx + (z0 >> 32) + (z2 >> 32) + (zy >> 32) ;
return r ;
}
Я сделал несколько очень простых таймингов (используя times()
, работающий на Ryzen 7 1800X):
- с использованием gcc __int128................... ~780 'единиц'
- с использованием mulu64x2()..................... ~895
- используя mulu64x2_karatsuba()... ~1,095
... так что да, вы можете сэкономить умножение с помощью Карацубы, но стоит ли это делать, скорее, зависит.