int main(int argc, char *argv[]) { // Give CLI the command line parameters the user passed in. CLI::ParseCommandLine(argc, argv); if (CLI::GetParam<int>("seed") != 0) math::RandomSeed((size_t) CLI::GetParam<int>("seed")); else math::RandomSeed((size_t) time(NULL)); // Get all the parameters. size_t k = CLI::GetParam<int>("k"); size_t secondHashSize = CLI::GetParam<int>("second_hash_size"); size_t bucketSize = CLI::GetParam<int>("bucket_size"); if (CLI::HasParam("input_model") && CLI::HasParam("reference")) { Log::Fatal << "Cannot specify both --reference_file and --input_model_file!" << " Either create a new model with --reference_file or use an existing" << " model with --input_model_file." << endl; } if (!CLI::HasParam("input_model") && !CLI::HasParam("reference")) { Log::Fatal << "Must specify either --input_model_file or --reference_file!" << endl; } if (!CLI::HasParam("neighbors") && !CLI::HasParam("distances") && !CLI::HasParam("output_model")) { Log::Warn << "Neither --neighbors_file, --distances_file, nor " << "--output_model_file are specified; no results will be saved." << endl; } if ((CLI::HasParam("query") && !CLI::HasParam("k")) || (!CLI::HasParam("query") && !CLI::HasParam("reference") && CLI::HasParam("k"))) { Log::Fatal << "Both --query_file or --reference_file and --k must be " << "specified if search is to be done!" << endl; } if (CLI::HasParam("input_model") && CLI::HasParam("k") && !CLI::HasParam("query")) { Log::Info << "Performing LSH-based approximate nearest neighbor search on " << "the reference dataset in the model stored in '" << CLI::GetUnmappedParam<LSHSearch<>>("input_model") << "'." << endl; } if (!CLI::HasParam("k") && CLI::HasParam("neighbors")) Log::Warn << "--neighbors_file ignored because --k is not specified." << endl; if (!CLI::HasParam("k") && CLI::HasParam("distances")) Log::Warn << "--distances_file ignored because --k is not specified." << endl; // These declarations are here so that the matrices don't go out of scope. arma::mat referenceData; arma::mat queryData; // Pick up the LSH-specific parameters. const size_t numProj = CLI::GetParam<int>("projections"); const size_t numTables = CLI::GetParam<int>("tables"); const double hashWidth = CLI::GetParam<double>("hash_width"); const size_t numProbes = (size_t) CLI::GetParam<int>("num_probes"); arma::Mat<size_t> neighbors; arma::mat distances; if (hashWidth == 0.0) Log::Info << "Using LSH with " << numProj << " projections (K) and " << numTables << " tables (L) with default hash width." << endl; else Log::Info << "Using LSH with " << numProj << " projections (K) and " << numTables << " tables (L) with hash width(r): " << hashWidth << endl; LSHSearch<> allkann; if (CLI::HasParam("reference")) { referenceData = std::move(CLI::GetParam<arma::mat>("reference")); Log::Info << "Loaded reference data from '" << CLI::GetUnmappedParam<arma::mat>("reference") << "' (" << referenceData.n_rows << " x " << referenceData.n_cols << ")." << endl; Timer::Start("hash_building"); allkann.Train(referenceData, numProj, numTables, hashWidth, secondHashSize, bucketSize); Timer::Stop("hash_building"); } else if (CLI::HasParam("input_model")) { allkann = std::move(CLI::GetParam<LSHSearch<>>("input_model")); } if (CLI::HasParam("k")) { Log::Info << "Computing " << k << " distance approximate nearest neighbors." << endl; if (CLI::HasParam("query")) { queryData = std::move(CLI::GetParam<arma::mat>("query")); Log::Info << "Loaded query data from '" << CLI::GetUnmappedParam<arma::mat>("query") << "' (" << queryData.n_rows << " x " << queryData.n_cols << ")." << endl; allkann.Search(queryData, k, neighbors, distances, 0, numProbes); } else { allkann.Search(k, neighbors, distances, 0, numProbes); } } Log::Info << "Neighbors computed." << endl; // Compute recall, if desired. if (CLI::HasParam("true_neighbors")) { // Load the true neighbors. arma::Mat<size_t> trueNeighbors = std::move(CLI::GetParam<arma::Mat<size_t>>("true_neighbors")); Log::Info << "Loaded true neighbor indices from '" << CLI::GetUnmappedParam<arma::Mat<size_t>>("true_neighbors") << "'." << endl; // Compute recall and print it. double recallPercentage = 100 * allkann.ComputeRecall(neighbors, trueNeighbors); Log::Info << "Recall: " << recallPercentage << endl; } // Save output, if desired. if (CLI::HasParam("distances")) CLI::GetParam<arma::mat>("distances") = std::move(distances); if (CLI::HasParam("neighbors")) CLI::GetParam<arma::Mat<size_t>>("neighbors") = std::move(neighbors); if (CLI::HasParam("output_model")) CLI::GetParam<LSHSearch<>>("output_model") = std::move(allkann); CLI::Destroy(); }
int main(int argc, char *argv[]) { // Give CLI the command line parameters the user passed in. CLI::ParseCommandLine(argc, argv); if (CLI::GetParam<int>("seed") != 0) math::RandomSeed((size_t) CLI::GetParam<int>("seed")); else math::RandomSeed((size_t) time(NULL)); // Get all the parameters. string referenceFile = CLI::GetParam<string>("reference_file"); string distancesFile = CLI::GetParam<string>("distances_file"); string neighborsFile = CLI::GetParam<string>("neighbors_file"); size_t k = CLI::GetParam<int>("k"); size_t secondHashSize = CLI::GetParam<int>("second_hash_size"); size_t bucketSize = CLI::GetParam<int>("bucket_size"); arma::mat referenceData; arma::mat queryData; // So it doesn't go out of scope. data::Load(referenceFile, referenceData, true); Log::Info << "Loaded reference data from '" << referenceFile << "' (" << referenceData.n_rows << " x " << referenceData.n_cols << ")." << endl; // Sanity check on k value: must be greater than 0, must be less than the // number of reference points. if (k > referenceData.n_cols) { Log::Fatal << "Invalid k: " << k << "; must be greater than 0 and less "; Log::Fatal << "than or equal to the number of reference points ("; Log::Fatal << referenceData.n_cols << ")." << endl; } // Pick up the LSH-specific parameters. const size_t numProj = CLI::GetParam<int>("projections"); const size_t numTables = CLI::GetParam<int>("tables"); const double hashWidth = CLI::GetParam<double>("hash_width"); arma::Mat<size_t> neighbors; arma::mat distances; if (CLI::GetParam<string>("query_file") != "") { string queryFile = CLI::GetParam<string>("query_file"); data::Load(queryFile, queryData, true); Log::Info << "Loaded query data from '" << queryFile << "' (" << queryData.n_rows << " x " << queryData.n_cols << ")." << endl; } if (hashWidth == 0.0) Log::Info << "Using LSH with " << numProj << " projections (K) and " << numTables << " tables (L) with default hash width." << endl; else Log::Info << "Using LSH with " << numProj << " projections (K) and " << numTables << " tables (L) with hash width(r): " << hashWidth << endl; Timer::Start("hash_building"); LSHSearch<>* allkann; if (CLI::GetParam<string>("query_file") != "") allkann = new LSHSearch<>(referenceData, queryData, numProj, numTables, hashWidth, secondHashSize, bucketSize); else allkann = new LSHSearch<>(referenceData, numProj, numTables, hashWidth, secondHashSize, bucketSize); Timer::Stop("hash_building"); Log::Info << "Computing " << k << " distance approximate nearest neighbors " << endl; allkann->Search(k, neighbors, distances); Log::Info << "Neighbors computed." << endl; // Save output. if (distancesFile != "") data::Save(distancesFile, distances); if (neighborsFile != "") data::Save(neighborsFile, neighbors); delete allkann; }