Exemplo n.º 1
0
Factor Inference::FactorMarginalization ( Factor A, vec ele)
{
    // Sums one variable out of a two-variable log-space factor
    // (MATLAB counterpart: B = FactorMarginalization(A, V)).
    //
    // Only the HMM case is handled: A has exactly two variables and ele names
    // one of them (only ele(0) is read). If ele(0) is not A.var(1), A.var(0)
    // is the variable that gets eliminated.
    //
    // @param A    factor over A.var = [v0, v1]. A.val is assumed to be laid out
    //             with v0 varying fastest (index = a0 + a1 * card(0)) -- TODO confirm
    //             against Utils::AssignmentToIndex.
    // @param ele  the variable to eliminate (ele(0)).
    // @return     factor over the remaining variable whose k-th value is the
    //             Utils::logsumexp of A's values over the eliminated variable.
    Factor B;

    // View the values as a card(0) x card(1) matrix: rows index var(0),
    // columns index var(1). Using each variable's own cardinality (rather than
    // card(0) twice) keeps the reshape exact when the cardinalities differ;
    // otherwise reshape would silently pad/truncate.
    mat Val = A.val.t();
    Val = reshape(Val, A.card(0), A.card(1));

    if (ele(0) == A.var(1))
    {
        // Eliminate var(1): transpose so it runs along the rows; var(0) is kept.
        Val = Val.t();
        B.var << A.var(0);
        B.card << A.card(0);
    }
    else
    {
        // Eliminate var(0): it already runs along the rows; var(1) is kept.
        B.var << A.var(1);
        B.card << A.card(1);
    }

    // One log-sum-exp per kept value, reducing over the eliminated variable.
    // MATLAB original: log(sum(exp(bsxfun(@minus, Val, max(Val))))) + max(Val)
    B.val = zeros<vec>(Val.n_cols);
    for (std::size_t col = 0; col < Val.n_cols; ++col)
    {
        B.val(col) = Utils::logsumexp( Val.col(col) );
    }

    return B;
}
Exemplo n.º 2
0
CliqueTree Inference::CreateCliqueTreeHMMFast( vector<Factor> factors )
{
    // Builds the chain-structured clique tree of an HMM directly.
    //
    // Variable names start from 1. With maxVar the largest variable name,
    // there are maxVar - 1 cliques; clique i (0-based) covers variables
    // {i+1, i+2} and is linked to clique i-1. Every factor is added into
    // exactly one clique:
    //   - a singleton factor over v goes to the clique {v-1, v}
    //     (or to {1, 2} when v == 1);
    //   - a pairwise factor over {u, v} goes to the clique whose first
    //     variable is min(u, v), after its values are reordered so that its
    //     variables are increasing.
    // Factors are combined with '+', i.e. all values are log-potentials.
    //
    // Not checked here: every variable has cardinality factors[0].card(0),
    // and pairwise factors only connect consecutive variables.
    // NOTE(review): Utils::AssignmentToIndex / IndexToAssignment are assumed to
    // use 0-based assignments and indices -- confirm against Utils.
    //
    // @param factors  singleton / pairwise log-space factors.
    // @return         the initial (uncalibrated) clique tree.

    if (factors.empty())
    {
        // factors[0] is read below; bail out instead of indexing an empty vector.
        std::cout << "ERROR: no factors!" << std::endl;
        return CliqueTree(0);
    }

    // get the max var id
    int maxVar = 0;
    for ( vector<Factor>::iterator iter = factors.begin();
          iter != factors.end(); ++iter )
    {
        double max_now = iter->var.max();
        if ( max_now > maxVar )
        {
            maxVar = static_cast<int>(max_now);
        }
    }
    int numNodes = maxVar - 1;
    int card = static_cast<int>(factors[0].card(0));

    CliqueTree P(numNodes);
    // '<' (not '!=') so a non-positive numNodes simply produces no cliques.
    for ( int i = 0; i < numNodes; i++ )
    {
        P.cliqueList[i].var << i+1 << i+2;
        P.cliqueList[i].card << card << card;
        // Potentials are accumulated with '+' (log space), so the neutral
        // starting value is log(1) = 0, not 1.
        P.cliqueList[i].val = zeros<vec>(card * card);

        if (i > 0)
        {
            P.edges(i, i-1) = 1;
            P.edges(i-1, i) = 1;
        }
    }

    // Cardinalities of every (two-variable) clique; loop-invariant.
    vec cards;
    cards << card << card;

    // the name of the variable starts from 1 !!!!
    for ( std::size_t i = 0; i < factors.size(); ++i )
    {
        // Copy: the pairwise branch may reorder f in place.
        Factor f = factors[i];
        int cliqueIdx = 0;   // 1-based clique index
        if (f.var.n_rows == 1)
        {
            if (f.var(0) > 1)
            {
                cliqueIdx = f.var(0) - 1;   // clique {v-1, v}: v is its 2nd variable
            }else{
                cliqueIdx = 1;              // clique {1, 2}: v is its 1st variable
            }

            // Column of the clique assignment that holds f's variable; the
            // other column holds the neighbouring variable.
            const int fCol = (f.var(0) == cliqueIdx) ? 0 : 1;
            const int otherCol = 1 - fCol;

            mat assignments = zeros<mat>(card, 2);
            assignments.col(fCol) = linspace<vec>(0, card - 1, card);
            for ( int otherValue = 0; otherValue < card; ++otherValue )
            {
                // Rows enumerate every value of f's variable with the
                // neighbour fixed, so row 'step' matches f.val(step).
                assignments.col(otherCol).fill(otherValue);
                vec updateIdxs = Utils::AssignmentToIndex(assignments, cards);
                for ( std::size_t step = 0; step < updateIdxs.n_rows; ++step )
                {
                    P.cliqueList[cliqueIdx - 1].val(updateIdxs(step)) += f.val(step);
                }
            }
        }else{
            if ( f.var.n_rows != 2 )
            {
                std::cout << "ERROR: var more than 2!" << std::endl;
            }
            cliqueIdx = min(f.var);
            if (f.var(0) > f.var(1))
            {
                // Reorder the factor so its variables are increasing, permuting
                // card and val to match.
                uvec order = sort_index(f.var);
                mat oldAssignments = Utils::IndexToAssignment(linspace<vec>(0, f.val.n_rows - 1, f.val.n_rows), f.card);
                mat newAssignments = oldAssignments;
                // newAssignments = oldAssignments(:, order)
                for ( std::size_t step = 0; step < oldAssignments.n_cols; ++step )
                {
                    newAssignments.col(step) = oldAssignments.col(order(step));
                }
                vec new_card = f.card;
                for ( std::size_t step = 0; step < f.card.n_rows; ++step )
                {
                    new_card(step) = f.card(order(step));
                }
                f.card = new_card;
                f.var = sort(f.var);

                // new_index(k) is the position, in the reordered layout, of the
                // joint assignment stored at old position k. Scatter each old
                // value to its new position; gathering (new_val(k) =
                // f.val(new_index(k))) is only right when cards are equal.
                vec new_index = Utils::AssignmentToIndex(newAssignments, f.card);
                vec new_val = f.val;
                for ( std::size_t step = 0; step < f.val.n_rows; ++step )
                {
                    new_val( new_index(step) ) = f.val(step);
                }
                f.val = new_val;
            }

            P.cliqueList[cliqueIdx - 1].val += f.val;
        }
    }
    return P;
}