Ускорение scipy griddata для нескольких интерполяций между двумя нерегулярными сетками

У меня есть несколько значений, которые определены на одной и той же нерегулярной сетке (x, y, z) что я хочу, чтобы интерполировать на новую сетку (x1, y1, z1), то есть у меня есть f(x, y, z), g(x, y, z), h(x, y, z) и я хочу рассчитать f(x1, y1, z1), g(x1, y1, z1), h(x1, y1, z1),

На данный момент я делаю это с помощью scipy.interpolate.griddata и это работает хорошо. Тем не менее, поскольку мне приходится выполнять каждую интерполяцию отдельно и есть много точек, она довольно медленная, с большим количеством дублирования в вычислениях (то есть нахождение точек, которые расположены ближе всего, настройка сеток и т. Д.).

Есть ли способ ускорить расчеты и сократить дублирующиеся расчеты? то есть что-то вроде определения двух сеток, а затем изменения значений для интерполяции?

3 ответа

Решение

Каждый раз, когда вы звоните scipy.interpolate.griddata:

  1. Во-первых, звонок в sp.spatial.qhull.Delaunay сделан для триангуляции нерегулярных координат сетки.
  2. Затем для каждой точки в новой сетке выполняется поиск триангуляции, чтобы определить, в каком треугольнике (собственно, в каком симплексе, в вашем трехмерном случае и в каком тетраэдре) он лежит.
  3. Вычисляются барицентрические координаты каждой новой точки сетки относительно вершин вмещающего симплекса.
  4. Для этой точки сетки вычисляются интерполированные значения с использованием барицентрических координат и значений функции в вершинах вмещающего симплекса.

Первые три шага идентичны для всех ваших интерполяций, поэтому, если бы вы могли хранить для каждой новой точки сетки индексы вершин вмещающего симплекса и веса для интерполяции, вы бы значительно минимизировали количество вычислений. К сожалению, это нелегко сделать напрямую с доступными функциями, хотя это действительно возможно:

import scipy.interpolate as spint
import scipy.spatial.qhull as qhull
import itertools

def interp_weights(xyz, uvw):
    tri = qhull.Delaunay(xyz)
    simplex = tri.find_simplex(uvw)
    vertices = np.take(tri.simplices, simplex, axis=0)
    temp = np.take(tri.transform, simplex, axis=0)
    delta = uvw - temp[:, d]
    bary = np.einsum('njk,nk->nj', temp[:, :d, :], delta)
    return vertices, np.hstack((bary, 1 - bary.sum(axis=1, keepdims=True)))

def interpolate(values, vtx, wts):
    return np.einsum('nj,nj->n', np.take(values, vtx), wts)

Функция interp_weights выполняет расчеты за первые три шага, которые я перечислил выше. Тогда функция interpolate использует эти рассчитанные значения, чтобы выполнить шаг 4 очень быстро:

m, n, d = 3.5e4, 3e3, 3
# make sure no new grid point is extrapolated
bounding_cube = np.array(list(itertools.product([0, 1], repeat=d)))
xyz = np.vstack((bounding_cube,
                 np.random.rand(m - len(bounding_cube), d)))
f = np.random.rand(m)
g = np.random.rand(m)
uvw = np.random.rand(n, d)

In [2]: vtx, wts = interp_weights(xyz, uvw)

In [3]: np.allclose(interpolate(f, vtx, wts), spint.griddata(xyz, f, uvw))
Out[3]: True

In [4]: %timeit spint.griddata(xyz, f, uvw)
1 loops, best of 3: 2.81 s per loop

In [5]: %timeit interp_weights(xyz, uvw)
1 loops, best of 3: 2.79 s per loop

In [6]: %timeit interpolate(f, vtx, wts)
10000 loops, best of 3: 66.4 us per loop

In [7]: %timeit interpolate(g, vtx, wts)
10000 loops, best of 3: 67 us per loop

Итак, во-первых, он делает то же самое, что и griddata, и это хорошо. Во-вторых, настройка интерполяции, то есть вычисления vtx а также wts занимает примерно так же, как вызов griddata, Но в-третьих, теперь вы можете интерполировать различные значения в одной и той же сетке практически за короткое время.

Единственное что griddata делает это не предусмотрено здесь присваивание fill_value к точкам, которые должны быть экстраполированы. Вы можете сделать это, проверив точки, для которых хотя бы один из весов отрицательный, например:

def interpolate(values, vtx, wts, fill_value=np.nan):
    ret = np.einsum('nj,nj->n', np.take(values, vtx), wts)
    ret[np.any(wts < 0, axis=1)] = fill_value
    return ret

Большое спасибо Хайме за его решение (даже если я не совсем понимаю, как выполняется барицентрическое вычисление...)

Здесь вы найдете пример, адаптированный из его случая в 2D:

import scipy.interpolate as spint
import scipy.spatial.qhull as qhull
import numpy as np

def interp_weights(xy, uv,d=2):
    tri = qhull.Delaunay(xy)
    simplex = tri.find_simplex(uv)
    vertices = np.take(tri.simplices, simplex, axis=0)
    temp = np.take(tri.transform, simplex, axis=0)
    delta = uv - temp[:, d]
    bary = np.einsum('njk,nk->nj', temp[:, :d, :], delta)
    return vertices, np.hstack((bary, 1 - bary.sum(axis=1, keepdims=True)))

