示例#1
0
// Run one training round for every sample in the global training set `trset`.
//
// Each round:
//   1. copies the sample into the global `input`, calls feed_forward() and
//      then printe();
//   2. takes one gradient-descent step on the hidden-layer weights and on the
//      output neuron's weights:  w_new = w - derv(&w) * RATE.
//
// The new weights are staged in copies (cpy / outcpy). They are written back
// only after every derv() call for the round has been made. That way, no
// derivative is evaluated against a partially updated network.
//
// NOTE(review): neurons in hidden layer 0 update only weights[0] (presumably
// because the first layer sees a single input). Every other layer updates
// NEURON_PER_LAYER weights per neuron.
void train(void) {
  const size_t nsamples = sizeof trset / sizeof trset[0];
  size_t i;  // size_t so the bound comparison is not signed-vs-unsigned
  int j, k, l;

  for (i = 0; i < nsamples; i++) {
    printf("\nTrainning Round %zu\n", i);

    // feed an input into the network
    input = trset[i];
    feed_forward();
    printe();

    // back-propagate: stage every hidden-layer update in cpy
    Neuron cpy[LAYER_NUM][NEURON_PER_LAYER];
    memcpy(cpy, hlayers, sizeof cpy);

    for (j = 0; j < LAYER_NUM; j++) {
      // layer 0 trains only weights[0]; deeper layers train NEURON_PER_LAYER weights
      const int nweights = (j == 0) ? 1 : NEURON_PER_LAYER;

      for (k = 0; k < NEURON_PER_LAYER; k++) {
        for (l = 0; l < nweights; l++) {
          cpy[j][k].weights[l] = hlayers[j][k].weights[l] -
            derv(&hlayers[j][k].weights[l]) * RATE;
        }
      }
    }

    // stage the output neuron's update the same way
    Neuron outcpy = output;

    for (j = 0; j < NEURON_PER_LAYER; j++) {
      outcpy.weights[j] = output.weights[j] -
        derv(&output.weights[j]) * RATE;
    }

    // commit all updates at once; use the same size as the copy-in above
    memcpy(hlayers, cpy, sizeof cpy);
    output = outcpy;
  }
}
示例#2
0
/*
 * Run 500 passes of a fixed schedule of single-element updates on the
 * factors u (M x 2) and v (2 x N). The error is logged after every evaluation.
 *
 * Each pass does two things:
 *   - It first re-evaluates the current factors: multiply(t,u,v), print(t),
 *     error = findError(o,t).
 *   - For each (choice, row, col) triple in `schedule`, in order, it calls
 *     s() and derv(o,t,u,v,choice,row,col), prints u and v, and re-evaluates.
 *
 * Every evaluation prints "error is ..." and a separator to stdout. It also
 * appends "<running index>\t<error>" to error.dat.
 *
 * NOTE(review): derv() presumably updates u[row][col] when choice == 0 and
 * v[row][col] when choice == 1. The index ranges in the table are consistent
 * with that, but confirm against derv's definition.
 *
 * If error.dat cannot be opened, the failure is reported on stderr. The
 * optimisation still runs, but nothing is written to the file; writing
 * through a NULL FILE* would be undefined behaviour.
 */
