Exemplo n.º 1
0
/**
 * Answers a query over the variables in `queryVids`.
 *
 * A single id is delegated to getPosterioriOf(); two or more ids are
 * delegated to getJointDistributionOf(). The query must not be empty.
 *
 * @param queryVids ids of the queried variables (at least one).
 * @return the parameters produced by the delegated computation.
 */
Params
BeliefProp::solveQuery (VarIds queryVids)
{
  assert (!queryVids.empty());
  if (queryVids.size() != 1) {
    return getJointDistributionOf (queryVids);
  }
  return getPosterioriOf (queryVids[0]);
}
Exemplo n.º 2
0
/**
 * Creates the ground solver selected by Globals::groundSolver for `fg`,
 * optionally prints its flags, and then prints the results:
 * printAllPosterioris() when `queryIds` is empty, printAnswer(queryIds)
 * otherwise. The solver is destroyed before returning.
 *
 * @param fg       factor graph the solver is built on.
 * @param queryIds ids of the queried variables (may be empty).
 */
void
runSolver (const FactorGraph& fg, const VarIds& queryIds)
{
  GroundSolver* solver = 0;
  switch (Globals::groundSolver) {
    case GroundSolverType::VE:
      solver = new VarElim (fg);
      break;
    case GroundSolverType::BP:
      solver = new BeliefProp (fg);
      break;
    case GroundSolverType::CBP:
      solver = new CountingBp (fg);
      break;
    default:
      break;
  }
  // An assert alone vanishes under NDEBUG and would leave `solver` null,
  // so an unknown solver type must be rejected explicitly before use.
  if (solver == 0) {
    cerr << "error: unknown ground solver type" << endl;
    abort();
  }
  if (Globals::verbosity > 0) {
    solver->printSolverFlags();
    cout << endl;
  }
  if (queryIds.empty()) {
    solver->printAllPosterioris();
  } else {
    solver->printAnswer (queryIds);
  }
  delete solver;
}
Exemplo n.º 3
0
void
VarElim::processFactorList (const VarIds& vids)
{
  totalFactorSize_   = 0;
  largestFactorSize_ = 0;
  for (size_t i = 0; i < elimOrder_.size(); i++) {
    if (Globals::verbosity >= 2) {
      if (Globals::verbosity >= 3) {
        Util::printDashedLine();
        printActiveFactors();
      }
      cout << "-> summing out " ;
      cout << fg.getVarNode (elimOrder_[i])->label() << endl;
    }
    eliminate (elimOrder_[i]);
  }

  Factor* finalFactor = new Factor();
  for (size_t i = 0; i < factorList_.size(); i++) {
    if (factorList_[i]) {
      finalFactor->multiply (*factorList_[i]);
      delete factorList_[i];
      factorList_[i] = 0;
    }
  }

  VarIds unobservedVids;
  for (size_t i = 0; i < vids.size(); i++) {
    if (fg.getVarNode (vids[i])->hasEvidence() == false) {
      unobservedVids.push_back (vids[i]);
    }
  }

  finalFactor->reorderArguments (unobservedVids);
  finalFactor->normalize();
  factorList_.push_back (finalFactor);
  if (Globals::verbosity > 0) {
    cout << "total factor size:   " << totalFactorSize_ << endl;
    cout << "largest factor size: " << largestFactorSize_ << endl;
    cout << endl;
  }
}
Exemplo n.º 4
0
/**
 * Computes a greedy elimination order for every graph variable that is
 * not listed in `exclude`.
 *
 * At each step the node returned by getLowestCostNode() is:
 *  - removed from the unmarked set;
 *  - detached from each of its neighbours;
 *  - appended to the order;
 *  - passed to connectAllNeighbors().
 *
 * @param exclude ids of variables that must not be eliminated. Duplicate
 *                ids, or ids with no node in the graph, are tolerated.
 * @return the ids of the eliminated variables, in elimination order.
 */
VarIds
ElimGraph::getEliminatingOrder (const VarIds& exclude)
{
  VarIds elimOrder;
  unmarked_.reserve (nodes_.size());
  // Count the nodes actually queued rather than computing
  // nodes_.size() - exclude.size(). That difference is wrong when
  // `exclude` holds duplicates or unknown ids, leaving variables
  // uneliminated. It also underflows when `exclude` is larger than
  // nodes_, so the loop would ask for nodes that do not exist.
  size_t nrVarsToEliminate = 0;
  for (size_t i = 0; i < nodes_.size(); i++) {
    if (Util::contains (exclude, nodes_[i]->varId()) == false) {
      unmarked_.insert (nodes_[i]);
      nrVarsToEliminate++;
    }
  }
  for (size_t i = 0; i < nrVarsToEliminate; i++) {
    EgNode* node = getLowestCostNode();
    unmarked_.remove (node);
    const EGNeighs& neighs = node->neighbors();
    for (size_t j = 0; j < neighs.size(); j++) {
      neighs[j]->removeNeighbor (node);
    }
    elimOrder.push_back (node->varId());
    connectAllNeighbors (node);
  }
  return elimOrder;
}
Exemplo n.º 5
0
void
ElimGraph::exportToGraphViz (
    const char* fileName,
    bool showNeighborless,
    const VarIds& highlightVarIds) const
{
  ofstream out (fileName);
  if (!out.is_open()) {
    cerr << "error: cannot open file to write at " ;
    cerr << "Markov::exportToDotFile()" << endl;
    abort();
  }

  out << "strict graph {" << endl;

  for (size_t i = 0; i < nodes_.size(); i++) {
    if (showNeighborless || nodes_[i]->neighbors().size() != 0) {
      out << '"' << nodes_[i]->label() << '"' << endl;
    }
  }

  for (size_t i = 0; i < highlightVarIds.size(); i++) {
    EgNode* node =getEgNode (highlightVarIds[i]);
    if (node) {
      out << '"' << node->label() << '"' ;
      out << " [shape=box3d]" << endl;
    } else {
      cout << "error: invalid variable id: " << highlightVarIds[i] << endl;
      abort();
    }
  }

  for (size_t i = 0; i < nodes_.size(); i++) {
    EGNeighs neighs = nodes_[i]->neighbors();
    for (size_t j = 0; j < neighs.size(); j++) {
      out << '"' << nodes_[i]->label() << '"' << " -- " ;
      out << '"' << neighs[j]->label() << '"' << endl;
    }
  }

  out << "}" << endl;
  out.close();
}
Exemplo n.º 6
0
/**
 * Answers a query over `queryVids` by variable elimination.
 *
 * Steps:
 *  1. Release the factors left from any previous query.
 *  2. Rebuild the factor list and absorb evidence.
 *  3. Compute an elimination order for `queryVids` and run the elimination.
 *  4. Return the parameters of the last factor in factorList_,
 *     exponentiated with Util::exp when Globals::logDomain is set.
 *
 * @param queryVids ids of the queried variables.
 * @return parameters of the resulting query factor.
 */
