Простая нейронная сеть Lasange не работает

Я использую пакет Lasagne для создания простой трехслойной нейронной сети и тестирую ее на очень простом наборе данных (всего 4 примера).

X = np.array([[0,0,1],
              [0,1,1],
              [1,0,1],
              [1,1,1]])         

y = np.array([[0, 0],[1, 0],[1, 1],[0, 1]])

Однако это не в состоянии изучить это, и приводит к предсказанию:

pred = theano.function([input_var], [prediction])
np.round(pred(X), 2)
array([[[ 0.5 ,  0.5 ],
        [ 0.98,  0.02],
        [ 0.25,  0.75],
        [ 0.25,  0.75]]])

Полный код:

def build_mlp(input_var=None):
    l_in = lasagne.layers.InputLayer(shape=(None, 3), input_var=input_var)
    l_hid1 = lasagne.layers.DenseLayer(
        l_in, num_units=4,
        nonlinearity=lasagne.nonlinearities.rectify,
        W=lasagne.init.GlorotUniform())
    l_hid2 = lasagne.layers.DenseLayer(
        l_hid1, num_units=4,
        nonlinearity=lasagne.nonlinearities.rectify,
        W=lasagne.init.GlorotUniform())
    l_out = lasagne.layers.DenseLayer(
        l_hid2, num_units=2,
        nonlinearity=lasagne.nonlinearities.softmax)
    return l_out

input_var = T.lmatrix('inputs')
target_var = T.lmatrix('targets')

network = build_mlp(input_var)

prediction = lasagne.layers.get_output(network, deterministic=True)
loss = lasagne.objectives.squared_error(prediction, target_var)
loss = loss.mean()

params = lasagne.layers.get_all_params(network, trainable=True)
updates = lasagne.updates.nesterov_momentum(
    loss, params, learning_rate=0.01, momentum=0.9)

train_fn = theano.function([input_var, target_var], loss, updates=updates)
val_fn = theano.function([input_var, target_var], [loss])

Повышение квалификации:

num_epochs = 1000
for epoch in range(num_epochs):
    inputs, targets = (X, y)
    train_fn(inputs, targets)   

Я предполагаю, что может быть проблема с нелинейными функциями, используемыми в скрытых слоях, или с методом обучения.

2 ответа

Это мое предположение о проблеме,

Во-первых, я не знаю, почему есть выход [0,0]? это значит, что выборка не классифицируется во всех классах?

Во-вторых, вы используете Softmax в последнем слое, который обычно используется для классификации, вы строите эту сеть для классификации? если вы запутались в выводе, на самом деле вывод - это вероятность каждого класса. Поэтому я думаю, что вывод правильный:

  • предсказание второго образца [0.98 0.02] так что это означает, что второй образец принадлежит первому классу, как ваша цель [1 0]

  • предсказание третьего образца [0.25 0.75] так что это означает, что третий образец относится ко второму классу, как ваша цель [1 1] (независимо от вашего первого значения класса, это классификация, поэтому она будет считаться правильной классификацией по системе)

  • Прогноз четвертого образца [0.25 0.75] так что это означает, что четвертый образец принадлежит второму классу, как ваша цель [0 1]

  • предсказание первой выборки [0.5 0.5] этот мне кажется немного смущающим, поэтому я предполагаю, что Лазанье предскажет первый образец, который имеет одинаковую вероятность в каждом классе, не будучи членом какого-либо класса

Я чувствую, что вы не можете судить, правильно ли модель учится на основе вышеизложенного.

  1. Количество тренировочных экземпляров. У вас есть 4 тренировочных экземпляра. Созданная вами нейронная сеть содержит 3*4 + 4*4 + 4*2 = 36 весов, которые она должна выучить. Не говоря уже о том, что у вас есть 4 различных типа выходов. Сеть определенно не подходит, что может объяснить неожиданные результаты.

  2. Как проверить, работает ли модель Если бы я хотел проверить, правильно ли обучается нейронная сеть, я бы протестировал рабочий набор данных (например, MNIST) и убедился, что моя модель обучается с высокой вероятностью. Вы также можете попробовать сравнить с другой библиотекой нейронной сети, которую вы уже написали, или с литературой. Если бы я действительно хотел использовать микро, я бы использовал бустинг с линейно разделяемым набором данных.

Если ваша модель по-прежнему не учится должным образом, я был бы обеспокоен.

Другие вопросы по тегам