Example #1
// Executes a k-means step with a KmTree on the given point-set with the given set of centers, and
// tests if it is correct by comparing results with the naive O(nkd) implementation. The cost is
// returned, which is used to determine when k-means has finished.
// Generates an assertion failure if there is an error.
//
// Parameters:
//   tree    - the KmTree whose DoKMeansStep is being checked.
//   n       - number of points; point i occupies points[i*d .. i*d + d - 1].
//   k       - number of centers; center j occupies centers[j*d .. j*d + d - 1].
//   d       - dimension of every point and center.
//   points  - the point-set the tree was built over.
//   centers - the current centers; DoKMeansStep overwrites them with the new centers.
// Returns the cost reported by tree.DoKMeansStep, which is checked against the sum of squared
// distances from each point to its assigned OLD center.
//
// A center whose bytes are all 0xff is treated as unused ("void"). Void old centers are skipped
// when checking assignments. A new center must be void exactly when no point was assigned to it.
Scalar TestKMeansStep(const KmTree &tree, int n, int k, int d, Scalar *points, Scalar *centers) {
  // Byte sizes are computed in size_t so that k*d cannot overflow int for large inputs.
  const size_t point_bytes = size_t(d) * sizeof(Scalar);
  const size_t centers_bytes = size_t(k) * point_bytes;

  // Allocate memory
  int *assignment = (int*)malloc(size_t(n) * sizeof(int));
  Scalar *old_centers = (Scalar*)malloc(centers_bytes);
  Scalar *new_sums = (Scalar*)calloc(size_t(k) * size_t(d), sizeof(Scalar));
  int *new_counts = (int*)calloc(size_t(k), sizeof(int));
  Scalar *bad_center = PointAllocate(d);
  if (assignment == 0 || old_centers == 0 || new_sums == 0 || new_counts == 0 || bad_center == 0)
    MemoryAssertFail("running unit test");
  // All-0xff bytes is the "void center" marker compared against below.
  memset(bad_center, 0xff, point_bytes);
  // Keep a copy of the input centers: DoKMeansStep replaces `centers` with the new ones.
  memcpy(old_centers, centers, centers_bytes);

  // Run fancy k-means
  Scalar fancy_cost = tree.DoKMeansStep(k, centers, assignment);

  // Test the assignments and build the correct aggregate data
  Scalar correct_cost = 0;
  for (int i = 0; i < n; i++) {
    Scalar fancy_dist_sq = PointDistSq(points + i*d, old_centers + assignment[i]*d, d);

    // No live (non-void) old center may be strictly closer than the one the tree picked.
    for (int j = 0; j < k; j++) {
      if (memcmp(old_centers + j*d, bad_center, point_bytes) != 0) {
        Scalar dist_sq = PointDistSq(points + i*d, old_centers + j*d, d);
        UnitTestAssert(TestScalarsGe(dist_sq, fancy_dist_sq),
                       "k-means assigned point to the wrong cluster");
      }
    }

    // Note that the cost is measured from the OLD centers, not from the new centers
    correct_cost += fancy_dist_sq;
    // Cluster sizes are needed both to detect void centers and to turn sums into means below.
    new_counts[assignment[i]]++;
    PointAdd(new_sums + assignment[i]*d, points + i*d, d);
  }

  // Test the costs
  UnitTestAssert(TestScalarsEq(correct_cost, fancy_cost),
                 "k-means calculated the cost function incorrectly");

  // Test the centers: every used center must be the mean of the points assigned to it
  for (int i = 0; i < k; i++) {
    bool fancy_is_void = (memcmp(centers + i*d, bad_center, point_bytes) == 0);
    bool correct_is_void = (new_counts[i] == 0);
    UnitTestAssert(fancy_is_void == correct_is_void,
                   "k-means failed to correctly mark whether a center was being used");
    // The count check avoids dividing by zero when the assertion above already flagged a
    // center that should have been void.
    if (!fancy_is_void && !correct_is_void) {
      PointScale(new_sums + i*d, Scalar(1) / new_counts[i], d);
      for (int j = 0; j < d; j++) {
        UnitTestAssert(TestScalarsEq(new_sums[i*d + j], centers[i*d + j]),
                       "k-means failed to set a center correctly");
      }
    }
  }

  // Free memory
  free(assignment);
  free(old_centers);
  free(new_sums);
  free(new_counts);
  // NOTE(review): PointFree is assumed to be the deallocator paired with PointAllocate — confirm.
  PointFree(bad_center);
  return fancy_cost;
}
Example #2
// See KMeans.h
// Performs one full execution of k-means, logging any relevant information, and
// tracking meta
// statistics for the run. If min or max values are negative, they are treated
// as unset.
// best_centers and best_assignment can be 0, in which case they are not set.
static void RunKMeansOnce(
	const KmTree& tree, int n, int k, int d, Scalar* points, Scalar* centers,
	Scalar* min_cost, Scalar* max_cost, Scalar* total_cost, double start_time,
	double* min_time, double* max_time, double* total_time,
	Scalar* best_centers, int* best_assignment)
	const Scalar kEpsilon =
		Scalar(1e-8);  // Used to determine when to terminate k-means

	// Do iterations of k-means until the cost stabilizes
	Scalar old_cost = 0;
	bool is_done = false;
	for (int iteration = 0; !is_done; iteration++)
		Scalar new_cost = tree.DoKMeansStep(k, centers, 0);
		is_done = (iteration > 0 && new_cost >= (1 - kEpsilon) * old_cost);
		old_cost = new_cost;
		LOG(true, "Completed iteration #" << (iteration + 1) << ", cost="
										  << new_cost << "..." << endl);
	double this_time = GetSeconds() - start_time;

	// Log the clustering we found
	LOG(false, "Completed run: cost=" << old_cost << " (" << this_time
									  << " seconds)" << endl);

	// Handle a new min cost, updating best_centers and best_assignment as
	// appropriate
	if (*min_cost < 0 || old_cost < *min_cost)
		*min_cost = old_cost;
		if (best_assignment != 0)
			tree.DoKMeansStep(k, centers, best_assignment);
		if (best_centers != 0)
			memcpy(best_centers, centers, sizeof(Scalar) * k * d);

	// Update all other aggregate stats
	if (*max_cost < old_cost) *max_cost = old_cost;
	*total_cost += old_cost;
	if (*min_time < 0 || *min_time > this_time) *min_time = this_time;
	if (*max_time < this_time) *max_time = this_time;
	*total_time += this_time;