int main(int argc, char** argv) { CLI::ParseCommandLine(argc, argv); // Check input parameters for validity. const string numericSplitStrategy = CLI::GetParam<string>("numeric_split_strategy"); if ((CLI::HasParam("predictions") || CLI::HasParam("probabilities")) && !CLI::HasParam("test")) Log::Fatal << "--test_file must be specified if --predictions_file or " << "--probabilities_file is specified." << endl; if (!CLI::HasParam("training") && !CLI::HasParam("input_model")) Log::Fatal << "One of --training_file or --input_model_file must be " << "specified!" << endl; if (CLI::HasParam("training") && !CLI::HasParam("labels")) Log::Fatal << "If --training_file is specified, --labels_file must be " << "specified too!" << endl; if (!CLI::HasParam("training") && CLI::HasParam("batch_mode")) Log::Warn << "--batch_mode (-b) ignored; no training set provided." << endl; if (CLI::HasParam("passes") && CLI::HasParam("batch_mode")) Log::Warn << "--batch_mode (-b) ignored because --passes was specified." << endl; if (CLI::HasParam("test") && !CLI::HasParam("predictions") && !CLI::HasParam("probabilities") && !CLI::HasParam("test_labels")) Log::Warn << "--test_file (-T) is specified, but none of " << "--predictions_file (-p), --probabilities_file (-P), or " << "--test_labels_file (-L) are specified, so no output will be given!" << endl; if ((numericSplitStrategy != "domingos") && (numericSplitStrategy != "binary")) { Log::Fatal << "Unrecognized numeric split strategy (" << numericSplitStrategy << ")! Must be 'domingos' or 'binary'." << endl; } // Do we need to load a model or do we already have one? HoeffdingTreeModel model; DatasetInfo datasetInfo; arma::mat trainingSet; arma::Mat<size_t> labels; if (CLI::HasParam("input_model")) { model = std::move(CLI::GetParam<HoeffdingTreeModel>("input_model")); } else { // Initialize a model. if (!CLI::HasParam("info_gain") && (numericSplitStrategy == "domingos")) model = HoeffdingTreeModel(HoeffdingTreeModel::GINI_HOEFFDING); else if (!CLI::HasParam("info_gain") && (numericSplitStrategy == "binary")) model = HoeffdingTreeModel(HoeffdingTreeModel::GINI_BINARY); else if (CLI::HasParam("info_gain") && (numericSplitStrategy == "domingos")) model = HoeffdingTreeModel(HoeffdingTreeModel::INFO_HOEFFDING); else if (CLI::HasParam("info_gain") && (numericSplitStrategy == "binary")) model = HoeffdingTreeModel(HoeffdingTreeModel::INFO_BINARY); } // Now, do we need to train? if (CLI::HasParam("training")) { // Load necessary parameters for training. const double confidence = CLI::GetParam<double>("confidence"); const size_t maxSamples = (size_t) CLI::GetParam<int>("max_samples"); const size_t minSamples = (size_t) CLI::GetParam<int>("min_samples"); bool batchTraining = CLI::HasParam("batch_mode"); const size_t bins = (size_t) CLI::GetParam<int>("bins"); const size_t observationsBeforeBinning = (size_t) CLI::GetParam<int>("observations_before_binning"); size_t passes = (size_t) CLI::GetParam<int>("passes"); if (passes > 1) batchTraining = false; // We already warned about this earlier. // We need to train the model. First, load the data. datasetInfo = std::move(std::get<0>(CLI::GetParam<TupleType>("training"))); trainingSet = std::move(std::get<1>(CLI::GetParam<TupleType>("training"))); for (size_t i = 0; i < trainingSet.n_rows; ++i) Log::Info << datasetInfo.NumMappings(i) << " mappings in dimension " << i << "." << endl; labels = CLI::GetParam<arma::Mat<size_t>>("labels"); if (labels.n_rows > 1) labels = labels.t(); if (labels.n_rows > 1) Log::Fatal << "Labels must be one-dimensional!" << endl; // Next, create the model with the right type. Then build the tree with the // appropriate type of instantiated numeric split type. This is a little // bit ugly. Maybe there is a nicer way to get this numeric split // information to the trees, but this is ok for now. Timer::Start("tree_training"); // Do we need to initialize a model? if (!CLI::HasParam("input_model")) { // Build the model. model.BuildModel(trainingSet, datasetInfo, labels.row(0), arma::max(labels.row(0)) + 1, batchTraining, confidence, maxSamples, 100, minSamples, bins, observationsBeforeBinning); --passes; // This model-building takes one pass. } // Now pass over the trees as many times as we need to. if (batchTraining) { // We only need to do batch training if we've not already called // BuildModel. if (CLI::HasParam("input_model")) model.Train(trainingSet, labels.row(0), true); } else { for (size_t p = 0; p < passes; ++p) model.Train(trainingSet, labels.row(0), false); } Timer::Stop("tree_training"); } // Do we need to evaluate the training set error? if (CLI::HasParam("training")) { // Get training error. arma::Row<size_t> predictions; model.Classify(trainingSet, predictions); size_t correct = 0; for (size_t i = 0; i < labels.n_elem; ++i) if (labels[i] == predictions[i]) ++correct; Log::Info << correct << " out of " << labels.n_elem << " correct " << "on training set (" << double(correct) / double(labels.n_elem) * 100.0 << ")." << endl; } // Get the number of nodes in the tree. Log::Info << model.NumNodes() << " nodes in the tree." << endl; // The tree is trained or loaded. Now do any testing if we need. if (CLI::HasParam("test")) { // Before loading, pre-set the dataset info by getting the raw parameter // (that doesn't call data::Load()). std::get<0>(CLI::GetRawParam<TupleType>("test")) = datasetInfo; arma::mat testSet = std::get<1>(CLI::GetParam<TupleType>("test")); arma::Row<size_t> predictions; arma::rowvec probabilities; Timer::Start("tree_testing"); model.Classify(testSet, predictions, probabilities); Timer::Stop("tree_testing"); if (CLI::HasParam("test_labels")) { arma::Mat<size_t> testLabels = std::move(CLI::GetParam<arma::Mat<size_t>>("test_labels")); if (testLabels.n_rows > 1) testLabels = testLabels.t(); if (testLabels.n_rows > 1) Log::Fatal << "Test labels must be one-dimensional!" << endl; size_t correct = 0; for (size_t i = 0; i < testLabels.n_elem; ++i) { if (predictions[i] == testLabels[i]) ++correct; } Log::Info << correct << " out of " << testLabels.n_elem << " correct " << "on test set (" << double(correct) / double(testLabels.n_elem) * 100.0 << ")." << endl; } if (CLI::HasParam("predictions")) CLI::GetParam<arma::Mat<size_t>>("predictions") = std::move(predictions); if (CLI::HasParam("probabilities")) CLI::GetParam<arma::mat>("probabilities") = std::move(probabilities); } // Check the accuracy on the training set. if (CLI::HasParam("output_model")) CLI::GetParam<HoeffdingTreeModel>("output_model") = std::move(model); CLI::Destroy(); }