示例#1
0
文件: pattern.c 项目: Mougatine/ocr
void make_nn(char *path_pattern, char *saved_name)
{
    size_t desc_layers[] = { 400, 210, 52 };
    struct net nwk = net_init(3, desc_layers);
    struct training t = load_pattern(path_pattern);

    train_nn(nwk, t);
    net_save(nwk, saved_name);
    net_free(nwk);
}
/*
 * Train a neural network on the input/target pairs of a specification file.
 *
 * Steps: seed the random generator, call parse_options (argc, argv), read
 * the specification named by spec_filename, obtain a network (allocated
 * fresh when input_filename is empty, otherwise loaded from that file),
 * train on randomly chosen pairs, then print the network or save it to
 * output_filename, and finally report the number of training steps and the
 * moving-average output error.
 *
 * The settings spec_filename, input_filename, output_filename,
 * hidden_nodes, use_bias, max_trainings and max_error are defined outside
 * this function.
 * NOTE(review): presumably they are filled in by parse_options -- confirm.
 *
 * Returns 0 on success, 1 when training was requested but the
 * specification contains no pairs, or when no network could be obtained.
 */
int
main (int argc, char **argv)
{
  network_t *net;
  /* Start at 0 so that a specification that sets nothing is detectable. */
  int no_of_pairs = 0;
  int no_of_inputs = 0;
  int no_of_outputs = 0;
  float input[MAX_SIZE];
  float target[MAX_SIZE];
  float output[MAX_SIZE];
  float error, total_error;
  int t;
  int i;

  srand (time (0));

  parse_options (argc, argv);
  read_specification (spec_filename, &no_of_inputs, &no_of_outputs,
                      &no_of_pairs, input, target);

  /* The training loop picks a pair with rand () % no_of_pairs, which is
   * undefined for zero.  Only reject when the loop would actually run, so
   * that a run with max_trainings == 0 keeps working as before. */
  if ((max_trainings > 0) && (no_of_pairs <= 0)) {
    fprintf (stderr, "No input/target pairs to train on.\n");
    return 1;
  }

  if (strlen (input_filename) == 0) {
    /* No hidden layer size given: default to the number of inputs. */
    if (hidden_nodes == 0)
      hidden_nodes = no_of_inputs;
    net = net_allocate (3, no_of_inputs, hidden_nodes, no_of_outputs);
  } else {
    net = net_load (input_filename);
  }

  /* Every call below dereferences the network, so stop here instead of
   * crashing when none was obtained. */
  if (net == NULL) {
    fprintf (stderr, "Could not allocate or load the network.\n");
    return 1;
  }

  if (!use_bias) {
    net_use_bias(net, 0);
  }

/* See spec.c for the way the input/target pairs are stored
 * in the input[] and target[] arrays.  The argument is parenthesized so
 * that expressions such as inputs (i + 1) scale correctly. */
#define inputs(i) (input + (i) * no_of_inputs)
#define targets(i) (target + (i) * no_of_outputs)

  t = 0;
  total_error = 0;
  /* Keep training until max_trainings is reached or the moving-average
   * error drops below max_error; the (t <= 10) term forces the first
   * eleven steps so that the average has time to settle. */
  while ((t < max_trainings) && ((total_error >= max_error) || (t <= 10))) {
    /* choose one of the input/target pairs: inputs(i), targets(i) */
    i = rand () % no_of_pairs;

    /* compute the outputs for inputs(i) */
    net_compute (net, inputs (i), output);

    /* find the error with respect to targets(i) */
    error = net_compute_output_error (net, targets (i));

    /* train the network one step */
    net_train (net);

    /* keep track of (moving) average of the error */
    if (t == 0) {
      total_error = error;
    } else {
      total_error = 0.9 * total_error + 0.1 * error;
    }

    /* next */
    t++;
  }

  /* Without an output file name the network is printed instead of saved. */
  if (strlen (output_filename) == 0) {
    net_print (net);
  } else {
    net_save (output_filename, net);
  }

  printf ("Number of training performed: %i (max %i)\n", t, max_trainings);
  printf ("Average output error: %f (max %f)\n", total_error, max_error);

  return 0;
}