示例#1
0
/*
 * Interactive sampler for a symbol-level RNN.
 *
 * Repeatedly reads a prompt from stdin, terminated by a NUL byte or EOF.
 * It resets the network state and primes the network with the prompt one
 * symbol at a time (one-hot encoded). It then samples up to `num` symbols
 * from the network output and prints them with print_symbol().
 * Generation for a prompt stops early when a '.' is followed by a sampled
 * '\n'. The function returns after handling the prompt that ends at EOF.
 *
 * cfgfile, weightfile: network configuration and weights to load.
 * num:        maximum number of symbols generated per prompt.
 * temp:       temperature written into every layer before sampling.
 * rseed:      seed passed to srand().
 * token_file: optional token table for print_symbol(); may be NULL.
 */
void test_tactic_rnn_multi(char *cfgfile, char *weightfile, int num,
		real_t temp, int rseed, char *token_file) {
	char **tokens = 0;
	if (token_file) {
		size_t n;
		tokens = read_tokens(token_file, &n);
	}

	srand(rseed);
	char *base = basecfg(cfgfile);
	fprintf(stderr, "%s\n", base);

	network *net = load_network(cfgfile, weightfile, 0);
	int inputs = net->inputs;

	int i, j;
	for (i = 0; i < net->n; ++i)
		net->layers[i].temperature = temp;
	int c = 0;
	real_t *input = calloc(inputs, sizeof(real_t));
	if (!input) {
		fprintf(stderr, "test_tactic_rnn_multi: out of memory\n");
		return;
	}

	while (1) {
		reset_network_state(net, 0);
		/* Prediction after the last primed symbol. It stays NULL until this
		 * prompt has fed at least one symbol into the freshly reset network,
		 * so an empty prompt never dereferences a null or stale pointer. */
		real_t *out = 0;
		while ((c = getc(stdin)) != EOF && c != 0) {
			/* getc() yields 0..UCHAR_MAX; skip codes the one-hot input
			 * vector cannot represent instead of writing past its end. */
			if (c >= inputs)
				continue;
			input[c] = 1;
			out = network_predict(net, input);
			input[c] = 0;
		}
		/* getc() keeps returning EOF once the stream is exhausted, so
		 * looping again would never terminate. */
		int at_eof = (c == EOF);

		if (out) {
			for (i = 0; i < num; ++i) {
				/* Drop negligible probabilities before sampling. */
				for (j = 0; j < inputs; ++j) {
					if (out[j] < .0001)
						out[j] = 0;
				}
				/* Assumed to return an index in [0, inputs) — confirm
				 * against sample_array(). */
				int next = sample_array(out, inputs);
				/* End generation when a '.' is followed by a newline. */
				if (c == '.' && next == '\n')
					break;
				c = next;
				print_symbol(c, tokens);

				input[c] = 1;
				out = network_predict(net, input);
				input[c] = 0;
			}
		}
		printf("\n");
		if (at_eof)
			break;
	}
	free(input);
}
示例#2
0
/*
 * Feeds one symbol into `net` as a one-hot vector of length `inputs`, then
 * clears the slot again so `input` is all zeros on return. Symbols that do
 * not fit in the vector are ignored rather than indexing out of bounds.
 */
static void vec_char_rnn_feed(network *net, real_t *input, int inputs,
		unsigned char symbol) {
	if (symbol >= inputs)
		return;
	input[symbol] = 1;
	network_predict(net, input);
	input[symbol] = 0;
}

/*
 * Produces one CSV row per line read from stdin: the stripped line followed
 * by the activations of layer 0. Before the activations are read, the
 * network state is reset and primed with `seed`, then the line, then a
 * single trailing space.
 *
 * cfgfile, weightfile: network configuration and weights to load.
 * seed: prefix fed before every line; NULL is treated as "".
 */