Params
VarElim::solveQuery (VarIds queryVids)
{
  if (Globals::verbosity > 1) {
    cout << "Solving query on " ;
    for (size_t i = 0; i < queryVids.size(); i++) {
      if (i != 0) cout << ", " ;
      cout << fg.getVarNode (queryVids[i])->label();
    }
    cout << endl;
  }
  // factorList_ owns its entries: processFactorList deletes the ones it
  // consumes and appends a newly allocated final factor. Clearing without
  // deleting would leak that factor on every query after the first.
  for (size_t i = 0; i < factorList_.size(); i++) {
    delete factorList_[i];
  }
  factorList_.clear();
  varFactors_.clear();
  elimOrder_.clear();
  createFactorList();
  absorveEvidence();
  findEliminationOrder (queryVids);
  processFactorList (queryVids);
  Params params = factorList_.back()->params();
  if (Globals::logDomain) {
    Util::exp (params);
  }
  return params;
}
Exemplo n.º 7
0
/**
 * Parses command-line arguments argv[start..argc) into query variables and
 * evidence for `fg`.
 *
 * Accepted forms:
 *  - `<varId>`: adds a query variable.
 *  - `<varId>=<stateIndex>`: calls setEvidence(stateIndex) on that variable.
 *
 * On any error the function prints a message to cerr and terminates the
 * process with EXIT_FAILURE. Errors are: a malformed argument, an unknown
 * variable id, or an invalid state index.
 *
 * @param fg    factor graph whose variables are looked up (and receive
 *              evidence).
 * @param argc  number of entries in argv.
 * @param argv  command-line arguments.
 * @param start index of the first argument to parse.
 * @return the query variable ids, in command-line order.
 */
VarIds
readQueryAndEvidence (
    FactorGraph& fg,
    int argc,
    const char* argv[],
    int start)
{
  VarIds queryIds;
  for (int i = start; i < argc; i++) {
    const string arg (argv[i]);
    const size_t pos = arg.find ('=');
    if (pos == std::string::npos) {
      if (Util::isInteger (arg) == false) {
        cerr << "error: `" << arg << "' " ;
        cerr << "is not a variable id" ;
        cerr << endl;
        exit (EXIT_FAILURE);
      }
      VarId vid = Util::stringToUnsigned (arg);
      VarNode* queryVar = fg.getVarNode (vid);
      // Compare with a null pointer, not `false`: pointer == bool is
      // ill-formed since C++11.
      if (queryVar == 0) {
        cerr << "error: unknown variable with id " ;
        cerr << "`" << vid << "'"  << endl;
        exit (EXIT_FAILURE);
      }
      queryIds.push_back (vid);
    } else {
      string leftArg  = arg.substr (0, pos);
      string rightArg = arg.substr (pos + 1);
      if (leftArg.empty()) {
        cerr << "error: missing left argument" << endl;
        cerr << USAGE << endl;
        exit (EXIT_FAILURE);
      }
      if (Util::isInteger (leftArg) == false) {
        cerr << "error: `" << leftArg << "' " ;
        cerr << "is not a variable id" << endl ;
        exit (EXIT_FAILURE);
      }
      VarId vid = Util::stringToUnsigned (leftArg);
      VarNode* observedVar = fg.getVarNode (vid);
      if (observedVar == 0) {
        cerr << "error: unknown variable with id " ;
        cerr << "`" << vid << "'"  << endl;
        exit (EXIT_FAILURE);
      }
      if (rightArg.empty()) {
        cerr << "error: missing right argument" << endl;
        cerr << USAGE << endl;
        exit (EXIT_FAILURE);
      }
      if (Util::isInteger (rightArg) == false) {
        cerr << "error: `" << rightArg << "' " ;
        cerr << "is not a state index" << endl ;
        exit (EXIT_FAILURE);
      }
      unsigned stateIdx = Util::stringToUnsigned (rightArg);
      if (observedVar->isValidState (stateIdx) == false) {
        cerr << "error: `" << stateIdx << "' " ;
        cerr << "is not a valid state index for variable with id " ;
        cerr << "`" << vid << "'"  << endl;
        exit (EXIT_FAILURE);
      }
      observedVar->setEvidence (stateIdx);
    }
  }
  return queryIds;
}