// Example #1
// Entry point for the Hoeffding tree command-line program: validates the
// command-line parameters, loads or initializes a HoeffdingTreeModel,
// optionally trains it on a (possibly categorical) dataset, reports training
// and test accuracy, and hands predictions/probabilities/the trained model
// back to CLI for output.
int main(int argc, char** argv)
{
  CLI::ParseCommandLine(argc, argv);

  // Check input parameters for validity.
  const string numericSplitStrategy =
      CLI::GetParam<string>("numeric_split_strategy");

  // Prediction/probability outputs only make sense when a test set is given.
  if ((CLI::HasParam("predictions") || CLI::HasParam("probabilities")) &&
       !CLI::HasParam("test"))
    Log::Fatal << "--test_file must be specified if --predictions_file or "
        << "--probabilities_file is specified." << endl;

  // We need something to work with: either a training set or a saved model.
  if (!CLI::HasParam("training") && !CLI::HasParam("input_model"))
    Log::Fatal << "One of --training_file or --input_model_file must be "
        << "specified!" << endl;

  // Training is supervised, so a training set requires labels.
  if (CLI::HasParam("training") && !CLI::HasParam("labels"))
    Log::Fatal << "If --training_file is specified, --labels_file must be "
        << "specified too!" << endl;

  if (!CLI::HasParam("training") && CLI::HasParam("batch_mode"))
    Log::Warn << "--batch_mode (-b) ignored; no training set provided." << endl;

  // Multiple streaming passes and batch mode are mutually exclusive; passes
  // win (see the batchTraining adjustment during training below).
  if (CLI::HasParam("passes") && CLI::HasParam("batch_mode"))
    Log::Warn << "--batch_mode (-b) ignored because --passes was specified."
        << endl;

  if (CLI::HasParam("test") && !CLI::HasParam("predictions") &&
      !CLI::HasParam("probabilities") && !CLI::HasParam("test_labels"))
    Log::Warn << "--test_file (-T) is specified, but none of "
        << "--predictions_file (-p), --probabilities_file (-P), or "
        << "--test_labels_file (-L) are specified, so no output will be given!"
        << endl;

  if ((numericSplitStrategy != "domingos") &&
      (numericSplitStrategy != "binary"))
  {
    Log::Fatal << "Unrecognized numeric split strategy ("
        << numericSplitStrategy << ")!  Must be 'domingos' or 'binary'."
        << endl;
  }

  // Do we need to load a model or do we already have one?
  HoeffdingTreeModel model;
  DatasetInfo datasetInfo;
  arma::mat trainingSet;
  arma::Mat<size_t> labels;
  if (CLI::HasParam("input_model"))
  {
    model = std::move(CLI::GetParam<HoeffdingTreeModel>("input_model"));
  }
  else
  {
    // Initialize a model of the right type from the two independent choices:
    // gain criterion (Gini impurity vs. information gain) and numeric split
    // strategy (domingos vs. binary).  The strategy was validated above, so
    // these four cases are exhaustive.
    if (!CLI::HasParam("info_gain") && (numericSplitStrategy == "domingos"))
      model = HoeffdingTreeModel(HoeffdingTreeModel::GINI_HOEFFDING);
    else if (!CLI::HasParam("info_gain") && (numericSplitStrategy == "binary"))
      model = HoeffdingTreeModel(HoeffdingTreeModel::GINI_BINARY);
    else if (CLI::HasParam("info_gain") && (numericSplitStrategy == "domingos"))
      model = HoeffdingTreeModel(HoeffdingTreeModel::INFO_HOEFFDING);
    else // info_gain with binary splits.
      model = HoeffdingTreeModel(HoeffdingTreeModel::INFO_BINARY);
  }

  // Now, do we need to train?
  if (CLI::HasParam("training"))
  {
    // Load necessary hyperparameters for training.
    const double confidence = CLI::GetParam<double>("confidence");
    const size_t maxSamples = (size_t) CLI::GetParam<int>("max_samples");
    const size_t minSamples = (size_t) CLI::GetParam<int>("min_samples");
    bool batchTraining = CLI::HasParam("batch_mode");
    const size_t bins = (size_t) CLI::GetParam<int>("bins");
    const size_t observationsBeforeBinning = (size_t)
        CLI::GetParam<int>("observations_before_binning");
    size_t passes = (size_t) CLI::GetParam<int>("passes");
    if (passes > 1)
      batchTraining = false; // We already warned about this earlier.

    // We need to train the model.  First, load the data.
    datasetInfo = std::move(std::get<0>(CLI::GetParam<TupleType>("training")));
    trainingSet = std::move(std::get<1>(CLI::GetParam<TupleType>("training")));
    for (size_t i = 0; i < trainingSet.n_rows; ++i)
      Log::Info << datasetInfo.NumMappings(i) << " mappings in dimension "
          << i << "." << endl;

    labels = CLI::GetParam<arma::Mat<size_t>>("labels");

    // Accept labels as either a row or a column vector; anything wider is an
    // error.
    if (labels.n_rows > 1)
      labels = labels.t();
    if (labels.n_rows > 1)
      Log::Fatal << "Labels must be one-dimensional!" << endl;

    // There must be exactly one label per training point; otherwise the
    // indexing during training and evaluation would be silently wrong or out
    // of bounds.
    if (labels.n_cols != trainingSet.n_cols)
      Log::Fatal << "The number of labels (" << labels.n_cols << ") does not "
          << "match the number of training points (" << trainingSet.n_cols
          << ")!" << endl;

    // Next, create the model with the right type.  Then build the tree with the
    // appropriate type of instantiated numeric split type.  This is a little
    // bit ugly.  Maybe there is a nicer way to get this numeric split
    // information to the trees, but this is ok for now.
    Timer::Start("tree_training");

    // Do we need to initialize a model?
    if (!CLI::HasParam("input_model"))
    {
      // Build the model.  The number of classes is taken as max(label) + 1.
      // The hard-coded 100 is the check interval between split attempts.
      model.BuildModel(trainingSet, datasetInfo, labels.row(0),
          arma::max(labels.row(0)) + 1, batchTraining, confidence, maxSamples,
          100, minSamples, bins, observationsBeforeBinning);
      // This model-building takes one pass.  Guard the decrement so that
      // '--passes 0' cannot underflow the size_t and spin the training loop
      // below (nearly) forever.
      if (passes > 0)
        --passes;
    }

    // Now pass over the trees as many times as we need to.
    if (batchTraining)
    {
      // We only need to do batch training if we've not already called
      // BuildModel.
      if (CLI::HasParam("input_model"))
        model.Train(trainingSet, labels.row(0), true);
    }
    else
    {
      // Streaming (incremental) training, once per remaining pass.
      for (size_t p = 0; p < passes; ++p)
        model.Train(trainingSet, labels.row(0), false);
    }

    Timer::Stop("tree_training");
  }

  // Do we need to evaluate the training set error?
  if (CLI::HasParam("training"))
  {
    // Get training error.
    arma::Row<size_t> predictions;
    model.Classify(trainingSet, predictions);

    size_t correct = 0;
    for (size_t i = 0; i < labels.n_elem; ++i)
      if (labels[i] == predictions[i])
        ++correct;

    Log::Info << correct << " out of " << labels.n_elem << " correct "
        << "on training set (" << double(correct) / double(labels.n_elem) *
        100.0 << ")." << endl;
  }

  // Get the number of nodes in the tree.
  Log::Info << model.NumNodes() << " nodes in the tree." << endl;

  // The tree is trained or loaded.  Now do any testing if we need.
  if (CLI::HasParam("test"))
  {
    // Before loading, pre-set the dataset info by getting the raw parameter
    // (that doesn't call data::Load()).  This way the test set is loaded with
    // the same categorical mappings as the training set.
    std::get<0>(CLI::GetRawParam<TupleType>("test")) = datasetInfo;
    arma::mat testSet = std::get<1>(CLI::GetParam<TupleType>("test"));

    arma::Row<size_t> predictions;
    arma::rowvec probabilities;

    Timer::Start("tree_testing");
    model.Classify(testSet, predictions, probabilities);
    Timer::Stop("tree_testing");

    if (CLI::HasParam("test_labels"))
    {
      arma::Mat<size_t> testLabels =
          std::move(CLI::GetParam<arma::Mat<size_t>>("test_labels"));
      // Accept test labels as either a row or a column vector, like the
      // training labels above.
      if (testLabels.n_rows > 1)
        testLabels = testLabels.t();
      if (testLabels.n_rows > 1)
        Log::Fatal << "Test labels must be one-dimensional!" << endl;

      // One label per test point, or the comparison loop below would read
      // past the end of the predictions.
      if (testLabels.n_cols != testSet.n_cols)
        Log::Fatal << "The number of test labels (" << testLabels.n_cols
            << ") does not match the number of test points ("
            << testSet.n_cols << ")!" << endl;

      size_t correct = 0;
      for (size_t i = 0; i < testLabels.n_elem; ++i)
      {
        if (predictions[i] == testLabels[i])
          ++correct;
      }
      Log::Info << correct << " out of " << testLabels.n_elem << " correct "
          << "on test set (" << double(correct) / double(testLabels.n_elem) *
          100.0 << ")." << endl;
    }

    // Hand the results back to CLI for output, if requested.
    if (CLI::HasParam("predictions"))
      CLI::GetParam<arma::Mat<size_t>>("predictions") = std::move(predictions);

    if (CLI::HasParam("probabilities"))
      CLI::GetParam<arma::mat>("probabilities") = std::move(probabilities);
  }

  // Save the trained model, if requested.
  if (CLI::HasParam("output_model"))
    CLI::GetParam<HoeffdingTreeModel>("output_model") = std::move(model);

  CLI::Destroy();
}