Exemplo n.º 1
0
Arquivo: gibbs.cpp Projeto: DylanV/lda
void gibbs::train(size_t numTopics) {

    this->numTopics = numTopics;

    zero_init_counts();
    random_assign_topics();

    this->beta = 1/numTerms;

    double max_iter = MAX_ITER;

    for(int iter=0; iter<max_iter; ++iter){
        for(int d=0; d<numDocs; ++d){

            document curr_doc = corpus.docs[d];
            int doc_word_index = 0;

            for(auto const& word_count : curr_doc.wordCounts){

                int w = word_count.first;
                int count = word_count.second;

                for(int i=0; i<count; ++i){
                    // get current assignment
                    int z = topic_assignments[d][doc_word_index];
                    // decrement counts
                    n_dk[d][z] -= 1;
                    n_kw[z][w] -= 1;
                    n_k[z] -= 1;
                    n_d[d] -= 1;
                    // get conditional distribution
                    std::vector<double> pz = get_pz(d, w);
                    // update assignment
                    z = sample_multinomial(pz);
                    topic_assignments[d][doc_word_index] = z;
                    doc_word_index++;
                    // update counts
                    n_dk[d][z] += 1;
                    n_kw[z][w] += 1;
                    n_k[z] += 1;
                    n_d[d] += 1;
                }
            }
        }

        std::cout << "<";
        int count = 0;
        int length = 50;
        for(int x=0; x<(iter/max_iter)*length; x++){
            std::cout << "=";
            count++;
        }
        for(int x=0; x<length-count; x++){
            std::cout << "-";
        }
        std::cout << "> " << iter+1 << "/" << max_iter << '\r' << std::flush;
    }
    // update phi and theta
    estimate_parameters();
}
void detect_breakpoints(std::string read_filename, IPrinter *& printer) {
	estimate_parameters(read_filename);
	BamParser * mapped_file = 0;
	RefVector ref;
	if (read_filename.find("bam") != string::npos) {
		mapped_file = new BamParser(read_filename);
		ref = mapped_file->get_refInfo();
	} else {
		cerr << "File Format not recognized. File must be a sorted .bam file!" << endl;
		exit(0);
	}
//Using PlaneSweep to comp coverage and iterate through reads:
//PlaneSweep * sweep = new PlaneSweep();
	std::cout << "Start parsing..." << std::endl;
//Using Interval tree to store and manage breakpoints:

	//IntervallList final;
	//IntervallList bst;

	IntervallTree final;
	IntervallTree bst;

	TNode * root_final = NULL;
	int current_RefID = 0;

	TNode *root = NULL;
//FILE * alt_allel_reads;
	FILE * ref_allel_reads;
	if (Parameter::Instance()->genotype) {

		std::string output = Parameter::Instance()->tmp_file.c_str();
		output += "ref_allele";
		ref_allel_reads = fopen(output.c_str(), "wb");

//	output = Parameter::Instance()->tmp_file.c_str();
		//output += "alt_allele";
		//alt_allel_reads = fopen(output.c_str(), "wb");
	}
	Alignment * tmp_aln = mapped_file->parseRead(Parameter::Instance()->min_mq);
	long ref_space = get_ref_lengths(tmp_aln->getRefID(), ref);
	std::string prev = "test";
	std::string curr = "curr";
	long num_reads = 0;
	while (!tmp_aln->getQueryBases().empty()) {
		if ((tmp_aln->getAlignment()->IsPrimaryAlignment()) && (!(tmp_aln->getAlignment()->AlignmentFlag & 0x800) && tmp_aln->get_is_save())) {
			//change CHR:
			if (current_RefID != tmp_aln->getRefID()) {
				current_RefID = tmp_aln->getRefID();
				ref_space = get_ref_lengths(tmp_aln->getRefID(), ref);
				std::cout << "\tSwitch Chr " << ref[tmp_aln->getRefID()].RefName << " " << ref[tmp_aln->getRefID()].RefLength << std::endl;
				std::vector<Breakpoint *> points;
				//	clarify(points);
				bst.get_breakpoints(root, points);
				polish_points(points, ref);
				for (int i = 0; i < points.size(); i++) {
					if (should_be_stored(points[i])) {
						if (points[i]->get_SVtype() & TRA) {
							final.insert(points[i], root_final);
						} else {
							printer->printSV(points[i]);
						}
					}
				}
				bst.clear(root);
			}
Exemplo n.º 3
0
Arquivo: lda.c Projeto: zzy7896321/ppp
/*
 * Driver for the "latent_dirichlet_allocation" model.
 *
 * Loads parse/models/lda.model, observes a 2x2 integer array "X", samples
 * traces with Metropolis-Hastings, prints the trace with the highest logprob
 * and the last trace, writes every trace to trace_dump.txt, and finally runs
 * estimate_parameters on the sampled traces.
 *
 * Returns 0 on success and 1 if allocation, query construction or sampling
 * fails.
 */
int main()
{
#ifdef ENABLE_MEM_PROFILE
    mem_profile_init();
#endif
    set_sample_method("Metropolis-hastings");
    set_sample_iterations(200);
    set_mh_burn_in(200);
    set_mh_lag(50);
    set_mh_max_initial_round(2000);

    /* Size of the scratch buffer that pp_trace_dump writes into. */
    enum { TRACE_BUFFER_SIZE = 8096 };

    /* use pointers to structs because the client doesn't need to know the struct sizes */
    struct pp_state_t* state;
    struct pp_query_t* query;
    struct pp_trace_store_t* traces;

    state = pp_new_state();
    printf("> state created\n");

    pp_load_file(state, "parse/models/lda.model");
    printf("> file loaded\n");

    ModelNode* model = model_map_find(state->model_map, state->symbol_table, "latent_dirichlet_allocation");
    /* Never pass generated text as the format string: a '%' in the dump
     * would be interpreted as a conversion specifier. */
    printf("%s", dump_model(model));

    //query = pp_compile_string_query("");
    //printf("> condition compiled\n");

    /* Four model arguments: two ints, a 2-element int vector, and an int. */
    pp_variable_t** param = malloc(sizeof(pp_variable_t*) * 4);
    if (!param) {
        fprintf(stderr, "ERROR: out of memory\n");
        return 1;
    }
    param[0] = new_pp_int(2);
    param[1] = new_pp_int(2);
    param[2] = new_pp_vector(2);
    PP_VARIABLE_VECTOR_LENGTH(param[2]) = 2;
    for (int i = 0; i < 2; ++i) {
        PP_VARIABLE_VECTOR_VALUE(param[2])[i] = new_pp_int(2);
    }
    param[3] = new_pp_int(3);

    int X[2][2] = {
        {0, 0},
        {1, 1},
    };
    query = pp_query_observe_int_array_2D(state, "X", &X[0][0], 2, 2);
    if (!query) return 1;

    traces = pp_sample_v(state, "latent_dirichlet_allocation", param, query, 1, "topic");
    printf("> traces sampled\n");

    if (!traces) {
        printf("ERROR encountered!!\n");
        return 1;
    }

    char buffer[TRACE_BUFFER_SIZE];

    /* Find the trace with the highest log-probability. */
    size_t max_index = 0;
    for (size_t i = 1; i < traces->n; ++i) {
        if (traces->trace[i]->logprob > traces->trace[max_index]->logprob) {
            max_index = i;
        }
    }
    printf("\nsample with max logprob:\n");
    pp_trace_dump(buffer, TRACE_BUFFER_SIZE, traces->trace[max_index]);
    printf("%s", buffer);

    printf("\nlast sample:\n");
    pp_trace_dump(buffer, TRACE_BUFFER_SIZE, traces->trace[traces->n - 1]);
    printf("%s", buffer);

    /* Dump every trace to a file; a failure to open the file is reported
     * but does not abort the remaining parameter estimation. */
    FILE* trace_dump_file = fopen("trace_dump.txt", "w");
    if (trace_dump_file) {
        for (size_t i = 0; i != traces->n; ++i) {
            pp_trace_dump(buffer, TRACE_BUFFER_SIZE, traces->trace[i]);
            /* %zu is the matching conversion for size_t; %u was undefined
             * behaviour where size_t is wider than unsigned int. */
            fprintf(trace_dump_file, "[trace %zu]\n", i);
            fputs(buffer, trace_dump_file);
        }
        fclose(trace_dump_file);
    } else {
        perror("trace_dump.txt");
    }

    int nwords[] = {2, 2};

    /* parameter estimation */
    estimate_parameters(traces, 1.0, 1.0, 2, 2, nwords, 3, &X[0][0], 2);

    // pp_free is broken
    /* NOTE(review): the original comment claims pp_free also deallocates
     * queries and trace stores, yet they are destroyed explicitly below --
     * confirm which is correct to rule out a double free. */
    pp_free(state);  /* free memory, associated models, instances, queries, and trace stores are deallocated */

    pp_trace_store_destroy(traces);

    pp_query_destroy(query);

    for (int i = 0; i < 4; ++i)
        pp_variable_destroy(param[i]);
    free(param);

#ifdef ENABLE_MEM_PROFILE
    mem_profile_print();
    mem_profile_destroy();
#endif

    return 0;
}