Многослойная сигмоидальная нейронная сеть на Java, сходящаяся к 0,5
Я посмотрел серию видео о нейронных сетях от 3blue1brown на YouTube и попытался реализовать такую нейронную сеть сам. Насколько я могу судить, сеть выполняет вычисления именно так, как я задумал, но тем не менее совсем не обучается. Я пытался обучить её отношению XOR, но выход сходится к 0,5. Я видел другую ветку, как мне кажется, о той же проблеме, но, поскольку я не разбираюсь в синтаксисе Python, она мне не помогла.
MultiLayerSigmoid.java:
private Neuron[][] neurons;
private ReadWriteLock lock = new ReentrantReadWriteLock();
public MultiLayerSigmoid(int... neurons) {
this.neurons = new Neuron[neurons.length][];
for (int i = 0; i < neurons.length; i++) {
this.neurons[i] = new Neuron[neurons[i]];
for (int o = 0; o < this.neurons[i].length; o++) {
this.neurons[i][o] = i == 0 ? new Neuron() : new Neuron(neurons[i - 1]);
}
}
}
@Override
public void learn(double[][] inputs, double[][] outputs) {
Preconditions.checkArgument(inputs.length == outputs.length, "Invalid training sample sizes");
lock.writeLock().lock();
// Iterate samples
Training[][] trainings = new Training[inputs.length][];
for (int i = 0; i < inputs.length; i++) {
double[] input = inputs[i];
double[] output = outputs[i];
double[] actual = compute(input);
// Calculate output layer
double[] activations = new double[output.length];
for (int o = 0; o < activations.length; o++) {
activations[o] = 2 * (actual[o] - output[o]);
}
System.out.println("Error derivatives: " + Arrays.toString(activations));
// Backpropagation (without input layer)
trainings[i] = new Training[this.neurons.length - 1];
for (int layer = this.neurons.length - 1; layer > 0; layer--) {
Training training = calculateLayer(layer, activations);
//System.out.println(i + " " + layer + " " + new Gson().toJson(training));
trainings[i][layer - 1] = training;
activations = training.prevActivationDerivatives;
}
}
Training[] summed = new Training[this.neurons.length - 1];
// Initialize
for (int i = 0; i < summed.length; i++) {
summed[i] = new Training(new double[this.neurons[i + 1].length][this.neurons[i].length], new double[this.neurons[i + 1].length], new double[this.neurons[i].length]);
}
for (int i = 0; i < trainings.length; i++) {
Training[] sample = trainings[i];
for (int o = 0; o < sample.length; o++) {
Training training = sample[o];
MathUtils.addUp(summed[o].weightDerivatives, training.weightDerivatives);
MathUtils.addUp(summed[o].biasDerivatives, training.biasDerivatives);
MathUtils.addUp(summed[o].prevActivationDerivatives, training.prevActivationDerivatives);
}
}
// Adjust weights and biases
double learningRate = 0.1;
for (int i = 1; i < neurons.length; i++) {
Neuron[] layer = neurons[i];
Training training = summed[i - 1];
for (int o = 0; o < layer.length; o++) {
Neuron neuron = layer[o];
for (int p = 0; p < neuron.weights.length; p++) {
neuron.weights[p] -= training.weightDerivatives[o][p] * learningRate;
}
neuron.bias -= training.biasDerivatives[o] * learningRate;
}
}
lock.writeLock().unlock();
}
private Training calculateLayer(int layer, double[] activationDerivatives) {
Neuron[] neurons = this.neurons[layer];
// dC0/dw(L) = dz(L)/dw(L) * da(L)/dz(L) * dC0/da(L)
// = a(L-1) * sigm'(z(L)) * 2(a(L) - y)
double[][] weightDerivatives = new double[neurons.length][];
double[][] previousActivationDerivativesRaw = new double[neurons.length][];
double[] biasDerivative = new double[neurons.length];
// Iterate over neurons for weights and bias
for (int i = 0; i < neurons.length; i++) {
System.out.println("Layer " + layer + ", neuron " + i);
// Activation derivative
double ca = activationDerivatives[i];
System.out.println("\tCa: " + ca);
// Activation over z
double az = MathUtils.derivativeSigmoid(neurons[i].z);
System.out.println("\tZ: " + neurons[i].z);
System.out.println("\tAz: " + az);
// z over bias
double zb = 1;
weightDerivatives[i] = new double[neurons[i].weights.length];
biasDerivative[i] = ca * az * zb;
System.out.println("\tBiasDerivative: " + biasDerivative[i]);
previousActivationDerivativesRaw[i] = new double[this.neurons[layer - 1].length];
// Iterate over connections
for (int o = 0; o < neurons[i].weights.length; o++) {
System.out.println("\tConnection to " + o);
// z over weight
double zw = this.neurons[layer - 1][o].activation;
System.out.println("\t\tZw: " + zw);
// z over previousActivation
double za = neurons[i].weights[o];
System.out.println("\t\tZa: " + za);
// FillArray
weightDerivatives[i][o] = ca * az * zw;
System.out.println("\t\tWeight: " + weightDerivatives[i][o]);
previousActivationDerivativesRaw[i][o] = ca * az * za;
System.out.println("\t\tPreviousActivationRaw: " + previousActivationDerivativesRaw[i][o]);
}
}
// Add up previousActivation
double[] previousActivationDerivative = new double[this.neurons[layer - 1].length];
for (int i = 0; i < previousActivationDerivative.length; i++) {
double derivative = 0;
for (int o = 0; o < previousActivationDerivativesRaw.length; o++) {
derivative += previousActivationDerivativesRaw[o][i];
}
previousActivationDerivative[i] = derivative;
System.out.println("PreviousActivation: " + i + ": " + derivative);
}
return new Training(weightDerivatives, biasDerivative, previousActivationDerivative);
}
@Override
public double[] compute(double[] input) {
Preconditions.checkArgument(input.length == neurons[0].length, "Input array lenght does not match network input layer size");
lock.readLock().lock();
double[] output = new double[input.length];
for (int i = 0; i < neurons.length; i++) {
double[] newOutput = new double[neurons[i].length];
for (int o = 0; o < neurons[i].length; o++) {
if (i == 0) {
newOutput[o] = neurons[i][o].compute(input[o]);
} else {
newOutput[o] = neurons[i][o].compute(output);
System.out.println("Neuron [" + i + "][" + o + "]: Z: " + neurons[i][o].z + "; A: " + neurons[i][o].activation);
}
}
output = newOutput;
//System.out.println(Arrays.toString(output));
}
lock.readLock().unlock();
return output;
}
private static class Neuron implements Serializable {
private boolean input;
private double[] weights;
private double bias;
private double z;
private double activation;
private Neuron() {
this.input = true;
}
private Neuron(int size) {
weights = new double[size];
for (int i = 0; i < weights.length; i++) {
//weights[i] = Math.random();
weights[i] = 0.5;
}
//bias = Math.random();
bias = 0.5;
}
public double compute(double input) {
Preconditions.checkState(this.input, "Neuron is not an input neuron");
this.activation = input;
return input;
}
public double compute(double[] activations) {
Preconditions.checkState(!input, "Neuron is an input neuron");
double activation = 0;
for (int i = 0; i < activations.length; i++) {
activation += weights[i] * activations[i];
}
activation += bias;
this.z = activation;
activation = MathUtils.sigmoid(activation);
this.activation = activation;
return activation;
}
}
@AllArgsConstructor(access = AccessLevel.PRIVATE)
private static class Training {
private double[][] weightDerivatives;
private double[] biasDerivatives;
private double[] prevActivationDerivatives;
}
MathUtils.java:
public static double sigmoid(double a) {
return 1 / (1 + Math.exp(-a));
}
public static double derivativeSigmoid(double a) {
double sigmoid = sigmoid(a);
return sigmoid * (1 - sigmoid);
}
public static void addUp(double[] base, double[] addition) {
Preconditions.checkArgument(base.length == addition.length, "Arrays not of the same lenght");
for (int i = 0; i < base.length; i++) {
base[i] += addition[i];
}
}
public static void addUp(double[][] base, double[][] addition) {
Preconditions.checkArgument(base.length == addition.length, "Arrays not of the same lenght");
for (int i = 0; i < base.length; i++) {
addUp(base[i], addition[i]);
}
}
Test.java:
public static void main(String[] args) {
double[][] inputs = new double[][] {
new double[] { 0, 0 },
new double[] { 0, 1 },
new double[] { 1, 0 },
new double[] { 1, 1 }
};
double[][] outputs = new double[][] {
new double[] { 0 },
new double[] { 1 },
new double[] { 1 },
new double[] { 0 }
};
MultiLayerSigmoid network = new MultiLayerSigmoid(2, 2, 1);
for (double[] input : inputs) {
System.out.println(Arrays.toString(network.compute(input)));
}
System.out.println();
System.out.println();
System.out.println();
for (int i = 0; i < 100; i++) {
network.learn(inputs, outputs);
}
System.out.println();
System.out.println();
System.out.println();
for (double[] input : inputs) {
System.out.println(Arrays.toString(network.compute(input)));
}
}