/*
 * Interactive sampler for a character-level RNN.
 *
 * Repeatedly reads a prompt from stdin (terminated by a NUL byte or EOF),
 * feeds each byte to the network as a one-hot input, then samples up to
 * `num` symbols from the network output and prints them with print_symbol.
 * A round of sampling stops early when a sampled '.' is followed by a
 * sampled '\n'.
 *
 * cfgfile, weightfile: passed to load_network.
 * num:        maximum number of symbols sampled per prompt.
 * temp:       written to every layer's `temperature` field.
 * rseed:      passed to srand (presumably drives sample_array - confirm).
 * token_file: optional; when non-NULL it is loaded with read_tokens and the
 *             resulting table is handed to print_symbol.
 *
 * Returns once stdin is exhausted (the original looped forever at EOF).
 */
void test_tactic_rnn_multi(char *cfgfile, char *weightfile, int num, real_t temp, int rseed, char *token_file)
{
    char **tokens = 0;
    if (token_file) {
        size_t n;
        tokens = read_tokens(token_file, &n);
    }

    srand(rseed);
    char *base = basecfg(cfgfile);
    fprintf(stderr, "%s\n", base);

    network *net = load_network(cfgfile, weightfile, 0);
    int inputs = net->inputs;

    int i, j;
    for (i = 0; i < net->n; ++i) net->layers[i].temperature = temp;

    int c = 0;
    real_t *input = calloc(inputs, sizeof(real_t));
    real_t *out = 0;

    for (;;) {
        reset_network_state(net, 0);

        int fed = 0;
        while ((c = getc(stdin)) != EOF && c != 0) {
            /* getc yields 0..255; skip bytes that have no input unit instead
             * of writing past the end of `input`. */
            if (c >= inputs) continue;
            input[c] = 1;
            out = network_predict(net, input);
            input[c] = 0;
            fed = 1;
        }
        int at_eof = (c == EOF);

        /* Stdin exhausted with no new prompt: stop instead of spinning forever. */
        if (at_eof && !fed) break;
        /* Nothing has been predicted yet (e.g. an empty first prompt), so there
         * is no output to sample from; wait for the next prompt. */
        if (!out) continue;

        for (i = 0; i < num; ++i) {
            /* Zero out tiny outputs before sampling. */
            for (j = 0; j < inputs; ++j) {
                if (out[j] < .0001) out[j] = 0;
            }
            int next = sample_array(out, inputs);
            if (c == '.' && next == '\n') break;
            c = next;
            print_symbol(c, tokens);
            input[c] = 1;
            out = network_predict(net, input);
            input[c] = 0;
        }
        printf("\n");

        /* The final prompt was terminated by EOF: it has been answered, so stop. */
        if (at_eof) break;
    }

    free(input);
}
void vec_char_rnn(char *cfgfile, char *weightfile, char *seed) { char *base = basecfg(cfgfile); fprintf(stderr, "%s\n", base); network *net = load_network(cfgfile, weightfile, 0); int inputs = net->inputs; int c; int seed_len = strlen(seed); real_t *input = calloc(inputs, sizeof(real_t)); int i; char *line; while ((line = fgetl(stdin)) != 0) { reset_network_state(net, 0); for (i = 0; i < seed_len; ++i) { c = seed[i]; input[(int) c] = 1; network_predict(net, input); input[(int) c] = 0; } strip(line); int str_len = strlen(line); for (i = 0; i < str_len; ++i) { c = line[i]; input[(int) c] = 1; network_predict(net, input); input[(int) c] = 0; } c = ' '; input[(int) c] = 1; network_predict(net, input); input[(int) c] = 0; layer l = net->layers[0]; #ifdef GPU cuda_pull_array(l.output_gpu, l.output, l.outputs); #endif printf("%s", line); for (i = 0; i < l.outputs; ++i) { printf(",%g", l.output[i]); } printf("\n"); } }
void train_char_rnn(char *cfgfile, char *weightfile, char *filename, int clear, int tokenized) { srand(time(0)); unsigned char *text = 0; int *tokens = 0; size_t size; if (tokenized) { tokens = read_tokenized_data(filename, &size); } else { text = read_file(filename); size = strlen((const char*) text); } char *backup_directory = "/home/pjreddie/backup/"; char *base = basecfg(cfgfile); fprintf(stderr, "%s\n", base); real_t avg_loss = -1; network *net = load_network(cfgfile, weightfile, clear); int inputs = net->inputs; fprintf(stderr, "Learning Rate: %g, Momentum: %g, Decay: %g, Inputs: %d %d %d\n", net->learning_rate, net->momentum, net->decay, inputs, net->batch, net->time_steps); int batch = net->batch; int steps = net->time_steps; if (clear) *net->seen = 0; int i = (*net->seen) / net->batch; int streams = batch / steps; size_t *offsets = calloc(streams, sizeof(size_t)); int j; for (j = 0; j < streams; ++j) { offsets[j] = rand_size_t() % size; } clock_t time; while (get_current_batch(net) < net->max_batches) { i += 1; time = clock(); real_t_pair p; if (tokenized) { p = get_rnn_token_data(tokens, offsets, inputs, size, streams, steps); } else { p = get_rnn_data(text, offsets, inputs, size, streams, steps); } copy_cpu(net->inputs * net->batch, p.x, 1, net->input, 1); copy_cpu(net->truths * net->batch, p.y, 1, net->truth, 1); real_t loss = train_network_datum(net) / (batch); free(p.x); free(p.y); if (avg_loss < 0) avg_loss = loss; avg_loss = avg_loss * .9 + loss * .1; size_t chars = get_current_batch(net) * batch; fprintf(stderr, "%d: %f, %f avg, %f rate, %lf seconds, %f epochs\n", i, loss, avg_loss, get_current_rate(net), sec(clock() - time), (real_t) chars / size); for (j = 0; j < streams; ++j) { //printf("%d\n", j); if (rand() % 64 == 0) { //fprintf(stderr, "Reset\n"); offsets[j] = rand_size_t() % size; reset_network_state(net, j); } } if (i % 10000 == 0) { char buff[256]; sprintf(buff, "%s/%s_%d.weights", backup_directory, base, i); save_weights(net, buff); } if 
(i % 100 == 0) { char buff[256]; sprintf(buff, "%s/%s.backup", backup_directory, base); save_weights(net, buff); } } char buff[256]; sprintf(buff, "%s/%s_final.weights", backup_directory, base); save_weights(net, buff); }
/*
 * Resets the recurrent state of `net` by calling reset_network_state with a
 * second argument of 0. (Elsewhere that argument is a stream index, so this
 * presumably resets stream 0 only - confirm against reset_network_state.)
 */
void reset_rnn(network *net)
{
    const int first_stream = 0;
    reset_network_state(net, first_stream);
}