Exemplo n.º 1
0
// Eliminates the variable `elimVar` from the working factor list.
//
// Every still-active factor registered for `elimVar` in varFactors_ is
// multiplied into a single product; the consumed factors are deleted and
// their slots in factorList_ are set to null. The size of the product is
// accumulated into totalFactorSize_ and largestFactorSize_.
//
// If the product has more than one argument, `elimVar` is summed out and
// the resulting factor is appended to factorList_ (which takes ownership)
// and registered in varFactors_ under each of its remaining arguments.
// Otherwise the product is discarded.
//
// Does nothing when `elimVar` has no entry in varFactors_ or when none of
// its factors is still active.
void
VarElim::eliminate (VarId elimVar)
{
  auto entry = varFactors_.find (elimVar);
  if (entry == varFactors_.end()) {
    // unknown variable: there is nothing to multiply or sum out
    return;
  }
  Factor* result = 0;
  vector<size_t>& idxs = entry->second;
  for (size_t i = 0; i < idxs.size(); i++) {
    size_t idx = idxs[i];
    if (factorList_[idx]) {
      if (result == 0) {
        result = new Factor (*factorList_[idx]);
      } else {
        result->multiply (*factorList_[idx]);
      }
      delete factorList_[idx];
      factorList_[idx] = 0;
    }
  }
  if (result == 0) {
    // every factor mentioning elimVar was already consumed, so there is
    // no product whose size could be measured (this used to dereference
    // a null pointer)
    return;
  }
  totalFactorSize_ += result->size();
  if (result->size() > largestFactorSize_) {
    largestFactorSize_ = result->size();
  }
  if (result->nrArguments() != 1) {
    result->sumOut (elimVar);
    factorList_.push_back (result);
    const size_t newIdx = factorList_.size() - 1;
    const VarIds& resultVarIds = result->arguments();
    for (size_t i = 0; i < resultVarIds.size(); i++) {
      vector<size_t>& argIdxs =
          varFactors_.find (resultVarIds[i])->second;
      argIdxs.push_back (newIdx);
    }
  } else {
    // the product is not handed over to factorList_, so it must be
    // released here (it used to leak)
    delete result;
  }
}
Exemplo n.º 2
0
// Returns the parameters of the joint distribution over `jointVarIds`,
// computed at the factor node `fn`.
//
// The solver is run first if it has not been run yet. The factor of `fn`
// is multiplied by one single-variable factor per link of `fn`, each built
// from that link's var-to-factor message; the product is then summed out
// to `jointVarIds`, reordered to match them and normalized. When
// Globals::logDomain is set, Util::exp is applied to the parameters
// before they are returned.
Params
BeliefProp::getFactorJoint (
    FacNode* fn,
    const VarIds& jointVarIds)
{
  if (!runned_) {
    runSolver();
  }
  Factor joint (fn->factor());
  const BpLinks& fnLinks = getLinks (fn);
  for (size_t k = 0; k < fnLinks.size(); k++) {
    auto var = fnLinks[k]->varNode();
    // wrap the incoming message in a one-variable factor
    Factor incoming ({var->varId()},
                     {var->range()},
                     getVarToFactorMsg (fnLinks[k]));
    joint.multiply (incoming);
  }
  joint.sumOutAllExcept (jointVarIds);
  joint.reorderArguments (jointVarIds);
  joint.normalize();
  Params dist = joint.params();
  if (Globals::logDomain) {
    Util::exp (dist);
  }
  return dist;
}
Exemplo n.º 3
0
void
VarElim::processFactorList (const VarIds& vids)
{
  totalFactorSize_   = 0;
  largestFactorSize_ = 0;
  for (size_t i = 0; i < elimOrder_.size(); i++) {
    if (Globals::verbosity >= 2) {
      if (Globals::verbosity >= 3) {
        Util::printDashedLine();
        printActiveFactors();
      }
      cout << "-> summing out " ;
      cout << fg.getVarNode (elimOrder_[i])->label() << endl;
    }
    eliminate (elimOrder_[i]);
  }

  Factor* finalFactor = new Factor();
  for (size_t i = 0; i < factorList_.size(); i++) {
    if (factorList_[i]) {
      finalFactor->multiply (*factorList_[i]);
      delete factorList_[i];
      factorList_[i] = 0;
    }
  }

  VarIds unobservedVids;
  for (size_t i = 0; i < vids.size(); i++) {
    if (fg.getVarNode (vids[i])->hasEvidence() == false) {
      unobservedVids.push_back (vids[i]);
    }
  }

  finalFactor->reorderArguments (unobservedVids);
  finalFactor->normalize();
  factorList_.push_back (finalFactor);
  if (Globals::verbosity > 0) {
    cout << "total factor size:   " << totalFactorSize_ << endl;
    cout << "largest factor size: " << largestFactorSize_ << endl;
    cout << endl;
  }
}
Exemplo n.º 4
0
// Computes the next message sent along `link`, from the link's factor
// node to the link's variable node, and stores it (normalized through
// LogAware::normalize) in link->nextMessage().
//
// The messages that reached the factor node from every variable other
// than the destination are combined into `msgProduct` with
// Util::apply_n_times, using addition when Globals::logDomain is set and
// multiplication otherwise. That product is multiplied by the node's own
// factor and everything but the destination variable is summed out.
// Intermediate values are traced to std::cout when Constants::showBpCalcs
// is set.
void
BeliefProp::calcFactorToVarMsg (BpLink* link)
{
  FacNode* factorNode = link->facNode();
  const VarNode* target = link->varNode();
  const BpLinks& neighbours = getLinks (factorNode);
  const bool inLogDomain = Globals::logDomain;
  unsigned repetitions = 1;
  unsigned productSize =
      Util::sizeExpected (factorNode->factor().ranges());
  Params msgProduct (productSize, LogAware::multIdenty());
  // walk the links from last to first; `repetitions` grows by the range
  // of every visited variable, the destination included
  for (size_t left = neighbours.size(); left > 0; left--) {
    const size_t pos = left - 1;
    auto sender = neighbours[pos]->varNode();
    if (sender != target) {
      if (Constants::showBpCalcs) {
        std::cout << "    message from " << sender->label() << ": " ;
      }
      if (inLogDomain) {
        Util::apply_n_times (msgProduct, getVarToFactorMsg (neighbours[pos]),
            repetitions, std::plus<double>());
      } else {
        Util::apply_n_times (msgProduct, getVarToFactorMsg (neighbours[pos]),
            repetitions, std::multiplies<double>());
      }
      if (Constants::showBpCalcs) {
        std::cout << std::endl;
      }
    }
    repetitions *= sender->range();
  }
  Factor combined (factorNode->factor().arguments(),
      factorNode->factor().ranges(), msgProduct);
  combined.multiply (factorNode->factor());
  if (Constants::showBpCalcs) {
    std::cout << "    message product:  " << msgProduct << std::endl;
    std::cout << "    original factor:  " << factorNode->factor().params()
              << std::endl;
    std::cout << "    factor product:   " << combined.params() << std::endl;
  }
  combined.sumOutAllExcept (target->varId());
  if (Constants::showBpCalcs) {
    std::cout << "    marginalized:     " << combined.params() << std::endl;
  }
  link->nextMessage() = combined.params();
  LogAware::normalize (link->nextMessage());
  if (Constants::showBpCalcs) {
    std::cout << "    curr msg:         " << link->message() << std::endl;
    std::cout << "    next msg:         " << link->nextMessage() << std::endl;
  }
}