Quantcast
Channel: Active questions tagged feed-forward+neural-network+backpropagation - Stack Overflow
Viewing all articles
Browse latest Browse all 35

Neural Network Feed Forward Back Propagation based learning not working for AND table

$
0
0

My neural network is able to learn the OR and XOR tables. However, when I want it to learn the AND table, it refuses.

NN Config:

1 input layer with 2 input neurons

1 hidden layer with 3 neurons

1 output layer with 1 result neuron.

in total 9 weights, 6 assigned between input and hidden and 3 assigned between hidden and output.

Using sigmoid as activation function.

With OR or XOR, the values tend to approach the right targets, as shown below:

OR

0 0 - 0.0232535024 // ~0

0 1 - 0.9882075648 // ~1

1 0 - 0.9881840412 // ~1

1 1 - 0.9932447649 // ~1

XOR

0 0 - 0.0419020172 // ~0

0 1 - 0.9742653893 // ~1

1 0 - 0.9742622526 // ~1

1 1 - 0.0096044003 // ~0

But when I try to train it for AND, the first 3 rows tend to approach 0, but the last row (1,1) approaches .5 and goes no further.

AND

0 0 - 0.0007202012 // ~0

0 1 - 0.0151875796 // ~0

1 0 - 0.0128653577 // ~0

1 1 - 0.4987960208 // NOT EVEN CLOSE, tends to approach .5, half of 1

If code is required, please let me know and I shall post it. I am wondering if my approach is right.

Can I apply the same activation function in all cases? Why is it approaching .5, half of 1? Is there anything conceptually wrong in my understanding?

I have followed https://stevenmiller888.github.io/mind-how-to-build-a-neural-network/ and a few others to understand neural networks and how they can be implemented. I am using Java.

Below is my class:

package com.example;


import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;

/**
 * A minimal 2-3-1 feed-forward neural network trained by batch gradient descent
 * with back-propagation, using a sigmoid activation on both layers.
 *
 * <p>Architecture: 2 input neurons, one hidden layer of 3 neurons, 1 output
 * neuron; 2x3 input-to-hidden weights plus 3x1 hidden-to-output weights.
 *
 * <p>NOTE(review): this network has no bias terms. Without a bias, every
 * neuron's output is forced to sigmoid(0) = 0.5 when all inputs are 0, which
 * makes functions like AND (whose decision boundary does not pass through the
 * origin) hard or impossible to fit — presumably the cause of the (1,1) -> 0.5
 * symptom. Adding a bias input to each layer is the conceptual fix; confirm
 * against the reference tutorial.
 */
public class MyNeuralNet {

    /** Step size applied to every weight update. */
    float learningRate = .7f;
    /** Number of full-batch training passes. */
    int iterations = 10000;
    /** Number of neurons in the single hidden layer. */
    int hiddenUnits = 3;

    // Training set: 4 samples of 2 inputs each, with 1 target output per sample.
    double[][] input = new double[4][2];
    double[][] knownOp = new double[4][1];
    // Initial weight values: 2x3 (input->hidden) and 3x1 (hidden->output).
    double[][] inputHiddenArray = new double[2][3];
    double[][] hiddenOutputArray = new double[3][1];

    RealMatrix inputHidden;   // input->hidden weights, 2x3
    RealMatrix hiddenOutput;  // hidden->output weights, 3x1
    RealMatrix outputSum;     // output-layer pre-activation, 4x1
    RealMatrix outPutResult;  // activated network output, 4x1
    RealMatrix hiddenSum;     // hidden-layer pre-activation, 4x3
    RealMatrix hidderResult;  // activated hidden layer, 4x3

