Beispiel #1
0
/**************************************
 * Function: getStrongError
 * ------------------------
 * Evaluates the strong classifier cumulatively: after adding each weak
 * classifier (in order) to the weighted vote, every training point is
 * classified and the resulting error statistics for that "level" are
 * recorded.
 *
 * td: Training data to check strong classifier against
 * strong: strong classifier (i.e. ordered set of weak classifiers)
 *
 * Side effects: overwrites the members `dimensions` and `num_features`
 * from td, and rebuilds `false_indices` so that it holds the indices of
 * the points still misclassified once the last weak classifier has been
 * applied.
 *
 * returns: one row per weak classifier, each row laid out as
 *   [0] flat error rate   (false_pos + false_neg) / number of points
 *   [1] precision         true_pos / (true_pos + false_pos)
 *   [2] recall            true_pos / (true_pos + false_neg)
 *   [3] true positives
 *   [4] true negatives
 *   [5] false positives
 *   [6] false negatives
 * A ratio whose denominator is zero is reported as 0.0 rather than NaN.
 */
vector<vector <double> > AdaBooster::getStrongError(TrainingData &td, const WeakClassifierList &strong){
	vector< vector<double> > strong_err;

	// indices of misclassified points are rebuilt from scratch on every call
	false_indices.clear();

	// cache the shape of the training data in the member fields
	// (num_features holds the number of training points, i.e. td.size())
	dimensions = td.dimensions();
	num_features = td.size();

	// running weighted vote for each training point, starting at zero
	vector<double> classify(num_features, 0.0);

	// traverse all weak classifiers
	for (unsigned int i=0; i<strong.size(); i++){
		unsigned int true_pos = 0;
		unsigned int false_pos = 0;
		unsigned int true_neg = 0;
		unsigned int false_neg = 0;

		// these do not depend on the training point, so read them once per level
		const double threshold = strong[i].threshold();
		const bool flipped = strong[i].isFlipped();
		// only the final level's mistakes are recorded in false_indices
		const bool is_last = (i == strong.size()-1);

		// traverse all training points
		for (unsigned int j=0; j<num_features; j++){
			const double value = td.at(j, strong[i].dimension());

			// weak classifier votes POS (+1) when the value lies strictly on
			// the positive side of the threshold, otherwise NEG (-1); a value
			// exactly equal to the threshold always votes NEG
			int sign;
			if ( (threshold > value && !flipped) || (threshold < value && flipped) )
				sign = 1;
			else
				sign = -1;

			// accumulate the weighted vote so far
			classify[j] += strong[i].weight() * sign;
			const double score = classify[j];

			// check classification against reality
			if (score >= strong_err_threshold && td.val(j) == POS)
				true_pos++;
			else if (score >= strong_err_threshold && td.val(j) == NEG){
				false_pos++;
				if (is_last)
					false_indices.push_back(j);
			}
			else if (score < strong_err_threshold && td.val(j) == POS){
				false_neg++;
				if (is_last)
					false_indices.push_back(j);
			}
			else
				true_neg++;
		}

		// guard every ratio against an empty denominator (0/0 would be NaN)
		const unsigned int predicted_pos = true_pos + false_pos;
		const unsigned int actual_pos = true_pos + false_neg;
		const double error_rate = (num_features > 0) ? (double)(false_pos + false_neg)/num_features : 0.0;
		const double precision = (predicted_pos > 0) ? (double)(true_pos)/predicted_pos : 0.0;
		const double recall = (actual_pos > 0) ? (double)(true_pos)/actual_pos : 0.0;

		// collect this level's stats and push into strong_err
		vector<double> stats;
		stats.reserve(7);
		stats.push_back(error_rate); // flat error percentage
		stats.push_back(precision); // precision
		stats.push_back(recall); // recall
		stats.push_back((double)true_pos); // true positives
		stats.push_back((double)true_neg); // true negatives
		stats.push_back((double)false_pos); // false positives
		stats.push_back((double)false_neg); // false negatives

		strong_err.push_back(stats);
	}
	return strong_err;
}