int main(int argc, char *argv[]) { // Give CLI the command line parameters the user passed in. CLI::ParseCommandLine(argc, argv); if (CLI::GetParam<int>("seed") != 0) math::RandomSeed((size_t) CLI::GetParam<int>("seed")); else math::RandomSeed((size_t) std::time(NULL)); // A user cannot specify both reference data and a model. if (CLI::HasParam("reference_file") && CLI::HasParam("input_model_file")) Log::Fatal << "Only one of --reference_file (-r) or --input_model_file (-m)" << " may be specified!" << endl; // A user must specify one of them... if (!CLI::HasParam("reference_file") && !CLI::HasParam("input_model_file")) Log::Fatal << "No model specified (--input_model_file) and no reference " << "data specified (--reference_file)! One must be provided." << endl; if (CLI::HasParam("input_model_file")) { // Notify the user of parameters that will be ignored. if (CLI::HasParam("tree_type")) Log::Warn << "--tree_type (-t) will be ignored because --input_model_file" << " is specified." << endl; if (CLI::HasParam("random_basis")) Log::Warn << "--random_basis (-R) will be ignored because " << "--input_model_file is specified." << endl; // Notify the user of parameters that will be only be considered for query // tree. if (CLI::HasParam("leaf_size")) Log::Warn << "--leaf_size (-l) will only be considered for the query " "tree, because --input_model_file is specified." << endl; } // The user should give something to do... if (!CLI::HasParam("k") && !CLI::HasParam("output_model_file")) Log::Warn << "Neither -k nor --output_model_file are specified, so no " << "results from this program will be saved!" << endl; // If the user specifies k but no output files, they should be warned. if (CLI::HasParam("k") && !(CLI::HasParam("neighbors_file") || CLI::HasParam("distances_file"))) Log::Warn << "Neither --neighbors_file nor --distances_file is specified, " << "so the furthest neighbor search results will not be saved!" << endl; // If the user specifies output files but no k, they should be warned. if ((CLI::HasParam("neighbors_file") || CLI::HasParam("distances_file")) && !CLI::HasParam("k")) Log::Warn << "An output file for furthest neighbor search is given (" << "--neighbors_file or --distances_file), but furthest neighbor search" << " is not being performed because k (--k) is not specified! No " << "results will be saved." << endl; if (!CLI::HasParam("k") && CLI::HasParam("true_neighbors_file")) Log::Warn << "--true_neighbors_file (-T) ignored because no search is being" << " performed (--k is not specified)." << endl; if (!CLI::HasParam("k") && CLI::HasParam("true_distances_file")) Log::Warn << "--true_distances_file (-D) ignored because no search is being" << " performed (--k is not specified)." << endl; // Sanity check on leaf size. const int lsInt = CLI::GetParam<int>("leaf_size"); if (lsInt < 1) Log::Fatal << "Invalid leaf size: " << lsInt << ". Must be greater than 0." << endl; // Sanity check on epsilon. double epsilon = CLI::GetParam<double>("epsilon"); if (epsilon < 0 || epsilon >= 1) Log::Fatal << "Invalid epsilon: " << epsilon << ". Must be in the range " << "[0,1)." << endl; // Sanity check on percentage. const double percentage = CLI::GetParam<double>("percentage"); if (percentage <= 0 || percentage > 1) Log::Fatal << "Invalid percentage: " << percentage << ". Must be in the " << "range (0,1] (decimal form)." << endl; if (CLI::HasParam("percentage") && CLI::HasParam("epsilon")) Log::Fatal << "Cannot provide both epsilon and percentage." << endl; if (CLI::HasParam("percentage")) epsilon = 1 - percentage; // We either have to load the reference data, or we have to load the model. NSModel<FurthestNeighborSort> kfn; const string algorithm = CLI::GetParam<string>("algorithm"); NeighborSearchMode searchMode = DUAL_TREE_MODE; if (algorithm == "naive") searchMode = NAIVE_MODE; else if (algorithm == "single_tree") searchMode = SINGLE_TREE_MODE; else if (algorithm == "dual_tree") searchMode = DUAL_TREE_MODE; else if (algorithm == "greedy") searchMode = GREEDY_SINGLE_TREE_MODE; else Log::Fatal << "Unknown neighbor search algorithm '" << algorithm << "'; " << "valid choices are 'naive', 'single_tree', 'dual_tree' and 'greedy'." << endl; if (CLI::HasParam("single_mode")) { searchMode = SINGLE_TREE_MODE; Log::Warn << "--single_mode is deprecated. Will be removed in mlpack " "3.0.0. Use '--algorithm single_tree' instead." << endl; if (CLI::HasParam("algorithm") && algorithm != "single_tree") Log::Fatal << "Contradiction between options --algorithm " << algorithm << " and --single_mode." << endl; } if (CLI::HasParam("naive")) { searchMode = NAIVE_MODE; Log::Warn << "--naive is deprecated. Will be removed in mlpack 3.0.0. Use " "'--algorithm naive' instead." << endl; if (CLI::HasParam("algorithm") && algorithm != "naive") Log::Fatal << "Contradiction between options --algorithm " << algorithm << " and --naive." << endl; if (CLI::HasParam("single_mode")) Log::Warn << "--single_mode ignored because --naive is present." << endl; } if (CLI::HasParam("reference_file")) { // Get all the parameters. const string referenceFile = CLI::GetParam<string>("reference_file"); const string treeType = CLI::GetParam<string>("tree_type"); const bool randomBasis = CLI::HasParam("random_basis"); KFNModel::TreeTypes tree = KFNModel::KD_TREE; if (treeType == "kd") tree = KFNModel::KD_TREE; else if (treeType == "cover") tree = KFNModel::COVER_TREE; else if (treeType == "r") tree = KFNModel::R_TREE; else if (treeType == "r-star") tree = KFNModel::R_STAR_TREE; else if (treeType == "ball") tree = KFNModel::BALL_TREE; else if (treeType == "x") tree = KFNModel::X_TREE; else if (treeType == "hilbert-r") tree = KFNModel::HILBERT_R_TREE; else if (treeType == "r-plus") tree = KFNModel::R_PLUS_TREE; else if (treeType == "r-plus-plus") tree = KFNModel::R_PLUS_PLUS_TREE; else if (treeType == "vp") tree = KFNModel::VP_TREE; else if (treeType == "rp") tree = KFNModel::RP_TREE; else if (treeType == "max-rp") tree = KFNModel::MAX_RP_TREE; else if (treeType == "ub") tree = KFNModel::UB_TREE; else if (treeType == "oct") tree = KFNModel::OCTREE; else Log::Fatal << "Unknown tree type '" << treeType << "'; valid choices are " << "'kd', 'vp', 'rp', 'max-rp', 'ub', 'cover', 'r', 'r-star', 'x', " << "'ball', 'hilbert-r', 'r-plus', 'r-plus-plus', and 'oct'." << endl; kfn.TreeType() = tree; kfn.RandomBasis() = randomBasis; arma::mat referenceSet; data::Load(referenceFile, referenceSet, true); Log::Info << "Loaded reference data from '" << referenceFile << "' (" << referenceSet.n_rows << "x" << referenceSet.n_cols << ")." << endl; kfn.BuildModel(std::move(referenceSet), size_t(lsInt), searchMode, epsilon); } else { // Load the model from file. const string inputModelFile = CLI::GetParam<string>("input_model_file"); data::Load(inputModelFile, "kfn_model", kfn, true); // Fatal on failure. // Adjust search mode. kfn.SearchMode() = searchMode; kfn.Epsilon() = epsilon; // If leaf_size wasn't provided, let's consider the current value in the // loaded model. Else, update it (only considered when building the query // tree). if (CLI::HasParam("leaf_size")) kfn.LeafSize() = size_t(lsInt); Log::Info << "Loaded kFN model from '" << inputModelFile << "' (trained on " << kfn.Dataset().n_rows << "x" << kfn.Dataset().n_cols << " dataset)." << endl; } // Perform search, if desired. if (CLI::HasParam("k")) { const string queryFile = CLI::GetParam<string>("query_file"); const size_t k = (size_t) CLI::GetParam<int>("k"); arma::mat queryData; if (queryFile != "") { data::Load(queryFile, queryData, true); Log::Info << "Loaded query data from '" << queryFile << "' (" << queryData.n_rows << "x" << queryData.n_cols << ")." << endl; } // Sanity check on k value: must be greater than 0, must be less than the // number of reference points. Since it is unsigned, we only test the upper // bound. if (k > kfn.Dataset().n_cols) { Log::Fatal << "Invalid k: " << k << "; must be greater than 0 and less " << "than or equal to the number of reference points (" << kfn.Dataset().n_cols << ")." << endl; } // Now run the search. arma::Mat<size_t> neighbors; arma::mat distances; if (CLI::HasParam("query_file")) kfn.Search(std::move(queryData), k, neighbors, distances); else kfn.Search(k, neighbors, distances); Log::Info << "Search complete." << endl; // Save output, if desired. if (CLI::HasParam("neighbors_file")) data::Save(CLI::GetParam<string>("neighbors_file"), neighbors); if (CLI::HasParam("distances_file")) data::Save(CLI::GetParam<string>("distances_file"), distances); // Calculate the effective error, if desired. if (CLI::HasParam("true_distances_file")) { if (kfn.Epsilon() == 0) Log::Warn << "--true_distances_file (-D) specified, but the search is " << "exact, so there is no need to calculate the error!" << endl; const string trueDistancesFile = CLI::GetParam<string>( "true_distances_file"); arma::mat trueDistances; data::Load(trueDistancesFile, trueDistances, true); if (trueDistances.n_rows != distances.n_rows || trueDistances.n_cols != distances.n_cols) Log::Fatal << "The true distances file must have the same number of " << "values than the set of distances being queried!" << endl; Log::Info << "Effective error: " << KFN::EffectiveError(distances, trueDistances) << endl; } // Calculate the recall, if desired. if (CLI::HasParam("true_neighbors_file")) { if (kfn.Epsilon() == 0) Log::Warn << "--true_neighbors_file (-T) specified, but the search is " << "exact, so there is no need to calculate the recall!" << endl; const string trueNeighborsFile = CLI::GetParam<string>( "true_neighbors_file"); arma::Mat<size_t> trueNeighbors; data::Load(trueNeighborsFile, trueNeighbors, true); if (trueNeighbors.n_rows != neighbors.n_rows || trueNeighbors.n_cols != neighbors.n_cols) Log::Fatal << "The true neighbors file must have the same number of " << "values than the set of neighbors being queried!" << endl; Log::Info << "Recall: " << KFN::Recall(neighbors, trueNeighbors) << endl; } } if (CLI::HasParam("output_model_file")) { const string outputModelFile = CLI::GetParam<string>("output_model_file"); data::Save(outputModelFile, "kfn_model", kfn); } }
static void mlpackMain() { if (CLI::GetParam<int>("seed") != 0) math::RandomSeed((size_t) CLI::GetParam<int>("seed")); else math::RandomSeed((size_t) std::time(NULL)); // A user cannot specify both reference data and a model. RequireOnlyOnePassed({ "reference", "input_model" }, true); ReportIgnoredParam({{ "input_model", true }}, "tree_type"); ReportIgnoredParam({{ "input_model", true }}, "random_basis"); // Notify the user of parameters that will be only be considered for query // tree. if (CLI::HasParam("input_model") && CLI::HasParam("leaf_size")) { Log::Warn << PRINT_PARAM_STRING("leaf_size") << " will only be considered" << " for the query tree, because " << PRINT_PARAM_STRING("input_model") << " is specified." << endl; } // The user should give something to do... RequireAtLeastOnePassed({ "k", "output_model" }, false, "no results will be saved"); // If the user specifies k but no output files, they should be warned. if (CLI::HasParam("k")) { RequireAtLeastOnePassed({ "neighbors", "distances" }, false, "furthest neighbor search results will not be saved"); } // If the user specifies output files but no k, they should be warned. ReportIgnoredParam({{ "k", false }}, "neighbors"); ReportIgnoredParam({{ "k", false }}, "distances"); ReportIgnoredParam({{ "k", false }}, "true_neighbors"); ReportIgnoredParam({{ "k", false }}, "true_distances"); ReportIgnoredParam({{ "k", false }}, "query"); // Sanity check on leaf size. RequireParamValue<int>("leaf_size", [](int x) { return x > 0; }, true, "leaf size must be positive"); const int lsInt = CLI::GetParam<int>("leaf_size"); // Sanity check on epsilon. double epsilon = CLI::GetParam<double>("epsilon"); RequireParamValue<double>("epsilon", [](double x) { return x >= 0.0; }, true, "epsilon must be positive"); // Sanity check on percentage. const double percentage = CLI::GetParam<double>("percentage"); RequireParamValue<double>("percentage", [](double x) { return x > 0.0 && x <= 1.0; }, true, "percentage must be in the range (0, 1]"); ReportIgnoredParam({{ "epsilon", true }}, "percentage"); if (CLI::HasParam("percentage")) epsilon = 1 - percentage; // We either have to load the reference data, or we have to load the model. NSModel<FurthestNeighborSort>* kfn; const string algorithm = CLI::GetParam<string>("algorithm"); RequireParamInSet<string>("algorithm", { "naive", "single_tree", "dual_tree", "greedy" }, true, "unknown neighbor search algorithm"); NeighborSearchMode searchMode = DUAL_TREE_MODE; if (algorithm == "naive") searchMode = NAIVE_MODE; else if (algorithm == "single_tree") searchMode = SINGLE_TREE_MODE; else if (algorithm == "dual_tree") searchMode = DUAL_TREE_MODE; else if (algorithm == "greedy") searchMode = GREEDY_SINGLE_TREE_MODE; if (CLI::HasParam("reference")) { kfn = new KFNModel(); // Get all the parameters. RequireParamInSet<string>("tree_type", { "kd", "cover", "r", "r-star", "ball", "x", "hilbert-r", "r-plus", "r-plus-plus", "vp", "rp", "max-rp", "ub", "oct" }, true, "unknown tree type"); const string treeType = CLI::GetParam<string>("tree_type"); const bool randomBasis = CLI::HasParam("random_basis"); KFNModel::TreeTypes tree = KFNModel::KD_TREE; if (treeType == "kd") tree = KFNModel::KD_TREE; else if (treeType == "cover") tree = KFNModel::COVER_TREE; else if (treeType == "r") tree = KFNModel::R_TREE; else if (treeType == "r-star") tree = KFNModel::R_STAR_TREE; else if (treeType == "ball") tree = KFNModel::BALL_TREE; else if (treeType == "x") tree = KFNModel::X_TREE; else if (treeType == "hilbert-r") tree = KFNModel::HILBERT_R_TREE; else if (treeType == "r-plus") tree = KFNModel::R_PLUS_TREE; else if (treeType == "r-plus-plus") tree = KFNModel::R_PLUS_PLUS_TREE; else if (treeType == "vp") tree = KFNModel::VP_TREE; else if (treeType == "rp") tree = KFNModel::RP_TREE; else if (treeType == "max-rp") tree = KFNModel::MAX_RP_TREE; else if (treeType == "ub") tree = KFNModel::UB_TREE; else if (treeType == "oct") tree = KFNModel::OCTREE; kfn->TreeType() = tree; kfn->RandomBasis() = randomBasis; arma::mat referenceSet = std::move(CLI::GetParam<arma::mat>("reference")); Log::Info << "Using reference data from '" << CLI::GetPrintableParam<arma::mat>("reference") << "' (" << referenceSet.n_rows << "x" << referenceSet.n_cols << ")." << endl; kfn->BuildModel(std::move(referenceSet), size_t(lsInt), searchMode, epsilon); } else { // Load the model from file. kfn = CLI::GetParam<KFNModel*>("input_model"); // Adjust search mode. kfn->SearchMode() = searchMode; kfn->Epsilon() = epsilon; // If leaf_size wasn't provided, let's consider the current value in the // loaded model. Else, update it (only considered when building the query // tree). if (CLI::HasParam("leaf_size")) kfn->LeafSize() = size_t(lsInt); Log::Info << "Using kFN model from '" << CLI::GetPrintableParam<KFNModel*>("input_model") << "' (trained on " << kfn->Dataset().n_rows << "x" << kfn->Dataset().n_cols << " dataset)." << endl; } // Perform search, if desired. if (CLI::HasParam("k")) { const size_t k = (size_t) CLI::GetParam<int>("k"); arma::mat queryData; if (CLI::HasParam("query")) { queryData = std::move(CLI::GetParam<arma::mat>("query")); Log::Info << "Using query data from '" << CLI::GetPrintableParam<arma::mat>("query") << "' (" << queryData.n_rows << "x" << queryData.n_cols << ")." << endl; } // Sanity check on k value: must be greater than 0, must be less than or // equal to the number of reference points. Since it is unsigned, // we only test the upper bound. if (k > kfn->Dataset().n_cols) { Log::Fatal << "Invalid k: " << k << "; must be greater than 0 and less " << "than or equal to the number of reference points (" << kfn->Dataset().n_cols << ")." << endl; } // Sanity check on k value: must not be equal to the number of reference // points when query data has not been provided. if (!CLI::HasParam("query") && k == kfn->Dataset().n_cols) { Log::Fatal << "Invalid k: " << k << "; must be less than the number of " << "reference points (" << kfn->Dataset().n_cols << ") " << "if query data has not been provided." << endl; } // Now run the search. arma::Mat<size_t> neighbors; arma::mat distances; if (CLI::HasParam("query")) kfn->Search(std::move(queryData), k, neighbors, distances); else kfn->Search(k, neighbors, distances); Log::Info << "Search complete." << endl; // Save output. CLI::GetParam<arma::Mat<size_t>>("neighbors") = std::move(neighbors); CLI::GetParam<arma::mat>("distances") = std::move(distances); // Calculate the effective error, if desired. if (CLI::HasParam("true_distances")) { if (kfn->Epsilon() == 0) Log::Warn << PRINT_PARAM_STRING("true_distances") << " specified, but " << "the search is exact, so there is no need to calculate the " << "error!" << endl; arma::mat trueDistances = std::move(CLI::GetParam<arma::mat>("true_distances")); if (trueDistances.n_rows != distances.n_rows || trueDistances.n_cols != distances.n_cols) Log::Fatal << "The true distances file must have the same number of " << "values than the set of distances being queried!" << endl; Log::Info << "Effective error: " << KFN::EffectiveError(distances, trueDistances) << endl; } // Calculate the recall, if desired. if (CLI::HasParam("true_neighbors")) { if (kfn->Epsilon() == 0) Log::Warn << PRINT_PARAM_STRING("true_neighbors") << " specified, but " << "the search is exact, so there is no need to calculate the " << "recall!" << endl; arma::Mat<size_t> trueNeighbors = std::move(CLI::GetParam<arma::Mat<size_t>>("true_neighbors")); if (trueNeighbors.n_rows != neighbors.n_rows || trueNeighbors.n_cols != neighbors.n_cols) Log::Fatal << "The true neighbors file must have the same number of " << "values than the set of neighbors being queried!" << endl; Log::Info << "Recall: " << KFN::Recall(neighbors, trueNeighbors) << endl; } } CLI::GetParam<KFNModel*>("output_model") = kfn; }