Почему я получаю неправильные результаты в матричной мультипликации с jcuda?

Я хотел сделать параллельное матричное умножение с jcuda. Я скачал и попробовал. Это работает для matix-сложения, но когда я пытаюсь умножить матрицу, я получаю неправильные результаты. Примерно 1e-7 - это ошибка, когда я инициализирую две матрицы случайным значением от 0 до 1. Когда я делаю случайное значение выше, я также получаю более высокую ошибку.

package JCuda; 
import static jcuda.driver.JCudaDriver.cuCtxCreate;
import static jcuda.driver.JCudaDriver.cuCtxSynchronize;
import static jcuda.driver.JCudaDriver.cuDeviceGet;
import static jcuda.driver.JCudaDriver.cuInit;
import static jcuda.driver.JCudaDriver.cuLaunchKernel;
import static jcuda.driver.JCudaDriver.cuMemAlloc;
import static jcuda.driver.JCudaDriver.cuMemFree;
import static jcuda.driver.JCudaDriver.cuMemcpyDtoH;
import static jcuda.driver.JCudaDriver.cuMemcpyHtoD;
import static jcuda.driver.JCudaDriver.cuModuleGetFunction;
import static jcuda.driver.JCudaDriver.cuModuleLoad;

import java.io.IOException;
import java.security.SecureRandom;

import jcuda.Pointer;
import jcuda.Sizeof;
import jcuda.driver.CUcontext;
import jcuda.driver.CUdevice;
import jcuda.driver.CUdeviceptr;
import jcuda.driver.CUfunction;
import jcuda.driver.CUmodule;
import jcuda.driver.JCudaDriver;

public class Main {

public static void main(String args[]) throws IOException {
    int N = 8;
    JCudaDriver.setExceptionsEnabled(true);

    cuInit(0);
    CUdevice device = new CUdevice();
    cuDeviceGet(device, 0);
    CUcontext context = new CUcontext();
    cuCtxCreate(context, 0, device);

    CUmodule module = new CUmodule();
    cuModuleLoad(module, "src/resources/MatMul.ptx");

    CUfunction function = new CUfunction();
    cuModuleGetFunction(function, module, "matmul");

    int numElements = N * N;

    for (int t = 0; t < 20; t++) {

        float hostInputA[] = new float[numElements];
        float hostInputB[] = new float[numElements];

        SecureRandom ran = new SecureRandom();
        for (int i = 0; i < numElements; i++) {
            hostInputA[i] = ran.nextFloat();
            hostInputB[i] = ran.nextFloat();
        }

        float hostOutput[] = new float[numElements];

        long l = System.currentTimeMillis();

        CUdeviceptr deviceInputA = new CUdeviceptr();
        cuMemAlloc(deviceInputA, numElements * Sizeof.FLOAT);
        cuMemcpyHtoD(deviceInputA, Pointer.to(hostInputA), numElements * Sizeof.FLOAT);

        CUdeviceptr deviceInputB = new CUdeviceptr();
        cuMemAlloc(deviceInputB, numElements * Sizeof.FLOAT);
        cuMemcpyHtoD(deviceInputB, Pointer.to(hostInputB), numElements * Sizeof.FLOAT);

        CUdeviceptr deviceOutput = new CUdeviceptr();
        cuMemAlloc(deviceOutput, numElements * Sizeof.FLOAT);

        Pointer kernelParameters = Pointer.to(Pointer.to(new int[] { numElements }), Pointer.to(deviceInputA),
                Pointer.to(deviceInputB), Pointer.to(deviceOutput), Pointer.to(new int[] { N }),
                Pointer.to(new int[] { N }));

        int blockSizeX = (N<256) ? N : 256;
        int gridSizeX = (int) Math.ceil(numElements / blockSizeX);
        cuLaunchKernel(function, gridSizeX, 1, 1, // Grid dimension
                blockSizeX, 1, 1, // Block dimension
                0, null, // Shared memory size and stream
                kernelParameters, null // Kernel- and extra parameters
        );

        cuCtxSynchronize();

        cuMemcpyDtoH(Pointer.to(hostOutput), deviceOutput, numElements * Sizeof.FLOAT);

        System.out.println(System.currentTimeMillis() - l);

        boolean test = true;
        for (int i = 0; i < numElements; i++) {
            int x = i % N;
            int y = i / N;
            float sum = 0;
            for (int k = 0; k < N; k++) {
                sum += hostInputA[y * N + k] * hostInputB[k * N + x];
            }

            if (Math.abs(hostOutput[i] - sum) > 0) {
                System.out.println("(" + i + ") Not passed, expected: " + sum + " but actualy out: " + hostOutput[i]
                        + ", wrong by: " + (hostOutput[i] - sum));
                test = false;
                break;
            }

        }
        System.out.println("Passed: " + (test ? "True" : "False"));

        cuMemFree(deviceInputA);
        cuMemFree(deviceInputB);
        cuMemFree(deviceOutput);
    }

}

}

`

И это код MatMul.cu (позже он преобразуется в MatMul.ptx):

extern "C"
__global__ void matmul(int n, float *mt, float *m, float *out, int mtcols, int mcols)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i<n)
    {
        int x = i % mcols;
        int y = i / mcols;
        float sum = 0;
        for (int k = 0; k < mtcols; k++) {
            sum += mt[y * mtcols + k] * m[k * mcols + x];
        }
        out[i] = sum;
    }
}

0 ответов

Другие вопросы по тегам