// Creates a Markov network whose cliques are given by `clqs` over the model
// domain `pMD`, allocates every factor and lets each one build its own
// matrices via CFactor::CreateAllNecessaryMatrices().
// NOTE(review): the "random" in the name presumably refers to what
// CreateAllNecessaryMatrices fills in — confirm against CFactor.
// Returns the new network; the caller takes ownership.
CMNet* CMNet::CreateWithRandomMatrices( const intVecVector& clqs,
                                        CModelDomain* pMD )
{
    CMNet* pNet = CMNet::Create( clqs, pMD );
    pNet->AllocFactors();

    const int factorCount = pNet->GetNumberOfFactors();
    for( int factorIdx = 0; factorIdx < factorCount; ++factorIdx )
    {
        pNet->AllocFactor( factorIdx );
        CFactor* pFactor = pNet->GetFactor( factorIdx );
        pFactor->CreateAllNecessaryMatrices();
    }

    return pNet;
}
// Builds a small test Markov network: 4 nodes that all share one node type
// (CNodeType(1, 2), presumably discrete with 2 states; confirm), and 4
// pairwise cliques on the cycle 0-1-2-3-0. Each clique gets a fixed
// 4-entry table (matTable).
// Returns the new network; the caller takes ownership.
CMNet* myCreateDiscreteMNet()
{
    const int numOfNds = 4;
    const int numOfNodeTypes = 1;
    const int numOfClqs = 4;

    // Clique k is the edge (k, (k + 1) % 4), closing the loop back to node 0.
    int cliqueNodes[numOfClqs][2] =
    {
        { 0, 1 },
        { 1, 2 },
        { 2, 3 },
        { 3, 0 }
    };
    const int *clqs[] =
    {
        cliqueNodes[0], cliqueNodes[1], cliqueNodes[2], cliqueNodes[3]
    };
    intVector clqSizes( numOfClqs, 2 );

    CNodeType nodeType( 1, 2 );
    intVector nodeAssociation( numOfNds, 0 );   // every node uses type 0

    CMNet *pMNet = CMNet::Create( numOfNds, numOfNodeTypes, &nodeType,
                                  &nodeAssociation.front(), numOfClqs,
                                  &clqSizes.front(), clqs );
    pMNet->AllocFactors();

    // One table per clique, in clique order.
    float tables[numOfClqs][4] =
    {
        { 0.79f, 0.21f, 0.65f, 0.35f },
        { 0.91f, 0.09f, 0.22f, 0.78f },
        { 0.45f, 0.55f, 0.24f, 0.76f },
        { 0.51f, 0.49f, 0.29f, 0.71f }
    };

    for( int clqIdx = 0; clqIdx < numOfClqs; clqIdx++ )
    {
        pMNet->AllocFactor( clqIdx );
        CFactor *pFactor = pMNet->GetFactor( clqIdx );
        pFactor->AllocMatrix( tables[clqIdx], matTable );
    }

    return pMNet;
}
int GibbsMPEForMNet( float eps) { std::cout<<std::endl<<"Gibbs MPE for MNet"<< std::endl; pEvidencesVector evidences; CGibbsSamplingInfEngine *pGibbsInf; int ret = 1; CMNet* pMNet = myCreateDiscreteMNet(); pMNet->GenerateSamples( &evidences, 1 ); const int ndsToToggleMNet[] = { 0, 3 }; evidences[0]->ToggleNodeState( 2, ndsToToggleMNet ); pGibbsInf = CGibbsSamplingInfEngine::Create(pMNet); intVecVector queries(1); //pGibbsInf->SetParemeter(pMaxTime, 500); queries[0].clear(); queries[0].push_back(0); queries[0].push_back(3); pGibbsInf->SetQueries(queries); pGibbsInf->EnterEvidence( evidences[0], 1 ); CNaiveInfEngine* pInf = CNaiveInfEngine::Create(pMNet); pInf->EnterEvidence( evidences[0], 1 ); const int querySzMNet = 2; const int queryMNet[] = {0, 3}; pGibbsInf->MarginalNodes( queryMNet,querySzMNet ); pInf->MarginalNodes( queryMNet,querySzMNet ); const CEvidence *pEvGibbs = pGibbsInf->GetMPE(); const CEvidence *pEvInf = pInf->GetMPE(); int i; for( i = 0; i < querySzMNet; i++ ) { if( pEvGibbs->GetValueBySerialNumber(i)->GetInt() != pEvInf->GetValueBySerialNumber(i)->GetInt() ) { ret = 0; break; } } std::cout<<"result of gibbs"<<std::endl; pEvGibbs->Dump(); std::cout<<std::endl<<"result of naive"<<std::endl; pEvInf->Dump(); delete pInf; delete pGibbsInf; delete evidences[0]; delete pMNet; //////////////////////////////////////////////////////////////////////////////////////// return ret; }
int GibbsForMNet(float eps) { std::cout<<std::endl<<"Gibbs for discrete MNet"<< std::endl; pEvidencesVector evidences; CGibbsSamplingInfEngine *pGibbsInf; int ret; CMNet* pMNet = myCreateDiscreteMNet(); pMNet->GenerateSamples( &evidences, 1 ); const int ndsToToggleMNet[] = { 0, 3 }; evidences[0]->ToggleNodeState( 2, ndsToToggleMNet ); pGibbsInf = CGibbsSamplingInfEngine::Create(pMNet); intVecVector queries(1); //pGibbsInf->SetParemeter(pMaxTime, 500); queries[0].clear(); queries[0].push_back(0); queries[0].push_back(3); pGibbsInf->SetQueries(queries); pGibbsInf->EnterEvidence( evidences[0] ); CNaiveInfEngine* pInf = CNaiveInfEngine::Create(pMNet); pInf->EnterEvidence( evidences[0] ); const int querySzMNet = 2; const int queryMNet[] = {0, 3}; pGibbsInf->MarginalNodes( queryMNet,querySzMNet ); pInf->MarginalNodes( queryMNet,querySzMNet ); const CPotential *pQueryPot1MNet = pGibbsInf->GetQueryJPD(); const CPotential *pQueryPot2MNet = pInf->GetQueryJPD(); ret = pQueryPot1MNet-> IsFactorsDistribFunEqual( pQueryPot2MNet, eps, 0 ); std::cout<<"result of gibbs"<<std::endl; pQueryPot1MNet->Dump(); std::cout<<std::endl<<"result of naive"<<std::endl; pQueryPot2MNet->Dump(); delete pInf; delete pGibbsInf; delete evidences[0]; delete pMNet; //////////////////////////////////////////////////////////////////////////////////////// return ret; }
// Copy constructor: INCOMPLETE STUB.
// It only initializes the CStaticGraphicalModel base with the source
// network's model domain, and the body does nothing else.
// NOTE(review): the graph and factors of rMNet are not copied, so the new
// object is not a real copy of rMNet.
CMNet::CMNet( const CMNet& rMNet )
    : CStaticGraphicalModel( rMNet.GetModelDomain() )
{
    // TODO: copy the graph and the factors of rMNet.
}