Beispiel #1
0
// Evaluates the quantity returned as nLL (a negative log-likelihood, going by the name)
// for a given log-delta.
//
// Parameters (constraints are the ones asserted below):
//   ldelta       log of delta; exp(ldelta) is added to every product S_R1[r]*S_C1[c].
//   A, X         equally long lists of matrices; for every term, X[term] has R rows and
//                A[term] has C columns.
//   Y            R x C matrix.
//   S_R1, S_R2   vectors with R entries.
//   S_C1, S_C2   vectors with C entries.
//
// Returns 0.5 * ( R*C*(L2pi + log(sigma) + 1) + ldet ), where
//   ldet  accumulates C*log(S_R2[r]), R*log(S_C2[c]) (entries <= 1e-10 are skipped and
//         printed to std::cout instead) and log(S_R1[r]*S_C1[c] + delta),
//   sigma is (sum(Y .* DY) - W_vec . XYA) / (R*C), with W_vec solving covW * W_vec = XYA.
//
// NOTE(review): L2pi and akron() are defined outside this function; their exact meaning
// cannot be confirmed from here.
mfloat_t CKroneckerLMM::nLLeval(mfloat_t ldelta, const MatrixXdVec& A,const MatrixXdVec& X, const MatrixXd& Y, const VectorXd& S_C1, const VectorXd& S_R1, const VectorXd& S_C2, const VectorXd& S_R2)
{
//#define debugll
	muint_t R = (muint_t)Y.rows();
	muint_t C = (muint_t)Y.cols();
	assert(A.size() == X.size());
	assert(R == (muint_t)S_R1.rows());
	assert(C == (muint_t)S_C1.rows());
	assert(R == (muint_t)S_R2.rows());
	assert(C == (muint_t)S_C2.rows());
	// nWeights = total number of weights over all terms: sum of A[term].rows() * X[term].cols()
	muint_t nWeights = 0;
	for(muint_t term = 0; term < A.size();++term)
	{
		assert((muint_t)(X[term].rows())==R);
		assert((muint_t)(A[term].cols())==C);
		nWeights+=(muint_t)(A[term].rows()) * (muint_t)(X[term].cols());
	}
	mfloat_t delta = exp(ldelta);
	mfloat_t ldet = 0.0;//R * C * ldelta;

	//build D and compute the logDet of D
	MatrixXd D = MatrixXd(R,C);
	// contribution of S_R2: each entry above the 1e-10 threshold adds C*log(entry);
	// other entries add nothing and are reported on std::cout
	for (muint_t r=0; r<R;++r)
	{
		if(S_R2(r)>1e-10)
		{
			ldet += (mfloat_t)C * log(S_R2(r));//ldet
		}
		else
		{
			std::cout << "S_R2(" << r << ")="<< S_R2(r)<<"\n";
		}
	}
#ifdef debugll
	std::cout << ldet;
	std::cout << "\n";
#endif
	// contribution of S_C2: same scheme, each accepted entry adds R*log(entry)
	for (muint_t c=0; c<C;++c)
	{
		if(S_C2(c)>1e-10)
		{
			ldet += (mfloat_t)R * log(S_C2(c));//ldet
		}
		else
		{
			std::cout << "S_C2(" << c << ")="<< S_C2(c)<<"\n";
		}
	}
#ifdef debugll
	std::cout << ldet;
	std::cout << "\n";
#endif
	// D(r,c) = 1 / (S_R1[r]*S_C1[c] + delta); the log of each denominator is added to ldet
	for (muint_t r=0; r<R;++r)
	{
		for (muint_t c=0; c<C;++c)
		{
			mfloat_t SSd = S_R1.data()[r]*S_C1.data()[c] + delta;
			ldet+=log(SSd);
			D(r,c) = 1.0/SSd;
		}
	}
#ifdef debugll
	std::cout << ldet;
	std::cout << "\n";
#endif
	// element-wise product of Y and D
	MatrixXd DY = Y.array() * D.array();

	// right-hand side of the linear system, one segment per term
	VectorXd XYA = VectorXd(nWeights);

	// row offset of the current term's segment/block inside XYA and covW
	muint_t cumSumR = 0;

	// system matrix, assembled block by block for every pair of terms (termR, termC)
	MatrixXd covW = MatrixXd(nWeights,nWeights);
	for(muint_t termR = 0; termR < A.size();++termR){
		muint_t nW_AR = A[termR].rows();
		muint_t nW_XR = X[termR].cols();
		muint_t rowsBlock = nW_AR * nW_XR;
		// XYAblock is nW_XR x nW_AR, i.e. it already holds rowsBlock coefficients, so the
		// resize below only changes the shape to a single column (Eigen keeps the
		// coefficients when the total size is unchanged)
		MatrixXd XYAblock = X[termR].transpose() * DY * A[termR].transpose();
		XYAblock.resize(rowsBlock,1);
		XYA.block(cumSumR,0,rowsBlock,1) = XYAblock;

		// column offset of the current block inside covW
		muint_t cumSumC = 0;

		for(muint_t termC = 0; termC < A.size(); ++termC){
			muint_t nW_AC = A[termC].rows();
			muint_t nW_XC = X[termC].cols();
			muint_t colsBlock = nW_AC * nW_XC;
			MatrixXd block = MatrixXd::Zero(rowsBlock,colsBlock);
			// the per-sample accumulation runs over whichever of R and C is smaller
			if (R<C)
			{
				for(muint_t r=0; r<R; ++r)
				{
					// AD = A[termR] with every row scaled element-wise by row r of D
					MatrixXd AD = A[termR];
					AD.array().rowwise() *= D.row(r).array();
					MatrixXd AA = AD * A[termC].transpose();
					// outer product of row r of X[termR] with row r of X[termC]
					MatrixXd XX = X[termR].row(r).transpose() * X[termC].row(r);
					// NOTE(review): akron is defined elsewhere; presumably it adds the
					// Kronecker product of AA and XX into block - confirm
					akron(block,AA,XX,true);
				}
			}
			else
			{//sum up col matrices
				for(muint_t c=0; c<C; ++c)
				{
					// XD = X[termR] with every column scaled element-wise by column c of D
					MatrixXd XD = X[termR];
					XD.array().colwise() *= D.col(c).array();
					MatrixXd XX = XD.transpose() * X[termC];
					// outer product of column c of A[termR] with column c of A[termC]
					MatrixXd AA = A[termR].col(c) * A[termC].col(c).transpose();
					akron(block,AA,XX,true);
				}
			}
			covW.block(cumSumR, cumSumC, rowsBlock, colsBlock) = block;
			cumSumC+=colsBlock;
		}
		cumSumR+=rowsBlock;
	}
	//std::cout << "covW = " << covW<<std::endl;
	// solve covW * W_vec = XYA with a column-pivoting Householder QR decomposition
	MatrixXd W_vec = covW.colPivHouseholderQr().solve(XYA);
	//MatrixXd W_vec = covW * XYA;
	//std::cout << "W = " << W_vec<<std::endl;
	//std::cout << "XYA = " << XYA<<std::endl;

	// res = sum(Y .* DY) - sum(W_vec .* XYA)
	mfloat_t res = (Y.array()*DY.array()).sum();
	mfloat_t varPred = (W_vec.array() * XYA.array()).sum();
	res-= varPred;

	// res averaged over all R*C entries of Y
	mfloat_t sigma = res/(mfloat_t)(R*C);

	mfloat_t nLL = 0.5 * ( R * C * (L2pi + log(sigma) + 1.0) + ldet);
#ifdef returnW
	// Only compiled when returnW is defined. W and F_tests are not declared inside this
	// function, so they must come from the surrounding scope in that configuration.
	covW = covW.inverse();
	//std::cout << "covW.inverse() = " << covW<<std::endl;

	muint_t cumSum = 0;
	// per-weight statistic: W_vec^2 / diag(inverse covW) / sigma
	VectorXd F_vec = W_vec.array() * W_vec.array() /covW.diagonal().array() / sigma;
	for(muint_t term = 0; term < A.size();++term)
	{
		muint_t currSize = X[term].cols() * A[term].rows();
		//W[term] = MatrixXd(X[term].cols(),A[term].rows());
		// copy this term's segment and reshape it to X[term].cols() x A[term].rows()
		W[term] = W_vec.block(cumSum,0,currSize,1);//
		W[term].resize(X[term].cols(),A[term].rows());
		//F_tests[term] = MatrixXd(X[term].cols(),A[term].rows());
		F_tests[term] = F_vec.block(cumSum,0,currSize,1);//
		F_tests[term].resize(X[term].cols(),A[term].rows());
		cumSum+=currSize;
	}
#endif
	return nLL;
}
Beispiel #2
0
// Packs the outcome of one extension attempt into the return structure.
static extendTreeR_t MakeExtendTreeResult(int flag, const MatrixXd& resultTree, int iINDC,
				const sigma_t& sigma, int iMaxIte)
{
	extendTreeR_t result;

	result.flag = flag;
	result.newTree = resultTree;
	result.INDC = iINDC;
	result.sigma = sigma;
	result.maxIte = iMaxIte;

	return result;
}

