Beispiel #1
0
int main(int argc, char *argv[])
{
  CLI::ParseCommandLine(argc, argv);

  // Validate input parameters.
  if (CLI::HasParam("training_file") && CLI::HasParam("input_model_file"))
    Log::Fatal << "Only one of --training_file (-t) or --input_model_file (-m) "
        << "may be specified!" << endl;

  if (!CLI::HasParam("training_file") && !CLI::HasParam("input_model_file"))
    Log::Fatal << "Neither --training_file (-t) nor --input_model_file (-m) "
        << "are specified!" << endl;

  if (!CLI::HasParam("training_file"))
  {
    if (CLI::HasParam("training_set_estimates_file"))
      Log::Warn << "--training_set_estimates_file (-e) ignored because "
          << "--training_file (-t) is not specified." << endl;
    if (CLI::HasParam("folds"))
      Log::Warn << "--folds (-f) ignored because --training_file (-t) is not "
          << "specified." << endl;
    if (CLI::HasParam("min_leaf_size"))
      Log::Warn << "--min_leaf_size (-l) ignored because --training_file (-t) "
          << "is not specified." << endl;
    if (CLI::HasParam("max_leaf_size"))
      Log::Warn << "--max_leaf_size (-L) ignored because --training_file (-t) "
          << "is not specified." << endl;
  }

  if (!CLI::HasParam("test_file") && CLI::HasParam("test_set_estimates_file"))
    Log::Warn << "--test_set_estimates_file (-E) ignored because --test_file "
        << "(-T) is not specified." << endl;

  // Are we training a DET or loading from file?
  DTree* tree;
  if (CLI::HasParam("training_file"))
  {
    const string trainSetFile = CLI::GetParam<string>("training_file");
    arma::mat trainingData;
    data::Load(trainSetFile, trainingData, true);

    // Cross-validation here.
    size_t folds = CLI::GetParam<int>("folds");
    if (folds == 0)
    {
      folds = trainingData.n_cols;
      Log::Info << "Performing leave-one-out cross validation." << endl;
    }
    else
    {
      Log::Info << "Performing " << folds << "-fold cross validation." << endl;
    }

    const bool regularization = false;
//    const bool regularization = CLI::HasParam("volume_regularization");
    const int maxLeafSize = CLI::GetParam<int>("max_leaf_size");
    const int minLeafSize = CLI::GetParam<int>("min_leaf_size");

    // Obtain the optimal tree.
    Timer::Start("det_training");
    tree = Trainer(trainingData, folds, regularization, maxLeafSize,
        minLeafSize, "");
    Timer::Stop("det_training");

    // Compute training set estimates, if desired.
    if (CLI::GetParam<string>("training_set_estimates_file") != "")
    {
      // Compute density estimates for each point in the training set.
      arma::rowvec trainingDensities(trainingData.n_cols);
      Timer::Start("det_estimation_time");
      for (size_t i = 0; i < trainingData.n_cols; i++)
        trainingDensities[i] = tree->ComputeValue(trainingData.unsafe_col(i));
      Timer::Stop("det_estimation_time");

      data::Save(CLI::GetParam<string>("training_set_estimates_file"),
          trainingDensities);
    }
  }
  else
  {
    data::Load(CLI::GetParam<string>("input_model_file"), "det_model", tree,
        true);
  }

  // Compute the density at the provided test points and output the density in
  // the given file.
  const string testFile = CLI::GetParam<string>("test_file");
  if (testFile != "")
  {
    arma::mat testData;
    data::Load(testFile, testData, true);

    // Compute test set densities.
    Timer::Start("det_test_set_estimation");
    arma::rowvec testDensities(testData.n_cols);
    for (size_t i = 0; i < testData.n_cols; i++)
      testDensities[i] = tree->ComputeValue(testData.unsafe_col(i));
    Timer::Stop("det_test_set_estimation");

    if (CLI::GetParam<string>("test_set_estimates_file") != "")
      data::Save(CLI::GetParam<string>("test_set_estimates_file"),
          testDensities);
  }

  // Print variable importance.
  if (CLI::HasParam("vi_file"))
    PrintVariableImportance(tree, CLI::GetParam<string>("vi_file"));

  // Save the model, if desired.
  if (CLI::HasParam("output_model_file"))
    data::Save(CLI::GetParam<string>("output_model_file"), "det_model", tree,
        false);

  delete tree;
}
Beispiel #2
0
/**
 * Binding entry point for density estimation trees (DET).
 *
 * Either trains a tree on the "training" matrix or takes one from
 * "input_model", then, depending on which parameters were passed:
 *  - stores density estimates for the training and/or test points,
 *  - stores the variable importances ("vi"),
 *  - writes per-point leaf tags (and optionally per-node counters) to files,
 *  - stores the tree in "output_model".
 */