    public static void main(String[] args) {
        System.out.println("Hello! NeuralNet");
        MyNeuralNet mind = new MyNeuralNet();

        init(mind);

        for (int i = 0; i < mind.iterations; i++) {
            // Forward pass over the whole batch.
            RealMatrix afterFwd = mind.forward();
            loggResult("afterFwd:: " + afterFwd.toString());

            // Error = target - prediction (element-wise, 4x1).
            RealMatrix target = MatrixUtils.createRealMatrix(mind.knownOp);
            RealMatrix errorOutputLayer = target.subtract(afterFwd);
            loggCost("errorOutputLayer: " + errorOutputLayer.toString());
            logg(errorOutputLayer.getRowDimension() + " * " + errorOutputLayer.getColumnDimension());

            // Backward pass: compute gradients AND apply the weight updates.
            mind.backPropagate(errorOutputLayer);
        }
    }

    /** Seeds the training samples, target outputs, and initial weights. */
    private static void init(MyNeuralNet mind) {
        mind.input[0] = new double[]{0, 0};
        mind.input[1] = new double[]{0, 1};
        mind.input[2] = new double[]{1, 0};
        mind.input[3] = new double[]{1, 1};

        mind.knownOp[0] = new double[]{0};
        mind.knownOp[1] = new double[]{0};
        mind.knownOp[2] = new double[]{0};
        mind.knownOp[3] = new double[]{1};

        // Fixed (non-random) initial weights, matching the reference tutorial.
        mind.inputHiddenArray[0] = new double[]{.8f, .4f, .3f};
        mind.inputHiddenArray[1] = new double[]{.2f, .9f, .5f};
        mind.inputHidden = MatrixUtils.createRealMatrix(mind.inputHiddenArray);

        mind.hiddenOutputArray[0] = new double[]{.3f};
        mind.hiddenOutputArray[1] = new double[]{.5f};
        mind.hiddenOutputArray[2] = new double[]{.9f};
        mind.hiddenOutput = MatrixUtils.createRealMatrix(mind.hiddenOutputArray);
    }

    /**
     * Runs the forward pass for the whole 4-sample batch.
     *
     * @return the activated output of the network, a 4x1 matrix
     */
    private RealMatrix forward() {
        RealMatrix inputMatrix = MatrixUtils.createRealMatrix(input);

        // Hidden layer: (4x2) * (2x3) = 4x3, then sigmoid.
        hiddenSum = inputMatrix.multiply(inputHidden);
        logg("hiddenSum: " + hiddenSum);
        logg(hiddenSum.getRowDimension() + " * " + hiddenSum.getColumnDimension());

        hidderResult = activate(hiddenSum);
        logg("hidderResult: " + hidderResult);
        logg(hidderResult.getRowDimension() + " * " + hidderResult.getColumnDimension());

        // Output layer: (4x3) * (3x1) = 4x1, then sigmoid.
        outputSum = hidderResult.multiply(hiddenOutput);
        logg("outputSum: " + outputSum);
        logg(outputSum.getRowDimension() + " * " + outputSum.getColumnDimension());

        outPutResult = activate(outputSum);
        logg("outPutResult:: " + outPutResult.toString());
        logg(outPutResult.getRowDimension() + " * " + outPutResult.getColumnDimension());

        return outPutResult;
    }