def interpolate(values, vtx, wts):
    return np.einsum('nj,nj->n', np.take(values, vtx), wts)

m, n = 101,201
mi, ni = 1001,2001

[Y,X]=np.meshgrid(np.linspace(0,1,n),np.linspace(0,2,m))
[Yi,Xi]=np.meshgrid(np.linspace(0,1,ni),np.linspace(0,2,mi))

xy=np.zeros([X.shape[0]*X.shape[1],2])
xy[:,0]=Y.flatten()
xy[:,1]=X.flatten()
uv=np.zeros([Xi.shape[0]*Xi.shape[1],2])
uv[:,0]=Yi.flatten()
uv[:,1]=Xi.flatten()

values=np.cos(2*X)*np.cos(2*Y)

#Computed once and for all !
vtx, wts = interp_weights(xy, uv)
valuesi=interpolate(values.flatten(), vtx, wts)
valuesi=valuesi.reshape(Xi.shape[0],Xi.shape[1])
print "interpolation error: ",np.mean(valuesi-np.cos(2*Xi)*np.cos(2*Yi))  
print "interpolation uncertainty: ",np.std(valuesi-np.cos(2*Xi)*np.cos(2*Yi))  

Возможно применить преобразование изображения, такое как отображение изображения с ускорением ускорения

Вы не можете использовать то же определение функции, поскольку новые координаты будут меняться на каждой итерации, но вы можете вычислить триангуляцию раз и навсегда.

import scipy.interpolate as spint
import scipy.spatial.qhull as qhull
import numpy as np
import time

# Definition of the fast  interpolation process. May be the Tirangulation process can be removed !!
def interp_tri(xy):
    tri = qhull.Delaunay(xy)
    return tri


def interpolate(values, tri,uv,d=2):
    simplex = tri.find_simplex(uv)
    vertices = np.take(tri.simplices, simplex, axis=0)
    temp = np.take(tri.transform, simplex, axis=0)
    delta = uv- temp[:, d]
    bary = np.einsum('njk,nk->nj', temp[:, :d, :], delta)  
    return np.einsum('nj,nj->n', np.take(values, vertices),  np.hstack((bary, 1.0 - bary.sum(axis=1, keepdims=True))))

m, n = 101,201
mi, ni = 101,201

[Y,X]=np.meshgrid(np.linspace(0,1,n),np.linspace(0,2,m))
[Yi,Xi]=np.meshgrid(np.linspace(0,1,ni),np.linspace(0,2,mi))

xy=np.zeros([X.shape[0]*X.shape[1],2])
xy[:,1]=Y.flatten()
xy[:,0]=X.flatten()
uv=np.zeros([Xi.shape[0]*Xi.shape[1],2])
# creation of a displacement field
uv[:,1]=0.5*Yi.flatten()+0.4
uv[:,0]=1.5*Xi.flatten()-0.7
values=np.zeros_like(X)
values[50:70,90:150]=100.

#Computed once and for all !
tri = interp_tri(xy)
t0=time.time()
for i in range(0,100):
  values_interp_Qhull=interpolate(values.flatten(),tri,uv,2).reshape(Xi.shape[0],Xi.shape[1])
t_q=(time.time()-t0)/100

t0=time.time()
values_interp_griddata=spint.griddata(xy,values.flatten(),uv,fill_value=0).reshape(values.shape[0],values.shape[1])
t_g=time.time()-t0

print "Speed-up:", t_g/t_q
print "Mean error: ",(values_interp_Qhull-values_interp_griddata).mean()
print "Standard deviation: ",(values_interp_Qhull-values_interp_griddata).std()

На моем ноутбуке ускорение составляет от 20 до 40 раз!

Надеюсь, что это может помочь кому-то

У меня была та же проблема (griddata очень медленная, сетка остается неизменной для многих интерполяций), и мне больше всего понравилось решение, описанное здесь, в основном потому, что его очень легко понять и применить.

Он использует LinearNDInterpolator, где можно пройти триангуляцию Делоне, которую нужно вычислить только один раз. Скопируйте и вставьте из этого сообщения (все кредиты на xdze2):

from scipy.spatial import Delaunay
from scipy.interpolate import LinearNDInterpolator

tri = Delaunay(mesh1)  # Compute the triangulation

# Perform the interpolation with the given values:
interpolator = LinearNDInterpolator(tri, values_mesh1)
values_mesh2 = interpolator(mesh2)

Это ускоряет мои вычисления примерно в 2 раза.

Вы можете попробовать использовать Pandas, так как он обеспечивает высокопроизводительные структуры данных.

Это правда, что метод интерполяции является оберткой для скучной интерполяции, НО, может быть, с улучшенными структурами вы получите лучшую скорость.

import pandas as pd;
wp = pd.Panel(randn(2, 5, 4));
wp.interpolate();

interpolate() заполняет значения NaN в наборе данных Panel разными способами. Надеюсь, это быстрее, чем Сципи.

Если это не работает, есть один способ повысить производительность (вместо использования распараллеленной версии вашего кода): используйте Cython и внедрите небольшую подпрограмму в C для использования внутри кода Python. Вот вам пример по этому поводу.

Другие вопросы по тегам