Сохранение данных из трассировки в PyMC3
Ниже приведен код простой байесовской линейной регрессии. После того, как я получу трассу и графики для параметров, есть ли способ сохранить данные, которые создали графики, в файле, так что если мне нужно будет построить их снова, я могу просто построить их из данных в файле вместо того, чтобы снова запустить всю симуляцию?
import pymc3 as pm
import matplotlib.pyplot as plt
import numpy as np
x = np.linspace(0,9,5)
y = 2*x + 5
yerr=np.random.rand(len(x))
def soln(x, p1, p2):
return p1+p2*x
with pm.Model() as model:
# Define priors
intercept = pm.Normal('Intercept', 15, sd=5)
slope = pm.Normal('Slope', 20, sd=5)
# Model solution
sol = soln(x, intercept, slope)
# Define likelihood
likelihood = pm.Normal('Y', mu=sol,
sd=yerr, observed=y)
# Sampling
trace = pm.sample(1000, nchains = 1)
pm.traceplot(trace)
print pm.summary(trace, ['Slope'])
print pm.summary(trace, ['Intercept'])
plt.show()
3 ответа
Есть два простых способа сделать это:
Используйте версию после 3.4.1 (в настоящее время это означает установку с мастера, с
pip install git+https://github.com/pymc-devs/pymc3
). Появилась новая функция, позволяющая эффективно сохранять и загружать трассы. Обратите внимание, что вам нужен доступ к модели, которая создала трассировку:... pm.save_trace(trace, 'linreg.trace') # later with model: trace = pm.load_trace('linreg.trace')
использование
cPickle
(или жеpickle
в питоне 3). Обратите внимание, чтоpickle
по крайней мере немного небезопасно, не извлекайте данные из ненадежных источников:import cPickle as pickle # just `import pickle` on python 3 ... with open('trace.pkl', 'wb') as buff: pickle.dump(trace, buff) #later with open('trace.pkl', 'rb') as buff: trace = pickle.load(buff)
Обновление для кого-то вроде меня, кто все еще подходит к этому вопросу:
Функции load_trace и save_trace были удалены. Начиная с версии 4.0, даже предупреждение об устаревании этих функций было удалено.
Способ сделать это теперь использовать arviz:
with model:
trace = pymc.sample(return_inferencedata=True)
trace.to_netcdf("filename.nc")
И его можно загрузить с помощью:
trace = arviz.from_netcdf("filename.nc")
У меня работает такой способ:
# saving trace
pm.save_trace(trace=trace_nb, directory=r"c:\Users\xxx\Documents\xxx\traces\trace_nb")
# loading saved traces
with model_nb:
t_nb = pm.load_trace(directory=r"c:\Users\xxx\Documents\xxx\traces\trace_nb")