// Tries to grow the tree by one node towards a biased random sample and checks whether
// the new node can be connected to the end node.
//
// Tree layout as used by this function (one node per row):
//   columns 0..2   node position
//   column  3      set to 1 on the new node when it connects to the end node
//   column  4      accumulated cost (parent cost + P.L)
//   last column    1-based row index of the parent; a candidate whose parent entry is 0
//                  aborts the attempt
//
// Parameters:
//   tree        current tree (taken by value, never modified)
//   vecEndNode  end node; entries 0..2 are its position
//   world       obstacle description forwarded to Collision()
//   P           settings; P.L is overwritten with 2*dMax when the heading check passes
//   dKmax, c4   used for the margin d_rm = c4*sin(beta) / (dKmax*cos(beta)^2)
//   dMax        length unit: step length is 2*dMax, terminal mode starts within 7*dMax
//   dMaxBeta    largest admissible absolute value for the entries of vecBeta
//   sigma       sampling region (r, theta0, theta, sp); re-centred on the new node when
//               terminal mode is entered
//   iMaxIte     number of nearest candidates examined; set to 5 when terminal mode starts
//   iINDC       0 = normal mode, 1 = terminal mode (end node sampled with p < 0.05
//               instead of p < 0.1)
//
// Returns flag = 1 and the extended tree when the new node connects to the end node,
// flag = 0 and the extended tree when a node was added without connecting, and flag = 0
// with the unchanged tree when the attempt was abandoned.
//
// NOTE(review): the loop repeats until Collision() accepts a node; it does not terminate
// if every sample collides.
extendTreeR_t ExtendTree_NHC_Sort_OnlyGS_TermCond_Heading(MatrixXd tree, VectorXd vecEndNode, worldLine_t world, setting_t P,
				double dKmax, double c4, double dMax, double dMaxBeta, sigma_t sigma, int iMaxIte, int iINDC)
{
	int kk, iTemp, n, idx;
	double p, r, theta, Sx, Sy;

	Vector3d vec3RandomPoint;
	VectorXd vecTempDiag, vecIdx, vecP1, vecP2, vecNewNode;
	MatrixXd matWP, newTree, matdKappa;
	VectorXd vecBeta;

	bool bNodeAdded = false;

	while (!bNodeAdded) {
		// select a biased random point: either the end node itself or a point drawn
		// from the sampling region described by sigma
		p=Uniform01();

		if ( (iINDC==0 && p<0.1) || (iINDC==1 && p<0.05) ) {
			vec3RandomPoint << vecEndNode(0), vecEndNode(1), vecEndNode(2);
		} else {
			r = sigma.r*Uniform01();
			theta = sigma.theta0 + sigma.theta*quasi_normal_random();

			Sx = sigma.sp(0) + r*cos(theta);
			Sy = sigma.sp(1) + r*sin(theta);
			vec3RandomPoint << Sx, Sy, vecEndNode(2);
		}
		std::cout << "Random Point : " << vec3RandomPoint(0) << " " << vec3RandomPoint(1) << " " << vec3RandomPoint(2) << std::endl;

		// Squared distance of every node to the random point. Computed row by row;
		// the former N x N product was only ever read on its diagonal.
		vecTempDiag = (tree.leftCols(3).rowwise() - vec3RandomPoint.transpose()).rowwise().squaredNorm();

		SortVec(vecTempDiag, vecIdx); // vecTempDiag : sorted vector, vecIdx : index of vecTempDiag

		// Modification 3: examine at most iMaxIte of the nearest nodes
		if ( vecIdx.rows() > iMaxIte )
			n = iMaxIte;
		else
			n = vecIdx.rows();

		/// Nonholonomic Length Decision
		// kk is the position (in sorted order) of the candidate that is extended;
		// -1 means no candidate loop ran and the nearest node is used
		kk=-1;

		// Modification 4
		if (tree.rows() == 2) {
			// way points: the two existing nodes followed by the random point
			matWP.resize(3,3);
			matWP.topRows(2) = tree.block(0,0,2,3);
			matWP.row(2) = vec3RandomPoint.transpose();
			Kappa_Max_Calculation(matWP, P, matdKappa, vecBeta);

			if ( (vecBeta.array().abs()<= dMaxBeta).all() )
				// Method 2 : use the maximum length margin - when there is straight line, there is more margin1
				P.L = 2*dMax;
			else if ( (vecBeta.array().abs() > dMaxBeta).all() )
				return MakeExtendTreeResult(0, tree, iINDC, sigma, iMaxIte);
		} else if (tree.rows() > 2)	{
			for (iTemp=0; iTemp<n; iTemp++)	{
				kk = iTemp;

				// a candidate whose parent entry is 0 cannot be extended
				if ( tree(vecIdx(iTemp), tree.cols()-1) == 0 )
					return MakeExtendTreeResult(0, tree, iINDC, sigma, iMaxIte);

				// way points: parent of the candidate, the candidate, the random point
				vecP2 = tree.row(vecIdx(iTemp)).transpose();
				vecP1 = tree.row( vecP2(vecP2.rows()-1)-1 ).transpose();

				matWP.resize(3,3);
				matWP.row(0) = vecP1.segment(0,3).transpose();
				matWP.row(1) = vecP2.segment(0,3).transpose();
				matWP.row(2) = vec3RandomPoint.transpose();

				Kappa_Max_Calculation(matWP, P, matdKappa, vecBeta);

				if ( (vecBeta.array().abs()<= dMaxBeta).all() )
				{
					// Method 2 : use the maximum length margin - when there is straight line, there is more margin1
					P.L = 2*dMax;
					break;
				} else if ( (vecBeta.array().abs() > dMaxBeta).all() && (iTemp == (n-1) ) )
					return MakeExtendTreeResult(0, tree, iINDC, sigma, iMaxIte);
			}
		}

		idx = vecIdx( (kk == -1) ? 0 : kk );

		const double cost = tree(idx,4) + P.L;
		const Vector3d vecBase = tree.block(idx,0,1,3).transpose();
		const Vector3d vecDirection = vec3RandomPoint - vecBase;
		const double dDirectionNorm = vecDirection.norm();

		// The sample coincides with the selected node: there is no direction to step in
		// and normalising would produce a NaN node. Report "no extension" the same way
		// the other abandoned attempts do.
		if ( dDirectionNorm == 0.0 )
			return MakeExtendTreeResult(0, tree, iINDC, sigma, iMaxIte);

		// step of length P.L from the selected node towards the random point
		const Vector3d vecNewPoint = vecBase + vecDirection/dDirectionNorm*P.L;

		if ( kk == -1)
		{
			vecNewNode.resize(vecNewPoint.size()+5);
			vecNewNode << vecNewPoint, 0, cost, P.L, 0, idx+1;
		}
		else
		{
			vecNewNode.resize(vecNewPoint.size()+vecBeta.size()+4);
			vecNewNode << vecNewPoint, 0, cost, P.L, vecBeta.array().abs(), idx+1;
		}

		// check to see if obstacle or random point is reached
		if ( Collision(vecNewNode, tree.row(idx), world, P.SZR, P.SZH, 0) == 0 )
		{
			newTree.resize( tree.rows()+1,tree.cols() );
			newTree.block(0, 0, tree.rows(), tree.cols()) = tree;
			newTree.row(tree.rows()) = vecNewNode.transpose();

			bNodeAdded = true;
		}
	}

	// check to see if new node connects directly to end node
	const double dTerm = (vecNewNode.segment(0,3) - vecEndNode.segment(0,3)).norm();

	// within 7*dMax of the end node: switch to terminal mode and re-centre the sampling
	// region on the new node
	if ( (dTerm <= 7*dMax) && (iINDC==0) )
	{
		iINDC = 1;

		sigma.r = dTerm;
		// atan2 already returns a value in [-pi, pi], so no wrap-around is needed
		sigma.theta0 = atan2( vecEndNode(1)-vecNewNode(1), vecEndNode(0)-vecNewNode(0) );
		sigma.theta = 1.0*M_PI;
		sigma.sp = vecNewNode.segment(0,3);

		iMaxIte = 5;
	}

	int flagRet = 0;

	if ( iINDC==1 )
	{
		// Method 2
		const int end = vecNewNode.rows();
		// second-to-last entry of the new node, used as the heading angle below
		const double dBeta = vecNewNode(end-2);

		// parent of the new node (last entry is the 1-based parent row)
		if ( vecNewNode(end-1)==0 )
			vecP1 = tree.row(0).transpose();
		else
			vecP1 = tree.row(vecNewNode(end-1)-1).transpose();

		// way points: parent, new node, end node
		matWP.resize(3,3);
		matWP.row(0) = vecP1.segment(0,3).transpose();
		matWP.row(1) = vecNewNode.segment(0,3).transpose();
		matWP.row(2) = vecEndNode.segment(0,3).transpose();

		Kappa_Max_Calculation(matWP, P, matdKappa, vecBeta);

		const double d_rm = ( c4*sin(dBeta) ) / ( dKmax*cos(dBeta)*cos(dBeta) );

		if ( (2*dMax-d_rm >= 1.0*matdKappa(0,0)) &&
			(Collision(vecNewNode, vecEndNode, world, P.SZR, P.SZH, 0) == 0) &&
			(matdKappa(0,0) <= dTerm) )
		{
			flagRet = 1;
			newTree(newTree.rows()-1,3) = 1;
		}
	}

	return MakeExtendTreeResult(flagRet, newTree, iINDC, sigma, iMaxIte);
}
Beispiel #3
0
void CTRPBPInference::infer(CGraph &graph,
                          map<size_t,VectorXd> &nodeBeliefs,
                          map<size_t,MatrixXd> &edgeBeliefs,
                          double &logZ)
{
    //
    //  Stages:
    //  a. Collect spanning trees until every node of the graph is covered
    //  b. Pass messages over the trees until the message sum settles
    //  c. Node beliefs
    //  d. Edge beliefs
    //  e. logZ
    //
    //  Both output maps are emptied first and are keyed by node ID / edge ID.
    //

    nodeBeliefs.clear();
    edgeBeliefs.clear();

    const vector<CNodePtr> nodes = graph.getNodes();
    const vector<CEdgePtr> edges = graph.getEdges();
    multimap<size_t,CEdgePtr> edges_f = graph.getEdgesF();

    const size_t nNodes = nodes.size();
    const size_t nEdges = edges.size();

    //
    // a. Spanning trees
    //

    vector<vector<size_t > > spanningTrees;
    vector<bool> covered(nNodes,false);
    map<size_t,size_t> indexOfNodeID;   // node ID -> position in `nodes`

    for (size_t i = 0; i < nNodes; ++i)
        indexOfNodeID[ nodes[i]->getID() ] = i;

    bool everyNodeCovered;
    do
    {
        vector<size_t> tree;
        getSpanningTree( graph, tree );

        // Empty trees are not stored
        if ( !tree.empty() )
            spanningTrees.push_back( tree );

        cout << "Tree: ";

        for ( size_t k = 0; k < tree.size(); ++k )
        {
            covered[ indexOfNodeID[ tree[k] ] ] = true;
            cout << tree[k] << " ";
        }

        cout << endl;

        // Stop scanning at the first node no tree has reached so far
        everyNodeCovered = true;
        for ( size_t k = 0; everyNodeCovered && k < nNodes; ++k )
            everyNodeCovered = covered[k];

    } while ( !everyNodeCovered );

    //
    // b. Message passing: one sweep over all trees per iteration; stop as soon as
    //    the total of all message entries changes by less than the tolerance
    //

    vector<vector<VectorXd> >   messages;
    bool                        maximize = false;

    double previousMsgSum = std::numeric_limits<double>::max();

    for ( size_t iteration = 0; iteration < m_options.maxIterations; ++iteration )
    {
        for ( size_t t = 0; t < spanningTrees.size(); ++t )
            messagesLBP( graph, m_options, messages, maximize, spanningTrees[t] );

        double currentMsgSum = 0;
        for ( size_t e = 0; e < nEdges; ++e )
            currentMsgSum += messages[e][0].sum() + messages[e][1].sum();

        if ( std::abs( previousMsgSum - currentMsgSum ) <
             m_options.convergency )
            break;

        previousMsgSum = currentMsgSum;
    }

    //
    // c. Node beliefs: node potential times every incoming message, normalized
    //

    for ( size_t n = 0; n < nNodes; ++n )
    {
        const CNodePtr nodePtr = graph.getNode( n );
        size_t         nodeID  = nodePtr->getID();
        VectorXd       belief  = nodePtr->getPotentials( m_options.considerNodeFixedValues );

        NEIGHBORS_IT incident = edges_f.equal_range(nodeID);

        for ( multimap<size_t,CEdgePtr>::iterator it = incident.first;
              it != incident.second;
              ++it )
        {
            CEdgePtr edgePtr( it->second );
            size_t edgeIndex = graph.getEdgeIndex( edgePtr->getID() );

            // Node at position 0 of the edge receives message slot 1, and vice versa
            const size_t incomingSlot = edgePtr->getNodePosition( nodeID ) ? 0 : 1;

            belief = belief.cwiseProduct( messages[ edgeIndex ][ incomingSlot ] );
        }

        belief = belief / belief.sum();

        nodeBeliefs[ nodeID ] = belief;
    }

    //
    // d. Edge beliefs: edge potential scaled by both end-node beliefs, each with the
    //    message that came over this edge divided out; normalized
    //

    for ( size_t e = 0; e < nEdges; ++e )
    {
        CEdgePtr edgePtr = edges[e];
        size_t   edgeID  = edgePtr->getID();

        size_t ID1, ID2;
        edgePtr->getNodesID( ID1, ID2 );

        MatrixXd edgePotentials = edgePtr->getPotentials();

        VectorXd &msg1To2 = messages[e][0];
        VectorXd &msg2To1 = messages[e][1];

        VectorXd node1Part = nodeBeliefs[ ID1 ].cwiseQuotient( msg2To1 );
        VectorXd node2Part = nodeBeliefs[ ID2 ].cwiseQuotient( msg1To2 );

        const size_t nRows = edgePotentials.rows();
        const size_t nCols = edgePotentials.cols();

        // Rows follow the first node, columns the second one
        MatrixXd edgeBelief = edgePotentials;
        for ( size_t row = 0; row < nRows; ++row )
            for ( size_t col = 0; col < nCols; ++col )
                edgeBelief(row,col) = ( edgePotentials(row,col) * node1Part(row) ) * node2Part(col);

        edgeBelief = edgeBelief / edgeBelief.sum();

        edgeBeliefs[ edgeID ] = edgeBelief;
    }

    //
    // e. logZ
    //

    double nodeBeliefTerm    = 0;   // sum over nodes of  neighbours * sum( b .* log b )
    double edgeBeliefTerm    = 0;   // sum over edges of  sum( b .* log b )
    double nodePotentialTerm = 0;   // sum over nodes of  neighbours * sum( b .* log potential )
    double edgePotentialTerm = 0;   // sum over edges of  sum( b .* log potential )

    for ( size_t n = 0; n < nodes.size(); ++n )
    {
        CNodePtr nodePtr    = nodes[n];
        size_t   nodeID     = nodePtr->getID();
        size_t   nNeighbors = graph.getNumberOfNodeNeighbors( nodeID );

        VectorXd &belief       = nodeBeliefs[nodeID];
        VectorXd logBelief     = belief.array().log();
        VectorXd potentials    = nodePtr->getPotentials( m_options.considerNodeFixedValues );
        VectorXd logPotentials = potentials.array().log();

        nodeBeliefTerm    += nNeighbors*( belief.cwiseProduct( logBelief ).sum() );
        nodePotentialTerm += nNeighbors*( belief.cwiseProduct( logPotentials ).sum() );
    }

    for ( size_t e = 0; e < nEdges; ++e )
    {
        CEdgePtr edgePtr = edges[e];
        size_t   edgeID  = edgePtr->getID();

        MatrixXd &belief       = edgeBeliefs[ edgeID ];
        MatrixXd logBelief     = belief.array().log();
        MatrixXd &potentials   = edgePtr->getPotentials();
        MatrixXd logPotentials = potentials.array().log();

        edgeBeliefTerm    += belief.cwiseProduct( logBelief ).sum();
        edgePotentialTerm += belief.cwiseProduct( logPotentials ).sum();
    }

    // logZ is the negated free-energy expression built from the four sums
    const double freeEnergy = ( nodeBeliefTerm - edgeBeliefTerm )
                            - ( nodePotentialTerm - edgePotentialTerm );

    logZ = - freeEnergy;
}
Beispiel #4
0
void CLBPInference::infer(CGraph &graph,
                          map<size_t,VectorXd> &nodeBeliefs,
                          map<size_t,MatrixXd> &edgeBeliefs,
                          double &logZ)
{
    //
    //  Stages:
    //  a. Message passing (delegated to messagesLBP)
    //  b. Node beliefs
    //  c. Edge beliefs
    //  d. logZ
    //
    //  Both output maps are emptied first and are keyed by node ID / edge ID.
    //

    nodeBeliefs.clear();
    edgeBeliefs.clear();

    const vector<CNodePtr> nodes = graph.getNodes();
    const vector<CEdgePtr> edges = graph.getEdges();
    multimap<size_t,CEdgePtr> edges_f = graph.getEdgesF();

    const size_t nNodes = nodes.size();
    const size_t nEdges = edges.size();

    //
    // a. Message passing
    //

    vector<vector<VectorXd> >   messages;
    bool                        maximize = false;

    messagesLBP( graph, m_options, messages, maximize );

    //
    // b. Node beliefs: node potential times every incoming message, normalized
    //

    for ( size_t n = 0; n < nNodes; ++n )
    {
        const CNodePtr nodePtr = graph.getNode( n );
        size_t         nodeID  = nodePtr->getID();
        VectorXd       belief  = nodePtr->getPotentials( m_options.considerNodeFixedValues );

        NEIGHBORS_IT incident = edges_f.equal_range(nodeID);

        for ( multimap<size_t,CEdgePtr>::iterator it = incident.first;
              it != incident.second;
              ++it )
        {
            CEdgePtr edgePtr( it->second );
            size_t edgeIndex = graph.getEdgeIndex( edgePtr->getID() );

            // Node at position 0 of the edge receives message slot 1, and vice versa
            const size_t incomingSlot = edgePtr->getNodePosition( nodeID ) ? 0 : 1;

            belief = belief.cwiseProduct( messages[ edgeIndex ][ incomingSlot ] );
        }

        belief = belief / belief.sum();

        nodeBeliefs[ nodeID ] = belief;
    }

    //
    // c. Edge beliefs: edge potential scaled by both end-node beliefs, each with the
    //    message that came over this edge divided out; normalized
    //

    for ( size_t e = 0; e < nEdges; ++e )
    {
        CEdgePtr edgePtr = edges[e];
        size_t   edgeID  = edgePtr->getID();

        size_t ID1, ID2;
        edgePtr->getNodesID( ID1, ID2 );

        MatrixXd edgePotentials = edgePtr->getPotentials();

        VectorXd &msg1To2 = messages[e][0];
        VectorXd &msg2To1 = messages[e][1];

        VectorXd node1Part = nodeBeliefs[ ID1 ].cwiseQuotient( msg2To1 );
        VectorXd node2Part = nodeBeliefs[ ID2 ].cwiseQuotient( msg1To2 );

        const size_t nRows = edgePotentials.rows();
        const size_t nCols = edgePotentials.cols();

        // Rows follow the first node, columns the second one
        MatrixXd edgeBelief = edgePotentials;
        for ( size_t row = 0; row < nRows; ++row )
            for ( size_t col = 0; col < nCols; ++col )
                edgeBelief(row,col) = ( edgePotentials(row,col) * node1Part(row) ) * node2Part(col);

        edgeBelief = edgeBelief / edgeBelief.sum();

        edgeBeliefs[ edgeID ] = edgeBelief;
    }

    //
    // d. logZ
    //

    double nodeBeliefTerm    = 0;   // sum over nodes of  neighbours * sum( b .* log b )
    double edgeBeliefTerm    = 0;   // sum over edges of  sum( b .* log b )
    double nodePotentialTerm = 0;   // sum over nodes of  neighbours * sum( b .* log potential )
    double edgePotentialTerm = 0;   // sum over edges of  sum( b .* log potential )

    for ( size_t n = 0; n < nodes.size(); ++n )
    {
        CNodePtr nodePtr    = nodes[n];
        size_t   nodeID     = nodePtr->getID();
        size_t   nNeighbors = graph.getNumberOfNodeNeighbors( nodeID );

        VectorXd &belief       = nodeBeliefs[nodeID];
        VectorXd logBelief     = belief.array().log();
        VectorXd potentials    = nodePtr->getPotentials( m_options.considerNodeFixedValues );
        VectorXd logPotentials = potentials.array().log();

        nodeBeliefTerm    += nNeighbors*( belief.cwiseProduct( logBelief ).sum() );
        nodePotentialTerm += nNeighbors*( belief.cwiseProduct( logPotentials ).sum() );
    }

    for ( size_t e = 0; e < nEdges; ++e )
    {
        CEdgePtr edgePtr = edges[e];
        size_t   edgeID  = edgePtr->getID();

        MatrixXd &belief       = edgeBeliefs[ edgeID ];
        MatrixXd logBelief     = belief.array().log();
        MatrixXd &potentials   = edgePtr->getPotentials();
        MatrixXd logPotentials = potentials.array().log();

        edgeBeliefTerm    += belief.cwiseProduct( logBelief ).sum();
        edgePotentialTerm += belief.cwiseProduct( logPotentials ).sum();
    }

    // logZ is the negated free-energy expression built from the four sums
    const double freeEnergy = ( nodeBeliefTerm - edgeBeliefTerm )
                            - ( nodePotentialTerm - edgePotentialTerm );

    logZ = - freeEnergy;
}
// Cost and gradient of a sparse autoencoder with one hidden layer, written in the
// shape of an L-BFGS evaluation callback.
//
// Parameters:
//   netParam  pointer to an instanceSP providing hiddenSize, visibleSize, lambda,
//             beta, sparsityParam and the data matrix (one sample per column)
//   ptheta    packed parameters, in this order: w1 (hiddenSize x visibleSize),
//             w2 (visibleSize x hiddenSize), b1 (hiddenSize), b2 (visibleSize);
//             matrices are read in the storage order of MatrixXd::data()
//   grad      output buffer, receives the gradient in the same packing as ptheta
//   n, step   unused
//
// Returns the cost: mean halved squared reconstruction error
//   + 0.5 * lambda * (sum of squared entries of w1 and w2)
//   + beta * sum over hidden units of
//       sp*log(sp/rho) + (1-sp)*log((1-sp)/(1-rho)),
// where rho is the mean activation of each hidden unit over all samples.
//
// NOTE(review): a data matrix without columns, or a hidden unit whose mean activation
// is exactly 0 or 1, yields non-finite values; callers are expected to avoid that.
lbfgsfloatval_t sparseAECost(
    void* netParam,
    const lbfgsfloatval_t *ptheta,
    lbfgsfloatval_t *grad,
    const int n,
    const lbfgsfloatval_t step)
{
    // Required by the callback signature but not needed here
    (void)n;
    (void)step;

    instanceSP* pStruct = static_cast<instanceSP*>(netParam);
    const int hiddenSize = pStruct->hiddenSize;
    const int visibleSize = pStruct->visibleSize;
    const double lambda = pStruct->lambda;
    const double beta = pStruct->beta;
    const double sp = pStruct->sparsityParam;
    const MatrixXd& data = pStruct->data;
    const int ndata = data.cols();
    const int nWeights = hiddenSize * visibleSize;

    //
    // Unpack the parameter vector
    //
    MatrixXd w1(hiddenSize, visibleSize);
    MatrixXd w2(visibleSize, hiddenSize);
    VectorXd b1(hiddenSize);
    VectorXd b2(visibleSize);

    const lbfgsfloatval_t* src = ptheta;
    for (int i = 0; i < nWeights; ++i)
        w1.data()[i] = *src++;
    for (int i = 0; i < nWeights; ++i)
        w2.data()[i] = *src++;
    for (int i = 0; i < hiddenSize; ++i)
        b1.data()[i] = *src++;
    for (int i = 0; i < visibleSize; ++i)
        b2.data()[i] = *src++;

    //
    // Forward pass; the bias is added to every column (sample)
    //
    MatrixXd z2 = w1 * data;
    z2.colwise() += b1;
    MatrixXd a2 = sigmoid(z2);

    MatrixXd z3 = w2 * a2;
    z3.colwise() += b2;
    MatrixXd a3 = sigmoid(z3);

    // Mean activation of each hidden unit and the derivative of the sparsity
    // penalty with respect to it
    VectorXd rho = a2.rowwise().sum() / ndata;
    VectorXd sparsityDelta = -sp / rho.array() + (1 - sp) / (1 - rho.array());

    //
    // Backward pass
    //
    MatrixXd delta3 = (a3 - data).array() * sigmoidGrad(z3).array();

    MatrixXd delta2 = w2.transpose() * delta3;
    delta2.colwise() += beta * sparsityDelta;
    delta2 = delta2.array() * sigmoidGrad(z2).array();

    MatrixXd w1Grad = delta2 * data.transpose() / ndata + lambda * w1;
    VectorXd b1Grad = delta2.rowwise().sum() / ndata;
    MatrixXd w2Grad = delta3 * a2.transpose() / ndata + lambda * w2;
    VectorXd b2Grad = delta3.rowwise().sum() / ndata;

    //
    // Cost
    //
    const double reconstructionCost = 0.5 * (a3 - data).squaredNorm() / ndata;
    const double weightDecayCost = 0.5 * lambda * (w1.squaredNorm() + w2.squaredNorm());
    const double sparsityCost = beta * (sp * (sp / rho.array()).log()
            + (1 - sp) * ((1 - sp) / (1 - rho.array())).log() ).matrix().sum();

    const double cost = reconstructionCost + weightDecayCost + sparsityCost;

    //
    // Pack the gradient in the same order as the parameters
    //
    lbfgsfloatval_t* pgrad = grad;
    for (int i = 0; i < nWeights; ++i)
        *pgrad++ = w1Grad.data()[i];
    for (int i = 0; i < nWeights; ++i)
        *pgrad++ = w2Grad.data()[i];
    for (int i = 0; i < hiddenSize; ++i)
        *pgrad++ = b1Grad.data()[i];
    for (int i = 0; i < visibleSize; ++i)
        *pgrad++ = b2Grad.data()[i];

    return cost;
}