Подсчет 1 бита (подсчет населения) для больших данных с использованием AVX-512 или AVX-2
У меня большой кусок памяти, скажем, 256 КиБ или больше. Я хочу подсчитать количество 1 битов во всем этом фрагменте, или другими словами: сложить значения "количества населения" для всех байтов.
Я знаю, что AVX-512 имеет инструкцию VPOPCNTDQ, которая подсчитывает количество 1 бит в каждых последовательных 64 битах в 512-битном векторе, и IIANM должна быть возможность выдавать один из них каждый цикл (если соответствующий векторный регистр SIMD доступно) - но у меня нет никакого опыта написания SIMD-кода (я скорее парень из GPU). Кроме того, я не уверен на 100% в поддержке компилятора для целей AVX-512.
Тем не менее, на большинстве процессоров AVX-512 не поддерживается (полностью); но AVX-2 широко доступен. Я не смог найти менее 512-битную векторизованную инструкцию, похожую на VPOPCNTDQ, поэтому даже теоретически я не уверен, как быстро считать биты с процессорами с поддержкой AVX-2; может что-то подобное существует и я просто как-то пропустил?
В любом случае, я был бы признателен за короткую функцию C/C++ - либо с использованием некоторой библиотеки intristics-wrapper, либо со встроенной сборкой - для каждого из двух наборов команд. Подпись
uint64_t count_bits(void* ptr, size_t size);
Заметки:
- Связано с Как быстро посчитать биты в отдельные бункеры в серии целых чисел на Sandy Bridge? но не дурак.
- Мы можем предположить, что вход хорошо выровнен, если это имеет значение.
- Забудьте о нескольких ядрах или сокетах, я хочу код для одного (потока на одном) ядра.
2 ответа
AVX-2
@HadiBreis 'комментирует ссылки на статью о быстром подсчете населения с помощью SSSE3, написанную Войцехом Мулой; статья ссылается на этот репозиторий GitHub; и репозиторий имеет следующую реализацию AVX-2. Он основан на векторизованной инструкции поиска и использует таблицу поиска из 16 значений для количества битов в полубайтах.
# include <immintrin.h>
# include <x86intrin.h>
std::uint64_t popcnt_AVX2_lookup(const uint8_t* data, const size_t n) {
size_t i = 0;
const __m256i lookup = _mm256_setr_epi8(
/* 0 */ 0, /* 1 */ 1, /* 2 */ 1, /* 3 */ 2,
/* 4 */ 1, /* 5 */ 2, /* 6 */ 2, /* 7 */ 3,
/* 8 */ 1, /* 9 */ 2, /* a */ 2, /* b */ 3,
/* c */ 2, /* d */ 3, /* e */ 3, /* f */ 4,
/* 0 */ 0, /* 1 */ 1, /* 2 */ 1, /* 3 */ 2,
/* 4 */ 1, /* 5 */ 2, /* 6 */ 2, /* 7 */ 3,
/* 8 */ 1, /* 9 */ 2, /* a */ 2, /* b */ 3,
/* c */ 2, /* d */ 3, /* e */ 3, /* f */ 4
);
const __m256i low_mask = _mm256_set1_epi8(0x0f);
__m256i acc = _mm256_setzero_si256();
#define ITER { \
const __m256i vec = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(data + i)); \
const __m256i lo = _mm256_and_si256(vec, low_mask); \
const __m256i hi = _mm256_and_si256(_mm256_srli_epi16(vec, 4), low_mask); \
const __m256i popcnt1 = _mm256_shuffle_epi8(lookup, lo); \
const __m256i popcnt2 = _mm256_shuffle_epi8(lookup, hi); \
local = _mm256_add_epi8(local, popcnt1); \
local = _mm256_add_epi8(local, popcnt2); \
i += 32; \
}
while (i + 8*32 <= n) {
__m256i local = _mm256_setzero_si256();
ITER ITER ITER ITER
ITER ITER ITER ITER
acc = _mm256_add_epi64(acc, _mm256_sad_epu8(local, _mm256_setzero_si256()));
}
__m256i local = _mm256_setzero_si256();
while (i + 32 <= n) {
ITER;
}
acc = _mm256_add_epi64(acc, _mm256_sad_epu8(local, _mm256_setzero_si256()));
#undef ITER
uint64_t result = 0;
result += static_cast<uint64_t>(_mm256_extract_epi64(acc, 0));
result += static_cast<uint64_t>(_mm256_extract_epi64(acc, 1));
result += static_cast<uint64_t>(_mm256_extract_epi64(acc, 2));
result += static_cast<uint64_t>(_mm256_extract_epi64(acc, 3));
for (/**/; i < n; i++) {
result += lookup8bit[data[i]];
}
return result;
}
AVX-512
В этом же хранилище также реализована реализация AVX-512 на основе VPOPCNT:
# include <immintrin.h>
# include <x86intrin.h>
uint64_t avx512_vpopcnt(const uint8_t* data, const size_t size) {
const size_t chunks = size / 64;
uint8_t* ptr = const_cast<uint8_t*>(data);
const uint8_t* end = ptr + size;
// count using AVX512 registers
__m512i accumulator = _mm512_setzero_si512();
for (size_t i=0; i < chunks; i++, ptr += 64) {
// Note: a short chain of dependencies, likely unrolling will be needed.
const __m512i v = _mm512_loadu_si512((const __m512i*)ptr);
const __m512i p = _mm512_popcnt_epi64(v);
accumulator = _mm512_add_epi64(accumulator, p);
}
// horizontal sum of a register
uint64_t tmp[8] __attribute__((aligned(64)));
_mm512_store_si512((__m512i*)tmp, accumulator);
uint64_t total = 0;
for (size_t i=0; i < 8; i++) {
total += tmp[i];
}
// popcount the tail
while (ptr + 8 < end) {
total += _mm_popcnt_u64(*reinterpret_cast<const uint64_t*>(ptr));
ptr += 8;
}
while (ptr < end) {
total += lookup8bit[*ptr++];
}
return total;
}
lookup8bit
является таблицей поиска popcnt для байтов, а не битов, и определяется здесь. редактирование: как отмечают комментаторы, использование 8-битной таблицы поиска в конце не очень хорошая идея и может быть улучшено.
Функции popcnt большого массива Войцеха Мулы выглядят оптимально, за исключением скалярных циклов очистки. (См. Ответ @einpoklum для получения подробной информации об основных циклах).
LUT с 256 записями, которое вы используете только пару раз в конце, может привести к потере кэша и не является оптимальным для более чем 1 байта, даже если кэш был горячим. Я считаю, что все процессоры AVX2 имеют аппаратное обеспечение popcnt
и мы можем легко выделить последние до 8 байтов, которые еще не были подсчитаны, чтобы настроить нас для одного popcnt
,
Как обычно с алгоритмами SIMD, он часто хорошо работает для полной загрузки, которая заканчивается на последнем байте буфера. Но в отличие от векторного регистра, сдвиги с переменным счетом полного целочисленного регистра дешевы (особенно с BMI2). Popcnt не волнует, где находятся биты, поэтому мы можем просто использовать сдвиг вместо необходимости создавать маску AND или что-то еще.
// untested
// ptr points at the first byte that hasn't been counted yet
uint64_t final_bytes = reinterpret_cast<const uint64_t*>(end)[-1] >> (8*(end-ptr));
total += _mm_popcnt_u64( final_bytes );
// Careful, this could read outside a small buffer.
Или, что еще лучше, используйте более сложную логику, чтобы избежать пересечения страниц. Это может избежать пересечения страниц для 6-байтового буфера в начале страницы, например.