// Reference: https://www.cnblogs.com/hesi/p/7218602.html
/**
 * Fully connected feed-forward network with sigmoid activations.
 *
 * <p>All weights live in one flat array. For every layer {@code l >= 1} the array holds
 * {@code layerShape[l] * layerShape[l - 1]} connection weights (the weight from input
 * {@code j} to neuron {@code i} is at {@code i * layerShape[l - 1] + j} within the layer's
 * segment), immediately followed by a single bias value shared by all neurons of that layer.
 * Segments of consecutive layers are stored back to back.
 *
 * <p>The arrays passed to the constructor and to {@link #update(double[])} are stored by
 * reference, not copied.
 */
public class NeuralNetwork {
    private static final Logger logger = LoggerFactory.getLogger(NeuralNetwork.class);

    /** Neuron count per layer, input layer first, e.g. {2, 2, 1} for a three-layer network. */
    private final int[] layerShape;
    /** Flat weight array: per layer, the connection weights followed by one shared bias. */
    private double[] weights;
    /** Activations of the final layer from the most recent {@link #run(double[])} call. */
    private double[] outputs;
    /** Number of weights (connections plus one bias per layer) implied by the layer shape. */
    private final int weightNum;

    /**
     * Creates a network for the given layer sizes.
     *
     * <p>A mismatch between the expected and the supplied number of weights is only logged
     * as an error; the instance is still created.
     *
     * @param weights    flat weight array in the layout described on the class
     * @param layerShape neuron count of each layer, input layer first
     */
    public NeuralNetwork(double[] weights, int[] layerShape) {
        this.layerShape = layerShape;
        this.weights = weights;
        logger.info("输入:{}", layerShape[0]);
        logger.info("输出:{}", layerShape[layerShape.length - 1]);

        int expected = 0;
        for (int layer = 1; layer < layerShape.length; layer++) {
            // One weight per connection plus a single bias shared by the whole layer.
            expected += layerShape[layer] * layerShape[layer - 1] + 1;
        }
        this.weightNum = expected;

        logger.info("权重数:{}", weightNum);
        if (weightNum != weights.length) {
            logger.error("权重数不匹配!!!");
        }
    }

    /**
     * Runs a forward pass and stores the activations of the final layer in this instance
     * (visible through {@link #toString()}).
     *
     * <p>Each neuron's output is {@code sigmoid(sum_j(input[j] * weight[i][j]) + bias)}, and
     * each layer's outputs become the next layer's inputs.
     *
     * @param inputs values for the input layer; the first {@code layerShape[0]} entries are read
     */
    public void run(double[] inputs) {
        double[] activations = inputs;
        int offset = 0; // start of the current layer's segment inside the flat weight array

        for (int layer = 1; layer < layerShape.length; layer++) {
            int fanIn = layerShape[layer - 1];
            int fanOut = layerShape[layer];

            // The layer's shared bias sits right after its fanOut * fanIn connection weights.
            double bias = weights[offset + fanOut * fanIn];
            logger.info("b:{}", bias);

            double[] next = new double[fanOut];
            for (int i = 0; i < fanOut; i++) {
                double weighted = 0;
                for (int j = 0; j < fanIn; j++) {
                    weighted += activations[j] * weights[offset + i * fanIn + j];
                }
                next[i] = sigmoid(weighted + bias);
            }
            logger.info("outputs:{}", next);

            // Feed this layer's result forward and skip past its weights and bias.
            activations = next;
            offset += fanOut * fanIn + 1;
        }

        outputs = activations;
    }

    /**
     * Replaces the weight array used by subsequent {@link #run(double[])} calls.
     *
     * <p>As in the constructor, a size mismatch is logged as an error but not rejected.
     *
     * @param weights new flat weight array in the layout described on the class
     */
    public void update(double[] weights) {
        if (weights == null || weights.length != weightNum) {
            logger.error("权重数不匹配!!!");
        }
        this.weights = weights;
    }

    /** Logistic sigmoid: {@code 1 / (1 + e^-x)}. */
    private double sigmoid(double x) {
        return 1 / (1 + Math.exp(-x));
    }

    @Override
    public String toString() {
        return "NeuralNetwork{" +
                "weights=" + Arrays.toString(weights) +
                ", outputs=" + Arrays.toString(outputs) +
                '}';
    }
}