Beispiel #1
0
void makeBetheMinSum(size_t nNodes,
                     const double *theta,
                     const cscMatrix &W,
                     const double *A,
                     const double *B,
                     const double *alpha,
                     double intervalSz,
                     MinSum &m) {
  mxAssert(m.nNodes == nNodes, "nNodes was a lie.");

  cscMatrix intervals = calcIntervals(nNodes, A, B, intervalSz);

  // Unaries: Term two of (Eq 4)
  for (size_t j = 0; j < nNodes; j++) {
    int nStates = intervals.jc[j+1] - intervals.jc[j];
    mxAssert(nStates >= 2, "states did not exceed two!");
    Node &nj = m.addNode(j, nStates);

    int degMinusOne = degree(W, j) - 1;

    for (size_t k = 0; k < nStates; k++) {
      double q = intervals.pr[intervals.jc[j] + k];
      nj(k) = -theta[j] * q + degMinusOne * binaryEntropy(q);
      //mexPrintf("%s:%d -- Node %d[%d] is %g\n", __FILE__, __LINE__, j, k, nj(k));
    }
  }

  // We know ahead of time there will be nnz distinct potentials.
  m.potentials.reserve(W.nzMax / 2);

  // Pairwise: (Eq 5)
  for (size_t hi = 0; hi < nNodes; hi++) {
    int nHiStates = m.nodes[hi].nStates;
    mxAssert(nHiStates >= 0, "nHiStates cannot be negative.");

    for (size_t nodeIdx = W.jc[hi]; nodeIdx < W.jc[hi+1]; nodeIdx++) {
      int lo = W.ir[nodeIdx];

      if (hi > lo) {
        int nLoStates = m.nodes[lo].nStates;
        mxAssert(nLoStates >= 0, "nLoStates cannot be negative.");

        // Only look at the upper triangular (minus diagonal)
        double w = W.pr[nodeIdx];
        double aij = alpha[nodeIdx];

        mxAssert(aij != 0, "alpha cannot be zero.");

        Potential &potential = m.addPotential(nLoStates, nHiStates);
        m.addEdge(lo, hi, 1.0, &potential);

        for (size_t iqLo = 0; iqLo < nLoStates; iqLo++) {
          double qLo = intervals.pr[intervals.jc[lo] + iqLo];

          for (size_t iqHi = 0; iqHi < nHiStates; iqHi++) {
            double qHi = intervals.pr[intervals.jc[hi] + iqHi];
            double marginals[4];
            double xi = marginalize(aij, qLo, qHi, marginals);
            potential(iqLo, iqHi) = -w*xi - entropy<4>(marginals);

            //mexPrintf("%s:%d -- Potential at %lx entry (%d, %d) is %g\n",
            //          __FILE__, __LINE__, &potential, iqLo, iqHi, potential(iqLo, iqHi));
          }
        }

      }
    }
  }

  deleteCscMatrix(intervals);
}
Beispiel #2
0
 /**
  * Convenience overload: wraps `dims` in a PyTensorIndex and forwards to the
  * marginalize() overload that accepts one, returning its result.
  */
 PySparseTensor marginalize(const TIV &dims) const
 {
   return this->marginalize(PyTensorIndex(dims));
 }