Ошибка с семплером pymc3 в pypesto: theano.graph.fg MissingInputError

Я занимаюсь проблемой байесовского вывода, и у меня возникли проблемы с использованием сэмплера pymc3, предоставленного pypesto, на моем ноутбуке с Windows. Чтобы убедиться, что я могу работать с семплером, я создаю простую фиктивную цель для использования.

Я устанавливаю среду create a conda (я пробовал как 3.7, так и 3.8) и устанавливаю модули pymc3 и theano с помощью pip3 / pip. Я пробовал несколько разных версий pymc3/theano, и мне удалось успешно их импортировать. Однако появляется сообщение об ошибке. Я не могу понять, как ее обойти. Я попытался найти решение в Интернете, но тоже не смог его найти. В настоящее время у меня установлены последние версии pymc3 и theano (3.11.0 и 1.0.5 соответственно). Это последняя строка сообщения

      theano.graph.fg.MissingInputError: Input 0 of the graph (indices start from 0), used to compute sigmoid(x2_interval__), was not provided and not given a value. Use the Theano flag exception_verbosity='high', for more information on this error.

Вот полное сообщение:

      Sampling 1 chain for 1_000 tune and 100 draw iterations (1_000 + 100 draws total) took 7 seconds.
Traceback (most recent call last):
  File "samplingPymc3.py", line 70, in <module>
    result2 = sample.sample(problem1, 100, sampler2, x0=np.array([0,0]))
  File "C:\Users\germa\anaconda3\envs\sampling\lib\site-packages\pypesto\sample\sample.py", line 68, in sample
    sampler.sample(n_samples=n_samples)
  File "C:\Users\germa\anaconda3\envs\sampling\lib\site-packages\pypesto\sample\pymc3.py", line 102, in sample
    **self.options)
  File "C:\Users\germa\anaconda3\envs\sampling\lib\site-packages\pymc3\sampling.py", line 637, in sample
    idata = arviz.from_pymc3(trace, **ikwargs)
  File "C:\Users\germa\anaconda3\envs\sampling\lib\site-packages\arviz\data\io_pymc3.py", line 559, in from_pymc3
    density_dist_obs=density_dist_obs,
  File "C:\Users\germa\anaconda3\envs\sampling\lib\site-packages\arviz\data\io_pymc3.py", line 163, in __init__
    self.observations, self.multi_observations = self.find_observations()
  File "C:\Users\germa\anaconda3\envs\sampling\lib\site-packages\arviz\data\io_pymc3.py", line 176, in find_observations
    multi_observations[key] = val.eval() if hasattr(val, "eval") else val
  File "C:\Users\germa\anaconda3\envs\sampling\lib\site-packages\theano\graph\basic.py", line 554, in eval
    self._fn_cache[inputs] = theano.function(inputs, self)
  File "C:\Users\germa\anaconda3\envs\sampling\lib\site-packages\theano\compile\function\__init__.py", line 350, in function
    output_keys=output_keys,
  File "C:\Users\germa\anaconda3\envs\sampling\lib\site-packages\theano\compile\function\pfunc.py", line 532, in pfunc
    output_keys=output_keys,
  File "C:\Users\germa\anaconda3\envs\sampling\lib\site-packages\theano\compile\function\types.py", line 1978, in orig_function
    name=name,
  File "C:\Users\germa\anaconda3\envs\sampling\lib\site-packages\theano\compile\function\types.py", line 1584, in __init__
    fgraph, additional_outputs = std_fgraph(inputs, outputs, accept_inplace)
  File "C:\Users\germa\anaconda3\envs\sampling\lib\site-packages\theano\compile\function\types.py", line 188, in std_fgraph
    fgraph = FunctionGraph(orig_inputs, orig_outputs, update_mapping=update_mapping)
  File "C:\Users\germa\anaconda3\envs\sampling\lib\site-packages\theano\graph\fg.py", line 162, in __init__
    self.import_var(output, reason="init")
  File "C:\Users\germa\anaconda3\envs\sampling\lib\site-packages\theano\graph\fg.py", line 330, in import_var
    self.import_node(var.owner, reason=reason)
  File "C:\Users\germa\anaconda3\envs\sampling\lib\site-packages\theano\graph\fg.py", line 383, in import_node
    raise MissingInputError(error_msg, variable=var)
theano.graph.fg.MissingInputError: Input 0 of the graph (indices start from 0), used to compute sigmoid(x2_interval__), was not provided and not given a value. Use the Theano flag exception_verbosity='high', for more information on this error.

Я где-то читал, что проблема может заключаться в используемой версии arviz, но в моем случае это не проблема. Я хотел включить сценарий, который я запускаю. Вот код скрипта:

      import numpy as np
import scipy as sp
import scipy.optimize as so
from scipy.stats import multivariate_normal
import pypesto
import pypesto.sample as sample
from pypesto import Objective

A = np.array([[2.0, 0.0], [0.0, 1.0]])
b = np.array([2.0, 1.0])
x_init = np.array([3.4302, 2.915])
x_true = np.array([1.0, 1.0])
temp = lambda x: A.dot(x) - b
f = lambda x: .5 * np.linalg.norm(temp(x))
A_t = A.transpose()
K = np.dot(A_t, A)
df = lambda x: K.dot(x) - A_t.dot(b)


def obj1(x):
    # f_val = f(x)
    # grad = df(x)
    return (f(x), df(x))


objfun = lambda x: obj1(x)
dim_full = 2
lb = -10 * np.ones((dim_full, 1))
ub = 10 * np.ones((dim_full, 1))
x_names = ['x1', 'x2']
# step_fcn = pymc3.step_methods.hmc.hmc.HamiltonianMC
objective = pypesto.Objective(fun=objfun, grad=True, hess=False)
problem1 = pypesto.Problem(objective=objective, lb=lb, ub=ub, x_names=x_names)
sampler = sample.AdaptiveMetropolisSampler()
print('function val: ', objfun(x_init))
sampler2 = sample.Pymc3Sampler()
result2 = sample.sample(problem1, 100, sampler2, x0=np.array([0, 0]))
print('Done sampling!')

Спасибо заранее за любую помощь!

1 ответ

pymc3 поддержка pypesto на данный момент ограничена, так как она была реализована в то время, когда theano был прекращен в пользу aesara в pymc3. Таким образом, pypesto поддерживает только определенные версии задействованных инструментов, в частности

      arviz >= 0.8.1, < 0.9.0
theano >= 1.0.4
packaging >= 20.0
pymc3 >= 3.8, < 3.9.2

(см. https://github.com/ICB-DCM/pyPESTO/blob/main/setup.cfg#L111). Переход на полную поддержку aesara и более поздних версий pymc3 находится в стадии реализации, но еще не вышел.

Другие вопросы по тегам