Неожиданный апостериор при подгонке гауссовского процесса с использованием sklearn
Я новичок в гауссовских процессах (GP) и в настоящее время пытаюсь подогнать модель к некоторым зависящим от времени данным, которые я использую sklearn.GaussianProcessRegressor
. После подгонки предыдущий выглядит разумным (насколько я понимаю) там, где заданы точки данных, но резко идет к 0 везде, как показано здесь (извините, все еще не разрешено публиковать изображения).
В заголовке графика показаны параметры ядра до и после подгонки. Мой код основан на примерах, как этот, и этот, и другие с аналогичными структурами.
Я несколько раз пробовал подгонку с разными начальными параметрами, с добавлением шумового ядра и без него, и с разными данными одного и того же характера, но результат всегда был нереалистичным скачком к нулю (и, соответственно, большими стандартными отклонениями). Это сделало бы рисунки из апостериорного распределения не подходящими для этих данных.
Правильно ли я использую георадар? Есть ли рекомендуемый способ выбора параметров ядра? Кроме того, есть ли какая-то математическая интуиция относительно того, почему среднее значение стремится к нулю, когда в данных есть большие "пробелы", вместо более плавного перехода к следующей точке данных (как можно увидеть в ссылках с примерами)?
Изменить: после настройки normalize_y=True
чтобы установить среднее значение распределения равным среднему значению данных, подобранный GP просто покрывает все данные (см. график), что не дает информативных прогнозов. Таким образом, похоже, что это проблема ядра / параметров.
Код, который я использую,
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF,\
WhiteKernel as White, ConstantKernel as Constant
def plot_gp(x, y, xs, y_mean, gp):
"""
Plot a GP model
:param x: given x-axis points used for fitting
:param y: given y-axis points used for fitting
:param xs: x-axis points to predict
:param y_mean: mean of fitted GP
:param gp: the fitted GaussianProcessRegressor
"""
plt.figure(figsize=(10, 5))
plt.plot(xs, y_mean, 'k', lw=3) # plot posterior mean
plt.fill_between(xs, # plot posterior std/confidence
y_mean - 1.96 * std,
y_mean + 1.96 * std,
alpha=0.4, color='k', zorder=9)
plt.scatter(x, y, c='r', s=40, # plot given data points
zorder=10,
edgecolors=(0, 0, 0))
plt.title(f"prior: {gp.kernel}\n" # print prior and posterior kernels
f"posterior: {gp.kernel_}\n"
f"Log-Marginal-Likelihood: {gp.log_marginal_likelihood(gp.kernel_.theta)}",
fontsize=14)
plt.show()
kernel = Constant(1.0) *\
RBF(length_scale=1, length_scale_bounds=(1e-2, 1e3)) +\
White(noise_level=1e-5, noise_level_bounds=(1e-10, 1e+1))
gp = GaussianProcessRegressor(kernel=kernel, alpha=0.0).fit(X, y)
mean, std = gp.predict(Xstar[:, np.newaxis], return_std=True)
plot_gp(X, y, Xstar, mean, gp)
Входные данные, используемые для прилагаемого графика:
X = [[ 1] [ 2] [ 3] [ 4] [ 5] [ 6] [ 7] [12] [15] [16] [17] [18] [19] [20] [26] [27] [36] [44] [53] [54] [55] [56] [57] [61] [62] [63] [64] [65] [66] [67] [68] [69] [70] [71] [72] [73]]
y = [220.33333333 226. 217. 209. 243.
214.5 195. 219.33333333 210. 218.66666667
221. 194. 209. 197. 214.
244. 222. 224. 217. 243.25
206. 222.5 231. 220. 207.
225. 190. 219. 232. 219.5
200. 223. 236. 228.5 239.
216.5 ]
Xstar = [ 0. 0.50326797 1.00653595 1.50980392 2.0130719 2.51633987
3.01960784 3.52287582 4.02614379 4.52941176 5.03267974 5.53594771
6.03921569 6.54248366 7.04575163 7.54901961 8.05228758 8.55555556
9.05882353 9.5620915 10.06535948 10.56862745 11.07189542 11.5751634
12.07843137 12.58169935 13.08496732 13.58823529 14.09150327 14.59477124
15.09803922 15.60130719 16.10457516 16.60784314 17.11111111 17.61437908
18.11764706 18.62091503 19.12418301 19.62745098 20.13071895 20.63398693
21.1372549 21.64052288 22.14379085 22.64705882 23.1503268 23.65359477
24.15686275 24.66013072 25.16339869 25.66666667 26.16993464 26.67320261
27.17647059 27.67973856 28.18300654 28.68627451 29.18954248 29.69281046
30.19607843 30.69934641 31.20261438 31.70588235 32.20915033 32.7124183
33.21568627 33.71895425 34.22222222 34.7254902 35.22875817 35.73202614
36.23529412 36.73856209 37.24183007 37.74509804 38.24836601 38.75163399
39.25490196 39.75816993 40.26143791 40.76470588 41.26797386 41.77124183
42.2745098 42.77777778 43.28104575 43.78431373 44.2875817 44.79084967
45.29411765 45.79738562 46.30065359 46.80392157 47.30718954 47.81045752
48.31372549 48.81699346 49.32026144 49.82352941 50.32679739 50.83006536
51.33333333 51.83660131 52.33986928 52.84313725 53.34640523 53.8496732
54.35294118 54.85620915 55.35947712 55.8627451 56.36601307 56.86928105
57.37254902 57.87581699 58.37908497 58.88235294 59.38562092 59.88888889
60.39215686 60.89542484 61.39869281 61.90196078 62.40522876 62.90849673
63.41176471 63.91503268 64.41830065 64.92156863 65.4248366 65.92810458
66.43137255 66.93464052 67.4379085 67.94117647 68.44444444 68.94771242
69.45098039 69.95424837 70.45751634 70.96078431 71.46405229 71.96732026
72.47058824 72.97385621 73.47712418 73.98039216 74.48366013 74.9869281
75.49019608 75.99346405 76.49673203 77. ]