// Executes a k-means step with a KmTree on the given point-set with the given set of centers, and // tests if it is correct by comparing results with the naive O(nkd) implementation. The cost is // returned, which is used to determine when k-means has finished. // // Generates an assertion failure if there is an error. Scalar TestKMeansStep(const KmTree &tree, int n, int k, int d, Scalar *points, Scalar *centers) { // Allocate memory int *assignment = (int*)malloc(n * sizeof(int)); Scalar *old_centers = (Scalar*)malloc(k * d * sizeof(Scalar)); Scalar *new_sums = (Scalar*)calloc(k * d, sizeof(Scalar)); int *new_counts = (int*)calloc(k, sizeof(int)); Scalar *bad_center = PointAllocate(d); if (assignment == 0 || old_centers == 0 || new_sums == 0 || new_counts == 0 || bad_center == 0) MemoryAssertFail("running unit test"); memset(bad_center, 0xff, d * sizeof(Scalar)); memcpy(old_centers, centers, k * d * sizeof(Scalar)); // Run fancy k-means Scalar fancy_cost = tree.DoKMeansStep(k, centers, assignment); // Test the assignments and build the correct aggregate data Scalar correct_cost = 0; for (int i = 0; i < n; i++) { Scalar fancy_dist_sq = PointDistSq(points + i*d, old_centers + assignment[i]*d, d); for (int j = 0; j < k; j++) if (memcmp(old_centers + j*d, bad_center, d*sizeof(Scalar)) != 0) { Scalar dist_sq = PointDistSq(points + i*d, old_centers + j*d, d); UnitTestAssert(TestScalarsGe(dist_sq, fancy_dist_sq), "k-means assigned point to the wrong cluster"); } // Note that the cost is measured from the OLD centers, not from the new centers correct_cost += fancy_dist_sq; PointAdd(new_sums + assignment[i]*d, points + i*d, d); new_counts[assignment[i]]++; } // Test the costs UnitTestAssert(TestScalarsEq(correct_cost, fancy_cost), "k-means calculated the cost function incorrectly"); // Test the centers for (int i = 0; i < k; i++) { bool fancy_is_void = (memcmp(centers + i*d, bad_center, d*sizeof(Scalar)) == 0); bool correct_is_void = (new_counts[i] == 0); UnitTestAssert(fancy_is_void == correct_is_void, "k-means failed to correctly mark whether a center was being used"); if (!fancy_is_void) { PointScale(new_sums + i*d, Scalar(1) / new_counts[i], d); for (int j = 0; j < d; j++) { UnitTestAssert(TestScalarsEq(new_sums[i*d + j], centers[i*d + j]), "k-means failed to set a center correctly"); } } } // Free memory PointFree(bad_center); free(new_counts); free(new_sums); free(old_centers); free(assignment); return fancy_cost; }
// See KMeans.h // Performs one full execution of k-means, logging any relevant information, and // tracking meta // statistics for the run. If min or max values are negative, they are treated // as unset. // best_centers and best_assignment can be 0, in which case they are not set. static void RunKMeansOnce( const KmTree& tree, int n, int k, int d, Scalar* points, Scalar* centers, Scalar* min_cost, Scalar* max_cost, Scalar* total_cost, double start_time, double* min_time, double* max_time, double* total_time, Scalar* best_centers, int* best_assignment) { MRPT_UNUSED_PARAM(n); MRPT_UNUSED_PARAM(points); const Scalar kEpsilon = Scalar(1e-8); // Used to determine when to terminate k-means // Do iterations of k-means until the cost stabilizes Scalar old_cost = 0; bool is_done = false; for (int iteration = 0; !is_done; iteration++) { Scalar new_cost = tree.DoKMeansStep(k, centers, 0); is_done = (iteration > 0 && new_cost >= (1 - kEpsilon) * old_cost); old_cost = new_cost; LOG(true, "Completed iteration #" << (iteration + 1) << ", cost=" << new_cost << "..." << endl); } double this_time = GetSeconds() - start_time; // Log the clustering we found LOG(false, "Completed run: cost=" << old_cost << " (" << this_time << " seconds)" << endl); // Handle a new min cost, updating best_centers and best_assignment as // appropriate if (*min_cost < 0 || old_cost < *min_cost) { *min_cost = old_cost; if (best_assignment != 0) tree.DoKMeansStep(k, centers, best_assignment); if (best_centers != 0) memcpy(best_centers, centers, sizeof(Scalar) * k * d); } // Update all other aggregate stats if (*max_cost < old_cost) *max_cost = old_cost; *total_cost += old_cost; if (*min_time < 0 || *min_time > this_time) *min_time = this_time; if (*max_time < this_time) *max_time = this_time; *total_time += this_time; }