Самая быстрая реализация экспоненциальной функции с использованием SSE
Я ищу приближение функции экспоненты, работающей на элементе SSE. А именно - __m128 exp( __m128 x )
,
У меня есть реализация, которая является быстрой, но, похоже, очень низкой по точности:
static inline __m128 FastExpSse(__m128 x)
{
__m128 a = _mm_set1_ps(12102203.2f); // (1 << 23) / ln(2)
__m128i b = _mm_set1_epi32(127 * (1 << 23) - 486411);
__m128 m87 = _mm_set1_ps(-87);
// fast exponential function, x should be in [-87, 87]
__m128 mask = _mm_cmpge_ps(x, m87);
__m128i tmp = _mm_add_epi32(_mm_cvtps_epi32(_mm_mul_ps(a, x)), b);
return _mm_and_ps(_mm_castsi128_ps(tmp), mask);
}
Может ли кто-нибудь иметь реализацию с большей точностью, но так же быстро (или быстрее)?
Я был бы счастлив, если бы я написал в стиле Си.
Благодарю вас.
6 ответов
Код C, приведенный ниже, является переводом в SSE-алгоритма алгоритма, который я использовал в предыдущем ответе на аналогичный вопрос.
Основная идея состоит в том, чтобы преобразовать вычисление стандартной экспоненциальной функции в вычисление степени 2: expf (x) = exp2f (x / logf (2.0f)) = exp2f (x * 1.44269504)
, Мы разделились t = x * 1.44269504
в целое число i
и фракция f
такой, что t = i + f
а также 0 <= f <= 1
, Теперь мы можем вычислить 2f с полиномиальным приближением, а затем масштабировать результат на 2i, добавив i
в поле экспоненты результата с плавающей запятой одинарной точности.
Одна проблема, которая существует с реализацией SSE, состоит в том, что мы хотим вычислить i = floorf (t)
, но нет быстрого способа вычислить floor()
функция. Тем не менее, мы видим, что для положительных чисел, floor(x) == trunc(x)
и что для отрицательных чисел, floor(x) == trunc(x) - 1
кроме случаев, когда x
является отрицательным целым числом. Тем не менее, поскольку основное приближение может обрабатывать f
ценность 1.0f
Использование аппроксимации для отрицательных аргументов безвредно. SSE предоставляет инструкцию для преобразования операндов с плавающей запятой одинарной точности в целые числа с усечением, поэтому это решение является эффективным.
Peter Cordes отмечает, что SSE4.1 поддерживает функцию быстрого пола _mm_floor_ps()
, поэтому вариант с использованием SSE4.1 также показан ниже. Не все наборы инструментов автоматически предопределяют макрос __SSE4_1__
когда включена генерация кода SSE 4.1, а gcc -
Compiler Explorer (Godbolt) показывает, что gcc 7.2 компилирует приведенный ниже код в шестнадцать инструкций для простого SSE и двенадцать инструкций для SSE 4.1.
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <math.h>
#include <emmintrin.h>
#ifdef __SSE4_1__
#include <smmintrin.h>
#endif
/* max. rel. error = 1.72863156e-3 on [-87.33654, 88.72283] */
__m128 fast_exp_sse (__m128 x)
{
__m128 t, f, e, p, r;
__m128i i, j;
__m128 l2e = _mm_set1_ps (1.442695041f); /* log2(e) */
__m128 c0 = _mm_set1_ps (0.3371894346f);
__m128 c1 = _mm_set1_ps (0.657636276f);
__m128 c2 = _mm_set1_ps (1.00172476f);
/* exp(x) = 2^i * 2^f; i = floor (log2(e) * x), 0 <= f <= 1 */
t = _mm_mul_ps (x, l2e); /* t = log2(e) * x */
#ifdef __SSE4_1__
e = _mm_floor_ps (t); /* floor(t) */
i = _mm_cvtps_epi32 (e); /* (int)floor(t) */
#else /* __SSE4_1__*/
i = _mm_cvttps_epi32 (t); /* i = (int)t */
j = _mm_srli_epi32 (_mm_castps_si128 (x), 31); /* signbit(t) */
i = _mm_sub_epi32 (i, j); /* (int)t - signbit(t) */
e = _mm_cvtepi32_ps (i); /* floor(t) ~= (int)t - signbit(t) */
#endif /* __SSE4_1__*/
f = _mm_sub_ps (t, e); /* f = t - floor(t) */
p = c0; /* c0 */
p = _mm_mul_ps (p, f); /* c0 * f */
p = _mm_add_ps (p, c1); /* c0 * f + c1 */
p = _mm_mul_ps (p, f); /* (c0 * f + c1) * f */
p = _mm_add_ps (p, c2); /* p = (c0 * f + c1) * f + c2 ~= 2^f */
j = _mm_slli_epi32 (i, 23); /* i << 23 */
r = _mm_castsi128_ps (_mm_add_epi32 (j, _mm_castps_si128 (p))); /* r = p * 2^i*/
return r;
}
int main (void)
{
union {
float f[4];
unsigned int i[4];
} arg, res;
double relerr, maxrelerr = 0.0;
int i, j;
__m128 x, y;
float start[2] = {-0.0f, 0.0f};
float finish[2] = {-87.33654f, 88.72283f};
for (i = 0; i < 2; i++) {
arg.f[0] = start[i];
arg.i[1] = arg.i[0] + 1;
arg.i[2] = arg.i[0] + 2;
arg.i[3] = arg.i[0] + 3;
do {
memcpy (&x, &arg, sizeof(x));
y = fast_exp_sse (x);
memcpy (&res, &y, sizeof(y));
for (j = 0; j < 4; j++) {
double ref = exp ((double)arg.f[j]);
relerr = fabs ((res.f[j] - ref) / ref);
if (relerr > maxrelerr) {
printf ("arg=% 15.8e res=%15.8e ref=%15.8e err=%15.8e\n",
arg.f[j], res.f[j], ref, relerr);
maxrelerr = relerr;
}
}
arg.i[0] += 4;
arg.i[1] += 4;
arg.i[2] += 4;
arg.i[3] += 4;
} while (fabsf (arg.f[3]) < fabsf (finish[i]));
}
printf ("maximum relative errror = %15.8e\n", maxrelerr);
return EXIT_SUCCESS;
}
Альтернативный дизайн для fast_sse_exp()
извлекает целочисленную часть скорректированного аргумента x / log(2)
в режиме округления до ближайшего, используя хорошо известную технику сложения "магической" константы преобразования 1,5 * 223 для принудительного округления в правильной битовой позиции, а затем вычитания того же числа снова. Это требует, чтобы режим округления SSE, действующий во время добавления, был "округлен до ближайшего или даже", что является значением по умолчанию. wim указал в комментариях, что некоторые компиляторы могут оптимизировать сложение и вычитание константы преобразования cvt
как избыточный, когда используется агрессивная оптимизация, мешающая функционированию этой кодовой последовательности, поэтому рекомендуется проверять сгенерированный машинный код. Интервал аппроксимации для вычисления 2f теперь сосредоточен вокруг нуля, так как -0.5 <= f <= 0.5
, требующий другого основного приближения.
/* max. rel. error <= 1.72860465e-3 on [-87.33654, 88.72283] */
__m128 fast_exp_sse (__m128 x)
{
__m128 t, f, p, r;
__m128i i, j;
const __m128 l2e = _mm_set1_ps (1.442695041f); /* log2(e) */
const __m128 cvt = _mm_set1_ps (12582912.0f); /* 1.5 * (1 << 23) */
const __m128 c0 = _mm_set1_ps (0.238428936f);
const __m128 c1 = _mm_set1_ps (0.703448006f);
const __m128 c2 = _mm_set1_ps (1.000443142f);
/* exp(x) = 2^i * 2^f; i = rint (log2(e) * x), -0.5 <= f <= 0.5 */
t = _mm_mul_ps (x, l2e); /* t = log2(e) * x */
r = _mm_sub_ps (_mm_add_ps (t, cvt), cvt); /* r = rint (t) */
f = _mm_sub_ps (t, r); /* f = t - rint (t) */
i = _mm_cvtps_epi32 (t); /* i = (int)t */
p = c0; /* c0 */
p = _mm_mul_ps (p, f); /* c0 * f */
p = _mm_add_ps (p, c1); /* c0 * f + c1 */
p = _mm_mul_ps (p, f); /* (c0 * f + c1) * f */
p = _mm_add_ps (p, c2); /* p = (c0 * f + c1) * f + c2 ~= exp2(f) */
j = _mm_slli_epi32 (i, 23); /* i << 23 */
r = _mm_castsi128_ps (_mm_add_epi32 (j, _mm_castps_si128 (p))); /* r = p * 2^i*/
return r;
}
Алгоритм кода в вопросе, по-видимому, взят из работы Никола Шраудольфа, в которой ловко используется полулогарифмическая природа двоичных форматов IEEE-754 с плавающей запятой:
Н.Н. Шраудольф. "Быстрое, компактное приближение экспоненциальной функции". Нейронные вычисления, 11(4), май 1999, с.853-862.
После удаления кода ограничения аргументов он сокращается до трех инструкций SSE. "Магическая" поправочная константа 486411
не является оптимальным для минимизации максимальной относительной ошибки во всей области ввода. Основываясь на простом бинарном поиске, значение 298765
кажется лучше, уменьшая максимальную относительную погрешность для FastExpSse()
до 3,56e-2 против максимальной относительной погрешности 1,73e-3 для fast_exp_sse()
,
/* max. rel. error = 3.55959567e-2 on [-87.33654, 88.72283] */
__m128 FastExpSse (__m128 x)
{
__m128 a = _mm_set1_ps (12102203.0f); /* (1 << 23) / log(2) */
__m128i b = _mm_set1_epi32 (127 * (1 << 23) - 298765);
__m128i t = _mm_add_epi32 (_mm_cvtps_epi32 (_mm_mul_ps (a, x)), b);
return _mm_castsi128_ps (t);
}
Алгоритм Шраудольфа в основном использует линейное приближение 2f ~ = 1.0 + f
за f
в [0,1], и его точность может быть улучшена путем добавления квадратичного члена. Умная часть подхода Шраудольфа заключается в вычислении 2i * 2f без явного разделения целочисленной части i = floor(x * 1.44269504)
от фракции. Я не вижу возможности расширить этот трюк до квадратичного приближения, но можно, конечно, объединить floor()
Вычисление Шраудольфа с использованием квадратичного приближения, использованного выше:
/* max. rel. error <= 1.72886892e-3 on [-87.33654, 88.72283] */
__m128 fast_exp_sse (__m128 x)
{
__m128 f, p, r;
__m128i t, j;
const __m128 a = _mm_set1_ps (12102203.0f); /* (1 << 23) / log(2) */
const __m128i m = _mm_set1_epi32 (0xff800000); /* mask for integer bits */
const __m128 ttm23 = _mm_set1_ps (1.1920929e-7f); /* exp2(-23) */
const __m128 c0 = _mm_set1_ps (0.3371894346f);
const __m128 c1 = _mm_set1_ps (0.657636276f);
const __m128 c2 = _mm_set1_ps (1.00172476f);
t = _mm_cvtps_epi32 (_mm_mul_ps (a, x));
j = _mm_and_si128 (t, m); /* j = (int)(floor (x/log(2))) << 23 */
t = _mm_sub_epi32 (t, j);
f = _mm_mul_ps (ttm23, _mm_cvtepi32_ps (t)); /* f = (x/log(2)) - floor (x/log(2)) */
p = c0; /* c0 */
p = _mm_mul_ps (p, f); /* c0 * f */
p = _mm_add_ps (p, c1); /* c0 * f + c1 */
p = _mm_mul_ps (p, f); /* (c0 * f + c1) * f */
p = _mm_add_ps (p, c2); /* p = (c0 * f + c1) * f + c2 ~= 2^f */
r = _mm_castsi128_ps (_mm_add_epi32 (j, _mm_castps_si128 (p))); /* r = p * 2^i*/
return r;
}
Хорошее увеличение точности в моем алгоритме (реализация FastExpSse в ответе выше) может быть достигнуто за счет целочисленного вычитания и деления с плавающей запятой, используя FastExpSse(x/2)/FastExpSse(-x/2) вместо FastExpSse (Икс). Хитрость здесь в том, чтобы установить параметр сдвига (298765 выше) на ноль, чтобы кусочно-линейные аппроксимации в числителе и знаменателе выстраивались в линию, чтобы дать вам существенное устранение ошибок. Сверните это в одну функцию:
__m128 BetterFastExpSse (__m128 x)
{
const __m128 a = _mm_set1_ps ((1 << 22) / float(M_LN2)); // to get exp(x/2)
const __m128i b = _mm_set1_epi32 (127 * (1 << 23)); // NB: zero shift!
__m128i r = _mm_cvtps_epi32 (_mm_mul_ps (a, x));
__m128i s = _mm_add_epi32 (b, r);
__m128i t = _mm_sub_epi32 (b, r);
return _mm_div_ps (_mm_castsi128_ps (s), _mm_castsi128_ps (t));
}
(Я не специалист по аппаратному обеспечению - насколько плохо здесь убийца производительности?)
Если вам нужно exp (x) только для того, чтобы получить y = tanh (x) (например, для нейронных сетей), используйте FastExpSse с нулевым сдвигом следующим образом:
a = FastExpSse(x);
b = FastExpSse(-x);
y = (a - b)/(a + b);
чтобы получить такой же тип отмены ошибки. Логистическая функция работает аналогично, используя FastExpSse(x/2)/(FastExpSse(x/2) + FastExpSse(-x/2)) с нулевым сдвигом. (Это просто для демонстрации принципа - вы, очевидно, не хотите оценивать FastExpSse несколько раз здесь, но сверните его в одну функцию по аналогии с BetterFastExpSse выше.)
Я разработал серию приближений высшего порядка из этого, еще более точного, но и более медленного. Неопубликованные, но рады сотрудничеству, если кто-то хочет дать им вращение.
И, наконец, для развлечения: используйте задний ход, чтобы получить FastLogSse. Объединение в цепочку с FastExpSse дает вам как оператор, так и отмену ошибок, и выдает невероятно быструю функцию питания...
Возвращаясь к своим записям с тех пор, я исследовал способы повышения точности без использования деления. Я использовал тот же трюк с реинтерпретацией, как с плавающей точкой, но применил к мантиссе полиномиальную коррекцию, которая была по существу рассчитана в 16-битной арифметике с фиксированной точкой (единственный способ сделать это быстро тогда).
Куб. Соотв. Quartic версии дают вам 4 соотв. 5 значащих цифр точности. Не было никакого смысла увеличивать порядок после этого, так как шум арифметики низкой точности затем начинает заглушать ошибку полиномиального приближения. Вот простые версии C:
#include <stdint.h>
float fastExp3(register float x) // cubic spline approximation
{
union { float f; int32_t i; } reinterpreter;
reinterpreter.i = (int32_t)(12102203.0f*x) + 127*(1 << 23);
int32_t m = (reinterpreter.i >> 7) & 0xFFFF; // copy mantissa
// empirical values for small maximum relative error (8.34e-5):
reinterpreter.i +=
((((((((1277*m) >> 14) + 14825)*m) >> 14) - 79749)*m) >> 11) - 626;
return reinterpreter.f;
}
float fastExp4(register float x) // quartic spline approximation
{
union { float f; int32_t i; } reinterpreter;
reinterpreter.i = (int32_t)(12102203.0f*x) + 127*(1 << 23);
int32_t m = (reinterpreter.i >> 7) & 0xFFFF; // copy mantissa
// empirical values for small maximum relative error (1.21e-5):
reinterpreter.i += (((((((((((3537*m) >> 16)
+ 13668)*m) >> 18) + 15817)*m) >> 14) - 80470)*m) >> 11);
return reinterpreter.f;
}
Квартиру подчиняется (fastExp4(0f) == 1f), что может быть важно для алгоритмов итерации с фиксированной точкой.
Насколько эффективны эти целочисленные последовательности умножения-сдвига в SSE? На архитектурах, где арифметика с плавающей точкой такая же быстрая, можно использовать ее, уменьшая арифметический шум. По сути, это приведет к кубическим и квартичным расширениям ответа @njuffa выше.
Есть статья о создании быстрых версий этих уравнений (tanh, cosh, artanh, sinh и т. Д.):
http://ijeais.org/wp-content/uploads/2018/07/IJAER180702.pdf"Создание оптимизированной для компилятора встроенной реализации Intel Svml Simd Intrinsics"
их уравнение 6 на странице 9 очень похоже на ответ @NicSchraudolph
Для использования softmax я представляю поток как:
auto a = _mm_mul_ps(x, _mm_set1_ps(12102203.2f));
auto b = _mm_castsi128_ps(_mm_cvtps_epi32(a)); // so far as in other variants
// copy 9 MSB from 0x3f800000 over 'b' so that 1 <= c < 2
// - also 1 <= poly_eval(...) < 2
auto c = replace_exponent(b, _mm_set1_ps(1.0f));
auto d = poly_eval(c, kA, kB, kC); // 2nd degree polynomial
auto e = replace_exponent(d, b); // restore exponent : 2^i * 2^f
Копирование экспоненты может быть выполнено поразрядным выбором с использованием соответствующей маски (AVX-512 имеет
vpternlogd
, и я использую на самом деле Arm Neon
vbsl
).
Все входные значения
x
должен быть отрицательным и ограничиваться между -17-f(N) <= x <= -f(N), чтобы при масштабировании на (1<<23)/log(2) максимальная сумма N результирующих чисел с плавающей запятой значения не достигают бесконечности и обратная величина не становится денормальной. Для N = 3, f(N) = 4. Чем больше f(N), тем меньше точность ввода.
Коэффициенты многозначности генерируются, например, с помощью
polyfit([1 1.5 2],[1 sqrt(2) 2])
, с kA = 0,343146, kB = -0,029437, kC = 0,68292, что дает строго значения меньше 2 и предотвращает разрывы. Максимальную среднюю ошибку можно уменьшить, оценив полином при x = [1 + max_err 1,5-eps 2], y = [1 2^(.5-eps) 2-max_err].
Для строго SSE/AVX замену экспоненты для 1.0f можно выполнить с помощью
(x & 0x007fffff) | 0x3f800000)
. Последовательность двух инструкций для замены последней экспоненты можно найти, убедившись, что poly_eval(x) оценивается как диапазон, который может быть напрямую обработан с помощью
b & 0xff800000
.
Я разработал для своих целей следующую функцию, которая быстро и точно вычисляет натуральный показатель с одинарной точностью. Функция работает во всем диапазоне значений float. Код написан под Visual Studio (x86). Вместо SSE используется AVX, но это не должно быть проблемой. Точность этой функции почти такая же, как у стандартной функции expf, но значительно быстрее. Используемая аппроксимация основана на разложении в ряд Чебышева функции f(t)=t/(2^(t/2)-1)+t/2 при t из [-1; 1]. Я благодарю Питера Кордеса за его хороший совет.
_declspec(naked) float _vectorcall fexp(float x)
{
static const float ct[7] = // Constants table
{
1.44269502f, // lb(e)
1.92596299E-8f, // Correction to the value lb(e)
-9.21120925E-4f, // 16*b2
0.115524396f, // 4*b1
2.88539004f, // b0
2.0f, // 2
4.65661287E-10f // 2^-31
};
_asm
{
mov ecx,offset ct // ecx contains the address of constants tables
vmulss xmm1,xmm0,[ecx] // xmm1 = x*lb(e)
vcvtss2si eax,xmm1 // eax = round(x*lb(e)) = k
cdq // edx=-1, if x<0 or overflow, otherwise edx=0
vmovss xmm3,[ecx+8] // Initialize the sum with highest coefficient 16*b2
and edx,4 // edx=4, if x<0 or overflow, otherwise edx=0
vcvtsi2ss xmm1,xmm1,eax // xmm1 = k
lea eax,[eax+8*edx] // Add 32 to exponent, if x<0
vfmsub231ss xmm1,xmm0,[ecx] // xmm1 = x*lb(e)-k = t/2 in the range from -0,5 to 0,5
add eax,126 // The exponent of 2^(k-1) or 2^(k+31) with bias 127
jle exp_low // Jump if x<<0 or overflow (|x| too large or x=NaN)
vfmadd132ss xmm0,xmm1,[ecx+4] // xmm0 = t/2 (corrected value)
cmp eax,254 // Check that the exponent is not too large
jg exp_inf // Jump to set Inf if overflow
vmulss xmm2,xmm0,xmm0 // xmm2 = t^2/4 - the argument of the polynomial
shl eax,23 // The bits of the float value 2^(k-1) or 2^(k+31)
vfmadd213ss xmm3,xmm2,[ecx+12] // xmm3 = 4*b1+4*b2*t^2
vmovd xmm1,eax // xmm1 = 2^(k-1) или 2^(k+31)
vfmsub213ss xmm3,xmm2,xmm0 // xmm3 = -t/2+b1*t^2+b2*t^4
vaddss xmm0,xmm0,xmm0 // xmm0 = t
vaddss xmm3,xmm3,[ecx+16] // xmm3 = b0-t/2+b1*t^2+b2*t^4 = f(t)-t/2
vdivss xmm0,xmm0,xmm3 // xmm0 = t/(f(t)-t/2)
vfmadd213ss xmm0,xmm1,xmm1 // xmm0 = e^x with shifted exponent of -1 or 31
vmulss xmm0,xmm0,[ecx+edx+20] // xmm0 = e^x
ret // Return
exp_low: // Handling the case of x<<0 or overflow
vucomiss xmm0,[ecx] // Check the sign of x and a condition x=NaN
jp exp_end // Complete with NaN result, if x=NaN
exp_inf: // Entry point for processing large x
vxorps xmm0,xmm0,xmm0 // xmm0 = 0
jc exp_end // Ready, if x<<0
vrcpss xmm0,xmm0,xmm0 // xmm0 = Inf in case x>>0
exp_end: // The result at xmm0 is ready
ret // Return
}
}
Ниже выкладываю упрощенный алгоритм. Здесь удалена поддержка денормализованных чисел в результате.
_declspec(naked) float _vectorcall fexp(float x)
{
static const float ct[5] = // Constants table
{
1.44269502f, // lb(e)
1.92596299E-8f, // Correction to the value lb(e)
-9.21120925E-4f, // 16*b2
0.115524396f, // 4*b1
2.88539004f // b0
};
_asm
{
mov edx,offset ct // edx contains the address of constants tables
vmulss xmm1,xmm0,[edx] // xmm1 = x*lb(e)
vcvtss2si eax,xmm1 // eax = round(x*lb(e)) = k
vmovss xmm3,[edx+8] // Initialize the sum with highest coefficient 16*b2
vcvtsi2ss xmm1,xmm1,eax // xmm1 = k
cmp eax,127 // Check that the exponent is not too large
jg exp_break // Jump to set Inf if overflow
vfmsub231ss xmm1,xmm0,[edx] // xmm1 = x*lb(e)-k = t/2 in the range from -0,5 to 0,5
add eax,127 // Receive the exponent of 2^k with the bias 127
jle exp_break // The result is 0, if x<<0
vfmadd132ss xmm0,xmm1,[edx+4] // xmm0 = t/2 (corrected value)
vmulss xmm2,xmm0,xmm0 // xmm2 = t^2/4 - the argument of polynomial
shl eax,23 // eax contains the bits of 2^k
vfmadd213ss xmm3,xmm2,[edx+12] // xmm3 = 4*b1+4*b2*t^2
vmovd xmm1,eax // xmm1 = 2^k
vfmsub213ss xmm3,xmm2,xmm0 // xmm3 = -t/2+b1*t^2+b2*t^4
vaddss xmm0,xmm0,xmm0 // xmm0 = t
vaddss xmm3,xmm3,[edx+16] // xmm3 = b0-t/2+b1*t^2+b2*t^4 = f(t)-t/2
vdivss xmm0,xmm0,xmm3 // xmm0 = t/(f(t)-t/2)
vfmadd213ss xmm0,xmm1,xmm1 // xmm0 = 2^k*(t/(f(t)-t/2)+1) = e^x
ret // Return
exp_break: // Get 0 for x<0 or Inf for x>>0
vucomiss xmm0,[edx] // Check the sign of x and a condition x=NaN
jp exp_end // Complete with NaN result, if x=NaN
vxorps xmm0,xmm0,xmm0 // xmm0 = 0
jc exp_end // Ready, if x<<0
vrcpss xmm0,xmm0,xmm0 // xmm0 = Inf, if x>>0
exp_end: // The result at xmm0 is ready
ret // Return
}
}