bool TreeClassification::findBestSplit(size_t nodeID, std::vector<size_t>& possible_split_varIDs) { size_t num_samples_node = sampleIDs[nodeID].size(); size_t num_classes = class_values->size(); double best_decrease = -1; size_t best_varID = 0; double best_value = 0; size_t* class_counts = new size_t[num_classes](); // Compute overall class counts for (size_t i = 0; i < num_samples_node; ++i) { size_t sampleID = sampleIDs[nodeID][i]; uint sample_classID = (*response_classIDs)[sampleID]; ++class_counts[sample_classID]; } // For all possible split variables for (auto& varID : possible_split_varIDs) { // Find best split value, if ordered consider all values as split values, else all 2-partitions if ((*is_ordered_variable)[varID]) { // Use memory saving method if option set if (memory_saving_splitting) { findBestSplitValueSmallQ(nodeID, varID, num_classes, class_counts, num_samples_node, best_value, best_varID, best_decrease); } else { // Use faster method for both cases double q = (double) num_samples_node / (double) data->getNumUniqueDataValues(varID); if (q < Q_THRESHOLD) { findBestSplitValueSmallQ(nodeID, varID, num_classes, class_counts, num_samples_node, best_value, best_varID, best_decrease); } else { findBestSplitValueLargeQ(nodeID, varID, num_classes, class_counts, num_samples_node, best_value, best_varID, best_decrease); } } } else { findBestSplitValueUnordered(nodeID, varID, num_classes, class_counts, num_samples_node, best_value, best_varID, best_decrease); } } delete[] class_counts; // Stop if no good split found if (best_decrease < 0) { return true; } // Save best values split_varIDs[nodeID] = best_varID; split_values[nodeID] = best_value; // Compute gini index for this node and to variable importance if needed if (importance_mode == IMP_GINI) { addGiniImportance(nodeID, best_varID, best_decrease); } return false; }
bool TreeProbability::findBestSplit(size_t nodeID, std::vector<size_t>& possible_split_varIDs) { size_t num_samples_node = sampleIDs[nodeID].size(); double best_decrease = -1; size_t best_varID = 0; double best_value = 0; // Compute sum of responses in node double sum_node = 0; for (auto& sampleID : sampleIDs[nodeID]) { sum_node += data->get(sampleID, dependent_varID); } // For all possible split variables for (auto& varID : possible_split_varIDs) { // Find best split value, if ordered consider all values as split values, else all 2-partitions if ((*is_ordered_variable)[varID]) { // Use memory saving method if option set if (memory_saving_splitting) { findBestSplitValueSmallQ(nodeID, varID, sum_node, num_samples_node, best_value, best_varID, best_decrease); } else { // Use faster method for both cases double q = (double) num_samples_node / (double) data->getNumUniqueDataValues(varID); if (q < Q_THRESHOLD) { findBestSplitValueSmallQ(nodeID, varID, sum_node, num_samples_node, best_value, best_varID, best_decrease); } else { findBestSplitValueLargeQ(nodeID, varID, sum_node, num_samples_node, best_value, best_varID, best_decrease); } } } else { findBestSplitValueUnordered(nodeID, varID, sum_node, num_samples_node, best_value, best_varID, best_decrease); } } // Stop if no good split found if (best_decrease < 0) { return true; } // Save best values split_varIDs[nodeID] = best_varID; split_values[nodeID] = best_value; // Compute decrease of impurity for this node and add to variable importance if needed if (importance_mode == IMP_GINI) { addImpurityImportance(nodeID, best_varID, best_decrease); } return false; }