Ejemplo n.º 1
0
/**
 * Run mean shift on the columns of `data` and assign every column to the
 * closest mode that was found.
 *
 * If the stored radius is not positive, it is first replaced by
 * EstimateRadius(data).  Each seed (every data column, or the columns produced
 * by GenSeeds() when `useSeeds` is true) is then shifted repeatedly, for at
 * most maxIterations iterations, until it moves by less than 1e-3 * radius.
 * A converged seed becomes a new centroid unless it lies closer than `radius`
 * to a centroid that has already been accepted.  Seeds that stop early
 * because the range search returns at most one neighbor, or that do not
 * converge within maxIterations, contribute no centroid.
 *
 * @param data Dataset to cluster; each column is treated as one point.
 * @param assignments Output: overwritten with, for each column of `data`, the
 *     index of its nearest centroid (column index into `centroids`).
 * @param centroids Output: cleared on entry, then filled with one column per
 *     distinct centroid found.
 * @param useSeeds If true, start from the seeds generated by GenSeeds()
 *     instead of starting from every data point.
 */
inline void MeanShift<UseKernel, KernelType, MatType>::Cluster(
    const MatType& data,
    arma::Col<size_t>& assignments,
    arma::mat& centroids,
    bool useSeeds)
{
  if (radius <= 0)
  {
    // An invalid radius is given; an estimation is needed.
    Radius(EstimateRadius(data));
  }

  // `centroids` is an output.  Columns are appended to it below and it is also
  // scanned for duplicates, so anything left in it by the caller (for example
  // from a previous call) would wrongly suppress new centroids and leak into
  // the result.  Start from an empty matrix.
  centroids.reset();

  MatType seeds;
  const MatType* pSeeds = &data;
  if (useSeeds)
  {
    GenSeeds(data, radius, 1, seeds);
    pSeeds = &seeds;
  }

  // Holds all centroids before removing duplicate ones.
  arma::mat allCentroids(pSeeds->n_rows, pSeeds->n_cols);

  assignments.set_size(data.n_cols);

  range::RangeSearch<> rangeSearcher(data);
  math::Range validRadius(0, radius);
  std::vector<std::vector<size_t> > neighbors;
  std::vector<std::vector<double> > distances;

  // A shift shorter than this counts as convergence.  The radius is fixed from
  // here on, so compute the threshold once instead of on every iteration.
  const double convergenceTolerance = 1e-3 * radius;

  // For each seed, perform mean shift algorithm.
  for (size_t i = 0; i < pSeeds->n_cols; ++i)
  {
    // Initial centroid is the seed itself.
    allCentroids.col(i) = pSeeds->unsafe_col(i);
    for (size_t completedIterations = 0; completedIterations < maxIterations;
         completedIterations++)
    {
      // Store new centroid in this; it is handed to CalculateCentroid() as a
      // zero vector on every iteration.
      arma::colvec newCentroid = arma::zeros<arma::colvec>(pSeeds->n_rows);

      rangeSearcher.Search(allCentroids.unsafe_col(i), validRadius,
          neighbors, distances);

      // With at most one point in range there is nothing to average over, so
      // this seed is abandoned without producing a centroid.
      if (neighbors[0].size() <= 1)
        break;

      // Calculate new centroid.  If the computation reports failure, keep the
      // current position, which makes the convergence test below succeed.
      if (!CalculateCentroid(data, neighbors[0], distances[0], newCentroid))
        newCentroid = allCentroids.unsafe_col(i);

      // If the mean shift vector is small enough, it has converged.
      if (metric::EuclideanDistance::Evaluate(newCentroid,
          allCentroids.unsafe_col(i)) < convergenceTolerance)
      {
        // Determine if the new centroid is duplicate with old ones: any
        // accepted centroid closer than `radius` counts as the same mode.
        bool isDuplicated = false;
        for (size_t k = 0; k < centroids.n_cols; ++k)
        {
          const double distance = metric::EuclideanDistance::Evaluate(
              allCentroids.unsafe_col(i), centroids.unsafe_col(k));
          if (distance < radius)
          {
            isDuplicated = true;
            break;
          }
        }

        if (!isDuplicated)
          centroids.insert_cols(centroids.n_cols, allCentroids.unsafe_col(i));

        // Get out of the loop.
        break;
      }

      // Update the centroid.
      allCentroids.col(i) = newCentroid;
    }
  }

  // Assign centroids to each point: a 1-nearest-neighbor search of the data
  // against the centroids gives the index of the closest centroid per point.
  // NOTE(review): if no seed converged, `centroids` is empty here; how the
  // search behaves on an empty reference set is not visible from this
  // function -- confirm and handle explicitly if needed.
  neighbor::KNN neighborSearcher(centroids);
  arma::mat neighborDistances;
  arma::Mat<size_t> resultingNeighbors;
  neighborSearcher.Search(data, 1, resultingNeighbors, neighborDistances);
  assignments = resultingNeighbors.t();
}