Softmax вероятности выбора с категорией в PyMC3

Я пытаюсь выполнить оценку параметра для одного параметра функции выбора softmax в следующем сценарии:

В каждом испытании дается три значения параметров (например, [1,2,3]), и субъект делает выбор между вариантами (0, 1 или 2). Функция softmax преобразует значения параметров в вероятности выбора (вектор из 3 вероятностей, суммирующих до 1), в зависимости от параметра температуры (здесь он ограничен значениями от 0 до 10).

Предполагается, что выбор в каждом испытании моделируется как категориальное распределение с вероятностями выбора, рассчитанными по softmax. Обратите внимание, что вероятности выбора Категорийного зависят от значений параметров и поэтому различны в каждом испытании.

Вот что я придумал:

# Generate data
nTrials = 60 # number of trials (value triplets and choices)
np.random.seed(42)
# generate nTrials triplets of values
values = np.random.choice([1,2,3,4,5], size=(nTrials, 3)) 

choices = values.argmax(axis=1) # choose highest value option
# add some random variation, so that *not* always the highest value option is chosen
errors = np.random.rand(nTrials)>0.8 # determine trials with non-optimal choice
# randomly determine new choices for these trials
choices[errors] = np.random.choice([0,1,2], size=sum(errors==True))

# Model specification & estimation
import pymc3 as pm
from theano import tensor as t
with pm.Model():

    # prior over theta
    theta = pm.Uniform('theta', lower=0, upper=10)

    # softmax implementation
    enumerator  = pm.exp(theta*values) 
    denominator = t.reshape(pm.sum(pm.exp(theta*values), axis=1), (nTrials, 1))
    ps = enumerator/denominator

    # Likelihood (sampling model for the data)
    for trial in range(nTrials):
        yobs = pm.Categorical('yobs{}'.format(trial), p=ps[trial], observed=choices[trial])

    # draw 500 samples from posterior
    trace = pm.sample(500, pm.Metropolis())

Этот код не работает для nTrials больше чем 50 с очень длинным предупреждением / сообщением об ошибке:

Предупреждение:

INFO (theano.gof.compilelock): Refreshing lock /Users/felixmolter/.theano/compiledir_Darwin-14.4.0-x86_64-i386-64bit-i386-2.7.8-64/lock_dir/lock
INFO:theano.gof.compilelock:Refreshing lock /Users/felixmolter/.theano/compiledir_Darwin-14.4.0-x86_64-i386-64bit-i386-2.7.8-64/lock_dir/lock
00001   #include <Python.h>
00002   #include <iostream>
00003   #include <math.h>
00004   #include <numpy/arrayobject.h>
00005   #include <numpy/arrayscalars.h>
00006   #include <vector>
00007   #include <algorithm>
00008   //////////////////////
00009   ////  Support Code
00010   //////////////////////
00011   
00012   
00013       namespace {
00014       struct __struct_compiled_op_65734e56ae54d89bdcf84e36893358e6 {
00015           PyObject* __ERROR;
00016   
00017           PyObject* storage_V3;
00018   PyObject* storage_V5;
00019   PyObject* storage_V7;
00020   PyObject* storage_V9;
00021   PyObject* storage_V11;
00022   PyObject* storage_V13;
[...]

Ошибка:

Exception: ('The following error happened while compiling the node', Elemwise{Composite{((Switch(LE(Abs((i0 + i1)), i2), log(i3), i4) + Switch(LE(Abs((i0 + i5)), i2), log(i6), i4) + Switch(LE(Abs((i0 + i7)), i2), log(i8), i4) + Switch(LE(Abs((i0 + i9)), i2), log(i10), i4) + Switch(LE(Abs((i0 + i11)), [...]

Я довольно новичок в PyMC (и Theano), и я чувствую, что моя реализация действительно неуклюжа и неоптимальна. Любая помощь и советы очень ценятся!

Феликс

Изменить: я загрузил код в виде записной книжки с полным отображением предупреждений и сообщений об ошибках: http://nbviewer.ipython.org/github/moltaire/softmaxPyMC/blob/master/softmax_stackru.ipynb

1 ответ

Я снова нашел это дело. Просто как продолжение, сейчас не в состоянии воспроизвести это сейчас. Так что я думаю, что это исправлено. Мы исправили связанную проблему, которая могла вызвать в некоторых случаях эту ошибку.

Работает с g ++ 4.5.1. Если у вас возникла эта проблема, обновите Theano до версии для разработчиков. Если это не помогает, попробуйте использовать более позднюю версию g ++, это может быть связано с более старой версией g ++.

Другие вопросы по тегам