/*
 * Return the output buffer of the network's final non-COST layer.
 *
 * When built with GPU support the work is delegated entirely to
 * get_network_output_gpu(). Otherwise the layers are scanned from the last
 * one downward, skipping trailing COST layers; layer 0 is used if every
 * later layer is a COST layer.
 *
 * NOTE(review): assumes net.n >= 1 -- with net.n == 0 this would index
 * net.layers[-1]. Confirm callers never pass an empty network.
 */
float *get_network_output(network net)
{
#ifdef GPU
    return get_network_output_gpu(net);
#else
    int idx = net.n - 1;
    /* Walk backward past any cost layers; never go below layer 0. */
    while (idx > 0 && net.layers[idx].type == COST) {
        --idx;
    }
    return net.layers[idx].output;
#endif
}
void train_writing(char *cfgfile, char *weightfile) { data_seed = time(0); srand(time(0)); float avg_loss = -1; char *base = basecfg(cfgfile); printf("%s\n", base); network net = parse_network_cfg(cfgfile); if(weightfile){ load_weights(&net, weightfile); } printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay); int imgs = 256; int i = net.seen/imgs; list *plist = get_paths("data/train.list"); char **paths = (char **)list_to_array(plist); printf("%d\n", plist->size); clock_t time; while(1){ ++i; time=clock(); data train = load_data_writing(paths, imgs, plist->size, 256, 256, 4); float loss = train_network(net, train); #ifdef GPU float *out = get_network_output_gpu(net); #else float *out = get_network_output(net); #endif // image pred = float_to_image(32, 32, 1, out); // print_image(pred); net.seen += imgs; if(avg_loss == -1) avg_loss = loss; avg_loss = avg_loss*.9 + loss*.1; printf("%d: %f, %f avg, %lf seconds, %d images\n", i, loss, avg_loss, sec(clock()-time), net.seen); free_data(train); if((i % 20000) == 0) net.learning_rate *= .1; //if(i%100 == 0 && net.learning_rate > .00001) net.learning_rate *= .97; if(i%250==0){ char buff[256]; sprintf(buff, "/home/pjreddie/writing_backup/%s_%d.weights", base, i); save_weights(net, buff); } } }