Ejemplo n.º 1
0
/*
 * Load a trained CRBM model and the 20newsgroup training corpus, then print
 * one line per training document of the form
 *     "<stored label + 1> : <index returned by get_max_index>"
 * where the right-hand side is the arg-max over the m.ncat scores produced by
 * get_y_given_x() for that document.
 *
 * Documents are processed in mini-batches of `minibatch` rows; every batch is
 * full except possibly the last one, which holds the remaining rows.
 *
 * NOTE(review): `py` is not declared in this function. It is presumably a
 * file-scope buffer that get_y_given_x() fills with batch_size rows of
 * m.ncat doubles, so it must hold at least minibatch * m.ncat values.
 * Confirm against its declaration.
 */
void test_predict(){
    crbm m;
    dataset_blas train_set;
    int j, k;
    int minibatch = 20;
    int batch_size, niter;
    double *x, *y = py;

    load_model(&m, "../data/20newsgroup/cross_valid/crbm.model");
    load_corpus("../data/20newsgroup/train.data.format", &train_set);
    load_corpus_label("../data/20newsgroup/train.label.format", &train_set);

    /* Exact integer ceiling division: number of mini-batches needed to cover
     * all train_set.N rows (no floating-point round trip, no <math.h>). */
    niter = (train_set.N + minibatch - 1) / minibatch;
    for(k = 0; k < niter; k++){
        /* Rows still unprocessed, clamped to one full batch; only the final
         * batch can come out shorter than `minibatch`. */
        batch_size = train_set.N - minibatch * k;
        if(batch_size > minibatch){
            batch_size = minibatch;
        }
        /* Start of batch k; each row is m.nvisible elements wide. */
        x = train_set.input + m.nvisible * minibatch * k;
        get_y_given_x(&m, x, y, batch_size);
        for(j = 0; j < batch_size; j++){
            /* Row j of y holds m.ncat scores for document k*minibatch+j. */
            printf("%d : %d\n", train_set.output[k*minibatch+j] + 1,
                    get_max_index(&y[j*m.ncat], m.ncat));
        }
    }

    free_crbm(&m);
    free_dataset_blas(&train_set);
}
Ejemplo n.º 2
0
/*
 * Interactive prefix completion.
 *
 * Builds the search structure with load_corpus() and reads one line from
 * stdin as the prefix. It then prints every word returned by root.search(),
 * ordered with compareRanks, together with its entry in wordRank.
 */
int main() {

	// Corpus-load timing. The report line below is disabled but kept so it
	// can be re-enabled for profiling.
	std::clock_t start,end;
	start = std::clock();
	Node root = load_corpus();
	end = std::clock();
//	std::cout << "Time: " << (end-start)/(double)(CLOCKS_PER_SEC/1000) << "\n";
	std::string prefix;
	std::cout << "Enter a prefix: ";
	std::getline(std::cin,prefix);

	std::vector<std::string> suggestedWords = root.search(prefix);
	std::sort(suggestedWords.begin(),suggestedWords.end(),compareRanks);
	// Bind each word by const reference (no per-word string copy) and end
	// lines with '\n' instead of std::endl so cout is not flushed per word.
	// The range-for also avoids narrowing size() into a signed int.
	for (const std::string& suggestedWord : suggestedWords) {
		std::cout << "Suggested: " << suggestedWord << ", #" << wordRank[suggestedWord] << '\n';
	}
}
Ejemplo n.º 3
0
/*
 * Train a CRBM on the 20newsgroup cross_valid data and label files, and pass
 * "../data/20newsgroup/cross_valid/crbm.model" to train_crbm() as the model
 * path.
 *
 * Fixed hyper-parameters: 500 hidden units, 30 epochs, learning rate 0.1,
 * mini-batch size 10, momentum 0. The feature count and label count passed to
 * train_crbm() come from the loaded data set (n_feature / nlabel).
 *
 * Alternative training inputs that can replace the two cross_valid loads:
 *   load_corpus("../data/20newsgroup/train.data.format", &train_set);
 *   load_corpus_label("../data/20newsgroup/train.label.format", &train_set);
 *   load_corpus("../data/tcga/train.pm.data", &train_set);
 *   load_corpus_label("../data/tcga/train.pm.label", &train_set);
 *   load_mnist_dataset_blas(&train_set, &valid_set);
 * valid_set is only referenced by the MNIST alternative.
 */
void test_crbm(){
    const int n_hidden = 500;
    const int n_epoch = 30;
    const double learning_rate = 0.1;
    const int batch_size = 10;
    const double momentum = 0;

    dataset_blas train_set, valid_set;

    load_corpus("../data/20newsgroup/cross_valid/data.out", &train_set);
    load_corpus_label("../data/20newsgroup/cross_valid/label.out", &train_set);

    train_crbm(&train_set,
               train_set.n_feature,
               n_hidden,
               train_set.nlabel,
               n_epoch,
               learning_rate,
               batch_size,
               momentum,
               "../data/20newsgroup/cross_valid/crbm.model");

    free_dataset_blas(&train_set);
}