Multilayer sigmoid neural network in Java converging to 0.5

I've been watching 3blue1brown's neural network series on YouTube and tried to code such a network myself. As far as I can tell, the network does the math exactly as intended, yet it doesn't learn at all. I'm trying to teach it the XOR relation, but the output converges to 0.5 for every input. I found another thread about what I believe is the same problem, but since I can't read Python syntax, it didn't help me.
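
For reference, these are the chain-rule expressions I believe I'm implementing (same notation as the comment in calculateLayer below; the bias and previous-activation lines are my own reading of the videos, so treat them as my assumptions):

dC0/dw(L)   = dz(L)/dw(L) * da(L)/dz(L) * dC0/da(L) = a(L-1) * sigm'(z(L)) * 2(a(L) - y)
dC0/db(L)   = 1 * sigm'(z(L)) * dC0/da(L)
dC0/da(L-1) = sum over the layer's neurons of: w(L) * sigm'(z(L)) * dC0/da(L)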

MultiLayerSigmoid.java:
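
(For completeness, the imports the snippets rely on — Preconditions is Guava's, @AllArgsConstructor/AccessLevel are Lombok's:)

import java.io.Serializable;
import java.util.Arrays;
import java.util.concurrent.locks.ReadWriteLock;
import java.util.concurrent.locks.ReentrantReadWriteLock;

import com.google.common.base.Preconditions;

import lombok.AccessLevel;
import lombok.AllArgsConstructor;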

private Neuron[][] neurons;

private ReadWriteLock lock = new ReentrantReadWriteLock();

public MultiLayerSigmoid(int... neurons) {
    this.neurons = new Neuron[neurons.length][];
    for (int i = 0; i < neurons.length; i++) {
        this.neurons[i] = new Neuron[neurons[i]];
        for (int o = 0; o < this.neurons[i].length; o++) {
            this.neurons[i][o] = i == 0 ? new Neuron() : new Neuron(neurons[i - 1]);
        }
    }
}

@Override
public void learn(double[][] inputs, double[][] outputs) {
    Preconditions.checkArgument(inputs.length == outputs.length, "Invalid training sample sizes");

    lock.writeLock().lock();

    // Iterate samples

    Training[][] trainings = new Training[inputs.length][];

    for (int i = 0; i < inputs.length; i++) {

        double[] input = inputs[i];
        double[] output = outputs[i];

        double[] actual = compute(input);

        // Derivative of the squared-error cost with respect to each output activation: dC/da = 2 * (a - y)
        double[] activationDerivatives = new double[output.length];

        for (int o = 0; o < activationDerivatives.length; o++) {
            activationDerivatives[o] = 2 * (actual[o] - output[o]);
        }

        System.out.println("Error derivatives: " + Arrays.toString(activationDerivatives));

        // Backpropagation (the input layer has no weights, so it is skipped)

        trainings[i] = new Training[this.neurons.length - 1];
        for (int layer = this.neurons.length - 1; layer > 0; layer--) {
            Training training = calculateLayer(layer, activationDerivatives);
            //System.out.println(i + " " + layer + " " + new Gson().toJson(training));
            trainings[i][layer - 1] = training;
            activationDerivatives = training.prevActivationDerivatives;
        }

    }

    Training[] summed = new Training[this.neurons.length - 1];
    // Initialize one zeroed gradient accumulator per trainable layer
    for (int i = 0; i < summed.length; i++) {
        summed[i] = new Training(
                new double[this.neurons[i + 1].length][this.neurons[i].length],
                new double[this.neurons[i + 1].length],
                new double[this.neurons[i].length]);
    }

    for (int i = 0; i < trainings.length; i++) {
        Training[] sample = trainings[i];
        for (int o = 0; o < sample.length; o++) {
            Training training = sample[o];

            MathUtils.addUp(summed[o].weightDerivatives, training.weightDerivatives);
            MathUtils.addUp(summed[o].biasDerivatives, training.biasDerivatives);
            MathUtils.addUp(summed[o].prevActivationDerivatives, training.prevActivationDerivatives);

        }
    }

    // Adjust weights and biases

    double learningRate = 0.1;
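    // Note: the gradients in "summed" are a plain sum over all samples, not an average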

    for (int i = 1; i < neurons.length; i++) {
        Neuron[] layer = neurons[i];
        Training training = summed[i - 1];

        for (int o = 0; o < layer.length; o++) {
            Neuron neuron = layer[o];

            for (int p = 0; p < neuron.weights.length; p++) {
                neuron.weights[p] -= training.weightDerivatives[o][p] * learningRate;
            }

            neuron.bias -= training.biasDerivatives[o] * learningRate;

        }
    }

    lock.writeLock().unlock();
}

private Training calculateLayer(int layer, double[] activationDerivatives) {
    Neuron[] neurons = this.neurons[layer];

    // dC0/dw(L)    = dz(L)/dw(L) * da(L)/dz(L) * dC0/da(L)
    //              = a(L-1) * sigm'(z(L)) * 2(a(L) - y)

    double[][] weightDerivatives = new double[neurons.length][];
    double[][] previousActivationDerivativesRaw = new double[neurons.length][];
    double[] biasDerivative = new double[neurons.length];

    // Iterate over neurons for weights and bias
    for (int i = 0; i < neurons.length; i++) {

        System.out.println("Layer " + layer + ", neuron " + i);

        // Activation derivative
        double ca = activationDerivatives[i];
        System.out.println("\tCa: " + ca);
        // Activation over z
        double az = MathUtils.derivativeSigmoid(neurons[i].z);
        System.out.println("\tZ: " + neurons[i].z);
        System.out.println("\tAz: " + az);

        // z over bias
        double zb = 1;

        weightDerivatives[i] = new double[neurons[i].weights.length];
        biasDerivative[i] = ca * az * zb;
        System.out.println("\tBiasDerivative: " + biasDerivative[i]);
        previousActivationDerivativesRaw[i] = new double[this.neurons[layer - 1].length];

        // Iterate over connections
        for (int o = 0; o < neurons[i].weights.length; o++) {

            System.out.println("\tConnection to " + o);

            // z over weight
            double zw = this.neurons[layer - 1][o].activation;
            System.out.println("\t\tZw: " + zw);

            // z over previousActivation
            double za = neurons[i].weights[o];
            System.out.println("\t\tZa: " + za);

            // FillArray
            weightDerivatives[i][o] = ca * az * zw;
            System.out.println("\t\tWeight: " + weightDerivatives[i][o]);
            previousActivationDerivativesRaw[i][o] = ca * az * za;
            System.out.println("\t\tPreviousActivationRaw: " + previousActivationDerivativesRaw[i][o]);
        }
    }

    // Sum every neuron's contribution to obtain dC/da for each neuron of the previous layer
    double[] previousActivationDerivative = new double[this.neurons[layer - 1].length];
    for (int i = 0; i < previousActivationDerivative.length; i++) {
        double derivative = 0;
        for (int o = 0; o < previousActivationDerivativesRaw.length; o++) {
            derivative += previousActivationDerivativesRaw[o][i];
        }
        previousActivationDerivative[i] = derivative;
        System.out.println("PreviousActivation: " + i + ": " + derivative);
    }

    return new Training(weightDerivatives, biasDerivative, previousActivationDerivative);
}

@Override
public double[] compute(double[] input) {
    Preconditions.checkArgument(input.length == neurons[0].length, "Input array length does not match network input layer size");

    lock.readLock().lock();

    double[] output = new double[input.length];

    // Feed the activations forward, layer by layer
    for (int i = 0; i < neurons.length; i++) {
        double[] newOutput = new double[neurons[i].length];
        for (int o = 0; o < neurons[i].length; o++) {
            if (i == 0) {
                newOutput[o] = neurons[i][o].compute(input[o]);
            } else {
                newOutput[o] = neurons[i][o].compute(output);
                System.out.println("Neuron [" + i + "][" + o + "]: Z: " + neurons[i][o].z + "; A: " + neurons[i][o].activation);
            }
        }
        output = newOutput;
        //System.out.println(Arrays.toString(output));
    }

    lock.readLock().unlock();

    return output;
}

private static class Neuron implements Serializable {

    private boolean input;

    private double[] weights;
    private double bias;

    // Cached by the last forward pass and reused during backpropagation
    private double z;
    private double activation;

    private Neuron() {
        this.input = true;
    }

    private Neuron(int size) {
        // Every weight and bias currently starts at the same fixed 0.5; random initialization is commented out
        weights = new double[size];
        for (int i = 0; i < weights.length; i++) {
            //weights[i] = Math.random();
            weights[i] = 0.5;
        }
        //bias = Math.random();
        bias = 0.5;
    }

    public double compute(double input) {
        Preconditions.checkState(this.input, "Neuron is not an input neuron");
        this.activation = input;
        return input;
    }

    public double compute(double[] activations) {
        Preconditions.checkState(!input, "Neuron is an input neuron");

        double activation = 0;
        for (int i = 0; i < activations.length; i++) {
            activation += weights[i] * activations[i];
        }
        activation += bias;
        this.z = activation;
        activation = MathUtils.sigmoid(activation);
        this.activation = activation;
        return activation;
    }

}

@AllArgsConstructor(access = AccessLevel.PRIVATE)
private static class Training {

    private double[][] weightDerivatives;       // indexed [neuron][incoming connection]
    private double[] biasDerivatives;           // indexed [neuron]
    private double[] prevActivationDerivatives; // dC/da for the previous layer, passed backwards

}

MathUtils.java:

public static double sigmoid(double a) {
    return 1 / (1 + Math.exp(-a));
}

public static double derivativeSigmoid(double a) {
    double sigmoid = sigmoid(a);
    return sigmoid * (1 - sigmoid);
}

public static void addUp(double[] base, double[] addition) {
    Preconditions.checkArgument(base.length == addition.length, "Arrays not of the same length");
    for (int i = 0; i < base.length; i++) {
        base[i] += addition[i];
    }
}

public static void addUp(double[][] base, double[][] addition) {
    Preconditions.checkArgument(base.length == addition.length, "Arrays not of the same length");
    for (int i = 0; i < base.length; i++) {
        addUp(base[i], addition[i]);
    }
}

Test.java:

public static void main(String[] args) {
    double[][] inputs = new double[][] {
            new double[] { 0, 0 },
            new double[] { 0, 1 },
            new double[] { 1, 0 },
            new double[] { 1, 1 }
    };
    double[][] outputs = new double[][] {
            new double[] { 0 },
            new double[] { 1 },
            new double[] { 1 },
            new double[] { 0 }
    };

    // 2 input neurons, 2 hidden neurons, 1 output neuron
    MultiLayerSigmoid network = new MultiLayerSigmoid(2, 2, 1);

    for (double[] input : inputs) {
        System.out.println(Arrays.toString(network.compute(input)));
    }

    System.out.println();
    System.out.println();
    System.out.println();

    for (int i = 0; i < 100; i++) {
        network.learn(inputs, outputs);
    }

    System.out.println();
    System.out.println();
    System.out.println();

    for (double[] input : inputs) {
        System.out.println(Arrays.toString(network.compute(input)));
    }
}
