void CJtreeInfEngine::MarginalNodes( const int *query, int querySz, int notExpandJPD ) { // bad-args check PNL_CHECK_IS_NULL_POINTER(query); PNL_CHECK_RANGES( querySz, 1, m_pGraphicalModel->GetNumberOfNodes() ); // bad-args check end /* // the following should be working differently for the case of doing the // whole EnterEvidence procedure or just CollectEvidence for the root node if( ( m_lastOpDone != opsDistribute ) && ( m_lastOpDone != opsMargNodes ) ) { if( m_lastOpDone != opsCollect ) { PNL_THROW( CInvalidOperation, " cannot perform marginalization, infEngine inconsistent " ); } int numOfClqsContQuery; const int *clqsContQuery; m_pOriginalJTree->GetClqNumsContainingSubset( querySz, query, &numOfClqsContQuery, &clqsContQuery ); PNL_CHECK_FOR_ZERO(numOfClqsContQuery); if( std::find( clqsContQuery, clqsContQuery + numOfClqsContQuery, m_JTreeRootNode ) == clqsContQuery + numOfClqsContQuery ) { PNL_THROW( CInvalidOperation, " cannot marginalize to the non-root-clq nodes set " ); } //////// this is to debug for( int i = 0; i < numOfClqsContQuery; ++i ) { CPotential *pJPot = m_pJTree->GetNodePotential(clqsContQuery[i]) ->Marginalize( query, querySz ); CPotential *pJPot1 = pJPot->GetNormalized(); pJPot1->Dump(); delete pJPot; delete pJPot1; } /////////////////////////////////////////////////////// MarginalizeCliqueToQuery( m_JTreeRootNode, querySz, query ); m_lastOpDone = opsMargNodes; } else { */ int numOfClqsContQuery; const int *clqsContQuery; m_pJTree->GetClqNumsContainingSubset( querySz, query, &numOfClqsContQuery, &clqsContQuery ); if(numOfClqsContQuery) { if( std::find( clqsContQuery, clqsContQuery + numOfClqsContQuery, m_JTreeRootNode ) != ( clqsContQuery + numOfClqsContQuery ) ) { MarginalizeCliqueToQuery( m_JTreeRootNode, querySz, query, notExpandJPD ); } else { MarginalizeCliqueToQuery( *clqsContQuery, querySz, query, notExpandJPD ); } } else { const int* clqDomain; int clqSize; CPotential *resPot = NULL; delete m_pQueryJPD; m_pQueryJPD = NULL; 
ShrinkJTreeCliques(querySz, const_cast<int*>(query)); resPot = MergeCliques(querySz, const_cast<int*>(query)); resPot->GetDomain(&clqSize, &clqDomain); if( !pnlIsIdentical(querySz, const_cast<int*>(query), clqSize, const_cast<int*>(clqDomain)) ) { m_pQueryJPD = resPot->Marginalize(const_cast<int*>(query), querySz); } else { m_pQueryJPD = static_cast<CPotential*>(resPot->Clone()); } m_pQueryJPD->Normalize(); delete resPot; } }
/*
 * Merges the potentials of the cliques still flagged in m_NodesAfterShrink
 * until one merged potential's domain contains every node of `Domain`, and
 * returns a copy of that potential.
 *
 *   domSize - number of entries in `Domain`
 *   Domain  - node numbers that the resulting potential must cover
 *
 * Returns a newly allocated potential; the caller owns it and must delete it
 * (MarginalNodes does so).
 *
 * Side effects: entries of m_NodesAfterShrink are reset to false as cliques
 * are absorbed, and the whole vector is cleared on success.
 *
 * Throws CInternalError (via PNL_THROW) when the collect sequence is
 * exhausted without any merged potential covering `Domain`.
 *
 * NOTE(review): m_NodesAfterShrink is presumably filled by
 * ShrinkJTreeCliques right before this call - confirm against that method.
 */
