Matlab отнимает слишком много времени на вычисление функции mex для большого массива
Я написал скрипт MATLAB, в котором я передаю несколько скаляров и один вектор-строку в качестве входных аргументов функции mex, и после выполнения некоторых вычислений он возвращает скаляр в качестве вывода. Этот процесс должен быть выполнен для всех элементов массива, размер которого составляет 1 X 1638400. Ниже приведен соответствующий код:
ans=0;
for i=0:1638400-1
temp = sub_imed(r,i,diff);
ans = ans + temp*diff(i+1);
end
где r,i - скаляры, diff - вектор размера 1 X 1638400, а sub_imed - функция MEX, которая выполняет следующую работу:
void sub_imed(double r,mwSize base, double* diff, mwSize dim, double* ans)
{
mwSize i,k,l,k1,l1;
double d,g,temp;
for(i=0; i<dim; i++)
{
k = (base/200) + 1;
l = (base%200) + 1;
k1 = (i/200) + 1;
l1 = (i%200) + 1;
d = sqrt(pow((k-k1),2) + pow((l-l1),2));
g=(1/(2*pi*pow(r,2)))*exp(-(pow(d,2))/(2*(pow(r,2))));
temp = temp + diff[i]*g;
}
*ans = temp;
}
void mexFunction(int nlhs,mxArray *plhs[],int nrhs,const mxArray *prhs[])
{
double *diff; /* Input data vectors */
double r; /* Value of r (input) */
double* ans; /* Output ImED distance */
size_t base,ncols; /* For storing the size of input vector and base */
/* Checking for proper number of arguments */
if(nrhs!=3)
mexErrMsgTxt("Error..Three inputs required.");
if(nlhs!=1)
mexErrMsgTxt("Error..Only one output required.");
/* make sure the first input argument(value of r) is scalar */
if( !mxIsDouble(prhs[0]) || mxIsComplex(prhs[0]) || mxGetNumberOfElements(prhs[0])!=1)
mexErrMsgTxt("Error..Value of r must be a scalar.");
/* make sure that the input value of base is a scalar */
if( !mxIsDouble(prhs[1]) || mxIsComplex(prhs[1]) || mxGetNumberOfElements(prhs[1])!=1)
mexErrMsgTxt("Error..Value of base must be a scalar.");
/* make sure that the input vector diff is of type double */
if(!mxIsDouble(prhs[2]) || mxIsComplex(prhs[2]))
mexErrMsgTxt("Error..Input vector must be of type double.");
/* check that number of rows in input arguments is 1 */
if(mxGetM(prhs[2])!=1)
mexErrMsgTxt("Error..Inputs must be row vectors.");
/* Get the value of r */
r = mxGetScalar(prhs[0]);
base = mxGetScalar(prhs[1]);
/* Getting the input vectors */
diff = mxGetPr(prhs[2]);
ncols = mxGetN(prhs[2]);
/* Creating link for the scalar output */
plhs[0] = mxCreateDoubleMatrix(1,1,mxREAL);
ans = mxGetPr(plhs[0]);
sub_imed(r,base,diff,(mwSize)ncols,ans);
}
Для получения более подробной информации о проблеме и алгоритме подчеркивания, пожалуйста, следуйте евклидову дистанцию между изображениями.
Я сделал профилирование моего сценария MATLAB и узнал, что это занимает 63 секунды. только для 387 вызовов функции sub_imed() mex. Таким образом, для 1638400 вызовов sub_imed, в идеале это займет около 74 часов, что слишком много.
Может кто-нибудь, пожалуйста, помогите мне оптимизировать код, предложив несколько альтернативных способов сократить время, затрачиваемое на вычисления.
Заранее спасибо.
2 ответа
Я перенес ваш код обратно в MATLAB и сделал несколько небольших корректировок, а результаты должны остаться прежними. Я ввел следующие константы:
N = 8192;
step = 0.005;
Обратите внимание, что N / step = 1638400
, С этим вы можете переписать вашу переменную k
(и переименуйте его в baseDiv
):
baseDiv = 1 + (0 : step : (N-step)).';
то есть 1:8193
по шагам 0.005
, Так же, l
является 1:200
(=1:(1/0.005)
), повторяется 8192 раза подряд, что называется (сейчас называется baseMod
):
baseMod = (repmat(1:1:(1/step), 1, N)).';
Ваши переменные k1
а также l1
просто i
й элемент k
а также l
т.е. baseDiv(i)
а также baseMod(i)
,
С векторами baseDiv
, а также baseMod
можно рассчитать d
, g
и ваша временная переменная tmp
с
d = sqrt((baseDiv(k)-baseDiv).^2 + (baseMod(k)-baseMod).^2);
g = 1/(2*pi*r^2) * exp(-(d.^2) / (2*r^2));
tmp = sum(diffVec .* g);
Мы можем поместить это в ваш цикл MATLAB for, чтобы вся программа стала
% Constants
N = 8192;
step = 0.005;
% Some example data
r = 2;
diffVec = rand(N/step,1);
base = (0:(numel(diffVec)-1)).';
baseDiv = (1:step:1+N-step).';
baseMod = (repmat(1:1:(1/step), 1, N)).';
res = 0;
for k=1:(N/step)
d = sqrt((baseDiv(k)-baseDiv).^2 + (baseMod(k)-baseMod).^2);
g = 1/(2*pi*r^2) * exp(-(d.^2) / (2*r^2));
tmp = sum(diffVec .* g);
res = res + tmp * diffVec(k);
end
Устраняя внутренний цикл for и вычисляя его в векторизованном виде, вам все равно нужно 11 секунд на 1000 итераций, в результате чего общее время выполнения составляет 5 часов. Тем не менее - ускорение более чем в 10 раз. Чтобы получить еще большее ускорение, у вас есть две возможности:
1) Полная векторизация: Вы можете легко векторизовать оставшийся цикл for, используя bsxfun(@minus, baseDiv, baseDiv.')
а также sum
над столбцами, чтобы рассчитать все значения одновременно. К сожалению, у нас есть небольшая проблема: двойная матрица 1638400 x 1638400 будет занимать 20 ТБ ОЗУ, чего, я полагаю, у вас нет в ноутбуке;-)
2) Меньше образцов: вы выполняете математическое преобразование с разрешением step=0.005
, Проверьте, действительно ли вам нужна эта точность! Если вы берете 1/10 точности: step=0.05
, вы в 100 раз быстрее и закончите за 3 минуты!
- Вы скомпилировали с флагом оптимизации? (См флаг
-O
) - Вместо того, чтобы использовать
pow(x,2)
в C/C++ вы должны написатьx*x
- Я почти уверен, что причина, почему это медленно, не из-за кода, который вы показали, а из-за того, что вы делаете в
mexFunction
, Если бы мне пришлось угадывать, я бы сказал, что вы бессмысленно копируете память вdiff
, но мы должны видеть всю функцию Mex, чтобы быть уверенным.
Попробуйте следующий код C++, используя mex -O myfile.cpp
:
void sub_imed( double r, size_t base, const double *diff, size_t dim, double& ans)
{
double d, g;
// these need to be double to avoid underflow
double k = base / 200;
double l = base % 200;
r = 2*r*r;
for(; dim; --dim, ++diff )
{
d = k - i/200;
g = l - i%200;
ans += (*diff) * exp( - (d*d + g*g)/r ) / (pi*r);
}
}
void mexFunction(int nlhs,mxArray *plhs[],int nrhs,const mxArray *prhs[])
{
/* Checking for proper number of arguments */
if(nrhs!=3)
mexErrMsgTxt("Error..Three inputs required.");
if(nlhs!=1)
mexErrMsgTxt("Error..Only one output required.");
/* make sure the first input argument(value of r) is scalar */
if( !mxIsDouble(prhs[0]) || mxIsComplex(prhs[0]) || mxGetNumberOfElements(prhs[0])!=1 )
mexErrMsgTxt("Error..Value of r must be a scalar.");
/* make sure that the input value of base is a scalar */
if( !mxIsDouble(prhs[1]) || mxIsComplex(prhs[1]) || mxGetNumberOfElements(prhs[1])!=1 )
mexErrMsgTxt("Error..Value of base must be a scalar.");
/* make sure that the input vector diff is of type double */
if( !mxIsDouble(prhs[2]) || mxIsComplex(prhs[2]) )
mexErrMsgTxt("Error..Input vector must be of type double.");
/* check that number of rows in input arguments is 1 */
if( mxGetM(prhs[2])!=1 )
mexErrMsgTxt("Error..Inputs must be row vectors.");
/* Get the value of r */
double r = mxGetScalar(prhs[0]);
size_t base = static_cast<size_t>(mxGetScalar(prhs[1]);
/* Getting the input vectors */
const double *diff = mxGetPr(prhs[2]);
size_t nrows = static_cast<size_t>(mxGetN(prhs[2]));
/* Creating link for the scalar output */
plhs[0] = mxCreateDoubleScalar(0.0);
sub_imed( r, base, diff, nrows, *mxGetPr(plhs[0]) );
}