Example #1
0
double sample_hmm_posterior(
    int blocklen, const LocalTree *tree, const States &states,
    const TransMatrix *matrix, const double *const *fw, int *path)
{
    // NOTE: path[n-1] must already be sampled

    const int nstates = max(states.size(), (size_t)1);
    double A[nstates];
    double trans[nstates];
    int last_k = -1;
    double lnl = 0.0;

    // recurse
    for (int i=blocklen-2; i>=0; i--) {
        int k = path[i+1];

        // recompute transition probabilities if state (k) changes
        if (k != last_k) {
            for (int j=0; j<nstates; j++)
                trans[j] = matrix->get(tree, states, j, k);
            last_k = k;
        }

        for (int j=0; j<nstates; j++)
            A[j] = fw[i][j] * trans[j];
        path[i] = sample(A, nstates);
        //lnl += log(A[path[i]]);

        // DEBUG
        assert(trans[path[i]] != 0.0);
    }

    return lnl;
}
Example #2
0
// Merges the tags of a group of states into one destination entry.
//
// For every state in `states` that has an entry in `from`, that entry's
// tags are inserted into (*to)[s].  The entry (*to)[s] is created lazily,
// i.e. only if at least one state of the group has tags in `from`.
//
//   states - group of source state ids (each must be >= 0)
//   from   - tags keyed by source state id
//   s      - key of the destination entry in *to
//   to     - output map receiving the merged tags
static void merge_tags(const States& states,
		const std::map<size_t, Tag>& from,
		int s, std::map<size_t, Tag> *to) {
	for (States::const_iterator it = states.begin();
			it != states.end(); ++it) {
		// ids are used as size_t keys below, so negative values are invalid
		assert(*it >= 0);
		std::map<size_t, Tag>::const_iterator fit = from.find(*it);
		if (fit != from.end()) {
			Tag& tag = (*to)[s];
			tag.insert(fit->second.begin(), fit->second.end());
		}
	}
}
Example #3
0
// Collects every NFA state reachable from `from` through EPSILON
// transitions alone (the states in `from` are included).
//
// *is_last is set to true when any visited state belongs to `last`; it is
// never reset to false here.
//
// The state -1 is handled specially: the result is exactly {-1} when -1
// was reached and no other visited state lacked an EPSILON transition;
// in every other case -1 is removed from the result.
static States fill(const std::vector<NFATran>& trans, const States& last, const States& from, bool* is_last) {
	std::queue<int> pending;
	for (States::const_iterator src = from.begin();
			src != from.end(); ++src) {
		pending.push(*src);
	}

	// Which kinds of end points were encountered: the -1 state itself, and
	// ordinary states that have no EPSILON transition.  These are needed to
	// decide whether this closure ends only in -1.
	bool reached_minus_one = false;
	bool reached_other_end = false;
	States closure;
	while (!pending.empty()) {
		const int cur = pending.front();
		pending.pop();

		closure.insert(cur);
		if (last.find(cur) != last.end()) {
			*is_last = true;
		}

		// -1 has no entry in trans, so it must not be looked up
		if (cur == -1) {
			reached_minus_one = true;
			continue;
		}

		const NFATran& tran = trans[cur];
		NFATran::const_iterator eps = tran.find(EPSILON);
		if (eps == tran.end()) {
			reached_other_end = true;
			continue;
		}

		// enqueue EPSILON successors that have not been visited yet
		const States& successors = eps->second;
		for (States::const_iterator nit = successors.begin();
				nit != successors.end(); ++nit) {
			if (closure.find(*nit) == closure.end()) {
				pending.push(*nit);
			}
		}
	}

	if (reached_minus_one && !reached_other_end) {
		closure.clear();
		closure.insert(-1);
	} else {
		closure.erase(-1);
	}

	return closure;
}
Example #4
0
// Straightforward version of the forward algorithm for one block: the
// transition probabilities are first expanded into a dense
// nstates x nstates table and each column is then computed by a full
// matrix-vector product.
// NOTE: fw[0] must be filled in by the caller before this is called.
// Intended as a reference for testing.
// The parameters ntimes and lineages are not used by this function.
void arghmm_forward_block_slow(const LocalTree *tree, const int ntimes,
                               const int blocklen, const States &states,
                               const LineageCounts &lineages,
                               const TransMatrix *matrix,
                               const double* const *emit, double **fw)
{
    const int nstates = states.size();

    // dense[src][dst] holds the probability of moving from state src at
    // one position to state dst at the next
    double **dense = new_matrix<double>(nstates, nstates);
    for (int dst=0; dst<nstates; dst++)
        for (int src=0; src<nstates; src++)
            dense[src][dst] = matrix->get(tree, states, src, dst);

    // advance the forward table one position at a time
    for (int pos=1; pos<blocklen; pos++) {
        const double *prev = fw[pos-1];
        double *cur = fw[pos];
        double total = 0.0;

        for (int dst=0; dst<nstates; dst++) {
            double acc = 0.0;
            for (int src=0; src<nstates; src++)
                acc += prev[src] * dense[src][dst];
            cur[dst] = acc * emit[pos][dst];
            total += cur[dst];
        }

        // rescale the column to sum to one for numerical stability
        for (int dst=0; dst<nstates; dst++)
            cur[dst] /= total;
    }

    // release the dense table
    delete_matrix<double>(dense, nstates);
}
Example #5
0
// Samples recombination events along an already sampled thread path.
//
// Walks every local block exposed by matrix_iter and, for each position i
// in the block, decides whether a recombination is placed at i (between
// the states thread_path[i-1] and thread_path[i]).  For every sampled
// event the position is appended to recomb_pos and the chosen NodePoint is
// appended to recombs, so the two vectors stay parallel.
//
//   trees       - local trees; only trees->start_coord is read here
//   model       - supplies ntimes and is passed to recomb_prob_unnormalized
//   matrix_iter - iterator over blocks (tree, matrices, block bounds)
//   thread_path - state index per position, indexed with the same
//                 coordinates as the block start/end
//   recomb_pos  - output: positions of sampled recombinations (appended)
//   recombs     - output: node/time point of each recombination (appended)
//   internal    - selects the internal-branch variant (lineage counting,
//                 empty state space handling, and candidate set)
void sample_recombinations(
    const LocalTrees *trees, const ArgModel *model,
    ArgHmmMatrixIter *matrix_iter,
    int *thread_path, vector<int> &recomb_pos, vector<NodePoint> &recombs,
    bool internal)
{
    States states;
    LineageCounts lineages(model->ntimes);
    const int new_node = -1;  // node id used for candidates on the new branch
    // scratch vectors, reused (cleared) for every sampled recombination
    vector <NodePoint> candidates;
    vector <double> probs;

    // loop through local blocks
    for (matrix_iter->begin(); matrix_iter->more(); matrix_iter->next()) {

        // get local block information
        ArgHmmMatrices &matrices = matrix_iter->ref_matrices();
        LocalTree *tree = matrix_iter->get_tree_spr()->tree;
        lineages.count(tree, internal);
        matrices.states_model.get_coal_states(tree, states);
        // position of the next optional recombination; -1 means "not yet
        // drawn" (reset per block and after every sampled recombination)
        int next_recomb = -1;

        // don't sample recombination if there is no state space
        if (internal && states.size() == 0)
            continue;

        int start = matrix_iter->get_block_start();
        int end = matrix_iter->get_block_end();
        if (matrices.transmat_switch || start == trees->start_coord) {
            // don't allow new recomb at start if we are switching blocks
            // the very first position is skipped as well: the loop below
            // reads thread_path[i-1]
            start++;
        }

        //int start = end + 1;  // don't allow new recomb at start
        //end = start - 1 + matrices.blocklen;

        // loop through positions in block
        for (int i=start; i<end; i++) {

            if (thread_path[i] == thread_path[i-1]) {
                // no change in state, recombination is optional

                if (i > next_recomb) {
                    // sample the next recomb pos
                    int last_state = thread_path[i-1];
                    TransMatrix *m = matrices.transmat;
                    int a = states[last_state].time;
                    double self_trans = m->get(
                        tree, states, last_state, last_state);
                    // NOTE(review): presumably the share of the
                    // self-transition probability that involves a
                    // recombination -- confirm against TransMatrix
                    double rate = 1.0 - (m->norecombs[a] / self_trans);

                    // NOTE: the min prevents large floats from overflowing
                    // when cast to int
                    next_recomb = int(min(double(end), i + expovariate(rate)));
                }

                // not yet at the drawn position: no recombination here
                if (i < next_recomb)
                    continue;
            }


            // sample recombination
            // reaching this point means either the state changed or
            // i == next_recomb
            next_recomb = -1;
            State state = states[thread_path[i]];
            State last_state = states[thread_path[i-1]];

            // there must be a recombination
            // either because state changed or we choose to recombine
            // find candidates
            candidates.clear();
            int end_time = min(state.time, last_state.time);
            if (state.node == last_state.node) {
                // y = v, k in [0, min(timei, last_timei)]
                // y = node, k in Sr(node)
                // same branch: any time from the node's age up to end_time
                for (int k=tree->nodes[state.node].age; k<=end_time; k++)
                    candidates.push_back(NodePoint(state.node, k));
            }

            if (internal) {
                // candidates on the branch above the root's first child
                const int subtree_root = tree->nodes[tree->root].child[0];
                const int subtree_root_age = tree->nodes[subtree_root].age;
                for (int k=subtree_root_age; k<=end_time; k++)
                    candidates.push_back(NodePoint(subtree_root, k));
            } else {
                // candidates on the new branch (node id -1), from time 0
                for (int k=0; k<=end_time; k++)
                    candidates.push_back(NodePoint(new_node, k));
            }

            // compute probability of each candidate
            probs.clear();
            for (vector<NodePoint>::iterator it=candidates.begin();
                 it != candidates.end(); ++it) {
                probs.push_back(recomb_prob_unnormalized(
                    model, tree, lineages, last_state, state, *it));
            }

            // sample recombination
            // NOTE(review): &probs[0] requires a non-empty candidate list;
            // presumably at least one candidate always exists -- confirm
            recomb_pos.push_back(i);
            recombs.push_back(candidates[sample(&probs[0], probs.size())]);

            // the chosen recombination time cannot exceed either state time
            assert(recombs[recombs.size()-1].time <= min(state.time,
                                                         last_state.time));
        }
    }
}
Example #6
0
// sample the thread of the last chromosome, conditioned on a given
// start and end state
//
// Internal-branch variant: all helpers are called with internal == true.
// A null start_state / end_state means that end of the thread is left
// open; the corresponding flag (prior_given / last_state_given) is then
// cleared before being passed to the forward algorithm / traceback.
//
//   model       - model parameters (ntimes is read here)
//   sequences   - sequence data used to build the first matrix iterator
//   trees       - local trees; modified by add_arg_thread_path
//   start_state - state at trees->start_coord, or a null state
//   end_state   - state at trees->end_coord-1, or a null state
void cond_sample_arg_thread_internal(
    const ArgModel *model, const Sequences *sequences, LocalTrees *trees,
    const State start_state, const State end_state)
{
    // allocate temp variables
    ArgHmmForwardTable forward(trees->start_coord, trees->length());
    States states;
    double **fw = forward.get_table();
    int *thread_path_alloc = new int [trees->length()];
    // offset the base pointer so that thread_path[trees->start_coord]
    // refers to thread_path_alloc[0]
    int *thread_path = &thread_path_alloc[-trees->start_coord];
    const bool internal = true;
    bool prior_given = true;
    bool last_state_given = true;

    // build matrices
    ArgHmmMatrixIter matrix_iter(model, sequences, trees);
    matrix_iter.set_internal(internal);

    // fill in first column of forward table
    matrix_iter.begin();
    matrix_iter.get_coal_states(states);
    forward.new_block(matrix_iter.get_block_start(),
                      matrix_iter.get_block_end(), states.size());

    if (states.size() > 0) {
        if (!start_state.is_null()) {
            // find start state
            int j = find_vector(states, start_state);
            assert(j != -1);
            // first column puts all probability on the start state
            double *col = fw[trees->start_coord];
            fill(col, col + states.size(), 0.0);
            col[j] = 1.0;
        } else {
            // open ended, sample start state
            prior_given = false;
        }
    } else {
        // fully specified tree
        fw[trees->start_coord][0] = 1.0;
    }

    // compute forward table
    Timer time;
    arghmm_forward_alg(trees, model, sequences, &matrix_iter, &forward, NULL,
                       prior_given, internal);
    // the state count is only used for the log message below
    int nstates = get_num_coal_states_internal(
        trees->front().tree, model->ntimes);
    printTimerLog(time, LOG_LOW,
                  "forward (%3d states, %6d blocks):",
                  nstates, trees->get_num_trees());

    // fill in last state of traceback
    matrix_iter.rbegin();
    matrix_iter.get_coal_states(states);
    if (states.size() > 0) {
        if (!end_state.is_null()) {
            thread_path[trees->end_coord-1] = find_vector(states, end_state);
            assert(thread_path[trees->end_coord-1] != -1);
        } else {
            // sample end state
            last_state_given = false;
        }
    } else {
        // fully specified tree
        thread_path[trees->end_coord-1] = 0;
    }

    // traceback
    // a second iterator is built without sequence data (NULL) and is used
    // for the traceback and for sampling recombinations
    time.start();
    ArgHmmMatrixIter matrix_iter2(model, NULL, trees);
    matrix_iter2.set_internal(internal);
    stochastic_traceback(trees, model, &matrix_iter2, fw, thread_path,
                         last_state_given, internal);
    printTimerLog(time, LOG_LOW,
                  "trace:                              ");
    // with a given start state, the first sampled state must be the one
    // with forward value exactly 1.0
    if (!start_state.is_null())
        assert(fw[trees->start_coord][thread_path[trees->start_coord]] == 1.0);

    // sample recombination points
    time.start();
    vector<int> recomb_pos;
    vector<NodePoint> recombs;
    sample_recombinations(trees, model, &matrix_iter2,
                          thread_path, recomb_pos, recombs, internal);

    // add thread to ARG
    // (the timer started above covers recombination sampling as well)
    add_arg_thread_path(trees, matrix_iter.states_model,
                        model->ntimes, thread_path,
                        recomb_pos, recombs);
    printTimerLog(time, LOG_LOW,
                  "add thread:                         ");

    // clean up
    delete [] thread_path_alloc;
}
Example #7
0
// sample the thread of the last chromosome, conditioned on a given
// start and end state
void cond_sample_arg_thread(const ArgModel *model, const Sequences *sequences,
                            LocalTrees *trees, int new_chrom,
                            State start_state, State end_state)
{
    // allocate temp variables
    ArgHmmForwardTable forward(trees->start_coord, trees->length());
    States states;
    double **fw = forward.get_table();
    int *thread_path_alloc = new int [trees->length()];
    int *thread_path = &thread_path_alloc[-trees->start_coord];

    // build matrices
    Timer time;
    ArgHmmMatrixList matrix_list(model, sequences, trees, new_chrom);
    matrix_list.setup();
    printf("matrix calc: %e s\n", time.time());

    // fill in first column of forward table
    matrix_list.begin();
    matrix_list.get_coal_states(states);
    forward.new_block(matrix_list.get_block_start(),
                      matrix_list.get_block_end(), states.size());
    int j = find_vector(states, start_state);
    assert(j != -1);
    double *col = fw[trees->start_coord];
    fill(col, col + states.size(), 0.0);
    col[j] = 1.0;

    // compute forward table
    time.start();
    arghmm_forward_alg(trees, model, sequences, &matrix_list, &forward, NULL, 
		       true);
    int nstates = get_num_coal_states(trees->front().tree, model->ntimes);
    printf("forward:     %e s  (%d states, %d blocks)\n", time.time(),
           nstates, trees->get_num_trees());

    // fill in last state of traceback
    matrix_list.rbegin();
    matrix_list.get_coal_states(states);
    thread_path[trees->end_coord-1] = find_vector(states, end_state);
    assert(thread_path[trees->end_coord-1] != -1);

    // traceback
    time.start();
    stochastic_traceback(trees, model, &matrix_list, fw, thread_path, true);
    printf("trace:       %e s\n", time.time());
    assert(fw[trees->start_coord][thread_path[trees->start_coord]] == 1.0);


    // sample recombination points
    time.start();
    vector<int> recomb_pos;
    vector<NodePoint> recombs;
    sample_recombinations(trees, model, &matrix_list,
                          thread_path, recomb_pos, recombs);

    // add thread to ARG
    add_arg_thread(trees, matrix_list.states_model,
                   model->ntimes, thread_path, new_chrom,
                   recomb_pos, recombs);

    printf("add thread:  %e s\n", time.time());

    // clean up
    delete [] thread_path_alloc;
}
Example #8
0
// compute one block of forward algorithm with compressed transition matrices
// NOTE: first column of forward table should be pre-populated
//
// Instead of a dense nstates x nstates product, each column is computed
// from (1) a term that depends only on the times of the two states and
// (2) a correction summed over states on the same branch.
//
//   tree     - local tree for this block
//   ntimes   - number of time points; sizes the temporary arrays
//   blocklen - number of columns in the block
//   states   - list of (node, time) states
//   lineages - not referenced in this function
//   matrix   - compressed transition matrix (minage, internal, get_time)
//   emit     - emit[i][k] is multiplied into column i, state k
//   fw       - forward table; fw[0] is read, fw[1..blocklen-1] are written
void arghmm_forward_block(const LocalTree *tree, const int ntimes,
                          const int blocklen, const States &states,
                          const LineageCounts &lineages,
                          const TransMatrix *matrix,
                          const double* const *emit, double **fw)
{
    const int nstates = states.size();
    const LocalNode *nodes = tree->nodes;

    //  handle internal branch resampling special cases
    int minage = matrix->minage;
    int maintree_root = 0;
    if (matrix->internal) {
        maintree_root = nodes[tree->root].child[1];

        if (nstates == 0) {
            // handle fully given case
            // a single entry per column is carried forward unchanged
            for (int i=1; i<blocklen; i++)
                fw[i][0] = fw[i-1][0];
            return;
        }
    }

    // compute ntimes*ntimes and ntime*nstates temp matrices
    // (runtime-sized stack arrays; only rows/columns 0..ntimes-2 are filled)
    double tmatrix[ntimes][ntimes];
    double tmatrix2[ntimes][nstates];
    for (int a=0; a<ntimes-1; a++) {
        // time-only term between time a and time b
        for (int b=0; b<ntimes-1; b++) {
            tmatrix[a][b] = matrix->get_time(a, b, 0, minage, false);
            assert(!isnan(tmatrix[a][b]));
        }

        // same-branch correction for destination state k: the difference
        // between the same-branch value and the time-only value
        for (int k=0; k<nstates; k++) {
            const int b = states[k].time;
            const int node2 = states[k].node;
            const int c = nodes[node2].age;
            assert(b >= minage);
            tmatrix2[a][k] = matrix->get_time(a, b, c, minage, true) -
                             matrix->get_time(a, b, 0, minage, false);
        }
    }

    // get max time
    int maxtime = 0;
    for (int k=0; k<nstates; k++)
        if (maxtime < states[k].time)
            maxtime = states[k].time;

    // get branch ages
    // ages1[i]   - lowest time considered on branch i (node age, but at
    //              least minage)
    // ages2[i]   - highest time on branch i: the parent's age, or maxtime
    //              for the root (and for maintree_root when internal)
    // indexes[i] - result of looking up (i, ages1[i]) in the state list
    NodeStateLookup state_lookup(states, tree->nnodes);
    int ages1[tree->nnodes];
    int ages2[tree->nnodes];
    int indexes[tree->nnodes];
    for (int i=0; i<tree->nnodes; i++) {
        ages1[i] = max(nodes[i].age, minage);
        indexes[i] = state_lookup.lookup(i, ages1[i]);
        if (matrix->internal)
            ages2[i] = (i == maintree_root || i == tree->root) ?
                maxtime : nodes[nodes[i].parent].age;
        else
            ages2[i] = (i == tree->root) ? maxtime : nodes[nodes[i].parent].age;
    }


    // per-column scratch: forward mass grouped by time, and its product
    // with tmatrix
    double tmatrix_fgroups[ntimes];
    double fgroups[ntimes];
    for (int i=1; i<blocklen; i++) {
        const double *col1 = fw[i-1];
        double *col2 = fw[i];
        const double *emit2 = emit[i];

        // precompute the fgroup sums
        // fgroups[a] = total previous-column mass of states at time a
        fill(fgroups, fgroups+ntimes, 0.0);
        for (int j=0; j<nstates; j++) {
            const int a = states[j].time;
            fgroups[a] += col1[j];
        }

        // multiply tmatrix and fgroups together
        for (int b=0; b<ntimes-1; b++) {
            double sum = 0.0;
            for (int a=0; a<ntimes-1; a++)
                sum += tmatrix[a][b] * fgroups[a];
            tmatrix_fgroups[b] = sum;
        }

        // fill in one column of forward table
        double norm = 0.0;
        for (int k=0; k<nstates; k++) {
            const int b = states[k].time;
            const int node2 = states[k].node;
            const int age1 = ages1[node2];
            const int age2 = ages2[node2];

            assert(!isnan(col1[k]));

            // same branch case
            // start from the time-only term, then add the correction for
            // every source state on the same branch as state k
            // NOTE(review): j and a advance together, which assumes the
            // states of one branch are stored consecutively by increasing
            // time starting at indexes[node2] -- confirm
            double sum = tmatrix_fgroups[b];
            const int j1 = indexes[node2];
            for (int j=j1, a=age1; a<=age2; j++, a++)
                sum += tmatrix2[a][k] * col1[j];

            col2[k] = sum * emit2[k];
            norm += col2[k];
        }

        // normalize column for numerical stability
        for (int k=0; k<nstates; k++)
            col2[k] /= norm;
    }
}
Example #9
0
 // Returns states.size(), converted to int.
 int size() {
     return static_cast<int>(states.size());
 }