hzxdark 发表于 2013-1-13 19:08:19

BP网络JAVA版源代码

 
这是神经网络课程的作业，一个简单的BP网络。准确率有点低，可能是我的算法有些问题：用100个训练数据训练，测试50个数据，正确率只有80%。
 
 
BP类：封装BP算法
package arithmetic;
public class BP {
 private double[] P;
 private double[] T;
 private double[][] W1;
 private double[][] W2;
 private int n_a0;
 private int n_a1;
 private int n_a2;
 private double[] B1;
 private double[] B2;
 private double[] a1;
 private double[] a2;
 private double[] q;
 private double[] db1;
 private double[] db2;
 private double[][] dw1;
 private double[][] dw2;
 private double e;
 private double r;
 private double e0;
 public BP(double[][] W1, double[][] W2, double[] B1, double[] B2) {
  this.W1 = W1;
  this.W2 = W2;
  this.B1 = B1;
  this.B2 = B2;
  n_a0 = W1.length;
  n_a1 = W1.length;
  n_a2 = W2.length;
  init();
 }
 public void setP(double[] P) {
  this.P = P;
 }
 public void setT(double[] T) {
  this.T = T;
 }
 private void init() {
  a1 = new double;
  a2 = new double;
  r = 0.4;
  e0 = 0.02;
  q = new double;
  db2 = new double;
  dw2 = new double;
  db1 = new double;
  dw1 = new double;
 }
 public void calA1() {
  double temp = 0;
  for (int i = 0; i < n_a1; i++) {
   for (int j = 0; j < n_a0; j++) {
    temp += W1 * P;
   }
   temp += B1;
   a1 = F.f1(temp);
  }
 }
 public double[] getA1() {
  return a1;
 }
 public double[] getA2() {
  return a2;
 }
 public void calA2() {
  double temp = 0;
  for (int k = 0; k < n_a2; k++) {
   for (int i = 0; i < n_a1; i++) {
    temp += W2 * a1;
   }
   temp += B2;
   a2 = F.f2(temp);
  }
 }
 public void calE() {
  e = 0;
  for (int k = 0; k < n_a2; k++) {
   double ek = T - a2;
   e += ek * ek;
   e /= 2;
  }
 }
 public void calDb2() {
  for (int k = 0; k < n_a2; k++) {
   q = (T - a2) * F.f2_1(a2);
   db2 = q * r;
  }
 }
 public void calDw2() {
  for (int k = 0; k < n_a2; k++) {
   for (int i = 0; i < n_a1; i++) {
    dw2 = db2 * a1;
   }
  }
 }
 public void calDb1() {
  for (int i = 0; i < n_a1; i++) {
   db1 = 0;
   for (int k = 0; k < n_a2; k++) {
    db1 += q * W2;
   }
   db1 *= r * F.f1_1(a1);
  }
 }
 public void calDw1() {
  for (int i = 0; i < n_a1; i++) {
   for (int j = 0; j < n_a0; j++) {
    dw1 = db1 * P;
   }
  }
 }
 public void changeDb2() {
  for (int i = 0; i < n_a2; i++) {
   B2 += db2;
  }
 }
 public void changeDw2() {
  for (int i = 0; i < n_a2; i++) {
   for (int j = 0; j < n_a1; j++) {
    W2 += dw2;
   }
  }
 }
 public void changeDb1() {
  for (int i = 0; i < n_a1; i++) {
   B1 += db1;
  }
 }
 public void changeDw1() {
  for (int i = 0; i < n_a1; i++) {
   for (int j = 0; j < n_a0; j++) {
    W1 += dw1;
   }
  }
 }
 public void train(double[][] P, double[][] T) {
  while(true){
   boolean isChange = false;
   for (int n = 0; n < P.length; n++) {
    setP(P);
    setT(T);
    this.calA1();
    this.calA2();
    this.calE();
    if (e < e0)
     continue;
    this.calDb2();
    this.calDw2();
    this.calDb1();
    this.calDw1();
    this.changeDb2();
    this.changeDw2();
    this.changeDb1();
    this.changeDw1();
    isChange = true;
   // break;
   }
   if (!isChange) {
    System.out.println("train succeed");
    break;
   }
  }
 }
 public double[] divide(double[] p, double[] t) {
  setP(p);
  setT(t);
  this.calA1();
  this.calA2();
  return a2;
 }
 
 public double getE(){
  return e;
 }
}
F类：封装神经元激活函数及其导数
package arithmetic;
/**
 * Neuron activation functions and their derivatives.
 *
 * <p>The derivative helpers take the neuron's output value, not its net input.
 */
public class F {
    /** Logistic sigmoid: 1 / (1 + e^(-x)). */
    public static double f1(double x) {
        double decay = Math.exp(-x);
        return 1.0 / (1.0 + decay);
    }

