void makeBetheMinSum(size_t nNodes, const double *theta, const cscMatrix &W, const double *A, const double *B, const double *alpha, double intervalSz, MinSum &m) { mxAssert(m.nNodes == nNodes, "nNodes was a lie."); cscMatrix intervals = calcIntervals(nNodes, A, B, intervalSz); // Unaries: Term two of (Eq 4) for (size_t j = 0; j < nNodes; j++) { int nStates = intervals.jc[j+1] - intervals.jc[j]; mxAssert(nStates >= 2, "states did not exceed two!"); Node &nj = m.addNode(j, nStates); int degMinusOne = degree(W, j) - 1; for (size_t k = 0; k < nStates; k++) { double q = intervals.pr[intervals.jc[j] + k]; nj(k) = -theta[j] * q + degMinusOne * binaryEntropy(q); //mexPrintf("%s:%d -- Node %d[%d] is %g\n", __FILE__, __LINE__, j, k, nj(k)); } } // We know ahead of time there will be nnz distinct potentials. m.potentials.reserve(W.nzMax / 2); // Pairwise: (Eq 5) for (size_t hi = 0; hi < nNodes; hi++) { int nHiStates = m.nodes[hi].nStates; mxAssert(nHiStates >= 0, "nHiStates cannot be negative."); for (size_t nodeIdx = W.jc[hi]; nodeIdx < W.jc[hi+1]; nodeIdx++) { int lo = W.ir[nodeIdx]; if (hi > lo) { int nLoStates = m.nodes[lo].nStates; mxAssert(nLoStates >= 0, "nLoStates cannot be negative."); // Only look at the upper triangular (minus diagonal) double w = W.pr[nodeIdx]; double aij = alpha[nodeIdx]; mxAssert(aij != 0, "alpha cannot be zero."); Potential &potential = m.addPotential(nLoStates, nHiStates); m.addEdge(lo, hi, 1.0, &potential); for (size_t iqLo = 0; iqLo < nLoStates; iqLo++) { double qLo = intervals.pr[intervals.jc[lo] + iqLo]; for (size_t iqHi = 0; iqHi < nHiStates; iqHi++) { double qHi = intervals.pr[intervals.jc[hi] + iqHi]; double marginals[4]; double xi = marginalize(aij, qLo, qHi, marginals); potential(iqLo, iqHi) = -w*xi - entropy<4>(marginals); //mexPrintf("%s:%d -- Potential at %lx entry (%d, %d) is %g\n", // __FILE__, __LINE__, &potential, iqLo, iqHi, potential(iqLo, iqHi)); } } } } } deleteCscMatrix(intervals); }
/// Convenience overload of marginalize() that takes a TIV.
///
/// Builds a PyTensorIndex from `dims` and forwards it, as a temporary, to the
/// PyTensorIndex overload of marginalize(). That overload's result is returned
/// unchanged. (`dims` presumably names the dimensions to marginalize — confirm
/// against the PyTensorIndex overload.)
PySparseTensor marginalize(const TIV &dims) const
{
    return this->marginalize(PyTensorIndex(dims));
}