Умножение xmm регистра

У меня проблема с умножением двух регистров в ассемблере sse. Здесь есть мой код:

moltiplicazionePuntoPunto:
    mov edx,[esp+20]                 ; edx = fxx
    mov esi,[esp+4]                  ; esi = fx
    mov edi,[esp+8]                  ; edi = fy
    xor eax,eax                      ; i=0
 fori:   cmp eax,[esp+12]            ; confronta i con N
    jge endfori
    xor ebx,ebx                       ; j=0
 forj:   cmp ebx,[esp+16]             ; confronta j con M
    jge endforj   
    mov ecx,eax
    imul ecx,[esp+16]                 ; ecx = i*M
    add ecx,ebx                       ; ecx = i*M+j
    movss xmm5,[esi+ecx*4]            ; xmm5 = fx[i*M+j]
    movss xmm6,[edi+ecx*4]            ; xmm6 = fy[i*M+j]
    mulps xmm5,xmm6                   ; xmm7 = fx[i*M+j]*fx[i*M+j]
    movss [edx+ecx*4],xmm5            ; fxx[i*M+j] = fx*fx
    inc ebx
    jmp forj
 endforj:
    inc eax
    jmp fori
 endfori: 

Этот код модифицирует матрицу fxx, где элемент fxx[i*M+j] = fx[i*M+j] * fy[i*M+j]. Проблема в том, когда я делаю операцию mulps xmm5,xmm6 результат 0.

1 ответ

Проблема решена. Проблема заключалась в том, что я передал из C матрицу типа int. Вместо этого, если я передаю матрицу с плавающей запятой, код работает.

Например, упрощенный C++, он будет просто проходить все элементы матрицы, потому что это то, что ваш [i,j] вложенный цикл делает. Вам не нужно рассчитывать i*M+j, поскольку ваша формула не использует i/j каким-либо особым образом, она просто проходит все элементы один раз:

void muldata(float* fxx, const float* fx, const float* fy, const unsigned int M, const unsigned int N) {
    int ofs = 0;
    do {
        fxx[ofs] = fx[ofs] * fy[ofs];
        ++ofs;
    } while (ofs < M*N);
}

Сделаю clang -O3 -m32 (v4.0.0) произвести это:

muldata(float*, float const*, float const*, unsigned int, unsigned int):                   # @muldata(float*, float const*, float const*, unsigned int, unsigned int)
        push    ebp
        push    ebx
        push    edi
        push    esi
        sub     esp, 12
        mov     esi, dword ptr [esp + 48]
        mov     edi, dword ptr [esp + 40]
        mov     ecx, dword ptr [esp + 36]
        mov     edx, dword ptr [esp + 32]
        mov     eax, 1
        imul    esi, dword ptr [esp + 44]
        cmp     esi, 1
        cmova   eax, esi
        xor     ebp, ebp
        cmp     eax, 8
        jb      .LBB0_7
        mov     ebx, eax
        and     ebx, -8
        je      .LBB0_7
        mov     dword ptr [esp + 4], eax # 4-byte Spill
        cmp     esi, 1
        mov     eax, 1
        mov     dword ptr [esp], ebx    # 4-byte Spill
        cmova   eax, esi
        lea     ebx, [ecx + 4*eax]
        lea     edi, [edx + 4*eax]
        mov     dword ptr [esp + 8], ebx # 4-byte Spill
        mov     ebx, dword ptr [esp + 40]
        cmp     edx, dword ptr [esp + 8] # 4-byte Folded Reload
        lea     eax, [ebx + 4*eax]
        sbb     bl, bl
        cmp     ecx, edi
        sbb     bh, bh
        and     bh, bl
        cmp     edx, eax
        sbb     al, al
        cmp     dword ptr [esp + 40], edi
        mov     edi, dword ptr [esp + 40]
        sbb     ah, ah
        test    bh, 1
        jne     .LBB0_7
        and     al, ah
        and     al, 1
        jne     .LBB0_7
        mov     eax, dword ptr [esp]    # 4-byte Reload
        lea     ebx, [edi + 16]
        lea     ebp, [ecx + 16]
        lea     edi, [edx + 16]
.LBB0_5:                                # =>This Inner Loop Header: Depth=1
        movups  xmm0, xmmword ptr [ebp - 16]
        movups  xmm2, xmmword ptr [ebx - 16]
        movups  xmm1, xmmword ptr [ebp]
        movups  xmm3, xmmword ptr [ebx]
        add     ebp, 32
        add     ebx, 32
        mulps   xmm2, xmm0
        mulps   xmm3, xmm1
        movups  xmmword ptr [edi - 16], xmm2
        movups  xmmword ptr [edi], xmm3
        add     edi, 32
        add     eax, -8
        jne     .LBB0_5
        mov     eax, dword ptr [esp]    # 4-byte Reload
        mov     edi, dword ptr [esp + 40]
        cmp     dword ptr [esp + 4], eax # 4-byte Folded Reload
        mov     ebp, eax
        je      .LBB0_8
.LBB0_7:                                # =>This Inner Loop Header: Depth=1
        movss   xmm0, dword ptr [ecx + 4*ebp] # xmm0 = mem[0],zero,zero,zero
        mulss   xmm0, dword ptr [edi + 4*ebp]
        movss   dword ptr [edx + 4*ebp], xmm0
        inc     ebp
        cmp     ebp, esi
        jb      .LBB0_7
.LBB0_8:
        add     esp, 12
        pop     esi
        pop     edi
        pop     ebx
        pop     ebp
        ret

Что намного превосходит ваш код (по умолчанию включает векторизацию цикла).

И было бы, вероятно, даже что-то лучше, если бы вы указали выравнивание указателей и сделали бы постоянную времени компиляции M/N.


Я только что убедился, что вариант C++ работает, зайдя на сайт cpp.sh и расширив его до следующего:

#include <iostream>

void muldata(float* fxx, const float* fx, const float* fy, const unsigned int M, const unsigned int N) {
    unsigned int ofs = 0;
    do {
        fxx[ofs] = fx[ofs] * fy[ofs];
        ++ofs;
    } while (ofs < M*N);
}

int main()
{
    // constexpr unsigned int M = 1;
    // constexpr unsigned int N = 1;
    // const float fx[M*N] = { 2.2f };
    // const float fy[M*N] = { 3.3f };

    constexpr unsigned int M = 3;
    constexpr unsigned int N = 2;
    const float fx[M*N] = { 2.2f, 1.0f, 0.0f,
                            1.0f, 1.0f, 1e-24f };
    const float fy[M*N] = { 3.3f, 3.3f, 3.3f,
                            5.5f, 1e30f, 1e-24f };

    float fr[M*N];
    muldata(fr, fx, fy, M, N);
    for (unsigned int i = 0; i < N; ++i) {
        for (unsigned int j = 0; j < M; ++j) std::cout << fr[i*M+j] << " ";
        std::cout << std::endl;
    }
}

выход:

7.26 3.3 0 
5.5 1e+30 0 

Там также закомментированы входные данные 1x1, которые должны быть первыми, что нужно отладить в вашем случае. Попробуйте заставить этот пример работать в вашей любимой C++ IDE, затем замените muldata с вашим ассемблерным кодом, и отладка через него, чтобы увидеть, где он выходит.

Другие вопросы по тегам