// Example #1
// 0
// File: ratefree.cpp  Project: bqminh/IQ-TREE
/**
 * Optimize the FreeRate category proportions (prop) and rates (rates) with the
 * EM algorithm described in Wang, Li, Susko, and Roger (2008).
 *
 * Each iteration:
 *   - E-step: phylo_tree->computePatternLhCat(WSL_RATECAT) fills
 *     phylo_tree->_pattern_lh_cat; each entry is then rescaled in place by
 *     ptn_freq / (total pattern likelihood incl. ptn_invar), i.e. into the
 *     category posterior weighted by the pattern frequency.
 *   - M-step: each proportion is the summed weight of its category divided by
 *     the number of alignment sites. If the model currently has a non-zero
 *     p-invar, the leftover mass (1 - sum of proportions) becomes the new
 *     p-invar and phylo_tree->computePtnInvar() is called.
 *   - Each rate is re-estimated by optimizing a length scaling of a scratch
 *     copy of the tree whose pattern frequencies are that category's weights.
 *
 * The loop runs at most ncategory iterations. It stops early in two cases:
 *   - every proportion and rate (and p-invar, when updated) changes by less
 *     than 1e-4;
 *   - some proportion falls below MIN_PROP. In this case prop keeps its values
 *     from the previous iteration, or the input values on the first one.
 *
 * Side effects:
 *   - overwrites prop[] and rates[], and possibly p-invar;
 *   - overwrites phylo_tree->_pattern_lh_cat;
 *   - calls phylo_tree->clearAllPartialLH();
 *   - if sorted_rates is set, sorts rates (with prop permuted alongside) in
 *     increasing order.
 *
 * @return phylo_tree->computeLikelihood() evaluated with the updated parameters
 */