static void mlpackMain()
{
  // Validate input parameters.
  RequireOnlyOnePassed({ "training", "input_model" }, true);

  // Training-only options are reported as ignored when not training.
  ReportIgnoredParam({{ "training", false }}, "training_set_estimates");
  ReportIgnoredParam({{ "training", false }}, "folds");
  ReportIgnoredParam({{ "training", false }}, "min_leaf_size");
  ReportIgnoredParam({{ "training", false }}, "max_leaf_size");

  // Tagging needs some data set to tag: either the test or the training set.
  if (CLI::HasParam("tag_file"))
    RequireAtLeastOnePassed({ "training", "test" }, true);

  if (CLI::HasParam("training"))
  {
    RequireAtLeastOnePassed({ "output_model", "training_set_estimates", "vi",
        "tag_file", "tag_counters_file" }, false, "no output will be saved");
  }

  ReportIgnoredParam({{ "test", false }}, "test_set_estimates");

  RequireParamValue<int>("max_leaf_size", [](int x) { return x > 0; }, true,
      "maximum leaf size must be positive");
  RequireParamValue<int>("min_leaf_size", [](int x) { return x > 0; }, true,
      "minimum leaf size must be positive");

  // Are we training a DET or loading from file?
  DTree<arma::mat, int>* tree;
  arma::mat trainingData;
  arma::mat testData;

  if (CLI::HasParam("training"))
  {
    trainingData = std::move(CLI::GetParam<arma::mat>("training"));

    const bool regularization = false;
//    const bool regularization = CLI::HasParam("volume_regularization");
    const int maxLeafSize = CLI::GetParam<int>("max_leaf_size");
    const int minLeafSize = CLI::GetParam<int>("min_leaf_size");
    const bool skipPruning = CLI::HasParam("skip_pruning");
    size_t folds = CLI::GetParam<int>("folds");

    // Zero folds selects one fold per training point.
    if (folds == 0)
      folds = trainingData.n_cols;

    // Obtain the optimal tree.
    Timer::Start("det_training");
    tree = Trainer<arma::mat, int>(trainingData, folds, regularization,
                                   maxLeafSize, minLeafSize,
                                   skipPruning);
    Timer::Stop("det_training");

    // Compute training set estimates, if desired.
    if (CLI::HasParam("training_set_estimates"))
    {
      // Compute density estimates for each point in the training set.
      arma::rowvec trainingDensities(trainingData.n_cols);
      Timer::Start("det_estimation_time");
      for (size_t i = 0; i < trainingData.n_cols; i++)
        trainingDensities[i] = tree->ComputeValue(trainingData.unsafe_col(i));
      Timer::Stop("det_estimation_time");

      CLI::GetParam<arma::mat>("training_set_estimates") =
          std::move(trainingDensities);
    }
  }
  else
  {
    // Not owned here: the tree lives inside the "input_model" parameter.
    tree = &CLI::GetParam<DTree<arma::mat>>("input_model");
  }

  // Compute the density at the provided test points and store it in the
  // "test_set_estimates" output, if requested.
  if (CLI::HasParam("test"))
  {
    testData = std::move(CLI::GetParam<arma::mat>("test"));
    if (CLI::HasParam("test_set_estimates"))
    {
      // Compute test set densities.
      Timer::Start("det_test_set_estimation");
      arma::rowvec testDensities(testData.n_cols);

      for (size_t i = 0; i < testData.n_cols; i++)
        testDensities[i] = tree->ComputeValue(testData.unsafe_col(i));

      Timer::Stop("det_test_set_estimation");

      CLI::GetParam<arma::mat>("test_set_estimates") = std::move(testDensities);
    }
  }

  // Store variable importance.  This call takes only the tree and an output
  // vector, so it must not depend on a test set being passed: "vi" is accepted
  // above as a valid output of a training-only run.
  if (CLI::HasParam("vi"))
  {
    arma::vec importances;
    tree->ComputeVariableImportance(importances);
    CLI::GetParam<arma::mat>("vi") = importances.t();
  }

  if (CLI::HasParam("tag_file"))
  {
    // Tag the test set when one was given, the training set otherwise.
    const arma::mat& estimationData =
        CLI::HasParam("test") ? testData : trainingData;
    const string tagFile = CLI::GetParam<string>("tag_file");
    std::ofstream ofs;
    ofs.open(tagFile, std::ofstream::out);

    arma::Row<size_t> counters;

    Timer::Start("det_test_set_tagging");
    if (!ofs.is_open())
    {
      Log::Warn << "Unable to open file '" << tagFile
        << "' to save tag membership info."
        << std::endl;
    }
    else if (CLI::HasParam("path_format"))
    {
      const bool reqCounters = CLI::HasParam("tag_counters_file");
      const string pathFormat = CLI::GetParam<string>("path_format");

      // Map the textual format name onto the PathCacher enumeration; anything
      // unrecognised falls back to 'lr' with a warning.
      PathCacher::PathFormat theFormat;
      if (pathFormat == "lr" || pathFormat == "LR")
        theFormat = PathCacher::FormatLR;
      else if (pathFormat == "lr-id" || pathFormat == "LR-ID")
        theFormat = PathCacher::FormatLR_ID;
      else if (pathFormat == "id-lr" || pathFormat == "ID-LR")
        theFormat = PathCacher::FormatID_LR;
      else
      {
        Log::Warn << "Unknown path format specified: '" << pathFormat
          << "'. Valid are: lr | lr-id | id-lr. Defaults to 'lr'." << endl;
        theFormat = PathCacher::FormatLR;
      }

      PathCacher path(theFormat, tree);
      counters.zeros(path.NumNodes());

      for (size_t i = 0; i < estimationData.n_cols; i++)
      {
        int tag = tree->FindBucket(estimationData.unsafe_col(i));

        ofs << tag << " " << path.PathFor(tag) << std::endl;
        // When counters are requested, credit the bucket and then each node
        // returned by ParentOf() until a negative tag is reached.
        for (; tag >= 0 && reqCounters; tag = path.ParentOf(tag))
          counters(tag) += 1;
      }

      ofs.close();

      if (reqCounters)
      {
        ofs.open(CLI::GetParam<string>("tag_counters_file"),
                 std::ofstream::out);

        for (size_t j = 0; j < counters.n_elem; ++j)
          ofs << j << " "
              << counters(j) << " "
              << path.PathFor(j) << endl;

        ofs.close();
      }
    }
    else
    {
      // No path output requested: write one tag per point and count how many
      // points received each tag.
      int numLeaves = tree->TagTree();
      counters.zeros(numLeaves);

      for (size_t i = 0; i < estimationData.n_cols; i++)
      {
        const int tag = tree->FindBucket(estimationData.unsafe_col(i));

        ofs << tag << std::endl;
        counters(tag) += 1;
      }

      if (CLI::HasParam("tag_counters_file"))
        data::Save(CLI::GetParam<string>("tag_counters_file"), counters);
    }

    Timer::Stop("det_test_set_tagging");
    ofs.close();
  }

  // Save the model, if desired.
  if (CLI::HasParam("output_model"))
    CLI::GetParam<DTree<arma::mat>>("output_model") = std::move(*tree);

  // Clean up memory, if we need to.
  // NOTE(review): when "training" and "output_model" are both passed, the tree
  // returned by Trainer() is moved from but never deleted here; confirm
  // whether that leaks the moved-from object before changing the condition.
  if (!CLI::HasParam("input_model") && !CLI::HasParam("output_model"))
    delete tree;
}
Beispiel #3
0
/**
 * Train the optimally pruned density estimation tree using cross-validation
 * with the given number of folds.
 *
 * The procedure is:
 *  1. grow a tree on the full dataset and prune it step by step, recording the
 *     sequence of (alpha, log-negative-error) pairs;
 *  2. for every fold, grow a tree on the remaining points, prune it along that
 *     alpha sequence and accumulate the density it assigns to the held-out
 *     points;
 *  3. pick the alpha with the best combined score;
 *  4. grow a fresh tree on the full dataset and prune it up to that alpha.
 *
 * @param dataset Data to train on; each column is one point.
 * @param folds Number of cross-validation folds.  Must be at least 2 and at
 *     most dataset.n_cols, otherwise a fatal error is reported.
 * @param useVolumeReg Forwarded to DTree::Grow() and DTree::PruneAndUpdate().
 * @param maxLeafSize Forwarded to DTree::Grow().
 * @param minLeafSize Forwarded to DTree::Grow().
 * @param unprunedTreeOutput If non-empty, the density of every training point
 *     under the unpruned tree is written to this file, one value per line.
 * @return Pointer to the pruned tree, allocated with new; the caller is
 *     responsible for deleting it.
 */
