//Train model with a regression dataset void CARTTrainer::train(ModelType& model, RegressionDataset const& dataset) { //Store the number of input dimensions m_inputDimension = inputDimension(dataset); //Pass input dimension (i.e., number of attributes) to tree model model.setInputDimension(m_inputDimension); //Store the size of the labels m_labelDimension = labelDimension(dataset); // create cross-validation folds RegressionDataset set=dataset; CVFolds<RegressionDataset > folds = createCVSameSize(set, m_numberOfFolds); double bestErrorRate = std::numeric_limits<double>::max(); CARTClassifier<RealVector>::TreeType bestTree; for (unsigned fold = 0; fold < m_numberOfFolds; ++fold){ //Run through all the cross validation sets RegressionDataset dataTrain = folds.training(fold); RegressionDataset dataTest = folds.validation(fold); std::size_t numTrainElements = dataTrain.numberOfElements(); AttributeTables tables = createAttributeTables(dataTrain.inputs()); std::vector < RealVector > labels(numTrainElements); boost::copy(dataTrain.labels().elements(),labels.begin()); //Build tree form this fold CARTClassifier<RealVector>::TreeType tree = buildTree(tables, dataTrain, labels, 0, dataTrain.numberOfElements()); //Add the tree to the model and prune model.setTree(tree); while(true){ //evaluate the error of current tree SquaredLoss<> loss; double error = loss.eval(dataTest.labels(), model(dataTest.inputs())); if(error < bestErrorRate){ //We have found a subtree that has a smaller error rate when tested! bestErrorRate = error; bestTree = tree; } if(tree.size() == 1) break; pruneTree(tree); model.setTree(tree); } } SHARK_CHECK(bestTree.size() > 0, "We should never set a tree that is empty."); model.setTree(bestTree); }
int main(int argc, char **argv) { RegressionDataset data; importCSV(data, "blogData_train.csv", LAST_COLUMN,1,',','#', 2<<16); LinearRegression trainer(100); LinearModel<> model; Timer time; trainer.train(model, data); double time_taken = time.stop(); SquaredLoss<> loss; cout << "Residual sum of squares:" << loss(data.labels(),model(data.inputs()))<<std::endl; cout << "Time:\n" << time_taken << endl; cout << time_taken << endl; }