Exemplo n.º 1
0
/// Evaluates this cost function on the current results of an inference algorithm
/** Depending on the value of this enum, the result is:
 *  - CFN_BETHE_ENT: minus \c ia.logZ();
 *  - CFN_VAR_ENT: minus the sum of the entropies of all variable beliefs;
 *  - CFN_FACTOR_ENT: minus the sum of the entropies of all factor beliefs;
 *  - CFN_GIBBS_B, CFN_GIBBS_B2, CFN_GIBBS_EXP: the sum over all variables of
 *    \c b, \c b*b/2 or \c exp(b) respectively, where \c b is the variable belief
 *    entry selected by the given state;
 *  - CFN_GIBBS_B_FACTOR, CFN_GIBBS_B2_FACTOR, CFN_GIBBS_EXP_FACTOR: the same sums
 *    taken over all factors, using the factor belief entry that
 *    getFactorEntryForState() returns for the given state.
 *
 *  \param ia     Inference algorithm whose factor graph, logZ and beliefs are read.
 *  \param stateP Pointer to a joint state with one entry per variable of \c ia.fg().
 *                Checked with DAI_ASSERT to be non-NULL and of the right size for the
 *                CFN_GIBBS_* cost functions; not read by the other cost functions.
 *  \return The value of the cost function.
 *  \throw UNKNOWN_ENUM_VALUE if the value of this enum is none of the cases above.
 */
Real BBPCostFunction::evaluate( const InfAlg &ia, const vector<size_t> *stateP ) const {
    Real cf = 0.0;
    const FactorGraph &fg = ia.fg();

    switch( (size_t)(*this) ) {
        case CFN_BETHE_ENT: // ignores state
            cf = -ia.logZ();
            break;
        case CFN_VAR_ENT: // ignores state
            for( size_t i = 0; i < fg.nrVars(); i++ )
                cf += -ia.beliefV(i).entropy();
            break;
        case CFN_FACTOR_ENT: // ignores state
            for( size_t I = 0; I < fg.nrFactors(); I++ )
                cf += -ia.beliefF(I).entropy();
            break;
        case CFN_GIBBS_B:
        case CFN_GIBBS_B2:
        case CFN_GIBBS_EXP: {
            DAI_ASSERT( stateP != NULL );
            // The state is only read, so refer to the caller's vector instead of copying it
            const vector<size_t> &state = *stateP;
            DAI_ASSERT( state.size() == fg.nrVars() );
            for( size_t i = 0; i < fg.nrVars(); i++ ) {
                // Belief of variable i in the value it takes in the given state
                Real b = ia.beliefV(i)[state[i]];
                switch( (size_t)(*this) ) {
                    case CFN_GIBBS_B:
                        cf += b;
                        break;
                    case CFN_GIBBS_B2:
                        cf += b * b / 2.0;
                        break;
                    case CFN_GIBBS_EXP:
                        cf += exp( b );
                        break;
                    default:
                        // Not reachable from the enclosing case labels; kept as a safeguard
                        DAI_THROW(UNKNOWN_ENUM_VALUE);
                }
            }
            break;
        } case CFN_GIBBS_B_FACTOR:
          case CFN_GIBBS_B2_FACTOR:
          case CFN_GIBBS_EXP_FACTOR: {
            DAI_ASSERT( stateP != NULL );
            // The state is only read, so refer to the caller's vector instead of copying it
            const vector<size_t> &state = *stateP;
            DAI_ASSERT( state.size() == fg.nrVars() );
            for( size_t I = 0; I < fg.nrFactors(); I++ ) {
                // Entry of factor I that corresponds to the given joint state
                size_t x_I = getFactorEntryForState( fg, I, state );
                Real b = ia.beliefF(I)[x_I];
                switch( (size_t)(*this) ) {
                    case CFN_GIBBS_B_FACTOR:
                        cf += b;
                        break;
                    case CFN_GIBBS_B2_FACTOR:
                        cf += b * b / 2.0;
                        break;
                    case CFN_GIBBS_EXP_FACTOR:
                        cf += exp( b );
                        break;
                    default:
                        // Not reachable from the enclosing case labels; kept as a safeguard
                        DAI_THROW(UNKNOWN_ENUM_VALUE);
                }
            }
            break;
        } default:
            DAI_THROWE(UNKNOWN_ENUM_VALUE, "Unknown cost function " + std::string(*this));
    }
    return cf;
}
Exemplo n.º 2
0
        /// Run the algorithm and store its results
        void doDAI() {
            double tic = toc();
            if( obj != NULL ) {
                // Initialize
                obj->init();
                // Run
                obj->run();
                // Record the time
                time += toc() - tic;

                // Store logarithm of the partition sum (if supported)
                try {
                    logZ = obj->logZ();
                    has_logZ = true;
                } catch( Exception &e ) {
                    if( e.getCode() == Exception::NOT_IMPLEMENTED )
                        has_logZ = false;
                    else
                        throw;
                }

                // Store maximum difference encountered in last iteration (if supported)
                try {
                    maxdiff = obj->maxDiff();
                    has_maxdiff = true;
                } catch( Exception &e ) {
                    if( e.getCode() == Exception::NOT_IMPLEMENTED )
                        has_maxdiff = false;
                    else
                        throw;
                }

                // Store number of iterations needed (if supported)
                try {
                    iters = obj->Iterations();
                    has_iters = true;
                } catch( Exception &e ) {
                    if( e.getCode() == Exception::NOT_IMPLEMENTED )
                        has_iters = false;
                    else
                        throw;
                }

                // Store variable marginals
                varMarginals.clear();
                for( size_t i = 0; i < obj->fg().nrVars(); i++ )
                    varMarginals.push_back( obj->beliefV( i ) );

                // Store factor marginals
                facMarginals.clear();
                for( size_t I = 0; I < obj->fg().nrFactors(); I++ )
                    try {
                        facMarginals.push_back( obj->beliefF( I ) );
                    } catch( Exception &e ) {
                        if( e.getCode() == Exception::BELIEF_NOT_AVAILABLE )
                            facMarginals.push_back( Factor( obj->fg().factor(I).vars(), INFINITY ) );
                        else
                            throw;
                    }

                // Store all marginals calculated by the method
                allMarginals = obj->beliefs();
            };
        }