Почему я получаю неправильные результаты в матричной мультипликации с jcuda?
Я хотел сделать параллельное матричное умножение с jcuda. Я скачал и попробовал. Это работает для matix-сложения, но когда я пытаюсь умножить матрицу, я получаю неправильные результаты. Примерно 1e-7 - это ошибка, когда я инициализирую две матрицы случайным значением от 0 до 1. Когда я делаю случайное значение выше, я также получаю более высокую ошибку.
package JCuda;
import static jcuda.driver.JCudaDriver.cuCtxCreate;
import static jcuda.driver.JCudaDriver.cuCtxSynchronize;
import static jcuda.driver.JCudaDriver.cuDeviceGet;
import static jcuda.driver.JCudaDriver.cuInit;
import static jcuda.driver.JCudaDriver.cuLaunchKernel;
import static jcuda.driver.JCudaDriver.cuMemAlloc;
import static jcuda.driver.JCudaDriver.cuMemFree;
import static jcuda.driver.JCudaDriver.cuMemcpyDtoH;
import static jcuda.driver.JCudaDriver.cuMemcpyHtoD;
import static jcuda.driver.JCudaDriver.cuModuleGetFunction;
import static jcuda.driver.JCudaDriver.cuModuleLoad;
import java.io.IOException;
import java.security.SecureRandom;
import jcuda.Pointer;
import jcuda.Sizeof;
import jcuda.driver.CUcontext;
import jcuda.driver.CUdevice;
import jcuda.driver.CUdeviceptr;
import jcuda.driver.CUfunction;
import jcuda.driver.CUmodule;
import jcuda.driver.JCudaDriver;
public class Main {
public static void main(String args[]) throws IOException {
int N = 8;
JCudaDriver.setExceptionsEnabled(true);
cuInit(0);
CUdevice device = new CUdevice();
cuDeviceGet(device, 0);
CUcontext context = new CUcontext();
cuCtxCreate(context, 0, device);
CUmodule module = new CUmodule();
cuModuleLoad(module, "src/resources/MatMul.ptx");
CUfunction function = new CUfunction();
cuModuleGetFunction(function, module, "matmul");
int numElements = N * N;
for (int t = 0; t < 20; t++) {
float hostInputA[] = new float[numElements];
float hostInputB[] = new float[numElements];
SecureRandom ran = new SecureRandom();
for (int i = 0; i < numElements; i++) {
hostInputA[i] = ran.nextFloat();
hostInputB[i] = ran.nextFloat();
}
float hostOutput[] = new float[numElements];
long l = System.currentTimeMillis();
CUdeviceptr deviceInputA = new CUdeviceptr();
cuMemAlloc(deviceInputA, numElements * Sizeof.FLOAT);
cuMemcpyHtoD(deviceInputA, Pointer.to(hostInputA), numElements * Sizeof.FLOAT);
CUdeviceptr deviceInputB = new CUdeviceptr();
cuMemAlloc(deviceInputB, numElements * Sizeof.FLOAT);
cuMemcpyHtoD(deviceInputB, Pointer.to(hostInputB), numElements * Sizeof.FLOAT);
CUdeviceptr deviceOutput = new CUdeviceptr();
cuMemAlloc(deviceOutput, numElements * Sizeof.FLOAT);
Pointer kernelParameters = Pointer.to(Pointer.to(new int[] { numElements }), Pointer.to(deviceInputA),
Pointer.to(deviceInputB), Pointer.to(deviceOutput), Pointer.to(new int[] { N }),
Pointer.to(new int[] { N }));
int blockSizeX = (N<256) ? N : 256;
int gridSizeX = (int) Math.ceil(numElements / blockSizeX);
cuLaunchKernel(function, gridSizeX, 1, 1, // Grid dimension
blockSizeX, 1, 1, // Block dimension
0, null, // Shared memory size and stream
kernelParameters, null // Kernel- and extra parameters
);
cuCtxSynchronize();
cuMemcpyDtoH(Pointer.to(hostOutput), deviceOutput, numElements * Sizeof.FLOAT);
System.out.println(System.currentTimeMillis() - l);
boolean test = true;
for (int i = 0; i < numElements; i++) {
int x = i % N;
int y = i / N;
float sum = 0;
for (int k = 0; k < N; k++) {
sum += hostInputA[y * N + k] * hostInputB[k * N + x];
}
if (Math.abs(hostOutput[i] - sum) > 0) {
System.out.println("(" + i + ") Not passed, expected: " + sum + " but actualy out: " + hostOutput[i]
+ ", wrong by: " + (hostOutput[i] - sum));
test = false;
break;
}
}
System.out.println("Passed: " + (test ? "True" : "False"));
cuMemFree(deviceInputA);
cuMemFree(deviceInputB);
cuMemFree(deviceOutput);
}
}
}
`
И это код MatMul.cu (позже он преобразуется в MatMul.ptx):
extern "C"
__global__ void matmul(int n, float *mt, float *m, float *out, int mtcols, int mcols)
{
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (i<n)
{
int x = i % mcols;
int y = i / mcols;
float sum = 0;
for (int k = 0; k < mtcols; k++) {
sum += mt[y * mtcols + k] * m[k * mcols + x];
}
out[i] = sum;
}
}