/// Run TRW-S message passing and record the per-iteration convergence history.
///
/// Each iteration has two passes:
/// - A forward pass over the node list in order (m_nodeFirst -> m_next). It sends
///   messages along each node's m_firstForward edges.
/// - A backward pass in reverse order (m_nodeLast -> m_prev). It sends messages
///   along each node's m_firstBackward edges.
///
/// The lower bound is accumulated only during the backward pass.
///
/// @param options        Fields read:
///                       - m_iterMax: iteration cap. The iteration with
///                         iter >= m_iterMax is the last one.
///                       - m_printMinIter, m_printIter: progress-print schedule.
///                       - m_eps: if >= 0, once the lower-bound increase between
///                         consecutive iterations is <= m_eps, one more
///                         iteration runs as the final one.
/// @param lowerBound_arr Appended with the lower bound after every iteration.
/// @param energy_arr     Appended with ComputeSolutionAndEnergy() after every
///                       iteration.
/// @param time_arr       Appended with the elapsed time since entry after every
///                       iteration, computed in seconds (see NOTE below).
/// @param min_marginals  Optional (may be NULL). On the final iteration, each
///                       node's backward-pass vector Di is copied here. Nodes are
///                       laid out consecutively in list order (m_nodeFirst
///                       first), each taking Di->GetArraySize(...) entries. The
///                       caller must provide room for the sum of those sizes.
/// @return Number of iterations performed.
template <class T> int MRFEnergy<T>::Minimize_TRW_S(Options& options, std::vector<REAL> &lowerBound_arr, std::vector<REAL> &energy_arr, std::vector<clock_t> &time_arr, REAL* min_marginals)
{
    Node* i;
    Node* j;
    MRFEdge* e;
    REAL vMin;
    int iter;
    REAL lowerBoundPrev = 0; // only read after it has been stored at least once
    clock_t tStart = clock();

    if (!m_isEnergyConstructionCompleted)
    {
        CompleteGraphConstruction();
    }

    printf("TRW_S algorithm\n");

    SetMonotonicTrees();

    // Di: scratch vector for the current node. buf: scratch space handed to UpdateMessage.
    Vector* Di = (Vector*) m_buf;
    void* buf = (void*) (m_buf + m_vectorMaxSizeInBytes);

    bool lastIter = false;

    // main loop
    for (iter = 1; ; iter++)
    {
        if (iter >= options.m_iterMax) lastIter = true;

        // Min-marginals are extracted only on the final iteration.
        // lastIter does not change again until the end of this iteration.
        const bool storeMarginals = lastIter && min_marginals;

        ////////////////////////////////////////////////
        //                forward pass                //
        ////////////////////////////////////////////////
        REAL* min_marginals_ptr = min_marginals;
        for (i = m_nodeFirst; i; i = i->m_next)
        {
            Di->Copy(m_Kglobal, i->m_K, &i->m_D);
            for (e = i->m_firstForward; e; e = e->m_nextForward)
            {
                Di->Add(m_Kglobal, i->m_K, e->m_message.GetMessagePtr());
            }
            for (e = i->m_firstBackward; e; e = e->m_nextBackward)
            {
                Di->Add(m_Kglobal, i->m_K, e->m_message.GetMessagePtr());
            }

            // Di is deliberately not normalized here, and no lower bound is
            // accumulated: only the backward pass computes the bound.

            // pass messages from i to nodes with higher m_ordering
            for (e = i->m_firstForward; e; e = e->m_nextForward)
            {
                assert(e->m_tail == i);
                j = e->m_head;
                // The returned minimum is unused in the forward pass.
                e->m_message.UpdateMessage(m_Kglobal, i->m_K, j->m_K, Di, e->m_gammaForward, 0, buf);
            }

            // Advance to the end of the marginals buffer. The backward pass walks
            // back from there and fills node slots in reverse order.
            if (storeMarginals)
            {
                min_marginals_ptr += Di->GetArraySize(m_Kglobal, i->m_K);
            }
        }

        ////////////////////////////////////////////////
        //               backward pass                //
        ////////////////////////////////////////////////
        REAL lowerBound = 0;
        for (i = m_nodeLast; i; i = i->m_prev)
        {
            Di->Copy(m_Kglobal, i->m_K, &i->m_D);
            for (e = i->m_firstBackward; e; e = e->m_nextBackward)
            {
                Di->Add(m_Kglobal, i->m_K, e->m_message.GetMessagePtr());
            }
            for (e = i->m_firstForward; e; e = e->m_nextForward)
            {
                Di->Add(m_Kglobal, i->m_K, e->m_message.GetMessagePtr());
            }

            // normalize Di, update lower bound
            vMin = Di->ComputeAndSubtractMin(m_Kglobal, i->m_K);
            lowerBound += vMin;

            // pass messages from i to nodes with smaller m_ordering
            for (e = i->m_firstBackward; e; e = e->m_nextBackward)
            {
                assert(e->m_head == i);
                j = e->m_tail;
                vMin = e->m_message.UpdateMessage(m_Kglobal, i->m_K, j->m_K, Di, e->m_gammaBackward, 1, buf);
                lowerBound += vMin;
            }

            // Copy Di into this node's slot. NOTE(review): Di has been passed to
            // UpdateMessage above; whether that call modifies it is not visible here.
            if (storeMarginals)
            {
                const int arraySize = Di->GetArraySize(m_Kglobal, i->m_K);
                min_marginals_ptr -= arraySize;
                for (int k = 0; k < arraySize; k++)
                {
                    min_marginals_ptr[k] = Di->GetArrayValue(m_Kglobal, i->m_K, k);
                }
            }
        }

        ////////////////////////////////////////////////
        //          check stopping criterion          //
        ////////////////////////////////////////////////

        // Evaluate the energy once; the progress print below reuses it.
        const REAL energy = ComputeSolutionAndEnergy();

        lowerBound_arr.push_back(lowerBound);
        energy_arr.push_back(energy);
        // NOTE(review): the elapsed time is computed in seconds and then
        // converted to clock_t. Where clock_t is an integer type this truncates
        // to whole seconds. Changing it would change the units callers receive.
        time_arr.push_back(static_cast<clock_t>((clock() - tStart) * 1.0 / CLOCKS_PER_SEC));

        // print lower bound and energy, if necessary
        if ( lastIter || ( iter >= options.m_printMinIter && (options.m_printIter < 1 || iter % options.m_printIter == 0) ) )
        {
            printf("iter %d: lower bound = %f, energy = %f\n", iter, lowerBound, energy);
        }

        if (lastIter) break;

        // Convergence of the lower bound: if the improvement is small, run one
        // more iteration flagged as the last (so min-marginals get filled).
        if (options.m_eps >= 0)
        {
            if (iter > 1 && lowerBound - lowerBoundPrev <= options.m_eps)
            {
                lastIter = true;
            }
            lowerBoundPrev = lowerBound;
        }
    }

    return iter;
}
/// Run TRW-S message passing until the relative duality gap or the iteration
/// cap is reached.
///
/// Each iteration has two passes:
/// - A forward pass over the node list in order (m_nodeFirst -> m_next). It sends
///   messages along each node's m_firstForward edges.
/// - A backward pass in reverse order (m_nodeLast -> m_prev). It sends messages
///   along each node's m_firstBackward edges and accumulates the lower bound.
///
/// @param options    Fields read:
///                   - m_iterMax: iteration cap.
///                   - m_relgapMax: stop once the relative gap
///                     (energy - lowerBound) / |energy| falls below it.
///                   - m_printMinIter: non-zero enables per-iteration output
///                     via mexPrintf.
/// @param lowerBound Out: lower bound from the last backward pass.
/// @param energy     Out: ComputeSolutionAndEnergy() after the last iteration.
///                   It may be infinite, which is reported as an
///                   "inconsistent solution".
/// @return Number of iterations performed.
template <class T> int MRFEnergy<T>::Minimize_TRW_S(Options& options, REAL& lowerBound, REAL& energy)
{
    Node* i;
    Node* j;
    MRFEdge* e;
    REAL vMin;
    int iter;

    if (!m_isEnergyConstructionCompleted)
    {
        CompleteGraphConstruction();
    }

    SetMonotonicTrees();

    // Di: scratch vector for the current node. buf: scratch space handed to UpdateMessage.
    Vector* Di = (Vector*) m_buf;
    void* buf = (void*) (m_buf + m_vectorMaxSizeInBytes);

    // main loop
    for (iter = 1; ; iter++)
    {
        ////////////////////////////////////////////////
        //                forward pass                //
        ////////////////////////////////////////////////
        for (i = m_nodeFirst; i; i = i->m_next)
        {
            Di->Copy(m_Kglobal, i->m_K, &i->m_D);
            for (e = i->m_firstForward; e; e = e->m_nextForward)
            {
                Di->Add(m_Kglobal, i->m_K, e->m_message.GetMessagePtr());
            }
            for (e = i->m_firstBackward; e; e = e->m_nextBackward)
            {
                Di->Add(m_Kglobal, i->m_K, e->m_message.GetMessagePtr());
            }

            // Di is deliberately not normalized here, and no lower bound is
            // accumulated: only the backward pass computes the bound.

            // pass messages from i to nodes with higher m_ordering
            for (e = i->m_firstForward; e; e = e->m_nextForward)
            {
                assert(e->m_tail == i);
                j = e->m_head;
                // The returned minimum is unused in the forward pass.
                e->m_message.UpdateMessage(m_Kglobal, i->m_K, j->m_K, Di, e->m_gammaForward, 0, buf);
            }
        }

        ////////////////////////////////////////////////
        //               backward pass                //
        ////////////////////////////////////////////////
        lowerBound = 0;
        for (i = m_nodeLast; i; i = i->m_prev)
        {
            Di->Copy(m_Kglobal, i->m_K, &i->m_D);
            for (e = i->m_firstBackward; e; e = e->m_nextBackward)
            {
                Di->Add(m_Kglobal, i->m_K, e->m_message.GetMessagePtr());
            }
            for (e = i->m_firstForward; e; e = e->m_nextForward)
            {
                Di->Add(m_Kglobal, i->m_K, e->m_message.GetMessagePtr());
            }

            // normalize Di, update lower bound
            vMin = Di->ComputeAndSubtractMin(m_Kglobal, i->m_K);
            lowerBound += vMin;

            // pass messages from i to nodes with smaller m_ordering
            for (e = i->m_firstBackward; e; e = e->m_nextBackward)
            {
                assert(e->m_head == i);
                j = e->m_tail;
                vMin = e->m_message.UpdateMessage(m_Kglobal, i->m_K, j->m_K, Di, e->m_gammaBackward, 1, buf);
                lowerBound += vMin;
            }
        }

        ////////////////////////////////////////////////
        //          check stopping criterion          //
        ////////////////////////////////////////////////
        bool finishFlag = (iter >= options.m_iterMax);

        energy = ComputeSolutionAndEnergy();

        // Relative duality gap. Normalize by |energy| so that a negative energy
        // does not yield a negative gap, which would satisfy
        // rel_gap < m_relgapMax immediately. Zero and infinite energies are
        // handled explicitly instead of producing inf/NaN by division.
        const REAL gap = energy - lowerBound;
        REAL rel_gap;
        if (std::isinf(energy))
        {
            // Inconsistent solution: cannot count as converged.
            rel_gap = static_cast<REAL>(HUGE_VAL);
        }
        else if (energy != 0)
        {
            rel_gap = gap / std::fabs(energy);
        }
        else
        {
            rel_gap = (gap > 0) ? static_cast<REAL>(HUGE_VAL) : static_cast<REAL>(0);
        }

        // NOTE(review): m_printMinIter is used here as an on/off switch, whereas
        // the history-recording overload treats it as the first iteration to
        // print. Confirm this is intended.
        if (options.m_printMinIter)
        {
            mexPrintf("iter: %d ", iter);
            if (std::isinf(energy))
                mexPrintf("lower bound: %g, inconsistent solution. \n", lowerBound);
            else
                mexPrintf("energy: %g lower bound: %g rel_gap: %g \n", energy, lowerBound, rel_gap);
            // presumably forces MATLAB to flush/display the output immediately
            mexEvalString("drawnow");
        }

        if (rel_gap < options.m_relgapMax) finishFlag = true;

        // if finishFlag==true terminate
        if (finishFlag) break;
    }

    return iter;
}