/* Example no. 1 (score: 0) */
int ftest_data(void)
{
    //	sar_start_epoch=0;
    //  printf("\r\n\r\n--------------------------------------------------------------------------------");

    double val_2[10];
    fann_type *calc_out2;
    unsigned calc2;
    int curi=0;
    unsigned fails=0,success=0;
    double perc=0;
    double minv=9,maxv=-1;
    int i;
    int minat=0,maxat=0;

    test_mse=fann_test_data(ann,test_data);

    for (curi=0;curi<fann_length_train_data(train_data);curi++)
    {

        calc2=curi;//rand()%(fann_length_train_data(train_data)-1);
        //printf("\r\ntesting %u %u ",calc1,calc2);
        //fann_scale_input(ann, test_data->input[calc1]);
        //fann_scale_input(ann, train_data->input[calc2]);
        //	fann_scale_output(ann, test_data->input[calc1]);

        //fann_scale_input(ann, train_data->input[calc2]);
        calc_out2 = fann_run(ann, train_data->input[calc2]);
        //	fann_descale_output(ann,calc_out2);

        memcpy(&val_2,  calc_out2, sizeof(double)*3);





        minv=9;
        maxv=-1;
        for (i=0;i<train_data->num_output;i++)
        {
            if ((double)calc_out2[i]<minv)
            {
                minv=val_2[i];
                minat=i;
            }
            if ((double)calc_out2[i]>maxv)
            {
                maxv=val_2[i];
                maxat=i;
            }
        }

        int ok=0;
        ok=0;
        for (i=0;i<train_data->num_output;i++)
            if (train_data->output[calc2][i]==1&&maxat==i)
                ok=1;

        if (ok)success++;
        else
            fails++;

    }
    train_perc=((double)success/(double)fann_length_train_data(train_data))*100.0f;
    /*   printf(" fails %5u success %5u (%5.2f%%) ",
             fails,success,train_perc
            ); */

    fails=0;
    success=0;
    unsigned failed_classes[10];

    for (curi=0;curi<test_data->num_output;curi++)
        failed_classes[curi]=0;

    int nfunc=0;
    double train_thr_mse=0;


    nfunc=fann_get_activation_function(ann, 3, 0);
    int stpns;
    stpns=fann_get_activation_steepness(ann,1,0);
    //	printf("\r\n%f",diff_mse*0.1f);
    //fann_set_activation_steepness_layer(ann, 0.3f, 1);
    //fann_set_activation_function_layer(ann,FANN_THRESHOLD_SYMMETRIC,3);




    for (curi=0;curi<fann_length_train_data(test_data);curi++)
    {

        calc2=curi;//rand()%(fann_length_train_data(train_data)-1);
        //printf("\r\ntesting %u %u ",calc1,calc2);
        //fann_scale_input(ann, test_data->input[calc1]);
        //fann_scale_input(ann, train_data->input[calc2]);
        //	fann_scale_output(ann, test_data->input[calc1]);

        //fann_scale_input(ann, train_data->input[calc2]);
        calc_out2 = fann_run(ann, test_data->input[calc2]);
        //	fann_descale_output(ann,calc_out2);

        memcpy(&val_2,  calc_out2, sizeof(double)*3);





        minv=9;
        maxv=-1;
        for (i=0;i<test_data->num_output;i++)
        {
            if (val_2[i]<minv)
            {
                minv=val_2[i];
                minat=i;
            }
            if (val_2[i]>maxv)
            {
                maxv=val_2[i];
                maxat=i;
            }
        }

        int ok=0;
        ok=0;
        for (i=0;i<test_data->num_output;i++)
        {
            if (test_data->output[calc2][i]==1&&maxat==i)
                ok=1;
            else if (test_data->output[calc2][i]==1&&maxat!=i)
                failed_classes[i]++;
        }

        if (ok)success++;
        else
            fails++;

    }
    test_perc=((double)success/(double)fann_length_train_data(test_data))*100.0f;
    /*   printf(" fails %5u success %5u (%5.2f%%) [fails: ",
             fails,success,test_perc
            );
      for (curi=0;curi<test_data->num_output;curi++)
          printf("%4u ",failed_classes[curi]);
      printf("] "); */
    // fann_set_activation_function_hidden ( ann,  rand()*0.81);
    // printf("\r\n rpropfact dec/inc r %.5f %.5f lr %.5f mom %.5f",fann_get_rprop_decrease_factor(ann),fann_get_rprop_increase_factor(ann), fann_get_learning_rate ( ann),
    //       fann_get_learning_momentum(ann));

    //	rebuild_functions();

    fann_set_activation_function_layer(ann,nfunc,3);
    fann_set_activation_steepness_layer(ann,stpns, 1);
}
/* Example no. 2 (score: 0) */
int main (int argc, char * argv[]) {
  int i, epoch, k, num_bits_failing, num_correct;
  int max_epochs = 10000, exit_code = 0, batch_items = -1;
  int flag_cups = 0, flag_last = 0, flag_mse = 0, flag_verbose = 0,
    flag_bit_fail = 0, flag_ignore_limits = 0, flag_percent_correct = 0;
  int mse_reporting_period = 1, bit_fail_reporting_period = 1,
    percent_correct_reporting_period = 1;
  float bit_fail_limit = 0.05, mse_fail_limit = -1.0;
  double learning_rate = 0.7;
  char id[100] = "0";
  char * file_video_string = NULL;
  FILE * file_video = NULL;
  struct fann * ann = NULL;
  struct fann_train_data * data = NULL;
  fann_type * calc_out;
  enum fann_train_enum type_training = FANN_TRAIN_BATCH;

  char * file_nn = NULL, * file_train = NULL;
  int c;
  while (1) {
    static struct option long_options[] = {
      {"video-data",           required_argument, 0, 'b'},
      {"stat-cups",            no_argument,       0, 'c'},
      {"num-batch-items",      required_argument, 0, 'd'},
      {"max-epochs",           required_argument, 0, 'e'},
      {"bit-fail-limit",       required_argument, 0, 'f'},
      {"mse-fail-limit",       required_argument, 0, 'g'},
      {"help",                 no_argument,       0, 'h'},
      {"id",                   required_argument, 0, 'i'},
      {"stat-last",            no_argument,       0, 'l'},
      {"stat-mse",             optional_argument, 0, 'm'},
      {"nn-config",            required_argument, 0, 'n'},
      {"stat-bit-fail",        optional_argument, 0, 'o'},
      {"stat-percent-correct", optional_argument, 0, 'q'},
      {"learning-rate",        required_argument, 0, 'r'},
      {"train-file",           required_argument, 0, 't'},
      {"verbose",              no_argument,       0, 'v'},
      {"incremental",          optional_argument, 0, 'x'},
      {"ignore-limits",        no_argument,       0, 'z'}
    };
    int option_index = 0;
     c = getopt_long (argc, argv, "b:cd:e:f:g:hi:lm::n:o::q::r:t:vx::z",
                     long_options, &option_index);
    if (c == -1)
      break;
    switch (c) {
    case 'b': file_video_string = optarg; break;
    case 'c': flag_cups = 1; break;
    case 'd': batch_items = atoi(optarg); break;
    case 'e': max_epochs = atoi(optarg); break;
    case 'f': bit_fail_limit = atof(optarg); break;
    case 'g': mse_fail_limit = atof(optarg); break;
    case 'h': usage(); exit_code = 0; goto bail;
    case 'i': strcpy(id, optarg); break;
    case 'l': flag_last = 1; break;
    case 'm':
      if (optarg)
        mse_reporting_period = atoi(optarg);
      flag_mse = 1;
      break;
    case 'n': file_nn = optarg; break;
    case 'o':
      if (optarg)
        bit_fail_reporting_period = atoi(optarg);
      flag_bit_fail = 1;
      break;
    case 'q':
      if (optarg)
        percent_correct_reporting_period = atoi(optarg);
      flag_percent_correct = 1;
      break;
    case 'r': learning_rate = atof(optarg); break;
    case 't': file_train = optarg; break;
    case 'v': flag_verbose = 1; break;
    case 'x': type_training=(optarg)?atoi(optarg):FANN_TRAIN_INCREMENTAL; break;
    case 'z': flag_ignore_limits = 1; break;
    }
  };

  // Make sure there aren't any arguments left over
  if (optind != argc) {
    fprintf(stderr, "[ERROR] Bad argument\n\n");
    usage();
    exit_code = -1;
    goto bail;
  }

  // Make sure we have all required inputs
  if (file_nn == NULL || file_train == NULL) {
    fprintf(stderr, "[ERROR] Missing required input argument\n\n");
    usage();
    exit_code = -1;
    goto bail;
  }

  // The training type needs to make sense
  if (type_training > FANN_TRAIN_SARPROP) {
    fprintf(stderr, "[ERROR] Training type %d outside of enumerated range (max: %d)\n",
            type_training, FANN_TRAIN_SARPROP);
    exit_code = -1;
    goto bail;
  }

  ann = fann_create_from_file(file_nn);
  data = fann_read_train_from_file(file_train);
  if (batch_items != -1 && batch_items < data->num_data)
    data->num_data = batch_items;
  enum fann_activationfunc_enum af =
    fann_get_activation_function(ann, ann->last_layer - ann->first_layer -1, 0);

  ann->training_algorithm = type_training;
  ann->learning_rate = learning_rate;
  printf("[INFO] Using training type %d\n", type_training);

  if (file_video_string != NULL)
    file_video = fopen(file_video_string, "w");

  double mse;
  for (epoch = 0; epoch < max_epochs; epoch++) {
    fann_train_epoch(ann, data);
    num_bits_failing = 0;
    num_correct = 0;
    fann_reset_MSE(ann);
    for (i = 0; i < fann_length_train_data(data); i++) {
      calc_out = fann_test(ann, data->input[i], data->output[i]);
      if (flag_verbose) {
        printf("[INFO] ");
        for (k = 0; k < data->num_input; k++) {
          printf("%8.5f ", data->input[i][k]);
        }
      }
      int correct = 1;
      for (k = 0; k < data->num_output; k++) {
        if (flag_verbose)
          printf("%8.5f ", calc_out[k]);
        num_bits_failing +=
          fabs(calc_out[k] - data->output[i][k]) > bit_fail_limit;
        if (fabs(calc_out[k] - data->output[i][k]) > bit_fail_limit)
          correct = 0;
        if (file_video)
          fprintf(file_video, "%f ", calc_out[k]);
      }
      if (file_video)
        fprintf(file_video, "\n");
      num_correct += correct;
      if (flag_verbose) {
        if (i < fann_length_train_data(data) - 1)
          printf("\n");
      }
    }
    if (flag_verbose)
      printf("%5d\n\n", epoch);
    if (flag_mse  && (epoch % mse_reporting_period == 0)) {
      mse = fann_get_MSE(ann);
      switch(af) {
      case FANN_LINEAR_PIECE_SYMMETRIC:
      case FANN_THRESHOLD_SYMMETRIC:
      case FANN_SIGMOID_SYMMETRIC:
      case FANN_SIGMOID_SYMMETRIC_STEPWISE:
      case FANN_ELLIOT_SYMMETRIC:
      case FANN_GAUSSIAN_SYMMETRIC:
      case FANN_SIN_SYMMETRIC:
      case FANN_COS_SYMMETRIC:
        mse *= 4.0;
      default:
        break;
      }
      printf("[STAT] epoch %d id %s mse %8.8f\n", epoch, id, mse);
    }
    if (flag_bit_fail && (epoch % bit_fail_reporting_period == 0))
      printf("[STAT] epoch %d id %s bfp %8.8f\n", epoch, id,
             1 - (double) num_bits_failing / data->num_output /
             fann_length_train_data(data));
    if (flag_percent_correct && (epoch % percent_correct_reporting_period == 0))
      printf("[STAT] epoch %d id %s perc %8.8f\n", epoch, id,
             (double) num_correct / fann_length_train_data(data));
    if (!flag_ignore_limits && (num_bits_failing == 0 || mse < mse_fail_limit))
      goto finish;
    // printf("%8.5f\n\n", fann_get_MSE(ann));
  }

 finish:
  if (flag_last)
    printf("[STAT] x 0 id %s epoch %d\n", id, epoch);
  if (flag_cups)
    printf("[STAT] x 0 id %s cups %d / ?\n", id,
           epoch * fann_get_total_connections(ann));

 bail:
  if (ann != NULL)
    fann_destroy(ann);
  if (data != NULL)
    fann_destroy_train(data);
  if (file_video != NULL)
    fclose(file_video);

  return exit_code;
}