/*
 * Run the training loop of network `nwk` over the samples of training set `t`.
 *
 * Each epoch visits all t.n_set samples in order, calling net_compute() on the
 * sample input and net_error() on the sample's expected output.
 * NOTE(review): net_error() presumably performs the weight update -- confirm.
 *
 * Every 500th epoch (including epoch 0) the expected and actual outputs are
 * decoded with convert_output(); each mismatch is printed and counted, and a
 * summary line with the epoch number and mismatch count follows.
 *
 * After 5000 epochs the user is asked "Continue? "; a non-zero integer answer
 * resets the epoch counter and starts another round of 5000 epochs. Any other
 * answer (0, unparsable input, end of input) ends the training.
 */
void train_nn(struct net nwk, struct training t)
{
    const unsigned freq = 500;   /* report every `freq` epochs */
    const unsigned limit = 5000; /* epochs between two "Continue?" prompts */
    unsigned epoch = 0;

    do {
        const int report = (epoch % freq == 0);
        unsigned error = 0; /* mismatches seen this epoch (report epochs only) */

        for (size_t i = 0; i < t.n_set; ++i) {
            net_compute(nwk, get_in(&t, i));
            net_error(nwk, get_out(&t, i));

            if (report) {
                /* NOTE(review): 52 is presumably the number of output
                 * classes understood by convert_output() -- confirm. */
                char c = convert_output(get_out(&t, i), 52);
                char rst = convert_output(net_output(nwk), 52);

                if (c != rst) {
                    printf("In: %c\t", c);
                    printf("Out: %c\n", rst);
                    ++error;
                }
            }
        }

        if (report) {
            /* Both values are unsigned, so %u is the matching conversion. */
            printf("[%u] ERROR: %u\n", epoch, error);
            printf("******************\n");
        }

        ++epoch;
        if (epoch == limit) {
            int b = 0;

            printf("Continue? ");
            /* The prompt has no trailing newline; flush so it is visible
             * before blocking on input. */
            fflush(stdout);
            /* A failed or incomplete read is treated as "no". */
            if (scanf("%d", &b) == 1 && b)
                epoch = 0;
        }
    } while (epoch < limit);
}
/*
 * Ask network `nwk` which character the single-channel image `img` shows.
 *
 * If the image is not 20x20 it is resized into a freshly malloc'ed 20x20
 * buffer, and `img` is updated in place (data, x, y, comp) to describe that
 * buffer. The image is then converted to 400 network inputs with gen_input(),
 * fed to net_compute(), and the network output is decoded by convert_output().
 *
 * Returns the decoded character, or '\0' if the resize buffer could not be
 * allocated or the resize itself failed (in which case `img` is untouched).
 * NOTE(review): '\0' is assumed not to be a valid convert_output() result --
 * confirm.
 *
 * Precondition (asserted): img->comp == 1.
 */
char ask_nn(struct net nwk, t_img_desc *img)
{
    enum { NN_IMG_SIDE = 20, NN_IMG_PIXELS = NN_IMG_SIDE * NN_IMG_SIDE };

    assert(img->comp == 1);

    if (img->x != NN_IMG_SIDE || img->y != NN_IMG_SIDE) {
        uchar *ptr = malloc(sizeof *ptr * NN_IMG_PIXELS);

        if (ptr == NULL)
            return '\0';

        /* stbir_resize_uint8() returns 0 on failure. */
        if (!stbir_resize_uint8(img->data, img->x, img->y, 0,
                                ptr, NN_IMG_SIDE, NN_IMG_SIDE, 0, 1)) {
            free(ptr);
            return '\0';
        }

        /* NOTE(review): the previous img->data buffer is not released here;
         * whoever owns it must still hold a pointer to it, otherwise it
         * leaks -- confirm ownership with the callers. */
        img->data = ptr;
        img->x = NN_IMG_SIDE;
        img->y = NN_IMG_SIDE;
        img->comp = 1;
    }

    double in[NN_IMG_PIXELS];
    gen_input(img, in);
    net_compute(nwk, in);

    double *out = net_output(nwk);
    /* NOTE(review): 52 is presumably the number of output classes -- confirm. */
    return convert_output(out, 52);
}
int main (int argc, char **argv) { network_t *net; int no_of_pairs; int no_of_inputs; int no_of_outputs; float input[MAX_SIZE]; float target[MAX_SIZE]; float output[MAX_SIZE]; float error, total_error; int t; int i; srand (time (0)); parse_options (argc, argv); read_specification (spec_filename, &no_of_inputs, &no_of_outputs, &no_of_pairs, input, target); if (strlen (input_filename) == 0) { if (hidden_nodes == 0) hidden_nodes = no_of_inputs; net = net_allocate (3, no_of_inputs, hidden_nodes, no_of_outputs); } else { net = net_load (input_filename); } if (!use_bias) { net_use_bias(net, 0); } /* See spec.c for the way the input/target pairs are stored * in the input[] and target[] arrays. */ #define inputs(i) (input + i * no_of_inputs) #define targets(i) (target + i* no_of_outputs) t = 0; total_error = 0; while ((t < max_trainings) && ((total_error >= max_error) || (t <= 10))) { /* choose one of the input/target pairs: inputs(i), targets(i) */ i = rand () % no_of_pairs; /* compute the outputs for inputs(i) */ net_compute (net, inputs (i), output); /* find the error with respect to targets(i) */ error = net_compute_output_error (net, targets (i)); /* train the network one step */ net_train (net); /* keep track of (moving) average of the error */ if (t == 0) { total_error = error; } else { total_error = 0.9 * total_error + 0.1 * error; } /* next */ t++; } if (strlen (output_filename) == 0) { net_print (net); } else { net_save (output_filename, net); } printf ("Number of training performed: %i (max %i)\n", t, max_trainings); printf ("Average output error: %f (max %f)\n", total_error, max_error); return 0; }