Пример #1
0
int GibbsMPEforScalarGaussianBNet( float eps)
{
    std::cout<<std::endl<<"Gibbs MPE for scalar gaussian BNet"<<std::endl;

    int ret =1;
    CBNet *pBnet = pnlExCreateScalarGaussianBNet();
    std::cout<<"BNet has been created \n";
    
    CGibbsSamplingInfEngine *pGibbsInf = CGibbsSamplingInfEngine::Create( pBnet );
    pGibbsInf->SetBurnIn( 100);
    pGibbsInf->SetMaxTime( 10000 );
    std::cout<<"burnIN and MaxTime have been defined \n";
    
    pEvidencesVector evidences;
    pBnet->GenerateSamples(&evidences, 1 );
    std::cout<<"evidence has been generated \n";
    
    const int ndsToToggle[] = { 0, 3 };
    evidences[0]->ToggleNodeState( 2, ndsToToggle );
    
    
    intVecVector queryes(1);
    queryes[0].push_back(0);
    pGibbsInf->SetQueries( queryes);
    std::cout<<"set queries"<<std::endl;
    
    pGibbsInf->EnterEvidence( evidences[0], 1 );
    std::cout<<"enter evidence"<<std::endl;
    
    
    intVector query(1,0);
    pGibbsInf->MarginalNodes( &query.front(),query.size() );
    std::cout<<"marginal nodes"<<std::endl;
    
    const CEvidence *pEvGibbs = pGibbsInf->GetMPE();

    CJtreeInfEngine *pJTreeInf = CJtreeInfEngine::Create(pBnet);
    pJTreeInf->EnterEvidence(evidences[0], 1);
    pJTreeInf->MarginalNodes(&query.front(), query.size());
    const CEvidence* pEvJTree = pJTreeInf->GetMPE();

    
   
    
    std::cout<<"result of gibbs"<<std::endl<<std::endl;
    pEvGibbs->Dump();
    pEvJTree->Dump();
    
    delete evidences[0];
   
    delete pGibbsInf;
    delete pJTreeInf;
    delete pBnet;
    return ret;
}
Пример #2
0
/**
 * Checks that Viterbi (MPE) inference on a DBN gives the same answer as MPE
 * inference on the same model unrolled into a static Bayesian network.
 *
 * The DBN is unrolled for nTimeSlice slices and a junction tree engine is run
 * on the unrolled net; a 1.5-slice junction tree engine with the ptViterbi
 * procedure is run on the DBN itself. For every slice the MPE evidences of
 * both engines are compared value by value. If the values differ, the slice
 * is still accepted when the MPE query factors of both engines are equal
 * within eps; otherwise both evidences are dumped and the test fails.
 *
 * A failure in any slice makes the overall result 0; later slices are still
 * processed so that all mismatches are dumped, but they cannot turn a failed
 * result back into a success.
 *
 * @param pDBN        Dynamic model under test. Not deleted by this function.
 * @param nTimeSlice  Number of time slices to unroll and to run inference on.
 * @param eps         Tolerance passed to IsFactorsDistribFunEqual().
 * @return            TRUE if every slice matched, 0 otherwise.
 */
int CompareViterbyArHMM( CDBN* pDBN, int nTimeSlice, float eps )
{
    CBNet *pUnrolledDBN =
        static_cast<CBNet *>( pDBN->UnrollDynamicModel( nTimeSlice ) );

    // Inference on the unrolled (static) network.
    pEvidencesVector myEvidencesForDBN;
    CreateEvidencesArHMM( pDBN, nTimeSlice, &myEvidencesForDBN );
    CEvidence *myEvidenceForUnrolledDBN =
        CreateEvidenceForUnrolledArHMM( pUnrolledDBN, nTimeSlice, myEvidencesForDBN );
    CJtreeInfEngine *pUnrolJTree = CJtreeInfEngine::Create( pUnrolledDBN );

    // Viterbi inference on the DBN itself.
    C1_5SliceJtreeInfEngine *pDynamicJTree = C1_5SliceJtreeInfEngine::Create( pDBN );
    pUnrolJTree->EnterEvidence( myEvidenceForUnrolledDBN, 1 );
    pDynamicJTree->DefineProcedure( ptViterbi, nTimeSlice );
    pDynamicJTree->EnterEvidence( &myEvidencesForDBN.front(), nTimeSlice );
    pDynamicJTree->FindMPE();

    intVector queryForDBN, queryForDBNPrior;
    intVecVector queryForUnrollBnet;
    DefineQueryArHMM( pDBN, nTimeSlice,
        &queryForDBNPrior, &queryForDBN, &queryForUnrollBnet );

    int itogResult = TRUE;
    for( int slice = 0; slice < nTimeSlice; slice++ )
    {
        pUnrolJTree->MarginalNodes( &(queryForUnrollBnet[slice]).front(),
            (queryForUnrollBnet[slice]).size() );

        // Slice 0 is queried with the prior query, all later slices with the regular one.
        if( slice )
        {
            pDynamicJTree->MarginalNodes( &queryForDBN.front(),
                queryForDBN.size(), slice );
        }
        else
        {
            pDynamicJTree->MarginalNodes( &queryForDBNPrior.front(),
                queryForDBNPrior.size(), slice );
        }

        const CEvidence* pEv1 = pUnrolJTree->GetMPE();
        const CEvidence* pEv2 = pDynamicJTree->GetMPE();

        // Per-slice flag: keeps one slice's verdict from leaking into the next
        // and from overwriting a failure recorded for an earlier slice.
        bool sliceMatches = true;

        const int nObsNodes = pEv1->GetNumberObsNodes();
        const CNodeType *const* nt = pEv1->GetNodeTypes();
        for( int i = 0; sliceMatches && i < nObsNodes; i++ )
        {
            if( nt[i]->IsDiscrete() )
            {
                const int v1 = pEv1->GetValueBySerialNumber(i)->GetInt();
                const int v2 = pEv2->GetValueBySerialNumber(i)->GetInt();
                if( v2 != v1 )
                {
                    sliceMatches = false;
                }
            }
            else
            {
                const int nodeSz = nt[i]->GetNodeSize();
                for( int j = 0; sliceMatches && j < nodeSz; j++ )
                {
                    // Exact comparison on purpose: any difference falls through
                    // to the eps-tolerant factor comparison below.
                    const float v1 = pEv1->GetValueBySerialNumber(i)[j].GetFlt();
                    const float v2 = pEv2->GetValueBySerialNumber(i)[j].GetFlt();
                    if( v2 != v1 )
                    {
                        sliceMatches = false;
                    }
                }
            }
        }

        if( !sliceMatches )
        {
            // Different MPE values are acceptable as long as the MPE query
            // factors of both engines agree within eps.
            if( !pUnrolJTree->GetQueryMPE()->
                IsFactorsDistribFunEqual( pDynamicJTree->GetQueryMPE(), eps ) )
            {
                itogResult = 0;
                pEv1->Dump();
                pEv2->Dump();
            }
        }
    }

    const int nEvidences = static_cast<int>( myEvidencesForDBN.size() );
    for( int ev = 0; ev < nEvidences; ev++ )
    {
        delete myEvidencesForDBN[ev];
    }

    delete myEvidenceForUnrolledDBN;

    delete pUnrolJTree;

    delete pUnrolledDBN;

    delete pDynamicJTree;

    return itogResult;
}