Сколько 64-битных умножений необходимо для вычисления младших 128-бит 64-битного 128-битного продукта?

Учтите, что вы хотите вычислить младшие 128-битный результат умножения 64-битного и 128-битного числа без знака, и что самое большое умножение, которое у вас есть, - это C-подобное 64-битное умножение, которое занимает два 64-битных входы без знака и возвращают младшие 64 бита результата.

Сколько умножений нужно?

Конечно, вы можете сделать это с восемью: разбить все входы на 32-битные порции и использовать 64-битное умножение, чтобы выполнить 4 * 2 = 8 требуемых умножений на всю ширину 32*32->64, но можно ли сделать лучше?

Конечно, алгоритм должен делать только "разумное" количество сложений или другую базовую арифметику поверх умножений (меня не интересуют решения, которые заново изобретают умножение в виде цикла сложения и, следовательно, требуют "умножения на ноль").

1 ответ

Четыре, но это становится немного сложнее.

Пусть a и b будут числами, которые должны быть умножены, где 0 и a 1 являются младшим и старшим 32 битами a соответственно, а b 0, b 1, b 2, b 3 являются 32-битными группами b, из от низкого до высокого соответственно.

Желаемым результатом является остаток от (a 0 + a 1 • 2 32) • (b 0 + b 1 • 2 32 + b 2 • 2 64 + b 3 • 2 96) по модулю 2 128.

Мы можем переписать это как (a 0 + a 1 • 2 32) • (b 0 + b 1 • 2 32) + (a 0 + a 1 • 2 32) • (b 2 • 2 64 + b 3 • 2 96) по модулю 2 128.

Оставшаяся часть последнего слагаемого по модулю 2 128 может быть вычислена как одиночное 64-разрядное на 64-разрядное умножение (результат которого неявно умножается на 2 64).

Затем первый член можно вычислить с помощью трех умножений, используя тщательно выполненный шаг Карацубы. Простая версия будет включать в себя 33-битный на 33-битный или 66-битный продукт, который недоступен, но есть более хитрая версия, которая этого избегает:

z0 = a0 * b0
z2 = a1 * b1
z1 = abs(a0 - a1) * abs(b0 - b1) * sgn(a0 - a1) * sgn(b1 - b0) + z0 + z2

Последняя строка содержит только одно умножение; два других псевдомультипликации - просто условные отрицания. Абсолютная разница и условное отрицание раздражают в реализации на чистом С, но это можно сделать.

Конечно, без Карацубы 5 умножений.

Карацуба прекрасна, но в наши дни умножение 64 x 64 может быть выполнено за 3 такта, а новое можно запланировать каждые часы. Таким образом, накладные расходы на работу со знаками и тем, что не может быть, могут быть значительно больше, чем экономия одного умножения.

Для простого умножения 64 x 64 нужно:

    r0 =       a0*b0
    r1 =    a0*b1
    r2 =    a1*b0
    r3 = a1*b1

  where need to add r0 = r0 + (r1 << 32) + (r2 << 32)
            and add r3 = r3 + (r1 >> 32) + (r2 >> 32) + carry

  where the carry is the carry from the additions to r0, and result is r3:r0.

typedef struct { uint64_t w0, w1 ; } uint64x2_t ;

uint64x2_t
mulu64x2(uint64_t x, uint64_t m)
{
  uint64x2_t r ;
  uint64_t r1, r2, rx, ry ;
  uint32_t x1, x0 ;
  uint32_t m1, m0 ;

  x1    = (uint32_t)(x >> 32) ;
  x0    = (uint32_t)x ;
  m1    = (uint32_t)(m >> 32) ;
  m0    = (uint32_t)m ;

  r1    = (uint64_t)x1 * m0 ;
  r2    = (uint64_t)x0 * m1 ;
  r.w0  = (uint64_t)x0 * m0 ;
  r.w1  = (uint64_t)x1 * m1 ;

  rx    = (uint32_t)r1 ;
  rx    = rx + (uint32_t)r2 ;    // add the ls halves, collecting carry
  ry    = r.w0 >> 32 ;           // pick up ms of r0
  r.w0 += (rx << 32) ;           // complete r0
  rx   += ry ;                   // complete addition, rx >> 32 == carry !

  r.w1 += (r1 >> 32) + (r2 >> 32) + (rx >> 32) ;

  return r ;
}

Для Карацубы предложено:

z1 = abs(a0 - a1) * abs(b0 - b1) * sgn(a0 - a1) * sgn(b1 - b0) + z0 + z2

сложнее, чем кажется... для начала, если z1 составляет 64 бита, тогда необходимо каким-то образом собрать перенос, который может сгенерировать это добавление... и это осложняется проблемами со знаком.

    z0 =       a0*b0
    z1 =    ax*bx        -- ax = (a1 - a0), bx = (b0 - b1)
    z2 = a1*b1

  where need to add r0 = z0 + (z1 << 32) + (z0 << 32) + (z2 << 32)
            and add r1 = z2 + (z1 >> 32) + (z0 >> 32) + (z2 >> 32) + carry

  where the carry is the carry from the additions to create r0, and result is r1:r0.

  where must take into account the signed-ness of ax, bx and z1. 

uint64x2_t
mulu64x2_karatsuba(uint64_t a, uint64_t b)
{
  uint64_t a0, a1, b0, b1 ;
  uint64_t ax, bx, zx, zy ;
  uint     as, bs, xs ;
  uint64_t z0, z2 ;
  uint64x2_t r ;

  a0 = (uint32_t)a ; a1 = a >> 32 ;
  b0 = (uint32_t)b ; b1 = b >> 32 ;

  z0 = a0 * b0 ;
  z2 = a1 * b1 ;

  ax = (uint64_t)(a1 - a0) ;
  bx = (uint64_t)(b0 - b1) ;

  as = (uint)(ax > a1) ;                // sign of magic middle, a
  bs = (uint)(bx > b0) ;                // sign of magic middle, b
  xs = (uint)(as ^ bs) ;                // sign of magic middle, x = a * b

  ax = (uint64_t)((ax ^ -(uint64_t)as) + as) ;  // abs magic middle a
  bx = (uint64_t)((bx ^ -(uint64_t)bs) + bs) ;  // abs magic middle b

  zx = (uint64_t)(((ax * bx) ^ -(uint64_t)xs) + xs) ;
  xs = xs & (uint)(zx != 0) ;           // discard sign if z1 == 0 !

  zy = (uint32_t)zx ;                   // start ls half of z1
  zy = zy + (uint32_t)z0 + (uint32_t)z2 ;

  r.w0 = z0 + (zy << 32) ;              // complete ls word of result.
  zy   = zy + (z0 >> 32) ;              // complete carry

  zx   = (zx >> 32) - ((uint64_t)xs << 32) ;   // start ms half of z1
  r.w1 = z2 + zx + (z0 >> 32) + (z2 >> 32) + (zy >> 32) ;

  return r ;
}

Я сделал несколько очень простых таймингов (используя times(), работающий на Ryzen 7 1800X):

  • с использованием gcc __int128................... ~780 'единиц'
  • с использованием mulu64x2()..................... ~895
  • используя mulu64x2_karatsuba()... ~1,095

... так что да, вы можете сэкономить умножение с помощью Карацубы, но стоит ли это делать, скорее, зависит.

Другие вопросы по тегам