    /** Derivative of {@link #f1} written in terms of its output y: y * (1 - y). */
    public static double f1_1(double y) {
        double complement = 1.0 - y;
        return y * complement;
    }

    /** Identity (linear) activation. */
    public static double f2(double x) {
        return x;
    }

    /** Derivative of {@link #f2}; always 1 regardless of y. */
    public static double f2_1(double y) {
        return 1.0;
    }
}

Controler类：读入训练数据和测试数据，并创建BP实例进行训练和测试
package arithmetic;
import java.io.*;
import java.util.ArrayList;
import java.util.List;
import javax.swing.JFrame;
public class Controler {
 private double[][] p_test;
 private double[][] t_test;
 private double[][] p_train;
 private double[][] t_train; 
 
 private BP bp;
 private JFrame viwer;
 
 
 public Controler() throws IOException{
  getTestData();
  getTrainData();
  double[][] w1 = new double[][] { { 0.2, 0.3, 0.4, 0.1 },
    { 0.3, 0.4, 0.2, 0.4 }, { 0.4, 0.8, 0.9, 0.3 }};
  double[][] w2 = new double[][] { { 0.3, 0.6, 0.7 }, { 0.1, 0.3, 0.7 } };
  double[] b1 = new double[] { 0.2, 0.4, 0.5 };
  double[] b2 = new double[] { 0.1, 0.5 };
  bp = new BP(w1,w2,b1,b2);
  bp.train(p_train, t_train);
  int a = 0;
  for(int i =0;i<p_test.length;i++){
   double[] a2 = bp.divide(p_test, t_test);  
   int t0 = (int)(t_test)*2+(int)(t_test);
   int t1 = (int)(a2)*2+(int)(a2);
   boolean equals = t1==t0;
   if(equals)a++;
   System.out.println("expected:"+t0+"\t"+"output:"+t1+"\t"+equals);
  }
  a = (int)(a/50.0*100);
  System.out.println(a);
 }
 
 private void getTestData() throws IOException{
  String fileName = "testData.txt";
  BufferedReader br = null;
  try {
   br = new BufferedReader(new FileReader(fileName));
  } catch (FileNotFoundException e) {
   br.close();
   e.printStackTrace();
  }
  ArrayList al = new ArrayList();
  String s = null;
  while((s=br.readLine())!=null){
   al.add(s);
  }
  br.close();
  p_test = new double;
  t_test = new double;
  double[] maxData = new double[]{0,0,0,0};
  for(int i =0;i<al.size();i++){
   String[] temp = al.get(i).toString().split(" ");;
   for(int j =0;j<4;j++){
    p_test = Double.parseDouble(temp);
    if(p_test>maxData)maxData = p_test;
   }
   int d = Integer.parseInt(temp);
   switch (d){
   case 0:
    t_test = 0;
    t_test = 0;
    break;
   case 1:
    t_test = 0;
    t_test = 1;
    break;
   case 2:
    t_test = 1;
    t_test = 0;
    break;
   default:
    t_test = 1;
    t_test = 1;
    break;
   }
  }
  for(int i =0;i<p_test.length;i++){
   for(int j =0;j<4;j++){
    p_test /= maxData;
   }
  }
 }
 private void getTrainData() throws IOException{
  String fileName = "trainData.txt";
  BufferedReader br = null;
  try {
   br = new BufferedReader(new FileReader(fileName));
  } catch (FileNotFoundException e) {
   br.close();
   e.printStackTrace();
  }
  ArrayList al = new ArrayList();
  String s = null;
  while((s=br.readLine())!=null){
   al.add(s);
  }
  p_train = new double;
  t_train = new double;
  double[] maxData = new double[]{0,0,0,0};
  for(int i =0;i<al.size();i++){
   String[] temp = al.get(i).toString().split(" ");;
   for(int j =0;j<temp.length-1;j++){
    p_train = Double.parseDouble(temp);
    if(p_train>maxData)maxData = p_train;
   }
   int d = Integer.parseInt(temp);
   switch (d){
   case 0:
    t_train = 0;
    t_train = 0;
    break;
   case 1:
    t_train = 0;
    t_train = 1;
    break;
   case 2:
    t_train = 1;
    t_train = 0;
    break;
   default:
    t_train = 1;
    t_train = 1;
    break;
   }
  }
  for(int i =0;i<p_train.length;i++){
   for(int j =0;j<4;j++){
    p_train /= maxData;
   }
  }
 }
 public static void main(String[] args) throws Exception{
  Controler ctr =new Controler();
  
 }
}

 
页: [1]
查看完整版本: BP网络JAVA版源代码