Избегайте циклов for при вычислении точек сетки, которые являются суммой векторов
Я хотел бы ускорить свой код Matlab. Обычно я нахожу способы избежать циклов for, чтобы сэкономить время вычислений, но в этом случае я наталкиваюсь на кирпичную стену. Мне нужно вычислить значения в сетке точек, но для вычисления значения требуется логическая операция и сумма по вектору, и это усложняет реализацию. Этот код выполняется на моей машине примерно за 8 секунд:
clear all
% Grid
dLimsX=[-100 +100];
dLimsY=[-100 +100];
dStep=1;
[x_map, y_map]=meshgrid((dLimsX(1):dStep:dLimsX(2)),(dLimsY(1):dStep:dLimsY(2)));
nPoints_map=numel(x_map);
% Inputs
smallDistance=1e-3;
N=10e3;
scaleFactor=10;
x_input = sin(linspace(0,1,N));
y_input = cos(linspace(0,1,N));
z_input = linspace(0,1,N);
tic
A=zeros(size(x_map));
for r=1:size(x_map,1)
y0=y_map(r,1);
for c=1:size(x_map,2)
x0=x_map(1,c);
idxTemp = find((x0-x_input).^2+(y0-y_input).^2>smallDistance); % do not consider in the calculation the inputs too close to the point
A(r,c) = sum( scaleFactor * z_input(idxTemp) .* (y0-y_input(idxTemp)) ./ ((x0-x_input(idxTemp)).^2 +(y0-y_input(idxTemp)).^2+eps) );
end
end
toc
2 ответа
Ускорение кода не связано с удалением циклов for. Мне кажется, много случаев, когда векторизованный код работает медленнее, чем эквивалент цикла. Циклы MATLAB становились все быстрее и быстрее за последние 20 лет, и они больше не являются значительным источником замедления. Например, следующее только в 4 раза медленнее, чемsum(x)
для суммирования 1 миллиона элементов:
y = 0;
for ii = 1:numel(x)
y = y+x(ii);
end
Если вычисления внутри цикла более дорогие, накладные расходы цикла полностью исчезают.
Причина, по которой вы все еще можете извлечь выгоду из векторизации, заключается в том, что в коде цикла часто извлекается строка или столбец матрицы. Это связано с дорогостоящим копированием данных. С другой стороны, если векторизованному коду требуется большая промежуточная матрица, то сохранение этой матрицы в памяти будет узким местом, которое значительно замедлит векторизованный код. Доступ к памяти обычно является проблемой.
Чтобы сделать код быстрее, вы должны в первую очередь сосредоточиться на том, чтобы избежать дублирования вычислений. Например,(y0-y_input).^2
вычисляется 3*size(x_map,2)
раз! (1/3 времени для подмножества данных, но количество точек, удаленных индексированием, невелико).
Кроме того, вы должны использовать логическое индексирование и избегать использования find
. A(find(condition))
такой же как A(condition)
, но медленнее.
На моем компьютере ваш цикл занимает ~10,5 с, эта версия - ~5,1 с:
tic
A = zeros(size(x_map));
for r = 1:size(x_map,1)
y0 = y_map(r,1);
dy2 = (y0-y_input).^2;
for c = 1:size(x_map,2)
x0 = x_map(1,c);
dx2 = (x0-x_input).^2;
idxTemp = dx2 + dy2 > smallDistance; % do not consider in the calculation the inputs too close to the point
A(r,c) = sum(scaleFactor * z_input(idxTemp) .* (y0-y_input(idxTemp)) ./ (dx2(idxTemp) + dy2(idxTemp) + eps));
end
end
toc
Могут быть дальнейшие улучшения, например, отказ от повторного вычисления y0-y_input
во внутреннем цикле.
Ответ Криса Луенго дал мне огромный намек и заставил задуматься о том, какие вычисления я могу избежать повторения. Избегайте пересчетаx0-x_input
а также y0-y_input
, как предлагает Cris, уже сокращает время вычислений на 50-60%.
Кроме того, при использовании кода в другом цикле это помогло мне разделить, какие изменения и что можно вычислить только один раз. В моем случае,x0
, x_input
, y0
, y_input
оставаться всегда таким же;,z_input
меняется с каждой новой итерацией. Поэтому я разделил расчет на две части: во-первых, вычисляю один раз все, что не меняется; во-вторых, рассчитайте требуемые значения.
Чтобы иметь возможность выполнить второе вычисление с помощью простого умножения матриц, я преобразовал значения xy из матрицы (сетки) в вектор. Вот что я сделал:
% Arrange x-y values in two vectors
x = x_map(:);
y = y_map(:);
nPoints = numel(x);
z_input = linspace(0,1,N)';
tic
a = zeros(nPoints,N);
for p = 1:nPoints
x0 = x(p);
y0 = y(p);
dx2 = (x0-x_input).^2;
dy2 = (y0-y_input).^2;
idxTemp = dx2 + dy2 > smallDistance; % do not consider in the calculation the inputs too close to the point
a(p,:) = (scaleFactor * 1 .* dy2 ./ (dx2 + dy2 + eps)) .*idxTemp;
end
toc
tic
A2 = a*z_input;
toc
% Check that the values calculated with the alternative method are correct
mean(mean(abs((A(:)-A2)./A(:))))
Некоторые комментарии к результатам:
Исходный код: ~122 с на моем домашнем ПК (намного медленнее, чем мой офисный ПК из исходного сообщения)
Код, предложенный Крисом: ~50 с
Код выше: ~75 с для первой части. ~0,6 с для второй части.
Альтернативный расчет имеет относительную погрешность около 1E-17.
Поэтому в случаях, когда расчет повторяется несколько раз, имеет смысл предварительно рассчитать то, что не меняется. Недостаток - большее использование памяти, необходимое для хранения переменной.a
.
Буду признателен за отзыв.