Обратное распространение не работает: нейронная сеть Java
Я создал простую нейронную сеть с 3 слоями в соответствии с этим примером на python: Link (PS: вам нужно прокрутить вниз, пока не дойдете до части 2)
Это моя реализация Java-кода:
private void trainNet()
{
// INPUT is a 4*3 matrix
// SYNAPSES is a 3*4 matrix
// SYNAPSES2 is a 4*1 matrix
// 4*3 matrix DOT 3*4 matrix => 4*4 matrix: unrefined test results
double[][] layer1 = sigmoid(dot(inputs, synapses), false);
// 4*4 matrix DOT 4*1 matrix => 4*1 matrix: 4 final test results
double[][] layer2 = sigmoid(dot(layer1, synapses2), false);
// 4*1 matrix - 4*1 matrix => 4*1 matrix: error of 4 test results
double[][] layer2Error = subtract(outputs, layer2);
// 4*1 matrix DOT 4*1 matrix => 4*1 matrix: percentage of change of 4 test results
double[][] layer2Delta = dot(layer2Error, sigmoid(layer2, true));
// 4*1 matrix DOT 3*1 matrix => 4*1 matrix
double[][] layer1Error = dot(layer2Delta, synapses2);
// 4*1 matrix DOT 4*4 matrix => 4*4 matrix: percentage of change of 4 test results
double[][] layer1Delta = dot(layer1Error, sigmoid(layer1, true));
double[][] transposedInputs = transpose(inputs);
double[][] transposedLayer1 = transpose(layer1);
// 4*4 matrix DOT 4*1 matrix => 4*1 matrix: the updated weights
// Update the weights
synapses2 = sum(synapses2, dot(transposedLayer1, layer2Delta));
// 3*4 matrix DOT 4*4 matrix => 3*4 matrix: the updated weights
// Update the weights
synapses = sum(synapses, dot(transposedInputs, layer1Delta));
// Test each value of two 4*1 matrices with each other
testValue(layer2, outputs);
}
Функции "точка", "сумма", "вычитание" и "транспонирование" я создал сам и уверен, что они отлично справляются со своей задачей.
Первая партия входных данных дает мне ошибку около 0,4, что нормально, потому что веса имеют случайное значение. При втором запуске допустимая погрешность меньше, но только на очень большое количество пальцев (0,001)
После 500 000 пакетов (итого 2 000 000 тестов) сеть все еще не дала правильного значения! Поэтому я попытался использовать еще большее количество партий. Используя 1 000 000 пакетов (всего 4 000 000 тестов), сеть генерирует колоссальные 16 900 правильных результатов.
Может ли кто-нибудь сказать мне, что происходит?
Это были используемые веса:
Первый слой:
- 2.038829298171684 2.816232761170282 1.6740269469812146 1.634422766238497
- 1.5890997594993828 1.7909325329112222 2.101840236824494 1.063579126586681
- 3.761238407071311 3.757148454039234 3.7557450538398176 3.6715972104291605
Второй слой:
- -0,019603811941904248
- +218,38253323323553
- +53,70133275445734
-272,83589796861514
РЕДАКТИРОВАТЬ: Спасибо lsnare за указание на меня, использование библиотеки было бы намного проще!
Для тех, кто заинтересован, вот рабочий код, использующий библиотеку math.nist.gov/javanumerics:
private void trainNet()
{
// INPUT is a 4*3 matrix
// SYNAPSES is a 3*4 matrix
// SYNAPSES2 is a 4*1 matrix
// 4*3 matrix DOT 3*4 matrix => 4*4 matrix: unrefined test results
Matrix hiddenLayer = sigmoid(inputs.times(synapses), false);
// 4*4 matrix DOT 4*1 matrix => 4*1 matrix: 4 final test results
Matrix outputLayer = sigmoid(hiddenLayer.times(synapses2), false);
// 4*1 matrix - 4*1 matrix => 4*1 matrix: error of 4 test results
Matrix outputLayerError = outputs.minus(outputLayer);
// 4*1 matrix DOT 4*1 matrix => 4*1 matrix: percentage of change of 4 test results
Matrix outputLayerDelta = outputLayerError.arrayTimes(sigmoid(outputLayer, true));
// 4*1 matrix DOT 1*4 matrix => 4*4 matrix
Matrix hiddenLayerError = outputLayerDelta.times(synapses2.transpose());
// 4*4 matrix DOT 4*4 matrix => 4*4 matrix: percentage of change of 4 test results
Matrix hiddenLayerDelta = hiddenLayerError.arrayTimes(sigmoid(hiddenLayer, true));
// 4*4 matrix DOT 4*1 matrix => 4*1 matrix: the updated weights
// Update the weights
synapses2 = synapses2.plus(hiddenLayer.transpose().times(outputLayerDelta));
// 3*4 matrix DOT 4*4 matrix => 3*4 matrix: the updated weights
// Update the weights
synapses = synapses.plus(inputs.transpose().times(hiddenLayerDelta));
// Test each value of two 4*1 matrices with each other
testValue(outputLayer.getArrayCopy(), outputs.getArrayCopy());
}
1 ответ
В общем, при написании кода, который включает в себя сложные математические или численные вычисления (например, линейную алгебру), лучше использовать существующие библиотеки, написанные экспертами в данной области, а не писать свои собственные функции. Стандартные библиотеки будут давать более точные результаты и, скорее всего, будут более эффективными. Например, в блоге, на который вы ссылаетесь, автор использует цифровую библиотеку для вычисления точечных произведений и транспонирования матриц. Для Java вы можете использовать Java Matrix Package (JAMA), разработанный NIST: http://math.nist.gov/javanumerics/jama/
Например, чтобы транспонировать матрицу:
double[4][3] in = {{0,0,1},{0,1,1},{1,0,1},{1,1,1}};
Matrix input = new Matrix(in);
input = input.transpose();
Я не уверен, что это решит вашу проблему полностью, но, надеюсь, это поможет вам сэкономить на написании дополнительного кода в будущем.