Beispiel #1
0
// Assigns the observations held in cld_ to the K()*6 vMF components
// (6 signed axes per Manhattan frame, component k belongs to MF k/6) via the
// GPU kernel and returns the sum of the residual vector it fills in.
//
// Steps:
//  1. Build unnormalized component weights `pi` from the cluster counts.
//     The branch depends on the iteration counter t_:
//       - t_ == 0 : uniform weights.
//       - t_ > 24 : every axis of an MF gets the MF's total count if that
//                   total exceeds 10% of N, otherwise 1e-20 (MF suppressed).
//       - else    : pseudo-count 1000 plus the MF's total count; if
//                   estimateTau_ is set, taus_ is re-estimated per axis,
//                   otherwise taus_ is reset to 60.
//  2. Normalize `pi` and take its log. If estimateTau_ is set, fold the
//     per-component log2SinhOverZ(tau) term into it.
//  3. Upload weights (pi_) and rotations (Rot2Device) and run
//     MMFvMFCostFctAssignmentGPU.
//
// @param N  out: set to the sum of cld_.counts(). Its address is also handed
//           to the GPU kernel, which presumably updates it (TODO confirm).
// @return   sum of the per-component residuals written by the GPU kernel.
float mmf::OptSO3MMFvMF::computeAssignment(uint32_t& N)
{
  N = this->cld_.counts().sum();
  // Compute log of the cluster weights and push them to GPU.
  // Start every component at a pseudo-count of 1000; branches below
  // overwrite or extend it.
  Eigen::VectorXf pi = Eigen::VectorXf::Ones(K()*6)*1000;
  std::cout << this->t_ << std::endl;
  std::cout<<"counts: "<<this->cld_.counts().transpose()<<std::endl;
  if (this->t_ == 0) {
    pi.fill(1.);
  } else if (this->t_ >24) {
    std::cout << "truncating noisy MFs: " << std::endl;
    const double threshold = 0.10*N;
    for (uint32_t m=0; m<K(); ++m) {
      // Total count of MF m over its 6 axis components (computed once per MF).
      const float count = this->cld_.counts().middleRows(m*6,6).sum();
      const bool keep = count > threshold;
      // All 6 axes of an MF share its weight; weak MFs collapse to ~0.
      const float weight = keep ? count : 1.e-20;
      pi.segment(m*6,6).setConstant(weight);
      // Report the actual comparison outcome (previously always printed "<").
      std::cout << count << (keep ? " > " : " <= ") << threshold << std::endl;
    }
  } else {
    // Alternative (disabled): estimate the axis and MF proportions with
    //    pi += this->cld_.counts();
    // Estimate only the MF proportions: every axis of an MF receives the
    // MF's total count on top of its pseudo-count.
    for (uint32_t m=0; m<K(); ++m) {
      pi.segment(m*6,6).array() += this->cld_.counts().middleRows(m*6,6).sum();
    }
    if (estimateTau_) {
      for (uint32_t k=0; k<K()*6; ++k) {
        if (this->cld_.counts()(k) == 0) {
          taus_(k) = 0.; // no observations on this axis: uniform
        } else {
          // Signed axis of component k: axes come in (-e_i, +e_i) pairs with
          // i = (k%6)/2, then rotated by the MF's rotation Rs_[k/6].
          Eigen::Vector3f mu = Eigen::Vector3f::Zero();
          mu((k%6)/2) = (k%6)%2==0?-1.:1.;
          mu = Rs_[k/6]*mu;
          taus_(k) = jsc::vMF<3>::MLEstimateTau(
              this->cld_.xSums().col(k).cast<double>(),
              mu.cast<double>(), this->cld_.counts()(k));
        }
      }
    } else {
      taus_.fill(60.);
    }
  }
  std::cout<<"pi: "<<pi.transpose()<<std::endl;
  pi = (pi.array() / pi.sum()).array().log();
  std::cout<<"pi: "<<pi.transpose()<<std::endl;
  if (estimateTau_) {
    std::cout << pi.transpose() << std::endl;
    std::cout << taus_.transpose() << std::endl;
    // taus_ is only re-estimated in the 0 < t_ <= 24 branch above; in the
    // other branches the values already stored in taus_ are used here.
    // NOTE(review): for a vMF on S^2, log C(tau) = -log(2 sinh(tau)/tau)
    // - log(2*pi), which would suggest `+ log2Pi` inside the parentheses.
    // The term is identical for every k, so it cannot change the argmax
    // assignment. Confirm against log2SinhOverZ and the GPU cost before
    // changing it.
    const double log2Pi = log(2.*M_PI);
    for (uint32_t k=0; k<K()*6; ++k) {
      pi(k) -= jsc::vMF<3>::log2SinhOverZ(taus_(k)) - log2Pi;
    }
  }
  pi_.set(pi);
  Rot2Device();
  Eigen::VectorXf residuals = Eigen::VectorXf::Zero(K()*6);
  MMFvMFCostFctAssignmentGPU(residuals.data(), d_cost, &N,
      d_N_, cld_.d_x(), d_weights_, cld_.d_z(), d_mu_, pi_.data(),
      cld_.N(), K());
  return residuals.sum();
}