Exemple #1
0
/*
 * Train a recurrent network on the contents of `filename`.
 *
 * cfgfile    - network configuration, handed to parse_network_cfg().
 * weightfile - optional weights loaded before training; may be NULL.
 * filename   - training corpus: raw bytes, or token data when `tokenized`.
 * clear      - non-zero resets the network's `seen` counter before training.
 * tokenized  - non-zero reads the corpus through read_tokenized_data().
 *
 * Runs until get_current_batch() reaches net.max_batches. A numbered
 * checkpoint is written every 1000 iterations, a rolling ".backup" every
 * 10 iterations, and "_final.weights" when the loop ends; all of them go
 * to the hard-coded backup directory.
 *
 * Returns early, with a message on stderr, if the corpus cannot be read
 * or is empty.
 */
void train_char_rnn(char *cfgfile, char *weightfile, char *filename, int clear, int tokenized)
{
    srand(time(0));
    data_seed = time(0);
    unsigned char *text = 0;
    int *tokens = 0;
    size_t size = 0;
    if(tokenized){
        tokens = read_tokenized_data(filename, &size);
    } else {
        FILE *fp = fopen(filename, "rb");
        if(!fp){
            fprintf(stderr, "Couldn't open file: %s\n", filename);
            return;
        }

        long length = -1;
        if(fseek(fp, 0, SEEK_END) == 0) length = ftell(fp);
        if(length < 0 || fseek(fp, 0, SEEK_SET) != 0){
            fprintf(stderr, "Couldn't determine size of file: %s\n", filename);
            fclose(fp);
            return;
        }

        /* One extra zeroed byte keeps the buffer NUL-terminated. */
        text = calloc((size_t)length + 1, sizeof(char));
        if(!text){
            fprintf(stderr, "Out of memory reading file: %s\n", filename);
            fclose(fp);
            return;
        }
        /* Use the byte count actually read so a short read never makes
         * the sampler treat zero padding as training data. */
        size = fread(text, 1, (size_t)length, fp);
        fclose(fp);
    }

    /* Every offset below is taken modulo `size`; zero would be a division
     * by zero. */
    if(size == 0){
        fprintf(stderr, "No training data in file: %s\n", filename);
        free(text);
        return;
    }

    char *backup_directory = "/home/pjreddie/backup/";
    char *base = basecfg(cfgfile);
    fprintf(stderr, "%s\n", base);
    float avg_loss = -1;
    network net = parse_network_cfg(cfgfile);
    if(weightfile){
        load_weights(&net, weightfile);
    }

    int inputs = get_network_input_size(net);
    fprintf(stderr, "Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
    int batch = net.batch;
    int steps = net.time_steps;
    if(clear) *net.seen = 0;
    int i = (*net.seen)/net.batch;

    /* One independent read position per stream, each starting at a random
     * point in the corpus. */
    int streams = batch/steps;
    size_t *offsets = calloc(streams, sizeof(size_t));
    if(streams > 0 && !offsets){
        fprintf(stderr, "Out of memory allocating %d stream offsets\n", streams);
        free(text);
        return;
    }
    int j;
    for(j = 0; j < streams; ++j){
        offsets[j] = rand_size_t()%size;
    }

    clock_t start;
    while(get_current_batch(net) < net.max_batches){
        i += 1;
        start = clock();
        float_pair p;
        if(tokenized){
            p = get_rnn_token_data(tokens, offsets, inputs, size, streams, steps);
        }else{
            p = get_rnn_data(text, offsets, inputs, size, streams, steps);
        }

        float loss = train_network_datum(net, p.x, p.y) / (batch);
        free(p.x);
        free(p.y);
        /* Seed the running average with the first observed loss. */
        if (avg_loss < 0) avg_loss = loss;
        avg_loss = avg_loss*.9 + loss*.1;

        /* size_t so long runs cannot overflow the processed-item count. */
        size_t chars = (size_t)get_current_batch(net)*batch;
        fprintf(stderr, "%d: %f, %f avg, %f rate, %lf seconds, %f epochs\n", i, loss, avg_loss, get_current_rate(net), sec(clock()-start), (float) chars/size);

        /* Each stream has a 1-in-10 chance per iteration of jumping to a
         * new random offset with its recurrent state reset. */
        for(j = 0; j < streams; ++j){
            if(rand()%10 == 0){
                offsets[j] = rand_size_t()%size;
                reset_rnn_state(net, j);
            }
        }

        if(i%1000==0){
            char buff[256];
            snprintf(buff, sizeof(buff), "%s/%s_%d.weights", backup_directory, base, i);
            save_weights(net, buff);
        }
        if(i%10==0){
            char buff[256];
            snprintf(buff, sizeof(buff), "%s/%s.backup", backup_directory, base);
            save_weights(net, buff);
        }
    }
    char buff[256];
    snprintf(buff, sizeof(buff), "%s/%s_final.weights", backup_directory, base);
    save_weights(net, buff);

    /* Release only what this function allocated itself; ownership of
     * `tokens` and `base` is decided by helpers not visible here. */
    free(offsets);
    free(text);
}
/*
 * Train a recurrent network on the contents of `filename`.
 *
 * cfgfile    - network configuration, handed to load_network().
 * weightfile - weights argument forwarded to load_network().
 * filename   - training corpus: text read with read_file(), or token data
 *              when `tokenized`.
 * clear      - forwarded to load_network(); non-zero also resets the
 *              network's `seen` counter here.
 * tokenized  - non-zero reads the corpus through read_tokenized_data().
 *
 * Runs until get_current_batch() reaches net->max_batches. A numbered
 * checkpoint is written every 10000 iterations, a rolling ".backup" every
 * 100 iterations, and "_final.weights" when the loop ends; all of them go
 * to the hard-coded backup directory.
 *
 * Returns early, with a message on stderr, if the corpus cannot be read
 * or is empty.
 */
void train_char_rnn(char *cfgfile, char *weightfile, char *filename, int clear,
		int tokenized) {
	srand(time(0));
	unsigned char *text = 0;
	int *tokens = 0;
	size_t size = 0;
	if (tokenized) {
		tokens = read_tokenized_data(filename, &size);
	} else {
		text = read_file(filename);
		/* strlen() on a null pointer is undefined behavior. */
		if (!text) {
			fprintf(stderr, "Couldn't read file: %s\n", filename);
			return;
		}
		size = strlen((const char*) text);
	}

	/* Every offset below is taken modulo `size`; zero would be a division
	 * by zero. */
	if (size == 0) {
		fprintf(stderr, "No training data in file: %s\n", filename);
		return;
	}

	char *backup_directory = "/home/pjreddie/backup/";
	char *base = basecfg(cfgfile);
	fprintf(stderr, "%s\n", base);
	real_t avg_loss = -1;
	network *net = load_network(cfgfile, weightfile, clear);

	int inputs = net->inputs;
	fprintf(stderr,
			"Learning Rate: %g, Momentum: %g, Decay: %g, Inputs: %d %d %d\n",
			net->learning_rate, net->momentum, net->decay, inputs, net->batch,
			net->time_steps);
	int batch = net->batch;
	int steps = net->time_steps;
	if (clear)
		*net->seen = 0;
	int i = (*net->seen) / net->batch;

	/* One independent read position per stream, each starting at a random
	 * point in the corpus. */
	int streams = batch / steps;
	size_t *offsets = calloc(streams, sizeof(size_t));
	if (streams > 0 && !offsets) {
		fprintf(stderr, "Out of memory allocating %d stream offsets\n",
				streams);
		return;
	}
	int j;
	for (j = 0; j < streams; ++j) {
		offsets[j] = rand_size_t() % size;
	}

	clock_t start;
	while (get_current_batch(net) < net->max_batches) {
		i += 1;
		start = clock();
		real_t_pair p;
		if (tokenized) {
			p = get_rnn_token_data(tokens, offsets, inputs, size, streams,
					steps);
		} else {
			p = get_rnn_data(text, offsets, inputs, size, streams, steps);
		}

		/* Stage the sampled batch in the network's own input and truth
		 * buffers before the training step. */
		copy_cpu(net->inputs * net->batch, p.x, 1, net->input, 1);
		copy_cpu(net->truths * net->batch, p.y, 1, net->truth, 1);
		real_t loss = train_network_datum(net) / (batch);
		free(p.x);
		free(p.y);
		/* Seed the running average with the first observed loss. */
		if (avg_loss < 0)
			avg_loss = loss;
		avg_loss = avg_loss * .9 + loss * .1;

		size_t chars = get_current_batch(net) * batch;
		fprintf(stderr, "%d: %f, %f avg, %f rate, %lf seconds, %f epochs\n", i,
				loss, avg_loss, get_current_rate(net), sec(clock() - start),
				(real_t) chars / size);

		/* Each stream has a 1-in-64 chance per iteration of jumping to a
		 * new random offset with its network state reset. */
		for (j = 0; j < streams; ++j) {
			if (rand() % 64 == 0) {
				offsets[j] = rand_size_t() % size;
				reset_network_state(net, j);
			}
		}

		if (i % 10000 == 0) {
			char buff[256];
			snprintf(buff, sizeof(buff), "%s/%s_%d.weights", backup_directory,
					base, i);
			save_weights(net, buff);
		}
		if (i % 100 == 0) {
			char buff[256];
			snprintf(buff, sizeof(buff), "%s/%s.backup", backup_directory,
					base);
			save_weights(net, buff);
		}
	}
	char buff[256];
	snprintf(buff, sizeof(buff), "%s/%s_final.weights", backup_directory, base);
	save_weights(net, buff);

	/* Release only what this function allocated itself; ownership of
	 * `text`, `tokens`, `base` and `net` is decided by helpers not
	 * visible here. */
	free(offsets);
}