void findresult(float o[M][N],float t[M][N],float u[M][2],float v[2][N])
{
	/* {choice, row, col} arguments for derv(), in the order they are applied each pass */
	static const int schedule[][3] = {
		{0,0,0}, {1,0,0}, {0,2,0}, {0,2,1}, {1,1,2}, {1,0,2}, {0,1,1},
		{1,0,1}, {0,4,1}, {1,1,4}, {0,0,1}, {1,0,3}, {0,1,1}, {1,1,0},
		{0,1,0}, {1,1,1}, {0,3,0}, {1,0,4}, {0,4,0}, {1,1,3}, {0,3,1},
	};
	const size_t nsteps = sizeof schedule / sizeof schedule[0];
	float error = 0;
	int val = 0;   /* running index written as the first column of error.dat */
	int ctr;
	size_t step;

	FILE *fp = fopen("error.dat","w");
	if (fp == NULL)
		perror("findresult: fopen(\"error.dat\")");

	for (ctr = 0; ctr < 500; ctr++) {
		/* step 0 only re-evaluates; step k > 0 first applies schedule[k - 1] */
		for (step = 0; step <= nsteps; step++) {
			if (step > 0) {
				const int *upd = schedule[step - 1];

				s();
				derv(o,t,u,v,upd[0],upd[1],upd[2]);
				print(M,2,u);
				print(2,N,v);
			}
			multiply(t,u,v);
			print(M,N,t);
			error = findError(o,t);
			printf("error is %f\n",error);
			if (fp != NULL)
				fprintf(fp,"%d\t%f\n",val,error);
			val++;
			printf("======================================================================================\n");
		}
	}

	/* fclose flushes buffered writes, so its failure means lost log data */
	if (fp != NULL && fclose(fp) != 0)
		perror("findresult: fclose(\"error.dat\")");
}
示例#3
0
/*
 * Optimise the factors u (M x 2) and v (2 x N) for 100 sweeps. Within a
 * sweep, single-element updates alternate between u and v.
 *
 * First, the error of the starting factors is printed and written to
 * error.dat as line "0\t<error>".
 *
 * In each sweep:
 *   - notA(a,M,detVal) or notB(b,N,detVal), depending on `choice`, supplies
 *     the next (row, col) to update through detVal.
 *   - s() and derv(o,t,u,v,choice,row,col) are called for that element.
 *   - t = multiply(u,v) and the error are recomputed, and a separator line
 *     is printed.
 *   - The sweep ends when neither noA(a,M) nor noB(b,N) reports remaining
 *     work.
 *
 * If error.dat cannot be opened, the failure is reported on stderr and the
 * optimisation still runs without the file log.
 */
void findresult(float o[M][N],float t[M][N],float u[M][2],float v[2][N])
{
	int choice = 0;      /* 0: next update goes through notA/a, 1: through notB/b */
	int i;
	float error = 0;
	int val = 0;         /* index written as the first column of error.dat */
	int a[M][2];         /* bookkeeping passed to noA/notA; cleared before each sweep */
	int b[2][N];         /* bookkeeping passed to noB/notB; cleared before each sweep */
	int detVal[2] = {0}; /* for finding the i, j indices */

	FILE *fp = fopen("error.dat","w");
	if (fp == NULL)
		perror("findresult: fopen(\"error.dat\")");

	multiply(t,u,v);
	print(M,N,t);
	error = findError(o,t);
	printf("error is %f\n",error);
	if (fp != NULL)
		fprintf(fp,"%d\t%f\n",val,error);
	val++;
	printf("======================================================================================\n");

	for (i = 0; i < 100; i++) {
		/*
		 * Clear the bookkeeping at the start of every sweep. The while loop
		 * below exits only once noA/noB both report false, and nothing else
		 * clears a/b. Without this reset, every sweep after the first would
		 * do no work.
		 * NOTE(review): this assumes noA/noB depend only on their arguments;
		 * confirm against their definitions.
		 */
		memset(a,0,sizeof(a));
		memset(b,0,sizeof(b));

		while (noA(a,M) || noB(b,N)) {
			int picked;

			switch (choice) {
			case 0:
				picked = (notA(a,M,detVal) != 0);
				if (!picked)
					memset(a,0,sizeof(a)); /* u side exhausted: start it over */
				break;
			case 1:
				picked = (notB(b,N,detVal) != 0);
				if (!picked)
					memset(b,0,sizeof(b)); /* v side exhausted: start it over */
				break;
			default:
				picked = 0; /* unreachable: choice only ever holds 0 or 1 */
				break;
			}
			if (!picked)
				continue;

			s();
			derv(o,t,u,v,choice,detVal[0],detVal[1]);
			choice = 1 - choice; /* alternate between the u and v sides */
			multiply(t,u,v);
			error = findError(o,t);
			printf("======================================================================================\n");
		}
	}

	/* fclose flushes buffered writes, so its failure means lost log data */
	if (fp != NULL && fclose(fp) != 0)
		perror("findresult: fclose(\"error.dat\")");
}