Esempio n. 1
0
// Belief-propagation message sent from node i to node j:
//   sum-product: m_ij(x_j) = sum_{x_i} { phi(i) * phi(i,j) * prod_{u in N(i)\j} m_ui(x_i) }
//   max-product: m_ij(x_j) = max_{x_i} { phi(i) * phi(i,j) * prod_{u in N(i)\j} m_ui(x_i) }
// The product over incoming messages is accumulated with dVector::add, so values are
// presumably kept in the log domain. NOTE(review): confirm against logMultiply /
// logMultiplyMaxProd.
//
//   xi, xj : sender and receiver nodes.
//   phi_i  : node potentials, indexed by node id.
//   phi_ij : edge potentials. Only phi_ij[lower id][higher id] is read; it is transposed
//            when the sender has the higher id.
//   msg    : message table where msg[a][b] holds m_ab. msg[xi->id][xj->id] is overwritten.
// The member flag isSumProduct selects sum-product (true) or max-product (false).
void InferenceEngineBP::sendMessage(BPNode* xi, BPNode* xj, dVector* phi_i, dMatrix** phi_ij, dVector** msg)
{
	// Start from the sender's own potential, phi(i).
	dVector incoming(phi_i[xi->id]);

	// Edge potentials are stored once per pair under [lower id][higher id];
	// when the sender is the higher-id node, read that entry and transpose it.
	const bool senderIsLower = xi->id < xj->id;
	dMatrix pairwise;
	pairwise.set( senderIsLower ? phi_ij[xi->id][xj->id] : phi_ij[xj->id][xi->id] );
	if( !senderIsLower )
		pairwise.transpose();

	// Fold in m_ui(x_i) from every neighbour u of i except the receiver j.
	for( std::list<BPNode*>::iterator nb = xi->neighbors.begin(); nb != xi->neighbors.end(); ++nb ) {
		if( !xj->equal(*nb) )
			incoming.add( msg[(*nb)->id][xi->id] );
	}

	// Combine with the edge potential and write the result into m_ij.
	dVector& outgoing = msg[xi->id][xj->id];
	if( !isSumProduct )
		logMultiplyMaxProd( incoming, pairwise, outgoing );
	else
		logMultiply( incoming, pairwise, outgoing );
}
// m_ij(x_j) = sum_xi {potential(i)*potential(i,j)*prod_{u \in N(i)\j} {m_ui(xi)}}
void InferenceEngineLoopyBP::sendMessage(int xi, int xj, int nbNodes, const Beliefs potentials,
                                         std::vector<dVector>& messages, iMatrix adjMat, int adjMatMax, bool bMaxProd)
{
  int max_hi=-1; // for Viterbi decoding
  
  // potential(i)
  dVector Vi(potentials.belStates[xi]);
  
  // potential(i,j)
  dMatrix Mij(potentials.belEdges[adjMat(xi,xj)-1]);
  if( xi>xj ) Mij.transpose();
  
  // prod_{u \in N(i)\j} {m_ui(xi)}}
  int msg_idx;
  for( int xu=0; xu<nbNodes; xu++ ) {
    if( !adjMat(xu,xi) || xu==xj ) continue;
    msg_idx = (xu>xi) ? adjMatMax+adjMat(xu,xi)-1 : adjMat(xu,xi)-1;
    Vi.add(messages[msg_idx]);
  }
  
  // m_ij(xj) = Vi \dot Mij
  msg_idx = (xi>xj) ? adjMatMax+adjMat(xi,xj)-1 : adjMat(xi,xj)-1;
  if( bMaxProd )
    max_hi = logMultiplyMaxProd(Vi, Mij, messages[msg_idx]);
  else
    logMultiply(Vi, Mij, messages[msg_idx]);
  
  // Normalize messages to avoid numerical over/under-flow
  // Make \sum_{xj} m_ij(xj)=1. Other methods could also be used.
  double min = messages[msg_idx].min();
  if( min < 0 ) messages[msg_idx].add(-min);
  messages[msg_idx].multiply(1/messages[msg_idx].sum());
}