np.roll vs scipy.interpolation.shift- расхождение для значений целочисленного сдвига

Я написал некоторый код для смещения массива и пытался обобщить его для обработки нецелых сдвигов, используя функцию "shift" в scipy.ndimage, Данные округлые, поэтому результат должен быть точно таким же, как np.roll Команда делает это.

Тем не мение, scipy.ndimage.shift не похоже, чтобы правильно переносить целочисленные сдвиги. Следующий фрагмент кода показывает несоответствие:

import numpy as np
import scipy.ndimage as sciim
import matplotlib.pyplot as plt 

def shiftfunc(data, amt):
    return sciim.interpolation.shift(data, amt, mode='wrap', order = 3)

if __name__ == "__main__":
    xvals = np.arange(100)*1.0

    yvals = np.sin(xvals*0.1)

    rollshift   = np.roll(yvals, 2)

    interpshift = shiftfunc(yvals, 2)

    plt.plot(xvals, rollshift, label = 'np.roll', alpha = 0.5)
    plt.plot(xvals, interpshift, label = 'interpolation.shift', alpha = 0.5)
    plt.legend()
    plt.show()

бросок против сдвига

Видно, что первая пара значений сильно не совпадает, а с остальными все в порядке. Я подозреваю, что это ошибка реализации операции предварительной фильтрации и интерполяции при использовании wrap вариант. Обойти это можно было бы изменить shiftfunc вернуться к np.roll, когда значение сдвига является целым числом, но это неудовлетворительно.

Я что-то упускаю здесь очевидное?

Есть ли способ сделать ndimage.shift совпадают с np.roll?

1 ответ

Решение

Я не думаю, что с функцией сдвига что-то не так. когда вы используете бросок, вам нужно нарезать дополнительный элемент для честного сравнения. пожалуйста, смотрите код ниже.

import numpy as np
import scipy.ndimage as sciim
import matplotlib.pyplot as plt 


def shiftfunc(data, amt):
    return sciim.interpolation.shift(data, amt, mode='wrap', order = 3)

def rollfunc(data,amt):
    rollshift   = np.roll(yvals, amt)
    # Here I remove one element (first one before rollshift) from the array 
    return np.concatenate((rollshift[:amt], rollshift[amt+1:]))

if __name__ == "__main__":
    shift_by = 5
    xvals = np.linspace(0,2*np.pi,20)
    yvals = np.sin(xvals)
    rollshift   = rollfunc(yvals, shift_by)
    interpshift = shiftfunc(yvals,shift_by)
    plt.plot(xvals, yvals, label = 'original', alpha = 0.5)
    plt.plot(xvals[1:], rollshift, label = 'np.roll', alpha = 0.5,marker='s')
    plt.plot(xvals, interpshift, label = 'interpolation.shift', alpha = 0.5,marker='o') 
    plt.legend()
    plt.show()

результаты в

Другие вопросы по тегам