DTree* mlpack::det::Trainer(arma::mat& dataset,
                            const size_t folds,
                            const bool useVolumeReg,
                            const size_t maxLeafSize,
                            const size_t minLeafSize,
                            const std::string unprunedTreeOutput)
{
  // Reject fold counts the splitting code below cannot handle:
  //  - 0 folds divides by zero when computing the test-set size;
  //  - 1 fold leaves an empty training set and an out-of-range column span;
  //  - more folds than points gives a test-set size of zero and an
  //    out-of-range column span.
  // NOTE(review): Log::Fatal is assumed to terminate (abort or throw), as in
  // this project's command-line programs; confirm.
  if (folds < 2 || folds > dataset.n_cols)
  {
    Log::Fatal << "Trainer(): number of folds (" << folds << ") must be at "
        << "least 2 and at most the number of points in the dataset ("
        << dataset.n_cols << ")." << std::endl;
  }

  // Initialize the tree.
  DTree* dtree = new DTree(dataset);

  // Prepare to grow the tree: start from the identity index mapping.
  arma::Col<size_t> oldFromNew(dataset.n_cols);
  for (size_t i = 0; i < oldFromNew.n_elem; i++)
    oldFromNew[i] = i;

  // Save the dataset since it would be modified while growing the tree.
  arma::mat newDataset(dataset);

  // Growing the tree
  double oldAlpha = 0.0;
  double alpha = dtree->Grow(newDataset, oldFromNew, useVolumeReg, maxLeafSize,
      minLeafSize);

  Log::Info << dtree->SubtreeLeaves() << " leaf nodes in the tree using full "
      << "dataset; minimum alpha: " << alpha << "." << std::endl;

  // Compute densities for the training points in the full tree, if we were
  // asked for this.
  if (unprunedTreeOutput != "")
  {
    std::ofstream outfile(unprunedTreeOutput.c_str());
    if (outfile.good())
    {
      for (size_t i = 0; i < dataset.n_cols; ++i)
      {
        arma::vec testPoint = dataset.unsafe_col(i);
        outfile << dtree->ComputeValue(testPoint) << std::endl;
      }
    }
    else
    {
      Log::Warn << "Can't open '" << unprunedTreeOutput << "' to write computed"
          << " densities to." << std::endl;
    }

    outfile.close();
  }

  // Sequentially prune and save the alpha values and the values of c_t^2 * r_t.
  std::vector<std::pair<double, double> > prunedSequence;
  while (dtree->SubtreeLeaves() > 1)
  {
    std::pair<double, double> treeSeq(oldAlpha,
        dtree->SubtreeLeavesLogNegError());
    prunedSequence.push_back(treeSeq);
    oldAlpha = alpha;
    alpha = dtree->PruneAndUpdate(oldAlpha, dataset.n_cols, useVolumeReg);

    // Some sanity checks.
    Log::Assert((alpha < std::numeric_limits<double>::max()) ||
        (dtree->SubtreeLeaves() == 1));
    Log::Assert(alpha > oldAlpha);
    Log::Assert(dtree->SubtreeLeavesLogNegError() < treeSeq.second);
  }

  // Record the final (single-leaf) tree as the last element of the sequence.
  std::pair<double, double> treeSeq(oldAlpha,
      dtree->SubtreeLeavesLogNegError());
  prunedSequence.push_back(treeSeq);

  Log::Info << prunedSequence.size() << " trees in the sequence; maximum alpha:"
      << " " << oldAlpha << "." << std::endl;

  delete dtree;

  arma::mat cvData(dataset);
  // Integer division: when folds does not divide the number of points, the
  // trailing remainder columns are never part of any test set.
  size_t testSize = dataset.n_cols / folds;

  // One accumulated cross-validation term per tree in the pruned sequence.
  std::vector<double> regularizationConstants;
  regularizationConstants.resize(prunedSequence.size(), 0);

  // Go through each fold.
  for (size_t fold = 0; fold < folds; fold++)
  {
    // Break up data into train and test sets.
    size_t start = fold * testSize;
    size_t end = std::min((fold + 1) * testSize, (size_t) cvData.n_cols);

    arma::mat test = cvData.cols(start, end - 1);
    arma::mat train(cvData.n_rows, cvData.n_cols - test.n_cols);

    if (start == 0 && end < cvData.n_cols)
    {
      // Test block is at the front: train on everything after it.
      train.cols(0, train.n_cols - 1) = cvData.cols(end, cvData.n_cols - 1);
    }
    else if (start > 0 && end == cvData.n_cols)
    {
      // Test block is at the back: train on everything before it.
      train.cols(0, train.n_cols - 1) = cvData.cols(0, start - 1);
    }
    else
    {
      // Test block is in the middle: train on the columns on both sides.  The
      // fold-count check at the top guarantees start > 0 here.
      train.cols(0, start - 1) = cvData.cols(0, start - 1);
      train.cols(start, train.n_cols - 1) = cvData.cols(end, cvData.n_cols - 1);
    }

    // Initialize the tree.
    DTree* cvDTree = new DTree(train);

    // Getting ready to grow the tree...
    arma::Col<size_t> cvOldFromNew(train.n_cols);
    for (size_t i = 0; i < cvOldFromNew.n_elem; i++)
      cvOldFromNew[i] = i;

    // Grow the tree.
    oldAlpha = 0.0;
    alpha = cvDTree->Grow(train, cvOldFromNew, useVolumeReg, maxLeafSize,
        minLeafSize);

    // Sequentially prune with all the values of available alphas and adding
    // values for test values.  Don't enter this loop if there are less than two
    // trees in the pruned sequence.
    for (size_t i = 0;
         i < ((prunedSequence.size() < 2) ? 0 : prunedSequence.size() - 2); ++i)
    {
      // Compute test values for this state of the tree.
      double cvVal = 0.0;
      for (size_t j = 0; j < test.n_cols; j++)
      {
        arma::vec testPoint = test.unsafe_col(j);
        cvVal += cvDTree->ComputeValue(testPoint);
      }

      // Update the cv regularization constant.
      regularizationConstants[i] += 2.0 * cvVal / (double) dataset.n_cols;

      // Determine the new alpha value and prune accordingly: the midpoint of
      // the next two alphas of the full-data pruning sequence is used.
      oldAlpha = 0.5 * (prunedSequence[i + 1].first +
          prunedSequence[i + 2].first);
      alpha = cvDTree->PruneAndUpdate(oldAlpha, train.n_cols, useVolumeReg);
    }

    // Compute test values for this state of the tree.
    double cvVal = 0.0;
    for (size_t i = 0; i < test.n_cols; ++i)
    {
      arma::vec testPoint = test.unsafe_col(i);
      cvVal += cvDTree->ComputeValue(testPoint);
    }

    if (prunedSequence.size() > 2)
      regularizationConstants[prunedSequence.size() - 2] += 2.0 * cvVal /
          (double) dataset.n_cols;

    test.reset();
    delete cvDTree;
  }

  double optimalAlpha = -1.0;
  long double cvBestError = -std::numeric_limits<long double>::max();

  // prunedSequence always holds at least one element at this point, so the
  // unsigned subtraction below cannot wrap around.
  for (size_t i = 0; i < prunedSequence.size() - 1; ++i)
  {
    // We can no longer work in the log-space for this because we have no
    // guarantee the quantity will be positive.
    long double thisError = -std::exp((long double) prunedSequence[i].second) +
        (long double) regularizationConstants[i];

    if (thisError > cvBestError)
    {
      cvBestError = thisError;
      optimalAlpha = prunedSequence[i].first;
    }
  }

  Log::Info << "Optimal alpha: " << optimalAlpha << "." << std::endl;

  // Initialize the tree.
  DTree* dtreeOpt = new DTree(dataset);

  // Getting ready to grow the tree...
  for (size_t i = 0; i < oldFromNew.n_elem; i++)
    oldFromNew[i] = i;

  // Save the dataset since it would be modified while growing the tree.
  newDataset = dataset;

  // Grow the tree.  oldAlpha starts at the lowest finite double so that the
  // pruning loop below is entered for any optimal alpha.
  oldAlpha = -std::numeric_limits<double>::max();
  alpha = dtreeOpt->Grow(newDataset, oldFromNew, useVolumeReg, maxLeafSize,
      minLeafSize);

  // Prune with optimal alpha.
  while ((oldAlpha < optimalAlpha) && (dtreeOpt->SubtreeLeaves() > 1))
  {
    oldAlpha = alpha;
    alpha = dtreeOpt->PruneAndUpdate(oldAlpha, newDataset.n_cols, useVolumeReg);

    // Some sanity checks.
    Log::Assert((alpha < std::numeric_limits<double>::max()) ||
        (dtreeOpt->SubtreeLeaves() == 1));
    Log::Assert(alpha > oldAlpha);
  }

  Log::Info << dtreeOpt->SubtreeLeaves() << " leaf nodes in the optimally "
      << "pruned tree; optimal alpha: " << oldAlpha << "." << std::endl;

  return dtreeOpt;
}
Beispiel #4
0
int main(int argc, char *argv[])
{
  CLI::ParseCommandLine(argc, argv);

  string trainSetFile = CLI::GetParam<string>("train_file");
  arma::Mat<double> trainingData;

  data::Load(trainSetFile, trainingData, true);

  // Cross-validation here.
  size_t folds = CLI::GetParam<int>("folds");
  if (folds == 0)
  {
    folds = trainingData.n_cols;
    Log::Info << "Performing leave-one-out cross validation." << endl;
  }
  else
  {
    Log::Info << "Performing " << folds << "-fold cross validation." << endl;
  }

  const string unprunedTreeEstimateFile =
      CLI::GetParam<string>("unpruned_tree_estimates_file");
  const bool regularization = false;
//  const bool regularization = CLI::HasParam("volume_regularization");
  const int maxLeafSize = CLI::GetParam<int>("max_leaf_size");
  const int minLeafSize = CLI::GetParam<int>("min_leaf_size");

  // Obtain the optimal tree.
  Timer::Start("det_training");
  DTree *dtreeOpt = Trainer(trainingData, folds, regularization, maxLeafSize,
      minLeafSize, unprunedTreeEstimateFile);
  Timer::Stop("det_training");

  // Compute densities for the training points in the optimal tree.
  FILE *fp = NULL;

  if (CLI::GetParam<string>("training_set_estimate_file") != "")
  {
    fp = fopen(CLI::GetParam<string>("training_set_estimate_file").c_str(),
        "w");

    // Compute density estimates for each point in the training set.
    Timer::Start("det_estimation_time");
    for (size_t i = 0; i < trainingData.n_cols; i++)
      fprintf(fp, "%lg\n", dtreeOpt->ComputeValue(trainingData.unsafe_col(i)));
    Timer::Stop("det_estimation_time");

    fclose(fp);
  }

  // Compute the density at the provided test points and output the density in
  // the given file.
  const string testFile = CLI::GetParam<string>("test_file");
  if (testFile != "")
  {
    arma::mat testData;
    data::Load(testFile, testData, true);

    fp = NULL;

    if (CLI::GetParam<string>("test_set_estimates_file") != "")
    {
      fp = fopen(CLI::GetParam<string>("test_set_estimates_file").c_str(), "w");

      Timer::Start("det_test_set_estimation");
      for (size_t i = 0; i < testData.n_cols; i++)
        fprintf(fp, "%lg\n", dtreeOpt->ComputeValue(testData.unsafe_col(i)));
      Timer::Stop("det_test_set_estimation");

      fclose(fp);
    }
  }

  // Print the final tree.
  if (CLI::HasParam("print_tree"))
  {
    fp = NULL;
    if (CLI::GetParam<string>("tree_file") != "")
    {
      fp = fopen(CLI::GetParam<string>("tree_file").c_str(), "w");

      if (fp != NULL)
      {
        dtreeOpt->WriteTree(fp);
        fclose(fp);
      }
    }
    else
    {
      dtreeOpt->WriteTree(stdout);
      printf("\n");
    }
  }

  // Print the leaf memberships for the optimal tree.
  if (CLI::GetParam<string>("labels_file") != "")
  {
    std::string labelsFile = CLI::GetParam<string>("labels_file");
    arma::Mat<size_t> labels;

    data::Load(labelsFile, labels, true);

    size_t numClasses = 0;
    for (size_t i = 0; i < labels.n_elem; ++i)
    {
      if (labels[i] > numClasses)
        numClasses = labels[i];
    }

    Log::Info << numClasses << " found in labels file '" << labelsFile << "'."
        << std::endl;

    Log::Assert(trainingData.n_cols == labels.n_cols);
    Log::Assert(labels.n_rows == 1);

    PrintLeafMembership(dtreeOpt, trainingData, labels, numClasses,
       CLI::GetParam<string>("leaf_class_table_file"));
  }

  // Print variable importance.
  if (CLI::HasParam("print_vi"))
  {
    PrintVariableImportance(dtreeOpt, CLI::GetParam<string>("vi_file"));
  }

  delete dtreeOpt;
}