/**************************************
 * Function: getStrongError
 * ------------------------
 * calculates error rates at each "level" of the strong classifier; i.e. after
 * each successive weak classifier has been added to the running weighted vote
 *
 * td: Training data to check strong classifier against
 * strong: strong classifier (i.e. ordered set of weak classifiers)
 *
 * returns: one entry per weak classifier i, describing the strong classifier
 *          built from weak classifiers 0..i. Each entry holds, in order:
 *            [0] flat error rate  (false_pos + false_neg) / num_features
 *            [1] precision        true_pos / (true_pos + false_pos)
 *            [2] recall           true_pos / (true_pos + false_neg)
 *            [3] true positives
 *            [4] true negatives
 *            [5] false positives
 *            [6] false negatives
 *          A ratio whose denominator is zero is reported as 0.0 instead of
 *          NaN (e.g. precision when no point is predicted positive).
 *
 * side effects: overwrites the members dimensions and num_features from td;
 *               clears false_indices and refills it with the indices of the
 *               points still misclassified after the last weak classifier.
 */
vector<vector <double> > AdaBooster::getStrongError(TrainingData &td, const WeakClassifierList &strong){
    vector< vector<double> > strong_err;
    strong_err.reserve(strong.size());

    // clear false_indices; it is refilled only at the last level below
    false_indices.clear();

    // set dimensions and number of features
    dimensions = td.dimensions();
    num_features = td.size();

    // running weighted vote for every point, accumulated level by level
    vector<double> classify(num_features, 0.0);

    // divide, but report 0.0 rather than NaN when the denominator is zero
    auto ratio = [](double num, double den) {
        return den == 0.0 ? 0.0 : num / den;
    };
    const double total = static_cast<double>(num_features);

    // traverse all weak classifiers
    for (unsigned int i=0; i<strong.size(); i++){
        const auto &wc = strong[i];
        const bool last_level = (i + 1 == strong.size());

        unsigned int true_pos = 0;
        unsigned int false_pos = 0;
        unsigned int true_neg = 0;
        unsigned int false_neg = 0;

        // traverse all features
        for (unsigned int j=0; j<num_features; j++){
            // the weak classifier votes POS (sign = 1) when the value is below
            // its threshold, or above it when the classifier is flipped;
            // otherwise (including a value equal to the threshold) it votes -1
            const auto value = td.at(j, wc.dimension());
            const bool weak_pos = wc.isFlipped() ? (wc.threshold() < value)
                                                 : (wc.threshold() > value);
            const int sign = weak_pos ? 1 : -1;

            // calculate classify so far
            classify[j] += wc.weight() * sign;

            // check the strong classifier's decision against reality
            const bool strong_pos = classify[j] >= strong_err_threshold;
            const auto label = td.val(j);
            if (strong_pos && label == POS){
                true_pos++;
            }
            else if (strong_pos && label == NEG){
                false_pos++;
                // at the last weak classifier we still can't classify this point
                if (last_level)
                    false_indices.push_back(j);
            }
            else if (!strong_pos && label == POS){
                false_neg++;
                // similarly, we can't classify the point
                if (last_level)
                    false_indices.push_back(j);
            }
            else{
                // NOTE(review): any label other than POS/NEG is also counted
                // here -- confirm td.val() only ever yields POS or NEG
                true_neg++;
            }
        }

        // calculate some stats and push into strong_err
        strong_err.push_back({
            ratio(false_pos + false_neg, total),           // flat error rate
            ratio(true_pos, true_pos + false_pos),         // precision
            ratio(true_pos, true_pos + false_neg),         // recall
            static_cast<double>(true_pos),                 // true positives
            static_cast<double>(true_neg),                 // true negatives
            static_cast<double>(false_pos),                // false positives
            static_cast<double>(false_neg)                 // false negatives
        });
    }
    return strong_err;
}