예제 #1
0
파일: embedded_da.cpp 프로젝트: nowlab/DALM
/**
 * Builds this double array from the n-gram records stored in a tree file.
 *
 * Every record is read through TreeFile::get_ngram(). Records that do not
 * belong to this array (selected with `% datotal != daid`, see the filter
 * below) are dropped. The remaining records are grouped by their history
 * (the first n-1 words); each time the history changes, the collected group
 * is inserted with det_base(). After the last group, replace_value() is
 * called.
 *
 * @param pathtotreefile    path handed to the TreeFile constructor.
 * @param value_array_index used to map a float value to the id stored in
 *                          value_id[] for terminal entries.
 * @param vocab             vocabulary; vocab.size() dimensions the working
 *                          buffers and the initial array size.
 *
 * NOTE(review): grouping only compares a record with the previous history,
 * so it presumably relies on the tree file keeping records with the same
 * history adjacent — confirm against the tree file writer.
 */
void EmbeddedDA::make_da(std::string &pathtotreefile, ValueArrayIndex *value_array_index, Vocabulary &vocab)
{
	TreeFile tf(pathtotreefile, vocab);
	size_t unigram_type = vocab.size();
	resize_array(unigram_type*20);
	da_array[0].base.base_val = 0;

	size_t total = tf.get_totalsize();
	logger << "EmbeddedDA[" << daid << "] total=" << total << Logger::endi;
	size_t order = tf.get_ngramorder();

	// Working buffers for one group: last words and their values, plus the
	// history shared by the group (at most order-1 words).
	int *words = new int[unigram_type+1];
	float *values = new float[unigram_type+1];
	int *history = new int[order-1];
	size_t wordssize = 0;
	size_t historysize = 0;
	// Index inside words/values of the entry whose word id is 1, or
	// (size_t)-1 when the current group has none.
	size_t terminal_pos=(size_t)-1;

	memset(history, 0, sizeof(int)*(order-1));

	// Progress is reported every total/10 records. With fewer than 10
	// records that step is 0; progress lines are skipped in that case
	// instead of evaluating a modulo by zero.
	size_t tenpercent = total / 10;

	for(size_t i = 0; i < total; i++){
		if(tenpercent != 0 && (i+1) % tenpercent == 0){
			logger << "EmbeddedDA[" << daid << "] " << (i+1)/tenpercent << "0% done." << Logger::endi;
		}

		unsigned short n;
		VocabId *ngram;
		float value;
		// get_ngram hands back a heap array in `ngram`; it is released with
		// delete [] on every path below.
		tf.get_ngram(n,ngram,value);

		// Ownership filter: bigrams whose first word is id 1 are assigned by
		// their second word; records whose first word is not id 1 are
		// assigned by their first word. Any other record starting with id 1
		// is always kept.
		if(n==2 && ngram[0]==1 && ngram[1]%datotal!=daid){
			delete [] ngram;
			continue;
		}else if(ngram[0]%datotal!=daid && ngram[0]!=1){
			delete [] ngram;
			continue;
		}

		// History changed: flush the group collected so far, then start a
		// new group for the current history.
		if(historysize != (size_t)n-1
				|| memcmp(history, ngram, sizeof(int)*(n-1))!=0){

			// Walk from position 0 along the previous history to find the
			// node under which the group is inserted.
			unsigned now=0;
			for(size_t j = 0; j < historysize; j++){
				now = get_pos(history[j], now);
			}
			if(historysize!=0 && history[historysize-1]==1){ // context ends for <#>.
				det_base(words, values, wordssize, now);
			}else{
				det_base(words, NULL, wordssize, now);
				// Only the value of the id-1 entry is recorded here, via
				// the value array index.
				if(historysize!=0 && terminal_pos!=(size_t)-1){
					unsigned terminal=get_terminal(now);
					value_id[terminal] = value_array_index->lookup(values[terminal_pos]);
				}
			}

			memcpy(history, ngram, sizeof(int)*(n-1));
			historysize = n-1;
			wordssize=0;
			terminal_pos=(size_t)-1;
		}

		if(ngram[n-1]==1){
			terminal_pos=wordssize;
		}

		words[wordssize]=ngram[n-1];
		values[wordssize]=value;
		wordssize++;

		delete [] ngram;
	}

	// Flush the final group; same steps as the in-loop flush.
	unsigned now=0;
	for(size_t j = 0; j < historysize; j++){
		now = get_pos(history[j], now);
	}
	if(historysize!=0 && history[historysize-1]==1){
		det_base(words, values, wordssize, now);
	}else{
		det_base(words, NULL, wordssize, now);
		if(historysize!=0 && terminal_pos!=(size_t)-1){
			unsigned terminal=get_terminal(now);
			value_id[terminal] = value_array_index->lookup(values[terminal_pos]);
		}
	}

	replace_value();
	delete [] history;
	delete [] words;
	delete [] values;
}
예제 #2
0
/**
 * Command-line entry point: builds a vocabulary from the training file,
 * trains a network and writes the resulting word vectors.
 *
 * Argument shape as parsed here:
 *     [-option value]... <train_file> <vector_file>
 * The last two arguments are always the training file and the output vector
 * file; everything before them must be complete "-option value" pairs.
 *
 * @return -1 on an argument, file or training error; 0 on success.
 */