double RateFree::optimizeWithEM() {
    size_t ptn, c;
    size_t nptn = phylo_tree->aln->getNPattern();
    size_t nmix = ncategory;
    const double MIN_PROP = 1e-4;
    
//    double *lk_ptn = aligned_alloc<double>(nptn);
    double *new_prop = aligned_alloc<double>(nmix);
    // Scratch copy of phylo_tree on which each category rate is re-estimated.
    PhyloTree *tree = new PhyloTree;

    // attach memory to save space
//    tree->central_partial_lh = phylo_tree->central_partial_lh;
//    tree->central_scale_num = phylo_tree->central_scale_num;
//    tree->central_partial_pars = phylo_tree->central_partial_pars;

    tree->copyPhyloTree(phylo_tree);
    tree->optimize_by_newton = phylo_tree->optimize_by_newton;
    tree->setParams(phylo_tree->params);
    tree->setLikelihoodKernel(phylo_tree->sse);
    tree->setNumThreads(phylo_tree->num_threads);

    // initialize model: the scratch tree gets its own factory and a plain
    // (single-rate) RateHeterogeneity, so each category is fitted on its own.
    // NOTE(review): model_fac and site_rate are not deleted in this function;
    // presumably ~PhyloTree releases them — verify to rule out a leak.
    ModelFactory *model_fac = new ModelFactory();
    model_fac->joint_optimize = phylo_tree->params->optimize_model_rate_joint;
//    model_fac->unobserved_ptns = phylo_tree->getModelFactory()->unobserved_ptns;

    RateHeterogeneity *site_rate = new RateHeterogeneity; 
    tree->setRate(site_rate);
    site_rate->setTree(tree);
            
    model_fac->site_rate = site_rate;
    tree->model_factory = model_fac;
    // NOTE(review): repeats the setParams call made right after copyPhyloTree.
    tree->setParams(phylo_tree->params);
    double old_score = 0.0;
    // EM algorithm loop described in Wang, Li, Susko, and Roger (2008)
    for (int step = 0; step < ncategory; step++) {
        // first compute _pattern_lh_cat
        double score;
        score = phylo_tree->computePatternLhCat(WSL_RATECAT);
        // Dump diagnostics exactly when the assertion below is about to fail
        // (the previous `score > 0.0` guard stayed silent for score == 0).
        if (score >= 0.0) {
            phylo_tree->printTree(cout, WT_BR_LEN+WT_NEWLINE);
            writeInfo(cout);
        }
        ASSERT(score < 0);
        
        // EM must not decrease the score by more than a small tolerance.
        if (step > 0) {
            if (score <= old_score-0.1) {
                phylo_tree->printTree(cout, WT_BR_LEN+WT_NEWLINE);
                writeInfo(cout);
                cout << "Partition " << phylo_tree->aln->name << endl;
                cout << "score: " << score << "  old_score: " << old_score << endl;
            }
            ASSERT(score > old_score-0.1);
        }
            
        old_score = score;
        
        memset(new_prop, 0, nmix*sizeof(double));
                
        // E-step
        // decoupled weights (prop) from _pattern_lh_cat to obtain L_ci and compute pattern likelihood L_i
        for (ptn = 0; ptn < nptn; ptn++) {
            double *this_lk_cat = phylo_tree->_pattern_lh_cat + ptn*nmix;
            double lk_ptn = phylo_tree->ptn_invar[ptn];
            for (c = 0; c < nmix; c++) {
                lk_ptn += this_lk_cat[c];
            }
            ASSERT(lk_ptn != 0.0);
            lk_ptn = phylo_tree->ptn_freq[ptn] / lk_ptn;
            
            // transform _pattern_lh_cat into posterior probabilities of each category
            // (weighted by the pattern frequency), accumulating per-category totals
            for (c = 0; c < nmix; c++) {
                this_lk_cat[c] *= lk_ptn;
                new_prop[c] += this_lk_cat[c];
            }
            
        } 
        
        // M-step, update weights according to (*)
        size_t maxpropid = 0;
        double new_pinvar = 0.0;    
        for (c = 0; c < nmix; c++) {
            new_prop[c] = new_prop[c] / phylo_tree->getAlnNSite();
            if (new_prop[c] > new_prop[maxpropid])
                maxpropid = c;
        }
        // regularize prop: raise tiny proportions to MIN_PROP, taking the mass
        // from the largest category.
        // NOTE(review): if this triggers, the loop breaks below before the
        // regularized values are copied into prop, so they are discarded and
        // prop keeps the previous iteration's values.
        bool zero_prop = false;
        for (c = 0; c < nmix; c++) {
            if (new_prop[c] < MIN_PROP) {
                new_prop[maxpropid] -= (MIN_PROP - new_prop[c]);
                new_prop[c] = MIN_PROP;
                zero_prop = true;
            }
        }
        // break if some probabilities too small
        if (zero_prop) break;

        bool converged = true;
        double sum_prop = 0.0;
        for (c = 0; c < nmix; c++) {
            // check for convergence
            sum_prop += new_prop[c];
            converged = converged && (fabs(prop[c]-new_prop[c]) < 1e-4);
            prop[c] = new_prop[c];
            new_pinvar += new_prop[c];
        }

        // Probability mass not assigned to any rate category.
        new_pinvar = 1.0 - new_pinvar;

        // Only update p-invar when the model already has one.
        if (new_pinvar > 1e-4 && getPInvar() != 0.0) {
            converged = converged && (fabs(getPInvar()-new_pinvar) < 1e-4);
            if (isFixPInvar())
                outError("Fixed given p-invar is not supported");
            setPInvar(new_pinvar);
//            setOptimizePInvar(false);
            phylo_tree->computePtnInvar();
        }
        
        // NOTE(review): new_pinvar was computed as 1.0 - sum_prop from the same
        // values, so this only guards against rounding error.
        ASSERT(fabs(sum_prop+new_pinvar-1.0) < MIN_PROP);
        
        // now optimize rates one by one
        for (c = 0; c < nmix; c++) {
            tree->copyPhyloTree(phylo_tree);
            // Use the per-category mixture class when the model is a mixture with
            // fused mixture/rate categories; otherwise the single shared model.
            ModelMarkov *subst_model;
            if (phylo_tree->getModel()->isMixture() && phylo_tree->getModelFactory()->fused_mix_rate)
                subst_model = (ModelMarkov*)phylo_tree->getModel()->getMixtureClass(c);
            else
                subst_model = (ModelMarkov*)phylo_tree->getModel();
            tree->setModel(subst_model);
            subst_model->setTree(tree);
            model_fac->model = subst_model;
            if (subst_model->isMixture() || subst_model->isSiteSpecificModel() || !subst_model->isReversible())
                tree->setLikelihoodKernel(phylo_tree->sse);

            // initialize likelihood
            tree->initializeAllPartialLh();
            // copy posterior probability into ptn_freq
            tree->computePtnFreq();
            double *this_lk_cat = phylo_tree->_pattern_lh_cat+c;
            for (ptn = 0; ptn < nptn; ptn++)
                tree->ptn_freq[ptn] = this_lk_cat[ptn*nmix];
            // The category rate is the tree-length scaling that maximizes the
            // weighted likelihood, bounded to [MIN_PROP, 1/prop[c]].
            double scaling = rates[c];
            tree->scaleLength(scaling);
            tree->optimizeTreeLengthScaling(MIN_PROP, scaling, 1.0/prop[c], 0.001);
            converged = converged && (fabs(rates[c] - scaling) < 1e-4);
            rates[c] = scaling;
            // reset subst model so it points back at phylo_tree
            tree->setModel(NULL);
            subst_model->setTree(phylo_tree);
            
        }
        
        phylo_tree->clearAllPartialLH();
        if (converged) break;
    }
    
    // sort the rates in increasing order
    if (sorted_rates)
        quicksort(rates, 0, ncategory-1, prop);
    
    // deattach memory
//    tree->central_partial_lh = NULL;
//    tree->central_scale_num = NULL;
//    tree->central_partial_pars = NULL;

    delete tree;
    aligned_free(new_prop);
    return phylo_tree->computeLikelihood();
}