void vec_char_rnn(char *cfgfile, char *weightfile, char *seed) {
	char *base = basecfg(cfgfile);
	fprintf(stderr, "%s\n", base);

	network *net = load_network(cfgfile, weightfile, 0);
	int inputs = net->inputs;

	const char *prefix = seed ? seed : "";
	size_t seed_len = strlen(prefix);
	real_t *input = calloc(inputs, sizeof(real_t));
	if (!input) {
		fprintf(stderr, "vec_char_rnn: out of memory\n");
		return;
	}

	size_t i;
	char *line;
	while ((line = fgetl(stdin)) != 0) {
		reset_network_state(net, 0);
		/* Cast through unsigned char: a plain char may be signed, and a
		 * byte > 127 would otherwise become a negative array index. */
		for (i = 0; i < seed_len; ++i)
			vec_char_rnn_feed(net, input, inputs, (unsigned char) prefix[i]);
		strip(line);
		size_t str_len = strlen(line);
		for (i = 0; i < str_len; ++i)
			vec_char_rnn_feed(net, input, inputs, (unsigned char) line[i]);
		vec_char_rnn_feed(net, input, inputs, ' ');

		layer l = net->layers[0];
#ifdef GPU
		cuda_pull_array(l.output_gpu, l.output, l.outputs);
#endif
		printf("%s", line);
		int k;
		for (k = 0; k < l.outputs; ++k) {
			printf(",%g", l.output[k]);
		}
		printf("\n");
		/* fgetl() returns a fresh heap buffer per call; release it here
		 * instead of leaking one allocation per input line. */
		free(line);
	}
	free(input);
}
示例#3
0
void train_char_rnn(char *cfgfile, char *weightfile, char *filename, int clear,
		int tokenized) {
	srand(time(0));
	unsigned char *text = 0;
	int *tokens = 0;
	size_t size;
	if (tokenized) {
		tokens = read_tokenized_data(filename, &size);
	} else {
		text = read_file(filename);
		size = strlen((const char*) text);
	}

	char *backup_directory = "/home/pjreddie/backup/";
	char *base = basecfg(cfgfile);
	fprintf(stderr, "%s\n", base);
	real_t avg_loss = -1;
	network *net = load_network(cfgfile, weightfile, clear);

	int inputs = net->inputs;
	fprintf(stderr,
			"Learning Rate: %g, Momentum: %g, Decay: %g, Inputs: %d %d %d\n",
			net->learning_rate, net->momentum, net->decay, inputs, net->batch,
			net->time_steps);
	int batch = net->batch;
	int steps = net->time_steps;
	if (clear)
		*net->seen = 0;
	int i = (*net->seen) / net->batch;

	int streams = batch / steps;
	size_t *offsets = calloc(streams, sizeof(size_t));
	int j;
	for (j = 0; j < streams; ++j) {
		offsets[j] = rand_size_t() % size;
	}

	clock_t time;
	while (get_current_batch(net) < net->max_batches) {
		i += 1;
		time = clock();
		real_t_pair p;
		if (tokenized) {
			p = get_rnn_token_data(tokens, offsets, inputs, size, streams,
					steps);
		} else {
			p = get_rnn_data(text, offsets, inputs, size, streams, steps);
		}

		copy_cpu(net->inputs * net->batch, p.x, 1, net->input, 1);
		copy_cpu(net->truths * net->batch, p.y, 1, net->truth, 1);
		real_t loss = train_network_datum(net) / (batch);
		free(p.x);
		free(p.y);
		if (avg_loss < 0)
			avg_loss = loss;
		avg_loss = avg_loss * .9 + loss * .1;

		size_t chars = get_current_batch(net) * batch;
		fprintf(stderr, "%d: %f, %f avg, %f rate, %lf seconds, %f epochs\n", i,
				loss, avg_loss, get_current_rate(net), sec(clock() - time),
				(real_t) chars / size);

		for (j = 0; j < streams; ++j) {
			//printf("%d\n", j);
			if (rand() % 64 == 0) {
				//fprintf(stderr, "Reset\n");
				offsets[j] = rand_size_t() % size;
				reset_network_state(net, j);
			}
		}

		if (i % 10000 == 0) {
			char buff[256];
			sprintf(buff, "%s/%s_%d.weights", backup_directory, base, i);
			save_weights(net, buff);
		}
		if (i % 100 == 0) {
			char buff[256];
			sprintf(buff, "%s/%s.backup", backup_directory, base);
			save_weights(net, buff);
		}
	}
	char buff[256];
	sprintf(buff, "%s/%s_final.weights", backup_directory, base);
	save_weights(net, buff);
}
示例#4
0
/*
 * Convenience wrapper: resets the network state of stream/batch index 0
 * by delegating to reset_network_state().
 */
void reset_rnn(network *net) {
	const int stream = 0;
	reset_network_state(net, stream);
}