int main(int argc, char** argv) { // Contains query points. arma::mat query(2,3); query << 5 << 6 << 2.75 << arma::endr << 1.45 << 2 << 0.75 << arma::endr; // Load the data. arma::mat data; data::Load("../data/fisheriris_data.csv", data, true); // Choose the last two columns. data = data.rows(2,3); std::vector<size_t> oldFromNewRefs; std::vector<size_t> oldFromNewQueries; arma::mat distancesOut; arma::Mat<size_t> neighborsOut; // Calculate the all 2-nearest-neighbors . BinarySpaceTree<bound::HRectBound<2>, QueryStat<NearestNeighborSort> > refTree(data, oldFromNewRefs, 20); BinarySpaceTree<bound::HRectBound<2>, QueryStat<NearestNeighborSort> >* queryTree = new BinarySpaceTree<bound::HRectBound<2>, QueryStat<NearestNeighborSort> >(query, oldFromNewQueries, 20); AllkNN* allknn = new AllkNN(&refTree, queryTree, data, query, false); allknn->Search(2, neighborsOut, distancesOut); arma::Mat<size_t> neighbors; arma::mat distances; // Map the results back to the correct places. Unmap(neighborsOut, distancesOut, oldFromNewRefs, oldFromNewQueries, neighbors, distances); // Clean up. delete queryTree; delete allknn; // Show the results. Log::cout << distances.t() << std::endl; Log::cout << neighbors.t() << std::endl; return 0; }
int main(int argc, char *argv[]) { // Give CLI the command line parameters the user passed in. CLI::ParseCommandLine(argc, argv); if (CLI::GetParam<int>("seed") != 0) math::RandomSeed((size_t) CLI::GetParam<int>("seed")); else math::RandomSeed((size_t) std::time(NULL)); // Get all the parameters. const string referenceFile = CLI::GetParam<string>("reference_file"); const string queryFile = CLI::GetParam<string>("query_file"); const string distancesFile = CLI::GetParam<string>("distances_file"); const string neighborsFile = CLI::GetParam<string>("neighbors_file"); int lsInt = CLI::GetParam<int>("leaf_size"); size_t k = CLI::GetParam<int>("k"); bool naive = CLI::HasParam("naive"); bool singleMode = CLI::HasParam("single_mode"); const bool randomBasis = CLI::HasParam("random_basis"); arma::mat referenceData; arma::mat queryData; // So it doesn't go out of scope. data::Load(referenceFile, referenceData, true); Log::Info << "Loaded reference data from '" << referenceFile << "' (" << referenceData.n_rows << " x " << referenceData.n_cols << ")." << endl; if (queryFile != "") { data::Load(queryFile, queryData, true); Log::Info << "Loaded query data from '" << queryFile << "' (" << queryData.n_rows << " x " << queryData.n_cols << ")." << endl; } // Sanity check on k value: must be greater than 0, must be less than the // number of reference points. if (k > referenceData.n_cols) { Log::Fatal << "Invalid k: " << k << "; must be greater than 0 and less "; Log::Fatal << "than or equal to the number of reference points ("; Log::Fatal << referenceData.n_cols << ")." << endl; } // Sanity check on leaf size. if (lsInt < 0) { Log::Fatal << "Invalid leaf size: " << lsInt << ". Must be greater " "than or equal to 0." << endl; } size_t leafSize = lsInt; // Naive mode overrides single mode. if (singleMode && naive) { Log::Warn << "--single_mode ignored because --naive is present." << endl; } if (naive) leafSize = referenceData.n_cols; // See if we want to project onto a random basis. if (randomBasis) { // Generate the random basis. while (true) { // [Q, R] = qr(randn(d, d)); // Q = Q * diag(sign(diag(R))); arma::mat q, r; if (arma::qr(q, r, arma::randn<arma::mat>(referenceData.n_rows, referenceData.n_rows))) { arma::vec rDiag(r.n_rows); for (size_t i = 0; i < rDiag.n_elem; ++i) { if (r(i, i) < 0) rDiag(i) = -1; else if (r(i, i) > 0) rDiag(i) = 1; else rDiag(i) = 0; } q *= arma::diagmat(rDiag); // Check if the determinant is positive. if (arma::det(q) >= 0) { referenceData = q * referenceData; if (queryFile != "") queryData = q * queryData; break; } } } } arma::Mat<size_t> neighbors; arma::mat distances; if (!CLI::HasParam("cover_tree")) { // Because we may construct it differently, we need a pointer. AllkNN* allknn = NULL; // Mappings for when we build the tree. std::vector<size_t> oldFromNewRefs; // Build trees by hand, so we can save memory: if we pass a tree to // NeighborSearch, it does not copy the matrix. Log::Info << "Building reference tree..." << endl; Timer::Start("tree_building"); BinarySpaceTree<bound::HRectBound<2>, NeighborSearchStat<NearestNeighborSort> > refTree(referenceData, oldFromNewRefs, leafSize); BinarySpaceTree<bound::HRectBound<2>, NeighborSearchStat<NearestNeighborSort> >* queryTree = NULL; // Empty for now. Timer::Stop("tree_building"); std::vector<size_t> oldFromNewQueries; if (CLI::GetParam<string>("query_file") != "") { if (naive && leafSize < queryData.n_cols) leafSize = queryData.n_cols; Log::Info << "Loaded query data from '" << queryFile << "' (" << queryData.n_rows << " x " << queryData.n_cols << ")." << endl; Log::Info << "Building query tree..." << endl; // Build trees by hand, so we can save memory: if we pass a tree to // NeighborSearch, it does not copy the matrix. if (!singleMode) { Timer::Start("tree_building"); queryTree = new BinarySpaceTree<bound::HRectBound<2>, NeighborSearchStat<NearestNeighborSort> >(queryData, oldFromNewQueries, leafSize); Timer::Stop("tree_building"); } allknn = new AllkNN(&refTree, queryTree, referenceData, queryData, singleMode); Log::Info << "Tree built." << endl; } else { allknn = new AllkNN(&refTree, referenceData, singleMode); Log::Info << "Trees built." << endl; } arma::mat distancesOut; arma::Mat<size_t> neighborsOut; Log::Info << "Computing " << k << " nearest neighbors..." << endl; allknn->Search(k, neighborsOut, distancesOut); Log::Info << "Neighbors computed." << endl; // We have to map back to the original indices from before the tree // construction. Log::Info << "Re-mapping indices..." << endl; // Map the results back to the correct places. if ((CLI::GetParam<string>("query_file") != "") && !singleMode) Unmap(neighborsOut, distancesOut, oldFromNewRefs, oldFromNewQueries, neighbors, distances); else if ((CLI::GetParam<string>("query_file") != "") && singleMode) Unmap(neighborsOut, distancesOut, oldFromNewRefs, neighbors, distances); else Unmap(neighborsOut, distancesOut, oldFromNewRefs, oldFromNewRefs, neighbors, distances); // Clean up. if (queryTree) delete queryTree; delete allknn; } else // Cover trees. { // Make sure to notify the user that they are using cover trees. Log::Info << "Using cover trees for nearest-neighbor calculation." << endl; // Build our reference tree. Log::Info << "Building reference tree..." << endl; Timer::Start("tree_building"); CoverTree<metric::LMetric<2, true>, tree::FirstPointIsRoot, NeighborSearchStat<NearestNeighborSort> > referenceTree(referenceData, 1.3); CoverTree<metric::LMetric<2, true>, tree::FirstPointIsRoot, NeighborSearchStat<NearestNeighborSort> >* queryTree = NULL; Timer::Stop("tree_building"); NeighborSearch<NearestNeighborSort, metric::LMetric<2, true>, CoverTree<metric::LMetric<2, true>, tree::FirstPointIsRoot, NeighborSearchStat<NearestNeighborSort> > >* allknn = NULL; // See if we have query data. if (CLI::HasParam("query_file")) { // Build query tree. if (!singleMode) { Log::Info << "Building query tree..." << endl; Timer::Start("tree_building"); queryTree = new CoverTree<metric::LMetric<2, true>, tree::FirstPointIsRoot, NeighborSearchStat<NearestNeighborSort> >( queryData, 1.3); Timer::Stop("tree_building"); } allknn = new NeighborSearch<NearestNeighborSort, metric::LMetric<2, true>, CoverTree<metric::LMetric<2, true>, tree::FirstPointIsRoot, NeighborSearchStat<NearestNeighborSort> > >(&referenceTree, queryTree, referenceData, queryData, singleMode); } else { allknn = new NeighborSearch<NearestNeighborSort, metric::LMetric<2, true>, CoverTree<metric::LMetric<2, true>, tree::FirstPointIsRoot, NeighborSearchStat<NearestNeighborSort> > >(&referenceTree, referenceData, singleMode); } Log::Info << "Computing " << k << " nearest neighbors..." << endl; allknn->Search(k, neighbors, distances); Log::Info << "Neighbors computed." << endl; delete allknn; if (queryTree) delete queryTree; } // Save output. data::Save(distancesFile, distances); data::Save(neighborsFile, neighbors); }