double DecisionStump<MatType>::SetupSplitDimension(
    const arma::rowvec& dimension,
    const arma::Row<size_t>& labels,
    const arma::rowvec& weights)
{
  // Scores a candidate split dimension.
  //
  // The points are ordered by their value in `dimension`. The ordered labels
  // are then cut into buckets: a bucket closes where the label changes and is
  // padded to at least `bucketSize` elements, clamped at the end of the data.
  // The return value is the sum over buckets of
  //   (bucket element count / total element count) * CalculateEntropy(bucket).
  //
  // dimension: value of every point along the candidate dimension.
  // labels:    class label of every point (same length as `dimension`).
  // weights:   per-point weights; only read when UseWeights is true.
  //            Otherwise the sorted weight buffer is left uninitialised and
  //            presumably ignored by CalculateEntropy<false> (TODO confirm).
  //
  // An empty `dimension` yields 0.0.
  const size_t n = dimension.n_elem;

  // Stable sort order of the dimension, used to permute labels (and weights)
  // into dimension order. The sorted values themselves are not needed here.
  const arma::uvec sortedIndexDim = arma::stable_sort_index(dimension.t());
  arma::Row<size_t> sortedLabels(n);
  arma::rowvec sortedWeights(n);
  for (size_t k = 0; k < n; ++k)
  {
    sortedLabels(k) = labels(sortedIndexDim(k));
    if (UseWeights)
      sortedWeights(k) = weights(sortedIndexDim(k));
  }

  double entropy = 0.0;
  size_t i = 0;
  size_t count = 0; // Elements seen since the current bucket started.
  while (i < n)
  {
    count++;
    const bool atLast = (i == n - 1);

    // Keep growing the bucket while the label stays the same.
    if (!atLast && sortedLabels(i) == sortedLabels(i + 1))
    {
      i++;
      continue;
    }

    // Close the bucket, either at a label change or at the last element.
    const size_t begin = i - count + 1;
    size_t end = i;
    if (!atLast && count < bucketSize)
    {
      // Too small: take at least bucketSize elements anyway, so that not
      // every label change produces its own bucket. Never run past the data.
      end = begin + bucketSize - 1;
      if (end > n - 1)
        end = n - 1;
    }

    // Weight this bucket's entropy by its share of the elements.
    const double ratioEl = static_cast<double>(end - begin + 1) / n;
    entropy += ratioEl * CalculateEntropy<UseWeights>(
        sortedLabels.subvec(begin, end), sortedWeights.subvec(begin, end));

    i = end + 1;
    count = 0;
  }

  return entropy;
}
// Squared Mahalanobis distance of row vector x from mean mu, using a
// lower-triangular Cholesky factor R of the covariance:
//   sigma = R * R.t()   (e.g. R = arma::chol(sigma, "lower")).
// Only the lower triangle of R is read.
//
// With err = x - mu, the distance is
//   err * inv(sigma) * err.t() = || inv(R) * err.t() ||^2,
// so z = inv(R) * err.t() is computed by a triangular solve (forward
// substitution) instead of forming inv(R) explicitly.
//
// Bug fix: the previous version computed err * inv(R) * inv(R).t() * err.t(),
// i.e. it used inv(R.t() * R) rather than inv(R * R.t()). That does not equal
// mahalanobis(x, mu, sigma) for a lower factor R.
//
// NOTE(review): for a singular R, Armadillo's solve() may warn and return an
// approximate solution rather than throw as inv() did; pass
// arma::solve_opts::no_approx if callers rely on an exception.
double mahalanobis_chol(const arma::rowvec& x, const arma::rowvec& mu, const arma::mat& R)
{
  const arma::rowvec err = x - mu;
  const arma::vec z = arma::solve(arma::trimatl(R), err.t());
  return arma::dot(z, z);
}
// Squared Mahalanobis distance of row vector x from mean mu under covariance
// matrix sigma: (x - mu) * inv(sigma) * (x - mu).t().
// sigma is presumably square, with side equal to the length of x and mu
// (TODO confirm at call sites).
// NOTE(review): this uses an explicit inverse (sigma.i()); arma::solve is
// generally the more numerically robust way to apply inv(sigma).
double mahalanobis(const arma::rowvec& x, const arma::rowvec& mu, const arma::mat& sigma)
{
  const arma::rowvec diff = x - mu;
  const double squaredDistance = arma::as_scalar(diff * sigma.i() * diff.t());
  return squaredDistance;
}
void DecisionStump<MatType>::TrainOnAtt(const arma::rowvec& attribute,
                                        const arma::Row<size_t>& labels)
{
  // Builds the stump's buckets along one attribute.
  //
  // The points are ordered by attribute value, and the ordered labels are cut
  // into buckets. A bucket closes at a label change and is padded to at least
  // `bucketSize` elements, clamped at the end of the data. For each bucket
  // this appends:
  //   - to `split`:     the sorted attribute value at the bucket's first element;
  //   - to `binLabels`: the result of CountMostFreq over the bucket's labels.
  // Finally MergeRanges() is called, presumably to merge adjacent buckets
  // that carry the same label (TODO confirm against MergeRanges).
  //
  // attribute: value of every point for this attribute.
  // labels:    class label of every point (same length as `attribute`).
  const size_t n = attribute.n_elem;

  // Attribute values in ascending order; bucket split points are read here.
  const arma::rowvec sortedSplitAtt = arma::sort(attribute);

  // Stable sort order, used to permute labels into attribute order. Every
  // element is assigned below, so no pre-fill is needed.
  const arma::uvec sortedSplitIndexAtt = arma::stable_sort_index(attribute.t());
  arma::Row<size_t> sortedLabels(n);
  for (size_t k = 0; k < n; ++k)
    sortedLabels(k) = labels(sortedSplitIndexAtt(k));

  size_t i = 0;
  size_t count = 0; // Elements seen since the current bucket started.
  while (i < n)
  {
    count++;
    const bool atLast = (i == n - 1);

    // Keep growing the bucket while the label stays the same.
    if (!atLast && sortedLabels(i) == sortedLabels(i + 1))
    {
      i++;
      continue;
    }

    // Close the bucket, either at a label change or at the last element.
    const size_t begin = i - count + 1;
    size_t end = i;
    if (!atLast && count < bucketSize)
    {
      // Too small: take at least bucketSize elements anyway, clamped to the
      // end of the data.
      end = begin + bucketSize - 1;
      if (end > n - 1)
        end = n - 1;
    }

    // The bucket's labels as doubles, as CountMostFreq<double> expects.
    arma::rowvec subCols =
        arma::conv_to<arma::rowvec>::from(sortedLabels.cols(begin, end));

    // Label the bucket with the value CountMostFreq reports for it.
    const rType mostFreq = CountMostFreq<double>(subCols);

    split.resize(split.n_elem + 1);
    split(split.n_elem - 1) = sortedSplitAtt(begin);
    binLabels.resize(binLabels.n_elem + 1);
    binLabels(binLabels.n_elem - 1) = mostFreq;

    i = end + 1;
    count = 0;
  }

  // Post-process the collected split points and bucket labels (presumably
  // merging consecutive buckets that point to the same class label).
  MergeRanges();
}