void writePRCurves(const char *testName, const drwnClassifier *classifier,
    const drwnClassifierDataset& dataset, const string& filename)
{
    drwnPRCurve pr;
    const int nClasses = dataset.maxTarget() + 1;
    if (nClasses == 2) {
        pr.accumulate(dataset, classifier);
        pr.writeCurve((filename + string(".txt")).c_str());
    } else {
        // compute marginals (normalized scores)
        vector<vector<double> > marginals;
        classifier->getClassMarginals(dataset.features, marginals);

        // compute pr curve for each class
        for (int c = 0; c < nClasses; c++) {
            pr.clear();
            for (int i = 0; i < dataset.size(); i++) {
                if (dataset.targets[i] == c) {
                    pr.accumulatePositives(marginals[i][c]);
                } else {
                    pr.accumulateNegatives(marginals[i][c]);
                }
            }
            pr.writeCurve((filename + string("_") + toString(c) + string(".txt")).c_str());
        }
    }
}
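// A minimal usage sketch of writePRCurves() above, assuming `classifier`
// has already been trained and `dataset` holds held-out evaluation
// samples; "prcurve" is an arbitrary filename stem chosen for
// illustration.
static void examplePRCurveUsage(const drwnClassifier *classifier,
    const drwnClassifierDataset& dataset)
{
    writePRCurves("example", classifier, dataset, string("prcurve"));
    // binary problems write "prcurve.txt"; multi-class problems write
    // one curve per class: "prcurve_0.txt", "prcurve_1.txt", ...
}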
// training
double drwnMultiClassLogisticBase::train(const drwnClassifierDataset& dataset)
{
    if (dataset.hasWeights()) {
        return train(dataset.features, dataset.targets, dataset.weights);
    } else {
        return train(dataset.features, dataset.targets);
    }
}
// training
double drwnDecisionTree::train(const drwnClassifierDataset& dataset)
{
    if (dataset.hasWeights()) {
        return train(dataset.features, dataset.targets, dataset.weights);
    } else {
        return train(dataset.features, dataset.targets);
    }
}
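// A minimal training sketch for the dataset-based train() overloads
// above, assuming drwnClassifierDataset is default-constructible and
// exposes the public `features` and `targets` members used throughout
// this file; the toy samples are assumed values.
static void exampleDecisionTreeTraining()
{
    drwnClassifierDataset dataset;
    vector<double> x(2);
    x[0] = 0.0; x[1] = 1.0;
    dataset.features.push_back(x); // class 0 sample
    dataset.targets.push_back(0);
    x[0] = 1.0; x[1] = 0.0;
    dataset.features.push_back(x); // class 1 sample
    dataset.targets.push_back(1);

    drwnDecisionTree tree(2, 2);   // 2 features, 2 classes
    tree.train(dataset);
    const int label = tree.getClassification(dataset.features[0]);
    DRWN_LOG_VERBOSE("predicted class " << label);
}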
// training
double drwnCompositeClassifier::train(const drwnClassifierDataset& dataset)
{
    DRWN_FCN_TIC;

    // compute label weights
    vector<int> classCounts(_nClasses, 0);
    int totalCount = 0;
    for (int i = 0; i < dataset.size(); i++) {
        if (dataset.targets[i] >= 0) {
            classCounts[dataset.targets[i]] += 1;
            totalCount += 1;
        }
    }

    // learn binary classifiers
    vector<double> weights(dataset.size(), 1.0);
    vector<int> targets(dataset.size(), -1);

    for (unsigned i = 0; i < _binaryClassifiers.size(); i++) {
        delete _binaryClassifiers[i];
    }
    _binaryClassifiers.clear();

    switch (_method) {
    case DRWN_ONE_VS_ALL:
        for (int k = 0; k < _nClasses; k++) {
            if (classCounts[k] == 0) continue;
            DRWN_LOG_VERBOSE("...training <class " << (k + 1) << ">-vs-all");

            // class-balanced weights: each class contributes unit total mass
            const double w1 = 1.0 / (double)classCounts[k];
            const double w0 = 1.0 / (double)(totalCount - classCounts[k]);
            for (int i = 0; i < dataset.size(); i++) {
                targets[i] = (dataset.targets[i] == k ? 1 : 0);
                weights[i] = (dataset.targets[i] == k ? w1 : w0);
            }

            drwnClassifier *c = drwnClassifierFactory::get().create(_baseClassifier.c_str());
            c->initialize(_nFeatures, 2);
            c->train(dataset.features, targets, weights);
            _binaryClassifiers.push_back(c);
        }
        break;

    case DRWN_ONE_VS_ONE:
        for (int k = 0; k < _nClasses; k++) {
            if (classCounts[k] == 0) continue;
            for (int l = 0; l < k; l++) {
                if (classCounts[l] == 0) continue;
                DRWN_LOG_VERBOSE("...training <class " << (k + 1)
                    << ">-vs-<class " << (l + 1) << ">");

                const double w1 = 1.0 / (double)classCounts[k];
                const double w0 = 1.0 / (double)classCounts[l];
                for (int i = 0; i < dataset.size(); i++) {
                    if ((dataset.targets[i] != k) && (dataset.targets[i] != l)) {
                        // sample belongs to neither class; ignore it
                        targets[i] = -1;
                    } else {
                        targets[i] = (dataset.targets[i] == k ? 1 : 0);
                        weights[i] = (dataset.targets[i] == k ? w1 : w0);
                    }
                }

                drwnClassifier *c = drwnClassifierFactory::get().create(_baseClassifier.c_str());
                c->initialize(_nFeatures, 2);
                c->train(dataset.features, targets, weights);
                _binaryClassifiers.push_back(c);
            }
        }
        break;

    default:
        DRWN_LOG_FATAL("unknown method in drwnCompositeClassifier::train");
    }

    // clear some memory
    targets.clear();
    weights.clear();

    // learn calibration weights: build a feature vector of signed binary
    // classifier scores for each sample
    vector<double> f(2);
    vector<vector<double> > features(dataset.size(),
        vector<double>(_binaryClassifiers.size(), 0.0));
    for (int i = 0; i < dataset.size(); i++) {
        for (unsigned j = 0; j < _binaryClassifiers.size(); j++) {
            _binaryClassifiers[j]->getClassScores(dataset.features[i], f);
            features[i][j] = f[0] - f[1];
        }
    }

    DRWN_LOG_VERBOSE("whitening feature vectors...");
    _featureWhitener.train(features);
    _featureWhitener.transform(features);

    // train multi-class logistic model on the calibration features
    DRWN_LOG_VERBOSE("learning calibration weights...");
    _calibrationWeights.initialize(_binaryClassifiers.size(), _nClasses);
    if (dataset.hasWeights()) {
        _calibrationWeights.train(features, dataset.targets, dataset.weights);
    } else {
        _calibrationWeights.train(features, dataset.targets);
    }

    _bValid = true;
    DRWN_FCN_TOC;
    return 0.0;
}
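// A minimal worked sketch of the one-vs-all weighting used in
// drwnCompositeClassifier::train() above: with class-balanced weights
// w1 = 1/classCounts[k] and w0 = 1/(totalCount - classCounts[k]), each
// class contributes unit total weight to the binary subproblem, so the
// base learner is not dominated by the majority class. The counts here
// are assumed toy values.
static void exampleOneVsAllWeights()
{
    const int classCounts[2] = { 10, 90 };
    const int totalCount = 100;
    const int k = 0; // training <class 1>-vs-all
    const double w1 = 1.0 / (double)classCounts[k];                // 0.1 per positive
    const double w0 = 1.0 / (double)(totalCount - classCounts[k]); // ~0.011 per negative
    DRWN_LOG_VERBOSE("w1 = " << w1 << ", w0 = " << w0
        << ", positive mass = " << classCounts[k] * w1
        << ", negative mass = " << (totalCount - classCounts[k]) * w0);
}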
// training
double drwnBoostedClassifier::train(const drwnClassifierDataset& dataset)
{
    DRWN_FCN_TIC;

    // pre-compute sorted feature index
    vector<vector<int> > sortIndex;
    drwnDecisionTree::computeSortedFeatureIndex(dataset.features, sortIndex);

    // allocate weights
    vector<double> weights;
    if (dataset.hasWeights()) {
        weights = dataset.weights;
    } else {
        weights.resize(dataset.size(), 1.0);
    }

    // mark unknown samples
    drwnBitArray sampleIndex(dataset.size());
    sampleIndex.ones();
    for (int i = 0; i < dataset.size(); i++) {
        if (dataset.targets[i] < 0) {
            sampleIndex.clear(i);
            weights[i] = 0.0;
        }
    }

    vector<int> predicted(dataset.size(), -1);

    // iterate over rounds
    DRWN_START_PROGRESS("training", _numRounds);
    for (int i = 0; i < _numRounds; i++) {
        DRWN_LOG_STATUS("training boosted classifier round " << i
            << " of " << _numRounds << "...");
        DRWN_INC_PROGRESS;

        // normalize the sample weights
        Eigen::Map<VectorXd>(&weights[0], weights.size()) /=
            Eigen::Map<VectorXd>(&weights[0], weights.size()).sum();

        // learn a weak-learner with the current weights
        drwnDecisionTree *tree = new drwnDecisionTree(_nFeatures, _nClasses);
        tree->setProperty(tree->findProperty("maxDepth"), _maxDepth);
        tree->learnDecisionTree(dataset.features, dataset.targets,
            weights, sortIndex, sampleIndex);

        // predict classes and calculate the weighted error (epsilon)
        double epsilon = 0.0;
        for (int j = 0; j < dataset.size(); j++) {
            if (dataset.targets[j] < 0) continue;
            predicted[j] = tree->getClassification(dataset.features[j]);
            if (predicted[j] != dataset.targets[j]) {
                epsilon += weights[j];
            }
        }

        // stop if the weak learner is no better than random guessing
        if (epsilon >= 1.0 - 1.0 / _nClasses) {
            DRWN_LOG_WARNING("boosting terminated at round " << (i + 1)
                << " of " << _numRounds);
            delete tree;
            break;
        }

        // check for perfect classification
        if ((i == 0) && (epsilon == 0.0)) {
            DRWN_LOG_WARNING("boosting found a perfect classifier in first round");
            _alphas.push_back(1.0);
            _weakLearners.push_back(tree);
            break;
        }

        // calculate boosting coefficient
        double alpha;
        if (_method == DRWN_BOOST_GENTLE) {
            alpha = 1.0;
        } else {
            alpha = log((1.0 - epsilon) / epsilon) + log(_nClasses - 1.0);
            if (!isfinite(alpha)) {
                DRWN_LOG_WARNING("boosting terminated at round " << (i + 1)
                    << " of " << _numRounds << " (non-finite alpha)");
                delete tree;
                break;
            }
        }

        _alphas.push_back(alpha);
        _weakLearners.push_back(tree);

        // update the sample weights (up-weight misclassified samples)
        const double nu = exp(alpha);
        for (unsigned j = 0; j < weights.size(); j++) {
            if (predicted[j] != dataset.targets[j]) {
                weights[j] *= nu;
            }
        }
    }
    DRWN_END_PROGRESS;

#if 0
    // normalize boosting coefficients
    Eigen::Map<VectorXd>(&_alphas[0], _alphas.size()) /=
        Eigen::Map<VectorXd>(&_alphas[0], _alphas.size()).sum();
#endif

    _bValid = true;

    // return (weighted) classification accuracy on the training set
    double totalCorrect = 0.0;
    double totalWeight = 0.0;
    for (int j = 0; j < dataset.size(); j++) {
        if (dataset.targets[j] < 0) continue;
        predicted[j] = this->getClassification(dataset.features[j]);
        if (predicted[j] == dataset.targets[j]) {
            totalCorrect += dataset.hasWeights() ? dataset.weights[j] : 1.0;
        }
        totalWeight += dataset.hasWeights() ? dataset.weights[j] : 1.0;
    }

    DRWN_FCN_TOC;
    return (totalWeight > 0.0) ? totalCorrect / totalWeight : 1.0;
}
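// A minimal sketch of the boosting coefficient computed in
// drwnBoostedClassifier::train() above. The formula
// alpha = log((1 - epsilon) / epsilon) + log(K - 1) is the multi-class
// SAMME rule; for K = 2 it reduces to the classical discrete AdaBoost
// coefficient. The epsilon and K values here are assumed for
// illustration.
static void exampleBoostingCoefficient()
{
    const double epsilon = 0.25; // weighted error of the weak learner
    const int K = 4;             // number of classes
    const double alpha = log((1.0 - epsilon) / epsilon) + log((double)K - 1.0);
    const double nu = exp(alpha); // misclassified samples are scaled by nu
    DRWN_LOG_VERBOSE("alpha = " << alpha << ", weight multiplier = " << nu);
}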