//Classification void CARTTrainer::train(ModelType& model, ClassificationDataset const& dataset){ //Store the number of input dimensions m_inputDimension = inputDimension(dataset); //Pass input dimension (i.e., number of attributes) to tree model model.setInputDimension(m_inputDimension); //Find the largest label, so we know how big the histogram should be m_maxLabel = static_cast<unsigned int>(numberOfClasses(dataset))-1; // create cross-validation folds ClassificationDataset set=dataset; CVFolds<ClassificationDataset> folds = createCVSameSizeBalanced(set, m_numberOfFolds); //find the best tree for the cv folds double bestErrorRate = std::numeric_limits<double>::max(); CARTClassifier<RealVector>::TreeType bestTree; //Run through all the cross validation sets for (unsigned fold = 0; fold < m_numberOfFolds; ++fold) { ClassificationDataset dataTrain = folds.training(fold); ClassificationDataset dataTest = folds.validation(fold); //Create attribute tables //O.K. stores how often label(i) can be found in the dataset //O.K. TODO: std::vector<unsigned int> is sufficient boost::unordered_map<std::size_t, std::size_t> cAbove = createCountMatrix(dataTrain); AttributeTables tables = createAttributeTables(dataTrain.inputs()); //create initial tree for the fold CARTClassifier<RealVector>::TreeType tree = buildTree(tables, dataTrain, cAbove, 0); model.setTree(tree); while(true){ ZeroOneLoss<unsigned int, RealVector> loss; double errorRate = loss.eval(dataTest.labels(), model(dataTest.inputs())); if(errorRate < bestErrorRate){ //We have found a subtree that has a smaller error rate when tested! bestErrorRate = errorRate; bestTree = tree; } if(tree.size()!=1) break; pruneTree(tree); model.setTree(tree); } } SHARK_CHECK(bestTree.size() > 0, "We should never set a tree that is empty."); model.setTree(bestTree); }
//Train model with a regression dataset void CARTTrainer::train(ModelType& model, RegressionDataset const& dataset) { //Store the number of input dimensions m_inputDimension = inputDimension(dataset); //Pass input dimension (i.e., number of attributes) to tree model model.setInputDimension(m_inputDimension); //Store the size of the labels m_labelDimension = labelDimension(dataset); // create cross-validation folds RegressionDataset set=dataset; CVFolds<RegressionDataset > folds = createCVSameSize(set, m_numberOfFolds); double bestErrorRate = std::numeric_limits<double>::max(); CARTClassifier<RealVector>::TreeType bestTree; for (unsigned fold = 0; fold < m_numberOfFolds; ++fold){ //Run through all the cross validation sets RegressionDataset dataTrain = folds.training(fold); RegressionDataset dataTest = folds.validation(fold); std::size_t numTrainElements = dataTrain.numberOfElements(); AttributeTables tables = createAttributeTables(dataTrain.inputs()); std::vector < RealVector > labels(numTrainElements); boost::copy(dataTrain.labels().elements(),labels.begin()); //Build tree form this fold CARTClassifier<RealVector>::TreeType tree = buildTree(tables, dataTrain, labels, 0, dataTrain.numberOfElements()); //Add the tree to the model and prune model.setTree(tree); while(true){ //evaluate the error of current tree SquaredLoss<> loss; double error = loss.eval(dataTest.labels(), model(dataTest.inputs())); if(error < bestErrorRate){ //We have found a subtree that has a smaller error rate when tested! bestErrorRate = error; bestTree = tree; } if(tree.size() == 1) break; pruneTree(tree); model.setTree(tree); } } SHARK_CHECK(bestTree.size() > 0, "We should never set a tree that is empty."); model.setTree(bestTree); }
/** * @brief Cross Validation * * @return */ std::pair<TrainingSet, Prediction> ANNPredictor::predictionCV(const size_t nFolds) { using namespace std; using namespace shark; CVFolds<ClassificationDataset> folds = createCVSameSizeBalanced(*m_data, nFolds); std::pair<TrainingSet, Prediction> data; for (size_t fold = 0; fold < folds.size(); ++fold) { ClassificationDataset training = folds.training(fold); ClassificationDataset validation = folds.validation(fold); auto model = createFFNetModel(training); auto elements = validation.elements(); for (auto iter = elements.begin(); iter != elements.end(); iter++) { append(data, iter->input, iter->label, model(iter->input)); } } return std::move(data); }