np.roll vs scipy.interpolation.shift- расхождение для значений целочисленного сдвига
Я написал некоторый код для смещения массива и пытался обобщить его для обработки нецелых сдвигов, используя функцию "shift" в scipy.ndimage
, Данные округлые, поэтому результат должен быть точно таким же, как np.roll
Команда делает это.
Тем не мение, scipy.ndimage.shift
не похоже, чтобы правильно переносить целочисленные сдвиги. Следующий фрагмент кода показывает несоответствие:
import numpy as np
import scipy.ndimage as sciim
import matplotlib.pyplot as plt
def shiftfunc(data, amt):
return sciim.interpolation.shift(data, amt, mode='wrap', order = 3)
if __name__ == "__main__":
xvals = np.arange(100)*1.0
yvals = np.sin(xvals*0.1)
rollshift = np.roll(yvals, 2)
interpshift = shiftfunc(yvals, 2)
plt.plot(xvals, rollshift, label = 'np.roll', alpha = 0.5)
plt.plot(xvals, interpshift, label = 'interpolation.shift', alpha = 0.5)
plt.legend()
plt.show()
Видно, что первая пара значений сильно не совпадает, а с остальными все в порядке. Я подозреваю, что это ошибка реализации операции предварительной фильтрации и интерполяции при использовании wrap
вариант. Обойти это можно было бы изменить shiftfunc
вернуться к np.roll, когда значение сдвига является целым числом, но это неудовлетворительно.
Я что-то упускаю здесь очевидное?
Есть ли способ сделать ndimage.shift
совпадают с np.roll
?
1 ответ
Я не думаю, что с функцией сдвига что-то не так. когда вы используете бросок, вам нужно нарезать дополнительный элемент для честного сравнения. пожалуйста, смотрите код ниже.
import numpy as np
import scipy.ndimage as sciim
import matplotlib.pyplot as plt
def shiftfunc(data, amt):
return sciim.interpolation.shift(data, amt, mode='wrap', order = 3)
def rollfunc(data,amt):
rollshift = np.roll(yvals, amt)
# Here I remove one element (first one before rollshift) from the array
return np.concatenate((rollshift[:amt], rollshift[amt+1:]))
if __name__ == "__main__":
shift_by = 5
xvals = np.linspace(0,2*np.pi,20)
yvals = np.sin(xvals)
rollshift = rollfunc(yvals, shift_by)
interpshift = shiftfunc(yvals,shift_by)
plt.plot(xvals, yvals, label = 'original', alpha = 0.5)
plt.plot(xvals[1:], rollshift, label = 'np.roll', alpha = 0.5,marker='s')
plt.plot(xvals, interpshift, label = 'interpolation.shift', alpha = 0.5,marker='o')
plt.legend()
plt.show()
результаты в