    /**
     * Back-propagates the batch error and applies the gradient-descent weight
     * updates to both weight matrices.
     *
     * @param errorOutputLayer the (target - prediction) error, 4x1
     */
    private void backPropagate(RealMatrix errorOutputLayer) {
        // delta_out = sigmoid'(outputSum) ⊙ error   (4x1)
        RealMatrix deltaOutputLayerMatrix =
                elementWiseProduct(activatePrime(outputSum), errorOutputLayer);
        logg("deltaOutputLayer1: " + deltaOutputLayerMatrix);
        logg(deltaOutputLayerMatrix.getRowDimension() + " * " + deltaOutputLayerMatrix.getColumnDimension());

        // hidden->output gradient: hidden^T (3x4) * delta_out (4x1) = 3x1.
        RealMatrix hiddenOutputChanges =
                hidderResult.transpose().multiply(deltaOutputLayerMatrix).scalarMultiply(learningRate);
        logg("hiddenOutputChanges: " + hiddenOutputChanges);
        logg(hiddenOutputChanges.getRowDimension() + " * " + hiddenOutputChanges.getColumnDimension());

        // delta_hidden = sigmoid'(hiddenSum) ⊙ (delta_out * W_ho^T)   (4x3)
        RealMatrix propagatedError = deltaOutputLayerMatrix.multiply(hiddenOutput.transpose());
        RealMatrix deltaHiddenLayer = elementWiseProduct(activatePrime(hiddenSum), propagatedError);
        logg("deltaHiddenLayer: " + deltaHiddenLayer);
        logg(deltaHiddenLayer.getRowDimension() + " * " + deltaHiddenLayer.getColumnDimension());

        // input->hidden gradient: input^T (2x4) * delta_hidden (4x3) = 2x3.
        RealMatrix inputMatrix = MatrixUtils.createRealMatrix(input);
        RealMatrix inputHiddenChanges =
                inputMatrix.transpose().multiply(deltaHiddenLayer).scalarMultiply(learningRate);
        logg("inputHiddenChanges: " + inputHiddenChanges.toString());
        logg(inputHiddenChanges.getRowDimension() + " * " + inputHiddenChanges.getColumnDimension());

        // BUG FIX: the original computed these changes but never applied them,
        // so the weights never moved. Ascend the error gradient (error is
        // target - prediction, so adding the changes reduces the error).
        hiddenOutput = hiddenOutput.add(hiddenOutputChanges);
        inputHidden = inputHidden.add(inputHiddenChanges);
    }

    /** Element-wise (Hadamard) product of two equally-sized matrices. */
    private static RealMatrix elementWiseProduct(RealMatrix a, RealMatrix b) {
        double[][] product = new double[a.getRowDimension()][a.getColumnDimension()];
        for (int i = 0; i < a.getRowDimension(); i++) {
            for (int j = 0; j < a.getColumnDimension(); j++) {
                product[i][j] = a.getEntry(i, j) * b.getEntry(i, j);
            }
        }
        return MatrixUtils.createRealMatrix(product);
    }

    /** Applies {@link #sigmoidPrime(double)} to every entry of the matrix. */
    private RealMatrix activatePrime(RealMatrix sum) {
        double[][] out = new double[sum.getRowDimension()][sum.getColumnDimension()];
        for (int i = 0; i < sum.getRowDimension(); i++) {
            for (int j = 0; j < sum.getColumnDimension(); j++) {
                out[i][j] = sigmoidPrime(sum.getEntry(i, j));
            }
        }
        return MatrixUtils.createRealMatrix(out);
    }

    /** Applies {@link #sigmoid(double)} to every entry of the matrix. */
    private RealMatrix activate(RealMatrix sum) {
        double[][] out = new double[sum.getRowDimension()][sum.getColumnDimension()];
        for (int i = 0; i < sum.getRowDimension(); i++) {
            for (int j = 0; j < sum.getColumnDimension(); j++) {
                out[i][j] = sigmoid(sum.getEntry(i, j));
            }
        }
        return MatrixUtils.createRealMatrix(out);
    }

    private static void loggCost(String str) {
        System.out.println("COST:::::::::::::::::::::::" + str);
    }

    private static void loggWeights(String str) {
        System.out.println("Weights:::::::::::::::::::::::" + str);
    }

    private static void loggResult(String s) {
        System.out.println(s);
    }

    /** Logistic sigmoid: 1 / (1 + e^-x). */
    public static double sigmoid(double x) {
        return 1 / (1 + Math.exp(-x));
    }

    /** Derivative of the sigmoid expressed in terms of the pre-activation x. */
    public static double sigmoidPrime(double x) {
        return sigmoid(x) * (1 - sigmoid(x));
    }

    private static void logg(String str) {
        System.out.println(str);
    }
}

Here is an update, if i randomly change out of 9 weights, it seems to be working fine

    //from .9 to .8
    mind.hiddenOutputArray[2] = new double[]{.8f};

Viewing all articles
Browse latest Browse all 35

Latest Images

Trending Articles



Latest Images