コード例 #1
0
ファイル: allknn.cpp プロジェクト: zoq/ML-Libraries
int main(int argc, char** argv)
{
    // Contains query points.
    arma::mat query(2,3);
    query << 5 << 6 << 2.75 << arma::endr << 1.45 << 2 << 0.75 << arma::endr;

    // Load the data.
    arma::mat data;
    data::Load("../data/fisheriris_data.csv", data, true);
    // Choose the last two columns.
    data = data.rows(2,3);    
    
    std::vector<size_t> oldFromNewRefs;
    std::vector<size_t> oldFromNewQueries;
    arma::mat distancesOut;
    arma::Mat<size_t> neighborsOut;
    
    // Calculate the all 2-nearest-neighbors .
    BinarySpaceTree<bound::HRectBound<2>, QueryStat<NearestNeighborSort> >
    refTree(data, oldFromNewRefs, 20);
    
    BinarySpaceTree<bound::HRectBound<2>, QueryStat<NearestNeighborSort> >*
    queryTree = new BinarySpaceTree<bound::HRectBound<2>,
    QueryStat<NearestNeighborSort> >(query, oldFromNewQueries, 20);
    
    AllkNN* allknn = new AllkNN(&refTree, queryTree, data, query, false);    
    allknn->Search(2, neighborsOut, distancesOut);
    
    arma::Mat<size_t> neighbors;
    arma::mat distances;
    
    // Map the results back to the correct places.
    Unmap(neighborsOut, distancesOut, oldFromNewRefs, oldFromNewQueries,
          neighbors, distances);
    
    // Clean up.
    delete queryTree;
    delete allknn;
    
    // Show the results.
    Log::cout << distances.t() << std::endl;
    Log::cout << neighbors.t() << std::endl;

    return 0;
}
コード例 #2
0
ファイル: allknn_main.cpp プロジェクト: dblalock/mlpack-ios
int main(int argc, char *argv[])
{
  // Give CLI the command line parameters the user passed in.
  CLI::ParseCommandLine(argc, argv);

  if (CLI::GetParam<int>("seed") != 0)
    math::RandomSeed((size_t) CLI::GetParam<int>("seed"));
  else
    math::RandomSeed((size_t) std::time(NULL));

  // Get all the parameters.
  const string referenceFile = CLI::GetParam<string>("reference_file");
  const string queryFile = CLI::GetParam<string>("query_file");

  const string distancesFile = CLI::GetParam<string>("distances_file");
  const string neighborsFile = CLI::GetParam<string>("neighbors_file");

  int lsInt = CLI::GetParam<int>("leaf_size");

  size_t k = CLI::GetParam<int>("k");

  bool naive = CLI::HasParam("naive");
  bool singleMode = CLI::HasParam("single_mode");
  const bool randomBasis = CLI::HasParam("random_basis");

  arma::mat referenceData;
  arma::mat queryData; // So it doesn't go out of scope.
  data::Load(referenceFile, referenceData, true);

  Log::Info << "Loaded reference data from '" << referenceFile << "' ("
      << referenceData.n_rows << " x " << referenceData.n_cols << ")." << endl;

  if (queryFile != "")
  {
    data::Load(queryFile, queryData, true);
    Log::Info << "Loaded query data from '" << queryFile << "' ("
      << queryData.n_rows << " x " << queryData.n_cols << ")." << endl;
  }

  // Sanity check on k value: must be greater than 0, must be less than the
  // number of reference points.
  if (k > referenceData.n_cols)
  {
    Log::Fatal << "Invalid k: " << k << "; must be greater than 0 and less ";
    Log::Fatal << "than or equal to the number of reference points (";
    Log::Fatal << referenceData.n_cols << ")." << endl;
  }

  // Sanity check on leaf size.
  if (lsInt < 0)
  {
    Log::Fatal << "Invalid leaf size: " << lsInt << ".  Must be greater "
        "than or equal to 0." << endl;
  }
  size_t leafSize = lsInt;

  // Naive mode overrides single mode.
  if (singleMode && naive)
  {
    Log::Warn << "--single_mode ignored because --naive is present." << endl;
  }

  if (naive)
    leafSize = referenceData.n_cols;

  // See if we want to project onto a random basis.
  if (randomBasis)
  {
    // Generate the random basis.
    while (true)
    {
      // [Q, R] = qr(randn(d, d));
      // Q = Q * diag(sign(diag(R)));
      arma::mat q, r;
      if (arma::qr(q, r, arma::randn<arma::mat>(referenceData.n_rows,
          referenceData.n_rows)))
      {
        arma::vec rDiag(r.n_rows);
        for (size_t i = 0; i < rDiag.n_elem; ++i)
        {
          if (r(i, i) < 0)
            rDiag(i) = -1;
          else if (r(i, i) > 0)
            rDiag(i) = 1;
          else
            rDiag(i) = 0;
        }

        q *= arma::diagmat(rDiag);

        // Check if the determinant is positive.
        if (arma::det(q) >= 0)
        {
          referenceData = q * referenceData;
          if (queryFile != "")
            queryData = q * queryData;
          break;
        }
      }
    }
  }

  arma::Mat<size_t> neighbors;
  arma::mat distances;

  if (!CLI::HasParam("cover_tree"))
  {
    // Because we may construct it differently, we need a pointer.
    AllkNN* allknn = NULL;

    // Mappings for when we build the tree.
    std::vector<size_t> oldFromNewRefs;

    // Build trees by hand, so we can save memory: if we pass a tree to
    // NeighborSearch, it does not copy the matrix.
    Log::Info << "Building reference tree..." << endl;
    Timer::Start("tree_building");

    BinarySpaceTree<bound::HRectBound<2>,
        NeighborSearchStat<NearestNeighborSort> >
        refTree(referenceData, oldFromNewRefs, leafSize);
    BinarySpaceTree<bound::HRectBound<2>,
        NeighborSearchStat<NearestNeighborSort> >*
        queryTree = NULL; // Empty for now.

    Timer::Stop("tree_building");

    std::vector<size_t> oldFromNewQueries;

    if (CLI::GetParam<string>("query_file") != "")
    {
      if (naive && leafSize < queryData.n_cols)
        leafSize = queryData.n_cols;

      Log::Info << "Loaded query data from '" << queryFile << "' ("
          << queryData.n_rows << " x " << queryData.n_cols << ")." << endl;

      Log::Info << "Building query tree..." << endl;

      // Build trees by hand, so we can save memory: if we pass a tree to
      // NeighborSearch, it does not copy the matrix.
      if (!singleMode)
      {
        Timer::Start("tree_building");

        queryTree = new BinarySpaceTree<bound::HRectBound<2>,
            NeighborSearchStat<NearestNeighborSort> >(queryData,
            oldFromNewQueries, leafSize);

        Timer::Stop("tree_building");
      }

      allknn = new AllkNN(&refTree, queryTree, referenceData, queryData,
          singleMode);

      Log::Info << "Tree built." << endl;
    }
    else
    {
      allknn = new AllkNN(&refTree, referenceData, singleMode);

      Log::Info << "Trees built." << endl;
    }

    arma::mat distancesOut;
    arma::Mat<size_t> neighborsOut;

    Log::Info << "Computing " << k << " nearest neighbors..." << endl;
    allknn->Search(k, neighborsOut, distancesOut);

    Log::Info << "Neighbors computed." << endl;

    // We have to map back to the original indices from before the tree
    // construction.
    Log::Info << "Re-mapping indices..." << endl;

    // Map the results back to the correct places.
    if ((CLI::GetParam<string>("query_file") != "") && !singleMode)
      Unmap(neighborsOut, distancesOut, oldFromNewRefs, oldFromNewQueries,
          neighbors, distances);
    else if ((CLI::GetParam<string>("query_file") != "") && singleMode)
      Unmap(neighborsOut, distancesOut, oldFromNewRefs, neighbors, distances);
    else
      Unmap(neighborsOut, distancesOut, oldFromNewRefs, oldFromNewRefs,
          neighbors, distances);

    // Clean up.
    if (queryTree)
      delete queryTree;

    delete allknn;
  }
  else // Cover trees.
  {
    // Make sure to notify the user that they are using cover trees.
    Log::Info << "Using cover trees for nearest-neighbor calculation." << endl;

    // Build our reference tree.
    Log::Info << "Building reference tree..." << endl;
    Timer::Start("tree_building");
    CoverTree<metric::LMetric<2, true>, tree::FirstPointIsRoot,
        NeighborSearchStat<NearestNeighborSort> > referenceTree(referenceData,
        1.3);
    CoverTree<metric::LMetric<2, true>, tree::FirstPointIsRoot,
        NeighborSearchStat<NearestNeighborSort> >* queryTree = NULL;
    Timer::Stop("tree_building");

    NeighborSearch<NearestNeighborSort, metric::LMetric<2, true>,
        CoverTree<metric::LMetric<2, true>, tree::FirstPointIsRoot,
        NeighborSearchStat<NearestNeighborSort> > >* allknn = NULL;

    // See if we have query data.
    if (CLI::HasParam("query_file"))
    {
      // Build query tree.
      if (!singleMode)
      {
        Log::Info << "Building query tree..." << endl;
        Timer::Start("tree_building");
        queryTree = new CoverTree<metric::LMetric<2, true>,
            tree::FirstPointIsRoot, NeighborSearchStat<NearestNeighborSort> >(
            queryData, 1.3);
        Timer::Stop("tree_building");
      }

      allknn = new NeighborSearch<NearestNeighborSort, metric::LMetric<2, true>,
          CoverTree<metric::LMetric<2, true>, tree::FirstPointIsRoot,
          NeighborSearchStat<NearestNeighborSort> > >(&referenceTree, queryTree,
          referenceData, queryData, singleMode);
    }
    else
    {
      allknn = new NeighborSearch<NearestNeighborSort, metric::LMetric<2, true>,
          CoverTree<metric::LMetric<2, true>, tree::FirstPointIsRoot,
          NeighborSearchStat<NearestNeighborSort> > >(&referenceTree,
          referenceData, singleMode);
    }

    Log::Info << "Computing " << k << " nearest neighbors..." << endl;
    allknn->Search(k, neighbors, distances);

    Log::Info << "Neighbors computed." << endl;

    delete allknn;

    if (queryTree)
      delete queryTree;
  }

  // Save output.
  data::Save(distancesFile, distances);
  data::Save(neighborsFile, neighbors);
}