int main(int argc, char **argv) {
    uint64_t hidden_layer_size = 100;
    int min_count = 5;
    TrainPara train_para;
    string save_vocab_file;
    string read_vocab_file;
    string train_file;
    string vector_file;

    if (argc < 3) {
        cerr << usage << endl;
        return -1;
    }
    train_file = argv[argc - 2];
    vector_file = argv[argc - 1];

    // Options occupy argv[1 .. argc-3] and come in (name, value) pairs.
    for (int i = 1; i < argc - 2; i += 2) {
        string arg = argv[i];

        // An option in the last option slot has no value of its own: argv[i + 1]
        // would be the training file. Reject it rather than consuming that path.
        if (i + 1 >= argc - 2) {
            cerr << "missing value for argument: '" << arg << "'" << endl;
            return -1;
        }
        const char* val = argv[i + 1];

        if (arg == "-size") {
            hidden_layer_size = atoi(val);
        }
        else if (arg == "-type") {
            if (string(val) == "cbow") {
                train_para.type = CBOW;
            }
            else if (string(val) == "skip-gram") {
                train_para.type = SKIP_GRAM;
            }
            else {
                cerr << "unknown -type: " << val << endl;
                return -1;
            }
        }
        else if (arg == "-algo") {
            if (string(val) == "ns") {
                train_para.algo = NEG_SAMPLING;
            }
            else if (string(val) == "hs") {
                train_para.algo = HIER_SOFTMAX;
            }
            else {
                cerr << "unknown -algo: " << val << endl;
                return -1;
            }
        }
        else if (arg == "-neg-sample") {
            train_para.neg_sample_cnt = atoi(val);
        }
        else if (arg == "-window") {
            train_para.window_size = atoi(val);
        }
        else if (arg == "-subsample") {
            train_para.subsample_thres = atof(val);
        }
        else if (arg == "-thread") {
            train_para.thread_cnt = atoi(val);
        }
        else if (arg == "-iter") {
            train_para.iter_cnt = atoi(val);
        }
        else if (arg == "-min-count") {
            min_count = atoi(val);
        }
        else if (arg == "-alpha") {
            train_para.alpha = atof(val);
        }
        else if (arg == "-save-vocab") {
            save_vocab_file = val;
        }
        else if (arg == "-read-vocab") {
            // NOTE(review): this value is stored and echoed below, but nothing
            // else in this function reads it.
            read_vocab_file = val;
        }
        else {
            cerr << "unknow argument: '" << arg << "'" << endl;
            return -1;
        }
    }

    // A negative alpha is replaced by a per-model default
    // (presumably TrainPara initialises alpha to a negative value when it is
    // not given on the command line — confirm against TrainPara).
    if (train_para.alpha < 0) {
        if (train_para.type == CBOW) {
            train_para.alpha = 0.05;
        }
        else {
            train_para.alpha = 0.025;
        }
    }

    cerr << "parameters:" << endl
         << "size = " << hidden_layer_size << endl
         << "type = " << ((train_para.type==CBOW)?"cbow":"skip-gram") << endl
         << "algo = " << ((train_para.algo==HIER_SOFTMAX)?"hs":"neg sampling") << endl
         << "neg sampling cnt = " << train_para.neg_sample_cnt << endl
         << "window = " << train_para.window_size << endl
         << "subsample thres = " << train_para.subsample_thres << endl
         << "thread = " << train_para.thread_cnt << endl
         << "iter = " << train_para.iter_cnt << endl
         << "min count = " << min_count << endl
         << "alpha = " << train_para.alpha << endl
         << "save vocab = " << save_vocab_file << endl
         << "read vocab = " << read_vocab_file << endl
         << "training file = " << train_file << endl
         << "word vector file = " << vector_file << endl
         << endl;
    print_log("start ...");

    ifstream ifs_train(train_file.c_str());
    if (!ifs_train) {
        cerr << "can't open: " << train_file << endl;
        return -1;
    }

    // First pass over the training file: build the vocabulary.
    Vocabulary vocab;
    HuffmanTree* huffman_tree = NULL;
    vocab.parse(ifs_train, min_count);
    cerr << "vocab size = " << vocab.size() << ", total words count = " << vocab.total_cnt() << endl;
    print_log("calc vocab finished ...");
    ifs_train.close();

    if (!save_vocab_file.empty()) {
        ofstream ofs_vocab(save_vocab_file.c_str());
        if (!ofs_vocab) {
            cerr << "can't write to " << save_vocab_file << endl;
            return -1;
        }
        vocab.save(ofs_vocab);
        print_log("save vocab finished ...");
    }

    // Prepare the structure required by the selected training algorithm.
    if (train_para.algo == NEG_SAMPLING) {
        vocab.init_sampling_table();
        print_log("init sampling table finished ...");
    }
    else if (train_para.algo == HIER_SOFTMAX) {
        huffman_tree = new HuffmanTree(vocab.vocab());
        print_log("grow huffman tree finished ...");
    }


    Net net(vocab.size(), hidden_layer_size);
    print_log("net init finished ...");

    // NOTE(review): huffman_tree is still NULL here unless HIER_SOFTMAX was
    // selected, so *huffman_tree dereferences a null pointer in that case.
    // Left unchanged because train()'s signature is not visible here — confirm
    // that train() never touches the tree for negative sampling, or change it
    // to take a pointer.
    if (!train(train_file, vocab, *huffman_tree, net, train_para)) {
        cerr << "training failed" << endl;
        delete huffman_tree;  // do not leak the tree on the error path
        return -1;
    }
    print_log("training finished ...");


    ofstream ofs_result(vector_file.c_str());
    if (!ofs_result) {
        cerr << "can't write to " << vector_file << endl;
        delete huffman_tree;  // do not leak the tree on the error path
        return -1;
    }
    save_word_vec(ofs_result, net, vocab);
    ofs_result.close();
    print_log("saving word vector finished ...");

    delete huffman_tree;
    return 0;
}