Пример #1
0
/**
 * Text to binary conversion (format type 0).
 *
 * Reads a text point file with Points::InitializeFromFile and writes it back
 * out with Points::SavePoints (the usage text calls this the binary output).
 *
 * Expected command line:
 *   argv[0] 0 txt_file_in binary_file_out size dimension
 *
 * @param argc number of entries in argv; must be exactly 6
 * @param argv argv[1] is the format selector "0", argv[2] the input text
 *             file, argv[3] the output file, argv[4] the point count and
 *             argv[5] the dimension
 * @return 0 on success, -1 when the argument count is wrong or when size /
 *         dimension do not parse to a positive number
 */
int Format0(int argc, char** argv)
{
  // Check the argument count before touching argv[1]: it is only safe to
  // dereference once we know it exists.
  if (argc != 6) {
    cout
        << "-- For type 0, read txt file and output binary file, [size] and [dimension] is essential"
        << endl;
    cout << "   " << argv[0] << " 0 txt_file_in binary_file_out size dimension"
        << endl;
    return -1;
  }
  assert(argv[1][0] == '0' && argv[1][1] == 0);

  const int size = atoi(argv[4]);
  const int dimension = atoi(argv[5]);
  // atoi yields 0 for non-numeric text, so a non-positive value means the
  // argument was malformed (or meaningless for a point set).
  if (size <= 0 || dimension <= 0) {
    cout << "-- [size] and [dimension] must be positive integers" << endl;
    return -1;
  }

  Points<DefaultDataTypes> points;
  points.InitializeFromFile(argv[2], size, dimension);
  points.SavePoints(argv[3]);

  return 0;
}
Пример #2
0
int main(int argc, char** argv)
{
  typedef typename DefaultDataTypes::Value ValueType;
  typedef typename DefaultDataTypes::Dist DistType;
  typedef typename DefaultDataTypes::Index IndexType;
  typedef typename DefaultDataTypes::Dim DimType;
  const size_t BitLength = DefaultHashTypes::BitLength;

  if (argc < 2) {
    cout << "Usage: " << argv[0]
        << " config_file_name [config_key=config_value ...]" << endl;
    return -1;
  }
  //load config from config file
  Config config(argv[1]);

  //load config from argv, note that this may override the key_value of the original
  for (int i = 2; i < argc; i++) {
    string k_v(argv[i]);
    size_t pos = k_v.find("=");
    if (pos == string::npos) {
      cout << "Unrecognized arg:" << k_v << endl;
      cout << "Usage: " << argv[0]
          << " config_file_name [config_key=config_value ...]" << endl;
      return -1;
    } else {
      string key = k_v.substr(0, pos);
      string value = k_v.substr(pos + 1);
      config.Add(key, value);
    }
  }

  if (config.Read<bool>(kShowConfigKey)) {
    cout << "==============config content=============" << endl;
    cout << config;
    cout << "=========================================" << endl;
  }

#ifdef NNPLUS_DEBUG
  std::cout << "DEBUG mode on" << std::endl;
#else
  std::cout << "DEBUG mode off" << std::endl;
#endif

#ifdef USE_PARALLELIZATION
  //control threads used in the whole program
  int open_MPI_threads = config.Read<int>(kOpenMPIThreadsNumKey);
  if (open_MPI_threads != -1 && open_MPI_threads > 0) {
    omp_set_num_threads(open_MPI_threads);
    cout << "omp_set_num_threads:" << open_MPI_threads << endl;
  } else {
    cout << "use default omp_threads" << endl;
  }
#endif

  unsigned int random_seed = config.Read<unsigned int>(kRandomSeedKey);
  srand(random_seed);

  Stopwatch timer("");
  timer.Reset();
  timer.Start();

  bool data_format_binary_flag = config.Read<bool>(kDataFormatBinaryKey);
  Points<DefaultDataTypes> dps;
  string input_data_file_name = config.Read<string>(kDataFileNameKey);
  if (data_format_binary_flag) {
    dps.LoadPoints(input_data_file_name.c_str());
  } else {
    dps.InitializeFromFile(input_data_file_name.c_str(),
        config.Read<IndexType>(kTextDataSizeKey),
        config.Read<DimType>(kTextDataDimKey));
  }
  Points<DefaultDataTypes> qps;
  string input_query_file_name = config.Read<string>(kQueryFileNameKey);
  if (data_format_binary_flag) {
    qps.LoadPoints(input_query_file_name.c_str());
  } else {
    qps.InitializeFromFile(input_query_file_name.c_str(),
        config.Read<IndexType>(kTextQuerySizeKey),
        config.Read<DimType>(kTextQueryDimKey));
  }
  assert(dps.dim_ == qps.dim_);
  cout << "- Reading Data Finished (" << timer.GetTime() << " seconds)" << endl;

  // check expansion and save kg_load time
  SearchEngine<DefaultHashTypes, DefaultDataTypes>::BuildParams bp;
  string search_algorithm = config.Read<string>(kSearchAlgorithm);
  if (search_algorithm == "HashHKM") {
    bp.algorithm = SearchEngine<DefaultHashTypes, DefaultDataTypes>::HashHKM;
    bp.forest_size = config.Read<size_t>(kForestTreeCountKey);
    bp.tree_height = config.Read<size_t>(kTreeHeightKey);
    bp.branches = config.ReadVector<size_t>(kTreeBranchesKey);
  } else if (search_algorithm == "HGNNS") {
    bp.algorithm = SearchEngine<DefaultHashTypes, DefaultDataTypes>::HGNNS;
    bp.k_neighbor = config.Read<size_t>(kKnnGraphNeighborKey);
    bp.forest_size = config.Read<size_t>(kForestTreeCountKey);
    bp.tree_height = config.Read<size_t>(kTreeHeightKey);
    bp.branches = config.ReadVector<size_t>(kTreeBranchesKey);
    bp.keep_branches = config.ReadVector<size_t>(kKeepBranchesKey);
    bp.hamming_keep_branches = config.ReadVector<size_t>(
        kHammingKeepBranchesKey);
  } else if (search_algorithm == "MGNNS") {
    bp.algorithm = SearchEngine<DefaultHashTypes, DefaultDataTypes>::MGNNS;
    bp.k_neighbor = config.Read<size_t>(kKnnGraphNeighborKey);
    bp.forest_size = config.Read<size_t>(kForestTreeCountKey);
    bp.tree_height = config.Read<size_t>(kTreeHeightKey);
    bp.branches = config.ReadVector<size_t>(kTreeBranchesKey);
    bp.keep_branches = config.ReadVector<size_t>(kKeepBranchesKey);
    bp.hamming_keep_branches = config.ReadVector<size_t>(
        kHammingKeepBranchesKey);
  } else if (search_algorithm == "HashScan") {
    bp.algorithm = SearchEngine<DefaultHashTypes, DefaultDataTypes>::HashScan;
  } else {
    throw runtime_error("Undefined Search Algorithm");
  }
  SearchEngine<DefaultHashTypes, DefaultDataTypes>::SearchParams sp;
  if (bp.algorithm
      == SearchEngine<DefaultHashTypes, DefaultDataTypes>::HashHKM) {
    sp.keep_branches = config.ReadVector<size_t>(kKeepBranchesKey);
    sp.hamming_keep_branches = config.ReadVector<size_t>(
        kHammingKeepBranchesKey);
    sp.real_distance = ComputeEuclideanDistance<ValueType, DistType>;
    sp.hamming_distance = ComputeHammingDistance<BitLength>;
  } else if (bp.algorithm
      == SearchEngine<DefaultHashTypes, DefaultDataTypes>::HGNNS) {
    sp.restart = config.Read<size_t>(kGnnsRestartKey);
    sp.expansion = config.Read<size_t>(kGnnsExpansionKey);
    sp.hash_expansion = config.Read<size_t>(kGnnsHashExpansionKey);
    sp.max_expansion = config.Read<size_t>(kGnnsMaxExpansionKey);
    sp.greedy_step = config.Read<size_t>(kGnnsGreedyStepKey);
    sp.real_distance = ComputeEuclideanDistance<ValueType, DistType>;
    sp.hamming_distance = ComputeHammingDistance<BitLength>;
  } else if (bp.algorithm
      == SearchEngine<DefaultHashTypes, DefaultDataTypes>::MGNNS) {
    sp.restart = config.Read<size_t>(kGnnsRestartKey);
    sp.expansion = config.Read<size_t>(kGnnsExpansionKey);
    sp.max_expansion = config.Read<size_t>(kGnnsMaxExpansionKey);
    sp.greedy_step = config.Read<size_t>(kGnnsGreedyStepKey);
    sp.real_distance = ComputeEuclideanDistance<ValueType, DistType>;
  } else if (bp.algorithm
      == SearchEngine<DefaultHashTypes, DefaultDataTypes>::HashScan) {
    // do nothing
  } else {
    throw runtime_error("Undefined Search Algorithm");
  }

  size_t knn = config.Read<size_t>(kNearKey);

  timer.Reset();
  timer.Start();
  // load ground truth
  Groundtruth<DefaultDataTypes> groundtruth;
  groundtruth.Initialize(dps, qps, knn,
      &ComputeEuclideanDistance<ValueType, DistType>);

  string gt_file_name_prefix = config.Read<string>(
      kGroundtruthFileNamePrefixKey);
  const size_t kMaxFileNameLength = 256;
  char gt_file_name[kMaxFileNameLength];
  sprintf(gt_file_name, "%s_d%d_q%d_k%d", gt_file_name_prefix.c_str(),
      (int) dps.size_, (int) qps.size_, (int) knn);
  FILE *gt_file = fopen(gt_file_name, "rb");
  if (gt_file != NULL) {
    std::cout << "-- Groundtruth file exists, " << gt_file_name << std::endl;
    std::cout << "-- Loading Groundtruth ..." << std::endl;
    groundtruth.Load(gt_file);
    fclose(gt_file);
  } else {
    std::cout << "-- Groundtruth file not exists, " << gt_file_name
        << std::endl;
    std::cout << "-- Building Groundtruth ..." << std::endl;
    groundtruth.Build();
    std::cout << "-- Saving Groundtruth to disk..." << std::endl;
    gt_file = fopen(gt_file_name, "wb");
    groundtruth.Save(gt_file);
    fclose(gt_file);
  }

  cout << "- GroundTruth Finished (" << timer.GetTime() << " seconds)" << endl;
  SearchEngine<DefaultHashTypes, DefaultDataTypes> se(bp, sp);
  se.load_data(&dps);

  string se_filename = config.Read<string>(kSearchEngineFile);
  FILE *se_file = fopen(se_filename.c_str(), "rb");
  if (se_file == NULL) {
    cout << "- SearchEngine BuildIndex Start" << endl;
    timer.Reset();
    timer.Start();
    se.build_index();
    se.save_index(se_filename);
    cout << "- SearchEngine BuildIndex Finished (" << timer.GetTime()
        << " seconds)" << endl;
  } else {
    fclose(se_file);
    timer.Reset();
    timer.Start();
    se.load_index(se_filename);
    cout << "- SearchEngine StartEngine Finished (" << timer.GetTime()
        << " seconds)" << endl;
  }

  srand(random_seed);

  cout << "- SearchEngine Search Start" << endl;
  timer.Reset();
  timer.Start();

  std::vector<IndexType>* indices = new std::vector<IndexType>[qps.size_];
  std::vector<DistType>* dists = new std::vector<DistType>[qps.size_];

#ifdef USE_PARALLELIZATION
#pragma omp parallel for
#endif
  for (IndexType i = 0; i < qps.size_; i++) {
    //    se.Search(qps.d_[i], sp, results[i]);
    se.search(qps.d_[i], knn, indices[i], dists[i]);
  }
  timer.Stop();
  cout << "- SearchEngine Search Finished (" << timer.GetTime()
      << " seconds, with " << qps.size_ << " queries)" << endl;

  const double ZERO = 1e-6;
  double hit_rate = 0;
  for (size_t i = 0; i < qps.size_; i++) {
    size_t hit = 0;
    for (size_t j = 0; j < indices[i].size(); j++) {
      //for (size_t j = 0; j < sp.k; j++) {
      if (j > 0) {
        assert(dists[i][j] >= dists[i][j - 1]);
      }
      for (size_t k = 0; k < knn; k++) {
        if (indices[i][j] == groundtruth.GetIndexes()[i][k]) {
          hit++;
          break;
        } else if (abs(dists[i][j] - groundtruth.GetDists()[i][k]) <= ZERO) {
          hit++;
          break;
        }
      }
    }
    hit_rate += 1.0 * hit / knn;
  }
  hit_rate /= qps.size_;

  cout.setf(ios::fixed);
  cout << setprecision(3) << hit_rate << ",\t" << (timer.GetTime()) << endl;

  delete[] dists;
  delete[] indices;
}
Пример #3
0
int main(int argc, char** argv)
{
  typedef typename DefaultDataTypes::Value ValueType;
  typedef typename DefaultDataTypes::Dist DistType;
  typedef typename DefaultDataTypes::Index IndexType;
  typedef typename DefaultDataTypes::Dim DimType;

  if (argc < 2) {
    cout << "Usage: " << argv[0]
        << " config_file_name [config_key=config_value ...]" << endl;
    return -1;
  }
  //load config from config file
  Config config(argv[1]);

  //load config from argv, note that this may override the key_value of the original
  for (int i = 2; i < argc; i++) {
    string k_v(argv[i]);
    size_t pos = k_v.find("=");
    if (pos == string::npos) {
      cout << "Unrecognized arg:" << k_v << endl;
      cout << "Usage: " << argv[0]
          << " config_file_name [config_key=config_value ...]" << endl;
      return -1;
    } else {
      string key = k_v.substr(0, pos);
      string value = k_v.substr(pos + 1);
      config.Add(key, value);
    }
  }

  if (config.Read<bool>(kShowConfigKey)) {
    cout << "==============config content=============" << endl;
    cout << config;
    cout << "=========================================" << endl;
  }

  unsigned int random_seed = config.Read<unsigned int>(kRandomSeedKey);
  srand(random_seed);

  Stopwatch timer("");
  timer.Reset();
  timer.Start();

  Points<DefaultDataTypes> dps;
  string input_data_file_name = config.Read<string>(kDataFileNameKey);
  bool data_format_binary_flag = config.Read<bool>(kDataFormatBinaryKey);
  if (data_format_binary_flag) {
    dps.LoadPoints(input_data_file_name.c_str());
  } else {
    dps.InitializeFromFile(input_data_file_name.c_str(),
        config.Read<IndexType>(kTextDataSizeKey),
        config.Read<DimType>(kTextDataDimKey));
  }

  Points<DefaultDataTypes> qps;
  string input_query_file_name = config.Read<string>(kQueryFileNameKey);
  if (data_format_binary_flag) {
    qps.LoadPoints(input_query_file_name.c_str());
  } else {
    qps.InitializeFromFile(input_query_file_name.c_str(),
        config.Read<IndexType>(kTextQuerySizeKey),
        config.Read<DimType>(kTextQueryDimKey));
  }
  assert(dps.dim_ == qps.dim_);

  cout << "- Reading Data Finished (" << timer.GetTime() << " seconds)" << endl;
  if (config.Read<bool>(kSavePointsKey)) {
    cout << "Saving data points to "
        << config.Read<string>(kSaveDataPointsFileName) << endl;
    dps.SavePoints(config.Read<string>(kSaveDataPointsFileName).c_str());
    cout << "Saving query points to "
        << config.Read<string>(kSaveQueryPointsFileName) << endl;
    qps.SavePoints(config.Read<string>(kSaveQueryPointsFileName).c_str());
  }

  size_t knn = config.Read<size_t>(kNearKey);

  timer.Reset();
  timer.Start();
  // load ground truth
  Groundtruth<DefaultDataTypes> groundtruth;
  groundtruth.Initialize(dps, qps, knn,
      &ComputeEuclideanDistance<ValueType, DistType>);

  string gt_file_name_prefix = config.Read<string>(
      kGroundtruthFileNamePrefixKey);
  const size_t kMaxFileNameLength = 256;
  char gt_file_name[kMaxFileNameLength];
  sprintf(gt_file_name, "%s_d%d_q%d_k%d", gt_file_name_prefix.c_str(),
      (int) dps.size_, (int) qps.size_, (int) knn);
  FILE *gt_file = fopen(gt_file_name, "rb");
  if (gt_file != NULL) {
    std::cout << "-- Groundtruth file exists, " << gt_file_name << std::endl;
    std::cout << "-- Loading Groundtruth ..." << std::endl;
    groundtruth.Load(gt_file);
    fclose(gt_file);
  } else {
    std::cout << "-- Groundtruth file not exists, " << gt_file_name
        << std::endl;
    std::cout << "-- Building Groundtruth ..." << std::endl;
    groundtruth.Build();
    std::cout << "-- Saving Groundtruth to disk..." << std::endl;
    gt_file = fopen(gt_file_name, "wb");
    groundtruth.Save(gt_file);
    fclose(gt_file);
  }
  cout << "- GroundTruth Finished (" << timer.GetTime() << " seconds)" << endl;

  size_t repeat_count = config.Read<size_t>(kRepeatCountKey);

  size_t kg_max_expansion = 0;
  if (config.Read<bool>(kBatchTestKey)) { // Batch test
    vector<size_t> expansions = config.ReadVector<size_t>(kGnnsExpansionsKey);
    vector<size_t> max_expansions = config.ReadVector<size_t>(
        kGnnsMaxExpansionsKey);
    assert(expansions.size() == max_expansions.size());
    vector<size_t>::iterator me_it = max_expansions.begin();
    for (vector<size_t>::iterator e_it = expansions.begin();
        e_it != expansions.end(); ++me_it, ++e_it) {
      assert(me_it != max_expansions.end());
      size_t expansion = *e_it;
      size_t max_expansion = *me_it;
      assert(max_expansion >= expansion);
      if (max_expansion > kg_max_expansion) {
        kg_max_expansion = max_expansion;
      }
    }
  } else {
    kg_max_expansion = config.Read<size_t>(kGnnsMaxExpansionKey);
  }

  timer.Reset();
  timer.Start();
  string kg_file_name = config.Read<string>(kKnnGraphFileName);
  KnnGraph<DefaultDataTypes> kg(kg_file_name.c_str(), kg_max_expansion);
  cout << "KnnGraph Loaded (" << timer.GetTime() << " seconds)" << endl;
  Gnns<DefaultDataTypes> gnns(dps, kg);

  cout << "== Start Test ==" << endl;
  if (config.Read<bool>(kBatchTestKey)) { // Batch test
    vector<size_t> restarts = config.ReadVector<size_t>(kGnnsRestartsKey);
    vector<size_t> expansions = config.ReadVector<size_t>(kGnnsExpansionsKey);
    vector<size_t> max_expansions = config.ReadVector<size_t>(
        kGnnsMaxExpansionsKey);
    assert(expansions.size() == max_expansions.size());
    vector<size_t> greedy_steps = config.ReadVector<size_t>(
        kGnnsGreedyStepsKey);
    for (vector<size_t>::iterator r_it = restarts.begin();
        r_it != restarts.end(); ++r_it) {
      size_t restart = *r_it;
      vector<size_t>::iterator me_it = max_expansions.begin();
      for (vector<size_t>::iterator e_it = expansions.begin();
          e_it != expansions.end(); ++me_it, ++e_it) {
        assert(me_it != max_expansions.end());
        size_t expansion = *e_it;
        size_t max_expansion = *me_it;
        double last_hit_rate = -1;
        for (vector<size_t>::iterator gs_it = greedy_steps.begin();
            gs_it != greedy_steps.end(); ++gs_it) {
          srand(random_seed);
          size_t greedy_step = *gs_it;
          double hit_rate = test(knn, qps, groundtruth, gnns, restart,
              expansion, max_expansion, greedy_step, repeat_count);
          if (abs(last_hit_rate - hit_rate) < 0.003) {
//          if (abs(last_hit_rate - hit_rate) < 0.003 && hit_rate > 0.95) {
            break;
          }
          last_hit_rate = hit_rate;
        }
      }
    }
  } else {
    size_t restart = config.Read<size_t>(kGnnsRestartKey);
    size_t expansion = config.Read<size_t>(kGnnsExpansionKey);
    size_t max_expansion = config.Read<size_t>(kGnnsMaxExpansionKey);
    size_t greedy_step = config.Read<size_t>(kGnnsGreedyStepKey);
    srand(random_seed);
    test(knn, qps, groundtruth, gnns, restart, expansion, max_expansion,
        greedy_step, repeat_count);
  }
  cout << "== Finish Test ==" << endl;

  kg.FreeGraph();
}