CPotential* CJtreeInfEngine::MergeCliques(int domSize, int* Domain)
{
    int numNodes = m_pJTree->GetNumberOfNodes();
    // one working potential per clique; slots of skipped cliques stay null
    potsPVector vPots(numNodes, (CPotential*)0);
    int i;
    const int* clqDomain;
    int clqSize;
    const int* sepDomain;
    int sepSize;
    const int *nbr, *nbrs_end;
    int numOfNbrs;
    const int *nbrs;
    const ENeighborType *nbrsTypes;
    intVector::const_iterator sourceIt, source_end;
    intVecVector::const_iterator layerIt = m_collectSequence.begin(),
        collSeq_end = m_collectSequence.end();
    const CGraph *pGraph = m_pJTree->GetGraph();
    intVector nodesSentMessages;
    intVector tmpV;

    // Pass 1: give every retained clique a private potential restricted to
    // the nodes that still matter for it, i.e.
    // (Domain + separators towards retained neighbours) intersected with
    // the clique content.
    for( ; layerIt != collSeq_end; ++layerIt )
    {
        for( sourceIt = layerIt->begin(), source_end = layerIt->end();
            sourceIt != source_end; ++sourceIt )
        {
            if( !m_NodesAfterShrink[*sourceIt] ) continue;

            pGraph->GetNeighbors( *sourceIt, &numOfNbrs, &nbrs, &nbrsTypes );
            tmpV.assign(Domain, Domain+domSize);

            for( nbr = nbrs, nbrs_end = nbrs + numOfNbrs; nbr != nbrs_end; ++nbr )
            {
                if( !m_NodesAfterShrink[*nbr] ) continue;

                m_pJTree->GetSeparatorDomain(*sourceIt, *nbr, &sepSize, &sepDomain);
                tmpV = pnlSetUnion(sepSize, const_cast<int*>(sepDomain),
                    tmpV.size(), &tmpV.front());
            }

            m_pJTree->GetNodeContent(*sourceIt, &clqSize, &clqDomain);
            tmpV = pnlIntersect(clqSize, const_cast<int*>(clqDomain),
                tmpV.size(), &tmpV.front());

            if( !pnlIsIdentical(tmpV.size(), &tmpV.front(), clqSize,
                const_cast<int*>(clqDomain)) )
            {
                // some clique nodes are irrelevant: sum them out up front
                vPots[*sourceIt] =
                    m_pJTree->GetNodePotential(*sourceIt)->Marginalize(tmpV);
            }
            else
            {
                vPots[*sourceIt] = static_cast<CPotential*>(
                    m_pJTree->GetNodePotential(*sourceIt)->Clone());
            }
        }
    }

    // Pass 2: walk the collect sequence again and absorb each retained
    // clique into its retained neighbours that have not sent a message yet.
    intVector bigDomain;
    layerIt = m_collectSequence.begin();
    nodesSentMessages.assign(numNodes, false);
    CPotential* tPot;

    for( ; layerIt != collSeq_end; ++layerIt )
    {
        for( sourceIt = layerIt->begin(), source_end = layerIt->end();
            sourceIt != source_end; ++sourceIt )
        {
            if( !m_NodesAfterShrink[*sourceIt] ) continue;

            pGraph->GetNeighbors( *sourceIt, &numOfNbrs, &nbrs, &nbrsTypes );

            for( nbr = nbrs, nbrs_end = nbrs + numOfNbrs; nbr != nbrs_end; ++nbr )
            {
                if( !nodesSentMessages[*nbr] && m_NodesAfterShrink[*nbr] )
                {
                    CPotential* pPot = vPots[*nbr];
                    CPotential* cPot = vPots[*sourceIt];

                    // combine both cliques, then divide by the separator
                    // potential they share
                    CPotential* bigPot = pnlMultiply(pPot, cPot,
                        GetModel()->GetModelDomain());
                    *bigPot /= *(m_pJTree->GetSeparatorPotential(*sourceIt, *nbr));

                    // the source is absorbed and no longer takes part
                    m_NodesAfterShrink[*sourceIt] = false;

                    int numOfNbrs1;
                    const int *nbrs1, *nbr1, *nbrs1_end;
                    const ENeighborType *nbrsTypes1;

                    pGraph->GetNeighbors( *nbr, &numOfNbrs1, &nbrs1, &nbrsTypes1 );

                    // nodes the merged potential has to keep: the query plus
                    // the separators towards neighbours that are still retained
                    tmpV.assign(Domain, Domain+domSize);

                    for( nbr1 = nbrs1, nbrs1_end = nbrs1 + numOfNbrs1;
                        nbr1 != nbrs1_end; ++nbr1 )
                    {
                        if( !m_NodesAfterShrink[*nbr1] ) continue;

                        m_pJTree->GetSeparatorDomain(*nbr, *nbr1, &sepSize, &sepDomain);
                        tmpV = pnlSetUnion(sepSize, const_cast<int*>(sepDomain),
                            tmpV.size(), &tmpV.front());
                    }

                    bigPot->GetDomain(&bigDomain);
                    tmpV = pnlIntersect(tmpV.size(), &tmpV.front(),
                        bigDomain.size(), &bigDomain.front());

                    if( tmpV.size() < bigDomain.size() )
                    {
                        tPot = bigPot->Marginalize(&tmpV.front(), tmpV.size());
                        delete bigPot;
                        bigPot = tPot;
                    }

                    // the neighbour now carries the merged potential
                    delete vPots[*nbr];
                    vPots[*nbr] = bigPot;

                    bigPot->GetDomain(&bigDomain);

                    if( pnlIsSubset(domSize, Domain, bigDomain.size(),
                        &bigDomain.front()) )
                    {
                        // the whole query is covered: hand back a copy and
                        // release every working potential
                        CPotential* retPot = static_cast<CPotential*>(bigPot->Clone());

                        for( i = 0; i < numNodes; i++ )
                        {
                            delete vPots[i];
                        }

                        vPots.clear();
                        m_NodesAfterShrink.clear();

                        return retPot;
                    }
                }

                nodesSentMessages[*sourceIt] = true;
            }
        }
    }

    // No merged potential covered the query. Release the working potentials
    // before reporting the failure; they used to be leaked on this path.
    // Deleting the null slots of skipped cliques is harmless.
    for( i = 0; i < numNodes; i++ )
    {
        delete vPots[i];
    }
    vPots.clear();

    PNL_THROW(CInternalError, "internal error");
}