int main(int argc, char *argv[]) { CLI::ParseCommandLine(argc, argv); // Validate input parameters. if (CLI::HasParam("training_file") && CLI::HasParam("input_model_file")) Log::Fatal << "Only one of --training_file (-t) or --input_model_file (-m) " << "may be specified!" << endl; if (!CLI::HasParam("training_file") && !CLI::HasParam("input_model_file")) Log::Fatal << "Neither --training_file (-t) nor --input_model_file (-m) " << "are specified!" << endl; if (!CLI::HasParam("training_file")) { if (CLI::HasParam("training_set_estimates_file")) Log::Warn << "--training_set_estimates_file (-e) ignored because " << "--training_file (-t) is not specified." << endl; if (CLI::HasParam("folds")) Log::Warn << "--folds (-f) ignored because --training_file (-t) is not " << "specified." << endl; if (CLI::HasParam("min_leaf_size")) Log::Warn << "--min_leaf_size (-l) ignored because --training_file (-t) " << "is not specified." << endl; if (CLI::HasParam("max_leaf_size")) Log::Warn << "--max_leaf_size (-L) ignored because --training_file (-t) " << "is not specified." << endl; } if (!CLI::HasParam("test_file") && CLI::HasParam("test_set_estimates_file")) Log::Warn << "--test_set_estimates_file (-E) ignored because --test_file " << "(-T) is not specified." << endl; // Are we training a DET or loading from file? DTree* tree; if (CLI::HasParam("training_file")) { const string trainSetFile = CLI::GetParam<string>("training_file"); arma::mat trainingData; data::Load(trainSetFile, trainingData, true); // Cross-validation here. size_t folds = CLI::GetParam<int>("folds"); if (folds == 0) { folds = trainingData.n_cols; Log::Info << "Performing leave-one-out cross validation." << endl; } else { Log::Info << "Performing " << folds << "-fold cross validation." << endl; } const bool regularization = false; // const bool regularization = CLI::HasParam("volume_regularization"); const int maxLeafSize = CLI::GetParam<int>("max_leaf_size"); const int minLeafSize = CLI::GetParam<int>("min_leaf_size"); // Obtain the optimal tree. Timer::Start("det_training"); tree = Trainer(trainingData, folds, regularization, maxLeafSize, minLeafSize, ""); Timer::Stop("det_training"); // Compute training set estimates, if desired. if (CLI::GetParam<string>("training_set_estimates_file") != "") { // Compute density estimates for each point in the training set. arma::rowvec trainingDensities(trainingData.n_cols); Timer::Start("det_estimation_time"); for (size_t i = 0; i < trainingData.n_cols; i++) trainingDensities[i] = tree->ComputeValue(trainingData.unsafe_col(i)); Timer::Stop("det_estimation_time"); data::Save(CLI::GetParam<string>("training_set_estimates_file"), trainingDensities); } } else { data::Load(CLI::GetParam<string>("input_model_file"), "det_model", tree, true); } // Compute the density at the provided test points and output the density in // the given file. const string testFile = CLI::GetParam<string>("test_file"); if (testFile != "") { arma::mat testData; data::Load(testFile, testData, true); // Compute test set densities. Timer::Start("det_test_set_estimation"); arma::rowvec testDensities(testData.n_cols); for (size_t i = 0; i < testData.n_cols; i++) testDensities[i] = tree->ComputeValue(testData.unsafe_col(i)); Timer::Stop("det_test_set_estimation"); if (CLI::GetParam<string>("test_set_estimates_file") != "") data::Save(CLI::GetParam<string>("test_set_estimates_file"), testDensities); } // Print variable importance. if (CLI::HasParam("vi_file")) PrintVariableImportance(tree, CLI::GetParam<string>("vi_file")); // Save the model, if desired. if (CLI::HasParam("output_model_file")) data::Save(CLI::GetParam<string>("output_model_file"), "det_model", tree, false); delete tree; }
static void mlpackMain() { // Validate input parameters. RequireOnlyOnePassed({ "training", "input_model" }, true); ReportIgnoredParam({{ "training", false }}, "training_set_estimates"); ReportIgnoredParam({{ "training", false }}, "folds"); ReportIgnoredParam({{ "training", false }}, "min_leaf_size"); ReportIgnoredParam({{ "training", false }}, "max_leaf_size"); if (CLI::HasParam("tag_file")) RequireAtLeastOnePassed({ "training", "test" }, true); if (CLI::HasParam("training")) { RequireAtLeastOnePassed({ "output_model", "training_set_estimates", "vi", "tag_file", "tag_counters_file" }, false, "no output will be saved"); } ReportIgnoredParam({{ "test", false }}, "test_set_estimates"); RequireParamValue<int>("max_leaf_size", [](int x) { return x > 0; }, true, "maximum leaf size must be positive"); RequireParamValue<int>("min_leaf_size", [](int x) { return x > 0; }, true, "minimum leaf size must be positive"); // Are we training a DET or loading from file? DTree<arma::mat, int>* tree; arma::mat trainingData; arma::mat testData; if (CLI::HasParam("training")) { trainingData = std::move(CLI::GetParam<arma::mat>("training")); const bool regularization = false; // const bool regularization = CLI::HasParam("volume_regularization"); const int maxLeafSize = CLI::GetParam<int>("max_leaf_size"); const int minLeafSize = CLI::GetParam<int>("min_leaf_size"); const bool skipPruning = CLI::HasParam("skip_pruning"); size_t folds = CLI::GetParam<int>("folds"); if (folds == 0) folds = trainingData.n_cols; // Obtain the optimal tree. Timer::Start("det_training"); tree = Trainer<arma::mat, int>(trainingData, folds, regularization, maxLeafSize, minLeafSize, skipPruning); Timer::Stop("det_training"); // Compute training set estimates, if desired. if (CLI::HasParam("training_set_estimates")) { // Compute density estimates for each point in the training set. arma::rowvec trainingDensities(trainingData.n_cols); Timer::Start("det_estimation_time"); for (size_t i = 0; i < trainingData.n_cols; i++) trainingDensities[i] = tree->ComputeValue(trainingData.unsafe_col(i)); Timer::Stop("det_estimation_time"); CLI::GetParam<arma::mat>("training_set_estimates") = std::move(trainingDensities); } } else { tree = &CLI::GetParam<DTree<arma::mat>>("input_model"); } // Compute the density at the provided test points and output the density in // the given file. if (CLI::HasParam("test")) { testData = std::move(CLI::GetParam<arma::mat>("test")); if (CLI::HasParam("test_set_estimates")) { // Compute test set densities. Timer::Start("det_test_set_estimation"); arma::rowvec testDensities(testData.n_cols); for (size_t i = 0; i < testData.n_cols; i++) testDensities[i] = tree->ComputeValue(testData.unsafe_col(i)); Timer::Stop("det_test_set_estimation"); CLI::GetParam<arma::mat>("test_set_estimates") = std::move(testDensities); } // Print variable importance. if (CLI::HasParam("vi")) { arma::vec importances; tree->ComputeVariableImportance(importances); CLI::GetParam<arma::mat>("vi") = importances.t(); } } if (CLI::HasParam("tag_file")) { const arma::mat& estimationData = CLI::HasParam("test") ? testData : trainingData; const string tagFile = CLI::GetParam<string>("tag_file"); std::ofstream ofs; ofs.open(tagFile, std::ofstream::out); arma::Row<size_t> counters; Timer::Start("det_test_set_tagging"); if (!ofs.is_open()) { Log::Warn << "Unable to open file '" << tagFile << "' to save tag membership info." << std::endl; } else if (CLI::HasParam("path_format")) { const bool reqCounters = CLI::HasParam("tag_counters_file"); const string pathFormat = CLI::GetParam<string>("path_format"); PathCacher::PathFormat theFormat; if (pathFormat == "lr" || pathFormat == "LR") theFormat = PathCacher::FormatLR; else if (pathFormat == "lr-id" || pathFormat == "LR-ID") theFormat = PathCacher::FormatLR_ID; else if (pathFormat == "id-lr" || pathFormat == "ID-LR") theFormat = PathCacher::FormatID_LR; else { Log::Warn << "Unknown path format specified: '" << pathFormat << "'. Valid are: lr | lr-id | id-lr. Defaults to 'lr'." << endl; theFormat = PathCacher::FormatLR; } PathCacher path(theFormat, tree); counters.zeros(path.NumNodes()); for (size_t i = 0; i < estimationData.n_cols; i++) { int tag = tree->FindBucket(estimationData.unsafe_col(i)); ofs << tag << " " << path.PathFor(tag) << std::endl; for (; tag >= 0 && reqCounters; tag = path.ParentOf(tag)) counters(tag) += 1; } ofs.close(); if (reqCounters) { ofs.open(CLI::GetParam<string>("tag_counters_file"), std::ofstream::out); for (size_t j = 0; j < counters.n_elem; ++j) ofs << j << " " << counters(j) << " " << path.PathFor(j) << endl; ofs.close(); } } else { int numLeaves = tree->TagTree(); counters.zeros(numLeaves); for (size_t i = 0; i < estimationData.n_cols; i++) { const int tag = tree->FindBucket(estimationData.unsafe_col(i)); ofs << tag << std::endl; counters(tag) += 1; } if (CLI::HasParam("tag_counters_file")) data::Save(CLI::GetParam<string>("tag_counters_file"), counters); } Timer::Stop("det_test_set_tagging"); ofs.close(); } // Save the model, if desired. if (CLI::HasParam("output_model")) CLI::GetParam<DTree<arma::mat>>("output_model") = std::move(*tree); // Clean up memory, if we need to. if (!CLI::HasParam("input_model") && !CLI::HasParam("output_model")) delete tree; }
// This function trains the optimal decision tree using the given number of // folds. DTree* mlpack::det::Trainer(arma::mat& dataset, const size_t folds, const bool useVolumeReg, const size_t maxLeafSize, const size_t minLeafSize, const std::string unprunedTreeOutput) { // Initialize the tree. DTree* dtree = new DTree(dataset); // Prepare to grow the tree... arma::Col<size_t> oldFromNew(dataset.n_cols); for (size_t i = 0; i < oldFromNew.n_elem; i++) oldFromNew[i] = i; // Save the dataset since it would be modified while growing the tree. arma::mat newDataset(dataset); // Growing the tree double oldAlpha = 0.0; double alpha = dtree->Grow(newDataset, oldFromNew, useVolumeReg, maxLeafSize, minLeafSize); Log::Info << dtree->SubtreeLeaves() << " leaf nodes in the tree using full " << "dataset; minimum alpha: " << alpha << "." << std::endl; // Compute densities for the training points in the full tree, if we were // asked for this. if (unprunedTreeOutput != "") { std::ofstream outfile(unprunedTreeOutput.c_str()); if (outfile.good()) { for (size_t i = 0; i < dataset.n_cols; ++i) { arma::vec testPoint = dataset.unsafe_col(i); outfile << dtree->ComputeValue(testPoint) << std::endl; } } else { Log::Warn << "Can't open '" << unprunedTreeOutput << "' to write computed" << " densities to." << std::endl; } outfile.close(); } // Sequentially prune and save the alpha values and the values of c_t^2 * r_t. std::vector<std::pair<double, double> > prunedSequence; while (dtree->SubtreeLeaves() > 1) { std::pair<double, double> treeSeq(oldAlpha, dtree->SubtreeLeavesLogNegError()); prunedSequence.push_back(treeSeq); oldAlpha = alpha; alpha = dtree->PruneAndUpdate(oldAlpha, dataset.n_cols, useVolumeReg); // Some sanity checks. Log::Assert((alpha < std::numeric_limits<double>::max()) || (dtree->SubtreeLeaves() == 1)); Log::Assert(alpha > oldAlpha); Log::Assert(dtree->SubtreeLeavesLogNegError() < treeSeq.second); } std::pair<double, double> treeSeq(oldAlpha, dtree->SubtreeLeavesLogNegError()); prunedSequence.push_back(treeSeq); Log::Info << prunedSequence.size() << " trees in the sequence; maximum alpha:" << " " << oldAlpha << "." << std::endl; delete dtree; arma::mat cvData(dataset); size_t testSize = dataset.n_cols / folds; std::vector<double> regularizationConstants; regularizationConstants.resize(prunedSequence.size(), 0); // Go through each fold. for (size_t fold = 0; fold < folds; fold++) { // Break up data into train and test sets. size_t start = fold * testSize; size_t end = std::min((fold + 1) * testSize, (size_t) cvData.n_cols); arma::mat test = cvData.cols(start, end - 1); arma::mat train(cvData.n_rows, cvData.n_cols - test.n_cols); if (start == 0 && end < cvData.n_cols) { train.cols(0, train.n_cols - 1) = cvData.cols(end, cvData.n_cols - 1); } else if (start > 0 && end == cvData.n_cols) { train.cols(0, train.n_cols - 1) = cvData.cols(0, start - 1); } else { train.cols(0, start - 1) = cvData.cols(0, start - 1); train.cols(start, train.n_cols - 1) = cvData.cols(end, cvData.n_cols - 1); } // Initialize the tree. DTree* cvDTree = new DTree(train); // Getting ready to grow the tree... arma::Col<size_t> cvOldFromNew(train.n_cols); for (size_t i = 0; i < cvOldFromNew.n_elem; i++) cvOldFromNew[i] = i; // Grow the tree. oldAlpha = 0.0; alpha = cvDTree->Grow(train, cvOldFromNew, useVolumeReg, maxLeafSize, minLeafSize); // Sequentially prune with all the values of available alphas and adding // values for test values. Don't enter this loop if there are less than two // trees in the pruned sequence. for (size_t i = 0; i < ((prunedSequence.size() < 2) ? 0 : prunedSequence.size() - 2); ++i) { // Compute test values for this state of the tree. double cvVal = 0.0; for (size_t j = 0; j < test.n_cols; j++) { arma::vec testPoint = test.unsafe_col(j); cvVal += cvDTree->ComputeValue(testPoint); } // Update the cv regularization constant. regularizationConstants[i] += 2.0 * cvVal / (double) dataset.n_cols; // Determine the new alpha value and prune accordingly. oldAlpha = 0.5 * (prunedSequence[i + 1].first + prunedSequence[i + 2].first); alpha = cvDTree->PruneAndUpdate(oldAlpha, train.n_cols, useVolumeReg); } // Compute test values for this state of the tree. double cvVal = 0.0; for (size_t i = 0; i < test.n_cols; ++i) { arma::vec testPoint = test.unsafe_col(i); cvVal += cvDTree->ComputeValue(testPoint); } if (prunedSequence.size() > 2) regularizationConstants[prunedSequence.size() - 2] += 2.0 * cvVal / (double) dataset.n_cols; test.reset(); delete cvDTree; } double optimalAlpha = -1.0; long double cvBestError = -std::numeric_limits<long double>::max(); for (size_t i = 0; i < prunedSequence.size() - 1; ++i) { // We can no longer work in the log-space for this because we have no // guarantee the quantity will be positive. long double thisError = -std::exp((long double) prunedSequence[i].second) + (long double) regularizationConstants[i]; if (thisError > cvBestError) { cvBestError = thisError; optimalAlpha = prunedSequence[i].first; } } Log::Info << "Optimal alpha: " << optimalAlpha << "." << std::endl; // Initialize the tree. DTree* dtreeOpt = new DTree(dataset); // Getting ready to grow the tree... for (size_t i = 0; i < oldFromNew.n_elem; i++) oldFromNew[i] = i; // Save the dataset since it would be modified while growing the tree. newDataset = dataset; // Grow the tree. oldAlpha = -DBL_MAX; alpha = dtreeOpt->Grow(newDataset, oldFromNew, useVolumeReg, maxLeafSize, minLeafSize); // Prune with optimal alpha. while ((oldAlpha < optimalAlpha) && (dtreeOpt->SubtreeLeaves() > 1)) { oldAlpha = alpha; alpha = dtreeOpt->PruneAndUpdate(oldAlpha, newDataset.n_cols, useVolumeReg); // Some sanity checks. Log::Assert((alpha < std::numeric_limits<double>::max()) || (dtreeOpt->SubtreeLeaves() == 1)); Log::Assert(alpha > oldAlpha); } Log::Info << dtreeOpt->SubtreeLeaves() << " leaf nodes in the optimally " << "pruned tree; optimal alpha: " << oldAlpha << "." << std::endl; return dtreeOpt; }
int main(int argc, char *argv[]) { CLI::ParseCommandLine(argc, argv); string trainSetFile = CLI::GetParam<string>("train_file"); arma::Mat<double> trainingData; data::Load(trainSetFile, trainingData, true); // Cross-validation here. size_t folds = CLI::GetParam<int>("folds"); if (folds == 0) { folds = trainingData.n_cols; Log::Info << "Performing leave-one-out cross validation." << endl; } else { Log::Info << "Performing " << folds << "-fold cross validation." << endl; } const string unprunedTreeEstimateFile = CLI::GetParam<string>("unpruned_tree_estimates_file"); const bool regularization = false; // const bool regularization = CLI::HasParam("volume_regularization"); const int maxLeafSize = CLI::GetParam<int>("max_leaf_size"); const int minLeafSize = CLI::GetParam<int>("min_leaf_size"); // Obtain the optimal tree. Timer::Start("det_training"); DTree *dtreeOpt = Trainer(trainingData, folds, regularization, maxLeafSize, minLeafSize, unprunedTreeEstimateFile); Timer::Stop("det_training"); // Compute densities for the training points in the optimal tree. FILE *fp = NULL; if (CLI::GetParam<string>("training_set_estimate_file") != "") { fp = fopen(CLI::GetParam<string>("training_set_estimate_file").c_str(), "w"); // Compute density estimates for each point in the training set. Timer::Start("det_estimation_time"); for (size_t i = 0; i < trainingData.n_cols; i++) fprintf(fp, "%lg\n", dtreeOpt->ComputeValue(trainingData.unsafe_col(i))); Timer::Stop("det_estimation_time"); fclose(fp); } // Compute the density at the provided test points and output the density in // the given file. const string testFile = CLI::GetParam<string>("test_file"); if (testFile != "") { arma::mat testData; data::Load(testFile, testData, true); fp = NULL; if (CLI::GetParam<string>("test_set_estimates_file") != "") { fp = fopen(CLI::GetParam<string>("test_set_estimates_file").c_str(), "w"); Timer::Start("det_test_set_estimation"); for (size_t i = 0; i < testData.n_cols; i++) fprintf(fp, "%lg\n", dtreeOpt->ComputeValue(testData.unsafe_col(i))); Timer::Stop("det_test_set_estimation"); fclose(fp); } } // Print the final tree. if (CLI::HasParam("print_tree")) { fp = NULL; if (CLI::GetParam<string>("tree_file") != "") { fp = fopen(CLI::GetParam<string>("tree_file").c_str(), "w"); if (fp != NULL) { dtreeOpt->WriteTree(fp); fclose(fp); } } else { dtreeOpt->WriteTree(stdout); printf("\n"); } } // Print the leaf memberships for the optimal tree. if (CLI::GetParam<string>("labels_file") != "") { std::string labelsFile = CLI::GetParam<string>("labels_file"); arma::Mat<size_t> labels; data::Load(labelsFile, labels, true); size_t numClasses = 0; for (size_t i = 0; i < labels.n_elem; ++i) { if (labels[i] > numClasses) numClasses = labels[i]; } Log::Info << numClasses << " found in labels file '" << labelsFile << "'." << std::endl; Log::Assert(trainingData.n_cols == labels.n_cols); Log::Assert(labels.n_rows == 1); PrintLeafMembership(dtreeOpt, trainingData, labels, numClasses, CLI::GetParam<string>("leaf_class_table_file")); } // Print variable importance. if (CLI::HasParam("print_vi")) { PrintVariableImportance(dtreeOpt, CLI::GetParam<string>("vi_file")); } delete dtreeOpt; }