void train_char_rnn(char *cfgfile, char *weightfile, char *filename, int clear, int tokenized) { srand(time(0)); data_seed = time(0); unsigned char *text = 0; int *tokens = 0; size_t size; if(tokenized){ tokens = read_tokenized_data(filename, &size); } else { FILE *fp = fopen(filename, "rb"); fseek(fp, 0, SEEK_END); size = ftell(fp); fseek(fp, 0, SEEK_SET); text = calloc(size+1, sizeof(char)); fread(text, 1, size, fp); fclose(fp); } char *backup_directory = "/home/pjreddie/backup/"; char *base = basecfg(cfgfile); fprintf(stderr, "%s\n", base); float avg_loss = -1; network net = parse_network_cfg(cfgfile); if(weightfile){ load_weights(&net, weightfile); } int inputs = get_network_input_size(net); fprintf(stderr, "Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay); int batch = net.batch; int steps = net.time_steps; if(clear) *net.seen = 0; int i = (*net.seen)/net.batch; int streams = batch/steps; size_t *offsets = calloc(streams, sizeof(size_t)); int j; for(j = 0; j < streams; ++j){ offsets[j] = rand_size_t()%size; } clock_t time; while(get_current_batch(net) < net.max_batches){ i += 1; time=clock(); float_pair p; if(tokenized){ p = get_rnn_token_data(tokens, offsets, inputs, size, streams, steps); }else{ p = get_rnn_data(text, offsets, inputs, size, streams, steps); } float loss = train_network_datum(net, p.x, p.y) / (batch); free(p.x); free(p.y); if (avg_loss < 0) avg_loss = loss; avg_loss = avg_loss*.9 + loss*.1; int chars = get_current_batch(net)*batch; fprintf(stderr, "%d: %f, %f avg, %f rate, %lf seconds, %f epochs\n", i, loss, avg_loss, get_current_rate(net), sec(clock()-time), (float) chars/size); for(j = 0; j < streams; ++j){ //printf("%d\n", j); if(rand()%10 == 0){ //fprintf(stderr, "Reset\n"); offsets[j] = rand_size_t()%size; reset_rnn_state(net, j); } } if(i%1000==0){ char buff[256]; sprintf(buff, "%s/%s_%d.weights", backup_directory, base, i); save_weights(net, buff); } if(i%10==0){ char buff[256]; sprintf(buff, "%s/%s.backup", backup_directory, base); save_weights(net, buff); } } char buff[256]; sprintf(buff, "%s/%s_final.weights", backup_directory, base); save_weights(net, buff); }
void train_char_rnn(char *cfgfile, char *weightfile, char *filename, int clear, int tokenized) { srand(time(0)); unsigned char *text = 0; int *tokens = 0; size_t size; if (tokenized) { tokens = read_tokenized_data(filename, &size); } else { text = read_file(filename); size = strlen((const char*) text); } char *backup_directory = "/home/pjreddie/backup/"; char *base = basecfg(cfgfile); fprintf(stderr, "%s\n", base); real_t avg_loss = -1; network *net = load_network(cfgfile, weightfile, clear); int inputs = net->inputs; fprintf(stderr, "Learning Rate: %g, Momentum: %g, Decay: %g, Inputs: %d %d %d\n", net->learning_rate, net->momentum, net->decay, inputs, net->batch, net->time_steps); int batch = net->batch; int steps = net->time_steps; if (clear) *net->seen = 0; int i = (*net->seen) / net->batch; int streams = batch / steps; size_t *offsets = calloc(streams, sizeof(size_t)); int j; for (j = 0; j < streams; ++j) { offsets[j] = rand_size_t() % size; } clock_t time; while (get_current_batch(net) < net->max_batches) { i += 1; time = clock(); real_t_pair p; if (tokenized) { p = get_rnn_token_data(tokens, offsets, inputs, size, streams, steps); } else { p = get_rnn_data(text, offsets, inputs, size, streams, steps); } copy_cpu(net->inputs * net->batch, p.x, 1, net->input, 1); copy_cpu(net->truths * net->batch, p.y, 1, net->truth, 1); real_t loss = train_network_datum(net) / (batch); free(p.x); free(p.y); if (avg_loss < 0) avg_loss = loss; avg_loss = avg_loss * .9 + loss * .1; size_t chars = get_current_batch(net) * batch; fprintf(stderr, "%d: %f, %f avg, %f rate, %lf seconds, %f epochs\n", i, loss, avg_loss, get_current_rate(net), sec(clock() - time), (real_t) chars / size); for (j = 0; j < streams; ++j) { //printf("%d\n", j); if (rand() % 64 == 0) { //fprintf(stderr, "Reset\n"); offsets[j] = rand_size_t() % size; reset_network_state(net, j); } } if (i % 10000 == 0) { char buff[256]; sprintf(buff, "%s/%s_%d.weights", backup_directory, base, i); save_weights(net, buff); } if (i % 100 == 0) { char buff[256]; sprintf(buff, "%s/%s.backup", backup_directory, base); save_weights(net, buff); } } char buff[256]; sprintf(buff, "%s/%s_final.weights", backup_directory, base); save_weights(net, buff); }