示例#1
0
/**
 * Entry point for the LSH-based approximate nearest neighbor search program.
 *
 * Either trains a new LSHSearch<> model on --reference_file or takes one from
 * --input_model_file (exactly one of the two is required).  When --k is given,
 * it searches for the k approximate nearest neighbors of every point in
 * --query_file or, without a query file, of every point in the model's own
 * reference set.  Recall against --true_neighbors_file is optionally logged,
 * and the neighbors, distances, and model are handed back to CLI through the
 * corresponding output parameters.
 */
int main(int argc, char *argv[])
{
  // Give CLI the command line parameters the user passed in.
  CLI::ParseCommandLine(argc, argv);

  // A seed of 0 means "no seed given"; fall back to the current time.
  const int seed = CLI::GetParam<int>("seed");
  if (seed != 0)
    math::RandomSeed(static_cast<size_t>(seed));
  else
    math::RandomSeed(static_cast<size_t>(time(NULL)));

  // Exactly one of --input_model_file and --reference_file must be given.
  if (CLI::HasParam("input_model") && CLI::HasParam("reference"))
  {
    Log::Fatal << "Cannot specify both --reference_file and --input_model_file!"
        << " Either create a new model with --reference_file or use an existing"
        << " model with --input_model_file." << endl;
  }

  if (!CLI::HasParam("input_model") && !CLI::HasParam("reference"))
  {
    Log::Fatal << "Must specify either --input_model_file or --reference_file!"
        << endl;
  }

  if (!CLI::HasParam("neighbors") && !CLI::HasParam("distances") &&
      !CLI::HasParam("output_model"))
  {
    Log::Warn << "Neither --neighbors_file, --distances_file, nor "
        << "--output_model_file are specified; no results will be saved."
        << endl;
  }

  // A query set is useless without k.  The converse is allowed: --k without
  // --query_file searches the reference set of the trained or loaded model
  // (one of the two is guaranteed to be present by the checks above), so that
  // combination must not be rejected here.
  if (CLI::HasParam("query") && !CLI::HasParam("k"))
  {
    Log::Fatal << "Both --query_file or --reference_file and --k must be "
        << "specified if search is to be done!" << endl;
  }

  // k is read as an int; reject non-positive values before it is converted to
  // size_t, where a negative value would wrap around to a huge count.
  if (CLI::HasParam("k") && CLI::GetParam<int>("k") <= 0)
  {
    Log::Fatal << "Invalid k: " << CLI::GetParam<int>("k") << "; must be "
        << "greater than 0." << endl;
  }

  if (CLI::HasParam("input_model") && CLI::HasParam("k") &&
      !CLI::HasParam("query"))
  {
    Log::Info << "Performing LSH-based approximate nearest neighbor search on "
        << "the reference dataset in the model stored in '"
        << CLI::GetUnmappedParam<LSHSearch<>>("input_model") << "'." << endl;
  }

  // Every output that depends on search results is ignored without --k.
  if (!CLI::HasParam("k") && CLI::HasParam("neighbors"))
    Log::Warn << "--neighbors_file ignored because --k is not specified."
        << endl;

  if (!CLI::HasParam("k") && CLI::HasParam("distances"))
    Log::Warn << "--distances_file ignored because --k is not specified."
        << endl;

  if (!CLI::HasParam("k") && CLI::HasParam("true_neighbors"))
    Log::Warn << "--true_neighbors_file ignored because --k is not specified."
        << endl;

  // Get all the parameters.  k is only used when --k was passed (and validated
  // above).
  const size_t k = static_cast<size_t>(CLI::GetParam<int>("k"));
  const size_t secondHashSize =
      static_cast<size_t>(CLI::GetParam<int>("second_hash_size"));
  const size_t bucketSize =
      static_cast<size_t>(CLI::GetParam<int>("bucket_size"));

  // Pick up the LSH-specific parameters (only used when training a new model).
  const size_t numProj = static_cast<size_t>(CLI::GetParam<int>("projections"));
  const size_t numTables = static_cast<size_t>(CLI::GetParam<int>("tables"));
  const double hashWidth = CLI::GetParam<double>("hash_width");
  const size_t numProbes =
      static_cast<size_t>(CLI::GetParam<int>("num_probes"));

  // These declarations are here so that the matrices don't go out of scope.
  arma::mat referenceData;
  arma::mat queryData;
  arma::Mat<size_t> neighbors;
  arma::mat distances;

  LSHSearch<> allkann;
  if (CLI::HasParam("reference"))
  {
    referenceData = std::move(CLI::GetParam<arma::mat>("reference"));
    Log::Info << "Loaded reference data from '"
        << CLI::GetUnmappedParam<arma::mat>("reference") << "' ("
        << referenceData.n_rows << " x " << referenceData.n_cols << ")."
        << endl;

    // The LSH parameters only take effect when a new model is trained, so
    // they are reported here rather than for a loaded model.  A hash width of
    // 0.0 is reported as the default hash width.
    if (hashWidth == 0.0)
      Log::Info << "Using LSH with " << numProj << " projections (K) and " <<
          numTables << " tables (L) with default hash width." << endl;
    else
      Log::Info << "Using LSH with " << numProj << " projections (K) and " <<
          numTables << " tables (L) with hash width(r): " << hashWidth << endl;

    Timer::Start("hash_building");
    allkann.Train(referenceData, numProj, numTables, hashWidth, secondHashSize,
        bucketSize);
    Timer::Stop("hash_building");
  }
  else if (CLI::HasParam("input_model"))
  {
    allkann = std::move(CLI::GetParam<LSHSearch<>>("input_model"));
  }

  if (CLI::HasParam("k"))
  {
    Log::Info << "Computing " << k << " distance approximate nearest neighbors."
        << endl;
    if (CLI::HasParam("query"))
    {
      queryData = std::move(CLI::GetParam<arma::mat>("query"));
      Log::Info << "Loaded query data from '"
          << CLI::GetUnmappedParam<arma::mat>("query") << "' ("
          << queryData.n_rows << " x " << queryData.n_cols << ")." << endl;

      allkann.Search(queryData, k, neighbors, distances, 0, numProbes);
    }
    else
    {
      // No query set: search the model's own reference set.
      allkann.Search(k, neighbors, distances, 0, numProbes);
    }

    Log::Info << "Neighbors computed." << endl;

    // Compute recall, if desired.  This needs the neighbors found above, so it
    // is only attempted when a search was actually run.
    if (CLI::HasParam("true_neighbors"))
    {
      // Load the true neighbors.
      arma::Mat<size_t> trueNeighbors =
          std::move(CLI::GetParam<arma::Mat<size_t>>("true_neighbors"));
      Log::Info << "Loaded true neighbor indices from '"
          << CLI::GetUnmappedParam<arma::Mat<size_t>>("true_neighbors") << "'."
          << endl;

      // Compute recall and print it.
      const double recallPercentage = 100 * allkann.ComputeRecall(neighbors,
          trueNeighbors);

      Log::Info << "Recall: " << recallPercentage << endl;
    }
  }

  // Save output, if desired.
  if (CLI::HasParam("distances"))
    CLI::GetParam<arma::mat>("distances") = std::move(distances);
  if (CLI::HasParam("neighbors"))
    CLI::GetParam<arma::Mat<size_t>>("neighbors") = std::move(neighbors);
  if (CLI::HasParam("output_model"))
    CLI::GetParam<LSHSearch<>>("output_model") = std::move(allkann);

  CLI::Destroy();
}
示例#2
0
int main(int argc, char *argv[])
{
  // Give CLI the command line parameters the user passed in.
  CLI::ParseCommandLine(argc, argv);

  if (CLI::GetParam<int>("seed") != 0)
    math::RandomSeed((size_t) CLI::GetParam<int>("seed"));
  else
    math::RandomSeed((size_t) time(NULL));

  // Get all the parameters.
  string referenceFile = CLI::GetParam<string>("reference_file");
  string distancesFile = CLI::GetParam<string>("distances_file");
  string neighborsFile = CLI::GetParam<string>("neighbors_file");

  size_t k = CLI::GetParam<int>("k");
  size_t secondHashSize = CLI::GetParam<int>("second_hash_size");
  size_t bucketSize = CLI::GetParam<int>("bucket_size");

  arma::mat referenceData;
  arma::mat queryData; // So it doesn't go out of scope.
  data::Load(referenceFile, referenceData, true);

  Log::Info << "Loaded reference data from '" << referenceFile << "' ("
      << referenceData.n_rows << " x " << referenceData.n_cols << ")." << endl;

  // Sanity check on k value: must be greater than 0, must be less than the
  // number of reference points.
  if (k > referenceData.n_cols)
  {
    Log::Fatal << "Invalid k: " << k << "; must be greater than 0 and less ";
    Log::Fatal << "than or equal to the number of reference points (";
    Log::Fatal << referenceData.n_cols << ")." << endl;
  }

  // Pick up the LSH-specific parameters.
  const size_t numProj = CLI::GetParam<int>("projections");
  const size_t numTables = CLI::GetParam<int>("tables");
  const double hashWidth = CLI::GetParam<double>("hash_width");

  arma::Mat<size_t> neighbors;
  arma::mat distances;

  if (CLI::GetParam<string>("query_file") != "")
  {
    string queryFile = CLI::GetParam<string>("query_file");

    data::Load(queryFile, queryData, true);
    Log::Info << "Loaded query data from '" << queryFile << "' ("
              << queryData.n_rows << " x " << queryData.n_cols << ")." << endl;
  }

  if (hashWidth == 0.0)
    Log::Info << "Using LSH with " << numProj << " projections (K) and " <<
        numTables << " tables (L) with default hash width." << endl;
  else
    Log::Info << "Using LSH with " << numProj << " projections (K) and " <<
        numTables << " tables (L) with hash width(r): " << hashWidth << endl;

  Timer::Start("hash_building");

  LSHSearch<>* allkann;

  if (CLI::GetParam<string>("query_file") != "")
    allkann = new LSHSearch<>(referenceData, queryData, numProj, numTables,
                              hashWidth, secondHashSize, bucketSize);
  else
    allkann = new LSHSearch<>(referenceData, numProj, numTables, hashWidth,
                              secondHashSize, bucketSize);

  Timer::Stop("hash_building");

  Log::Info << "Computing " << k << " distance approximate nearest neighbors "
      << endl;
  allkann->Search(k, neighbors, distances);

  Log::Info << "Neighbors computed." << endl;

  // Save output.
  if (distancesFile != "")
    data::Save(distancesFile, distances);

  if (neighborsFile != "")
    data::Save(neighborsFile, neighbors);

  delete allkann;
}