Как реализовать модифицируемую функцию активации в классе нейронов в Java?

Я изучаю концепцию нейронных сетей. Я решил попробовать сделать нейронный класс самостоятельно. Каков наилучший способ реализации различных функций активации в моем коде? Теперь он использует только двоичную шаговую функцию. Это моя первая попытка кодирования нейронных сетей, так что если у вас есть какие-либо предложения по поводу моего кода, или он совершенно тупой, пожалуйста, дайте мне знать.

Вот мой код:

public class Neuron {

// properties
    private ArrayList<Neuron> input;
    private ArrayList<Float> weight;
    private float pot, bias, sense, out;
    private boolean checked;

// methods
    public float fire(){
        pot = 0f;
        if (input != null) {
            for (Neuron n : input){
                if (!n.getChecked()){
                    pot += n.fire()*weight.get(input.indexOf(n));
                } else {
                        pot += n.getOut()*weight.get(input.indexOf(n));
                } // end of condition (checked)
            } // end of loop (for input)
        } // end of condition (input exists)
        checked = true;
        pot -= bias;
        pot += sense;
        out = actFunc(pot);
        return out;
    } // end of fire()

    // getting properties
    public float getPot(){return pot;}
    public boolean getChecked(){return checked;}
    public float getOut(){return out;}

    // setting properties
    public void stimulate(float f){sense = f;}
    public void setBias(float b){bias = b;}
    public void setChecked(boolean c){checked = c;}
    public void setOut(float o){out = o;}

    // connection
    public void connect(Neuron n, float w){
        input.add(n);
        weight.add(w);
        }
    public void deconnect(Neuron n){
        weight.remove(input.indexOf(n));
        input.remove(n);
    }

    // activation function
        private float actFunc(float x){
            if (x < 0) {
                return 0f;
            } else {
                return 1f;
            }
        }

// constructor
    public Neuron(Neuron[] ns, float[] ws, float b, float o){
        if (ns != null){
            input = new ArrayList<Neuron>();
            weight = new ArrayList<Float>();
            for (Neuron n : ns) input.add(n);
            for (int i = 0; i < ws.length; i++) weight.add(ws[i]);
        } else {
            input = null;
            weight = null;
        }
        bias = b;
        out = o;
    }

    public Neuron(Neuron[] ns){
        if (ns != null){
            input = new ArrayList<Neuron>();
            weight = new ArrayList<Float>();
            for (Neuron n : ns) input.add(n);
            for (int i = 0; i < input.size(); i++) weight.add((float)Math.random()*2f-1f);
        } else {
            input = null;
            weight = null;
        }
        bias = (float)Math.random();
        out = (float)Math.random();
    }

}

1 ответ

Решение

Сначала определите интерфейс любой функции активации:

public interface ActivationFunction {
    float get(float f);
}

Затем напишите несколько реализаций:

public class StepFunction implements ActivationFunction {
    @Override
    public float get() {return (x < 0) ? 0f : 1f;}
}

public class SigmoidFunction implements ActivationFunction {
    @Override
    public float get() {return StrictMath.tanh(h);}
}

Наконец, установите некоторую реализацию для вашего Neuron:

public class Neuron {
    private final ActivationFunction actFunc;
    // other fields...

    public Neuron(ActivationFunction actFunc) {
        this.actFunc = actFunc;
    }

    public float fire(){
        // ...
        out = actFunc.get(pot);
        return out;
    } 
}

следующим образом:

Neuron n = new Neuron(new SigmoidFunction());

Обратите внимание, что нейронные сети используют распространение сигнала через нейроны, где производятся веса. Вычисление веса зависит также от первой производной функции активации. Поэтому я бы продлил ActivationFunction по методу, который вернет первую производную в указанной точке x:

public interface ActivationFunction {
    float get(float f);
    float firstDerivative(float x);
}

Таким образом, реализации будут выглядеть так:

public class StepFunction implements ActivationFunction {
    @Override
    public float get(float x) {return (x < 0) ? 0f : 1f;}

    @Override
    public float firstDerivative(float x) {return 1;}
}

public class SigmoidFunction implements ActivationFunction {
    @Override
    public float get(float x) {return StrictMath.tanh(x);}

    // derivative_of tanh(x) = (4*e^(2x))/(e^(2x) + 1)^2 == 1-tanh(x)^2 
    @Override
    public float firstDerivative(float x) {return 1 - Math.pow(StrictMath.tanh(x), 2);}
}

Затем используйте actFunction.firstDerivative(x); в fire() метод, где вычисляется вес.

Другие вопросы по тегам