Попытка распараллелить NTT с потоками cpp

Я впервые задаю вопросы здесь, так что извините, если что-то не так.
Я пытаюсь распараллелить NTT, используя потоки cpp, но я здесь просто потерялся. Я основал код на статье, объясняющей CUDA-распараллеливание NTT, и адаптировал его так, чтобы он имел больше смысла для процессора (меньше потоков), но я уперся в стену и не могу прогрессировать. По сути, создал класс для сопоставления каждого потока с той парой элементов в массиве, на которой он должен будет сделать бабочку, результаты неверны, и psis, похоже, вычисляется правильно (порядок битов в обратном порядке).
Я новичок в параллельных вычислениях и NTT, приветствую любую помощь.

      class threadInfo {
    public:
        std::vector<long long *> psi, u, v;
        void clear(){
            psi.clear();
            u.clear();
            v.clear();
        }
};

void threadButterfly(threadInfo info, long long mod, bool invert){
    for (long long i = 0; i < info.u.size(); i++)
    {
        long long u = * info.u[i], v = * info.v[i];
        if(!invert)     v = modulo(v* * info.psi[i],mod);
        * info.u[i] = modulo(u+v,mod);
        * info.v[i] = modulo(u-v,mod);
        if(invert)      * info.v[i] = modulo((u-v)* * info.psi[i],mod);
        else            * info.v[i] = modulo(u-v,mod);
    }
}

void threadSched(vector<long long> &a, long long mod, long long len, std::vector<long long> &psi, vector<threadInfo> info, bool invert){
    vector<thread> threads(THREAD_NUM);
    long long n = a.size();
    for (long long id = 0; id < n>>1; id++)                                 //puts each u and v pairs in each thread object
    {
        long long step = (a.size()/len)/2;                                  // step counts the distance between u and v
        long long psi_step = id/step;                                       // what k in psi**k to use relative to the first in the group
        long long target = (psi_step * step * 2) + (id % step);             // what u and v we want
        long long group = len + psi_step;                                   // what k in psi**k to use relative to all psis
        long long arrayid = floor((2*id*THREAD_NUM)/n);                     // what thread will the par go to
        info[arrayid].psi.push_back( & psi[group]);
        info[arrayid].u.push_back( & a[target]);
        info[arrayid].v.push_back( & a[target+step]);
    }
    for ( size_t id=0; id<THREAD_NUM; id++ )        threads[id] = thread(threadButterfly,info[id],mod,invert);
    for ( size_t id=0; id<THREAD_NUM; id++ )        threads[id].join(); 
    for ( long long i = 0; i<info.size(); i++ )     info[i].clear();
    if(invert)  for ( long long j = 0; j < n; j++ ) a[j]=modulo(a[j]*mod_in(n,mod),mod);
}

void ntt(vector<long long> &a, long long mod, vector<long long> &psi){
    vector<threadInfo> fwd(THREAD_NUM);
    for (long long len = 1; len < a.size(); len = 2 * len)
    {
        threadSched(a,mod,len,psi,fwd,false);
    }
}

void intt(vector<long long> &a, long long mod, vector<long long> &psi){
    vector<threadInfo> rev(THREAD_NUM);
    for (long long len = 1; len < a.size(); len = 2 * len)
    {
        threadSched(a,mod,len,psi,rev,true);
    }
}

Спасибо за внимание.

0 ответов

Другие вопросы по тегам