void CleanTree(TreeType& node)
{
  node.Stat().LastDistance() = 0.0;

  for (size_t i = 0; i < node.NumChildren(); ++i)
    CleanTree(node.Child(i));
}
void CARTTrainer::pruneTree(TreeType & tree){
	//Calculate g of all the nodes
	measureStrength(tree, 0, 0);

	//Find the lowest g of the internal nodes
	double g = std::numeric_limits<double>::max();
	for(std::size_t i = 0; i != tree.size(); i++){
		if(tree[i].leftNodeId > 0 && tree[i].g < g){
			//Update g
			g = tree[i].g;
		}
	}

	//Prune the nodes with lowest g and make them terminal
	for(std::size_t i = 0; i != tree.size(); i++){
		//Make the internal nodes with the smallest g terminal nodes and prune their children!
		if(tree[i].leftNodeId > 0 && tree[i].g == g){
			// pruneNode(tree, tree[i].leftNodeId);
			// pruneNode(tree, tree[i].rightNodeId);
			//
			//Make the node terminal
			tree[i].leftNodeId = 0;
			tree[i].rightNodeId = 0;
		}
	}
}
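For context, the g compared above is the standard cost-complexity ("weakest link") strength of an internal node. measureStrength() is not shown here, but presumably it fills in each internal node's g as

g(t) = \frac{R(t) - R(T_t)}{|T_t| - 1}

where R(t) is the training error of node t treated as a leaf, R(T_t) is the error of the subtree rooted at t, and |T_t| is the number of leaves of that subtree; pruning the nodes with the smallest g removes the subtrees that buy the least error reduction per additional leaf.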
DTBStat(const TreeType& node) :
    maxNeighborDistance(DBL_MAX),
    minNeighborDistance(DBL_MAX),
    bound(DBL_MAX),
    componentMembership(
        ((node.NumPoints() == 1) && (node.NumChildren() == 0)) ?
            node.Point(0) : -1)
{ }
void CheckHierarchy(const TreeType& tree)
{
  for (size_t i = 0; i < tree.NumChildren(); i++)
  {
    BOOST_REQUIRE_EQUAL(&tree, tree.Child(i).Parent());
    CheckHierarchy(tree.Child(i));
  }
}
double FastMKSRules<KernelType, TreeType>::Rescore(TreeType& queryNode,
                                                   TreeType& /*referenceNode*/,
                                                   const double oldScore) const
{
  queryNode.Stat().Bound() = CalculateBound(queryNode);
  const double bestKernel = queryNode.Stat().Bound();

  return ((1.0 / oldScore) > bestKernel) ? oldScore : DBL_MAX;
}
FastMKSStat(const TreeType& node) :
    bound(-DBL_MAX),
    lastKernel(0.0),
    lastKernelNode(NULL)
{
  // Do we have to calculate the centroid?
  if (tree::TreeTraits<TreeType>::FirstPointIsCentroid)
  {
    // If this type of tree has self-children, then maybe the evaluation is
    // already done.  These statistics are built bottom-up, so the child stat
    // should already be done.
    if ((tree::TreeTraits<TreeType>::HasSelfChildren) &&
        (node.NumChildren() > 0) &&
        (node.Point(0) == node.Child(0).Point(0)))
    {
      selfKernel = node.Child(0).Stat().SelfKernel();
    }
    else
    {
      selfKernel = sqrt(node.Metric().Kernel().Evaluate(
          node.Dataset().col(node.Point(0)),
          node.Dataset().col(node.Point(0))));
    }
  }
  else
  {
    // Calculate the centroid.
    arma::vec center;
    node.Center(center);

    selfKernel = sqrt(node.Metric().Kernel().Evaluate(center, center));
  }
}
static int GetSplitPolicy(const TreeType& child,
                          const size_t axis,
                          const typename TreeType::ElemType cut)
{
  if (child.Bound()[axis].Hi() <= cut)
    return AssignToFirstTree;
  else if (child.Bound()[axis].Lo() >= cut)
    return AssignToSecondTree;

  return SplitRequired;
}
/* Removes branch with root node id nodeId, incl. the node itself */
void CARTTrainer::pruneNode(TreeType & tree, std::size_t nodeId){
	std::size_t i = findNode(tree, nodeId);

	if(tree[i].leftNodeId > 0){
		//Prune left branch
		pruneNode(tree, tree[i].leftNodeId);
		//Prune right branch
		pruneNode(tree, tree[i].rightNodeId);
	}

	//Remove node
	tree.erase(tree.begin() + i);
}
double RangeSearchRules<MetricType, TreeType>::Score(const size_t queryIndex,
                                                     TreeType& referenceNode)
{
  // We must get the minimum and maximum distances and store them in this
  // object.
  math::Range distances;

  if (tree::TreeTraits<TreeType>::FirstPointIsCentroid)
  {
    // In this situation, we calculate the base case.  So we should check to be
    // sure we haven't already done that.
    double baseCase;
    if (tree::TreeTraits<TreeType>::HasSelfChildren &&
        (referenceNode.Parent() != NULL) &&
        (referenceNode.Point(0) == referenceNode.Parent()->Point(0)))
    {
      // If the tree has self-children and this is a self-child, the base case
      // was already calculated.
      baseCase = referenceNode.Parent()->Stat().LastDistance();
      lastQueryIndex = queryIndex;
      lastReferenceIndex = referenceNode.Point(0);
    }
    else
    {
      // We must calculate the base case by hand.
      baseCase = BaseCase(queryIndex, referenceNode.Point(0));
    }

    // This bound may be loose for non-ball bound trees.
    distances.Lo() = baseCase - referenceNode.FurthestDescendantDistance();
    distances.Hi() = baseCase + referenceNode.FurthestDescendantDistance();

    // Update last distance calculation.
    referenceNode.Stat().LastDistance() = baseCase;
  }
  else
  {
    distances = referenceNode.RangeDistance(querySet.unsafe_col(queryIndex));
  }

  // If the ranges do not overlap, prune this node.
  if (!distances.Contains(range))
    return DBL_MAX;

  // In this case, all of the points in the reference node will be part of the
  // results.
  if ((distances.Lo() >= range.Lo()) && (distances.Hi() <= range.Hi()))
  {
    AddResult(queryIndex, referenceNode);
    return DBL_MAX; // We don't need to go any deeper.
  }

  // Otherwise the score doesn't matter.  Recursion order is irrelevant in
  // range search.
  return 0.0;
}
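When FirstPointIsCentroid holds, the interval built above is the triangle inequality applied around the single base-case distance: with p_0 = referenceNode.Point(0) and \lambda(N) the furthest descendant distance, every descendant x of the reference node satisfies

d(q, p_0) - \lambda(N) \;\le\; d(q, x) \;\le\; d(q, p_0) + \lambda(N)

so the node can be pruned (no overlap with the search range) or accepted wholesale (fully inside it) from this one distance computation.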
bool HilbertRTreeAuxiliaryInformation<TreeType, HilbertValueType>::
UpdateAuxiliaryInfo(TreeType* node)
{
  if (node->IsLeaf()) // Should already be updated.
    return true;

  TreeType* child = node->Children()[node->NumChildren() - 1];
  if (hilbertValue.CompareWith(child->AuxiliaryInfo().hilbertValue()) < 0)
  {
    hilbertValue.Copy(node, child);
    return true;
  }
  return false;
}
void SearchTree(TreeType<T>& tree)
{
  T input;
  cout << "Enter item to search: ";
  cin >> input;
  tree.SearchItem(input);
}
void DeleteFromTree(TreeType<T>& tree)
{
  T input;
  cout << "Enter item to delete: ";
  cin >> input;
  tree.DeleteItem(input);
}
void InsertItem(TreeType<T>& tree)
{
  cout << "Enter number: ";
  T num;
  cin >> num;
  tree.InsertItem(num);
}
int GetMaxLevel(const TreeType& tree)
{
  int max = 1;
  if (!tree.IsLeaf())
  {
    int m = 0;
    for (size_t i = 0; i < tree.NumChildren(); i++)
    {
      int n = GetMaxLevel(*tree.Children()[i]);
      if (n > m)
        m = n;
    }
    max += m;
  }
  return max;
}
int GetMinLevel(const TreeType& tree)
{
  int min = 1;
  if (!tree.IsLeaf())
  {
    int m = INT_MAX;
    for (size_t i = 0; i < tree.NumChildren(); i++)
    {
      int n = GetMinLevel(*tree.Children()[i]);
      if (n < m)
        m = n;
    }
    min += m;
  }
  return min;
}
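These two helpers are typically paired in test code to assert that an R-tree variant is height-balanced, i.e. that every leaf sits at the same depth. A minimal sketch, assuming a constructed tree object compatible with the helpers above and the same Boost.Test setting used elsewhere in this section; the assertion is illustrative, not taken from the source:

// A height-balanced tree has its shallowest and deepest leaves at the same level.
BOOST_REQUIRE_EQUAL(GetMinLevel(tree), GetMaxLevel(tree));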
void NeighborSearchRules<SortPolicy, MetricType, TreeType>::
UpdateAfterRecursion(TreeType& queryNode, TreeType& /* referenceNode */)
{
  // Find the worst distance that the children found (including any points),
  // and update the bound accordingly.
  double worstDistance = SortPolicy::BestDistance();

  // First look through children nodes.
  for (size_t i = 0; i < queryNode.NumChildren(); ++i)
  {
    if (SortPolicy::IsBetter(worstDistance, queryNode.Child(i).Stat().Bound()))
      worstDistance = queryNode.Child(i).Stat().Bound();
  }

  // Now look through children points.
  for (size_t i = 0; i < queryNode.NumPoints(); ++i)
  {
    if (SortPolicy::IsBetter(worstDistance,
        distances(distances.n_rows - 1, queryNode.Point(i))))
      worstDistance = distances(distances.n_rows - 1, queryNode.Point(i));
  }

  // Take the worst distance from all of these, and update our bound to
  // reflect that.
  queryNode.Stat().Bound() = worstDistance;
}
double DTBRules<MetricType, TreeType>::Score(TreeType& queryNode,
                                             TreeType& referenceNode)
{
  // If all the queries belong to the same component as all the references,
  // then we prune.
  if ((queryNode.Stat().ComponentMembership() >= 0) &&
      (queryNode.Stat().ComponentMembership() ==
           referenceNode.Stat().ComponentMembership()))
    return DBL_MAX;

  ++scores;
  const double distance = queryNode.MinDistance(&referenceNode);
  const double bound = CalculateBound(queryNode);

  // If all the points in the reference node are farther than the candidate
  // nearest neighbor for all queries in the node, we prune.
  return (bound < distance) ? DBL_MAX : distance;
}
void CheckContainment(const TreeType& tree)
{
  if (tree.NumChildren() == 0)
  {
    for (size_t i = 0; i < tree.Count(); i++)
      BOOST_REQUIRE(tree.Bound().Contains(
          tree.Dataset().unsafe_col(tree.Points()[i])));
  }
  else
  {
    for (size_t i = 0; i < tree.NumChildren(); i++)
    {
      for (size_t j = 0; j < tree.Bound().Dim(); j++)
        BOOST_REQUIRE(tree.Bound()[j].Contains(tree.Children()[i]->Bound()[j]));

      CheckContainment(*(tree.Children()[i]));
    }
  }
}
inline double KDERules<MetricType, KernelType, TreeType>::
Score(const size_t queryIndex, TreeType& referenceNode)
{
  double score, maxKernel, minKernel, bound;
  const arma::vec& queryPoint = querySet.unsafe_col(queryIndex);
  const double minDistance = referenceNode.MinDistance(queryPoint);
  bool newCalculations = true;

  if (tree::TreeTraits<TreeType>::FirstPointIsCentroid &&
      lastQueryIndex == queryIndex &&
      traversalInfo.LastReferenceNode() != NULL &&
      traversalInfo.LastReferenceNode()->Point(0) == referenceNode.Point(0))
  {
    // Don't duplicate calculations.
    newCalculations = false;
    lastQueryIndex = queryIndex;
    lastReferenceIndex = referenceNode.Point(0);
  }
  else
  {
    // Calculations are new.
    maxKernel = kernel.Evaluate(minDistance);
    minKernel = kernel.Evaluate(referenceNode.MaxDistance(queryPoint));
    bound = maxKernel - minKernel;
  }

  if (newCalculations &&
      bound <= (absError + relError * minKernel) / referenceSet.n_cols)
  {
    // Estimate values.
    double kernelValue;

    // Calculate kernel value based on reference node centroid.
    if (tree::TreeTraits<TreeType>::FirstPointIsCentroid)
    {
      kernelValue = EvaluateKernel(queryIndex, referenceNode.Point(0));
    }
    else
    {
      kde::KDEStat& referenceStat = referenceNode.Stat();
      kernelValue = EvaluateKernel(queryPoint, referenceStat.Centroid());
    }

    densities(queryIndex) += referenceNode.NumDescendants() * kernelValue;

    // Don't explore this tree branch.
    score = DBL_MAX;
  }
  else
  {
    score = minDistance;
  }

  ++scores;
  traversalInfo.LastReferenceNode() = &referenceNode;
  traversalInfo.LastScore() = score;
  return score;
}
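The prune test above replaces every point in the reference node by a single kernel evaluation whenever the kernel cannot vary much across the node. With K the kernel profile, d_min and d_max the minimum and maximum query-to-node distances, and N_ref = referenceSet.n_cols, the condition checked is

K(d_{\min}) - K(d_{\max}) \;\le\; \frac{\text{absError} + \text{relError} \cdot K(d_{\max})}{N_{\text{ref}}}

the intent being that, summed over all reference points, the approximation error stays within the absolute/relative tolerance budget.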
size_t nodeDepth(const TreeType& tree, const NodeType& node)
{
  const auto* currentNode = &node;
  size_t depth = 0;

  while (!currentNode->isRoot())
  {
    currentNode = &tree.nodeAt(currentNode->parentNodeIndex);
    depth++;
  }
  return depth;
}
inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Score(
    TreeType& queryNode,
    TreeType& referenceNode) const
{
  const double distance = SortPolicy::BestNodeToNodeDistance(&queryNode,
      &referenceNode);
  const double bestDistance = queryNode.Stat().Bound();

  return (SortPolicy::IsBetter(distance, bestDistance)) ? distance : DBL_MAX;
}
inline double DTBRules<MetricType, TreeType>::CalculateBound(
    TreeType& queryNode) const
{
  double worstPointBound = -DBL_MAX;
  double bestPointBound = DBL_MAX;

  double worstChildBound = -DBL_MAX;
  double bestChildBound = DBL_MAX;

  // Now, find the best and worst point bounds.
  for (size_t i = 0; i < queryNode.NumPoints(); ++i)
  {
    const size_t pointComponent = connections.Find(queryNode.Point(i));
    const double bound = neighborsDistances[pointComponent];

    if (bound > worstPointBound)
      worstPointBound = bound;
    if (bound < bestPointBound)
      bestPointBound = bound;
  }

  // Find the best and worst child bounds.
  for (size_t i = 0; i < queryNode.NumChildren(); ++i)
  {
    const double maxBound = queryNode.Child(i).Stat().MaxNeighborDistance();
    if (maxBound > worstChildBound)
      worstChildBound = maxBound;

    const double minBound = queryNode.Child(i).Stat().MinNeighborDistance();
    if (minBound < bestChildBound)
      bestChildBound = minBound;
  }

  // Now calculate the actual bounds.
  const double worstBound = std::max(worstPointBound, worstChildBound);
  const double bestBound = std::min(bestPointBound, bestChildBound);

  // We must check that bestBound != DBL_MAX; otherwise, we risk overflow.
  const double bestAdjustedBound = (bestBound == DBL_MAX) ? DBL_MAX :
      bestBound + 2 * queryNode.FurthestDescendantDistance();

  // Update the relevant quantities in the node.
  queryNode.Stat().MaxNeighborDistance() = worstBound;
  queryNode.Stat().MinNeighborDistance() = bestBound;
  queryNode.Stat().Bound() = std::min(worstBound, bestAdjustedBound);

  return queryNode.Stat().Bound();
}
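Written out, the bound stored by the final assignment above is

B(N) = \min\bigl( d_{\text{worst}}(N),\; d_{\text{best}}(N) + 2\lambda(N) \bigr)

where d_worst(N) and d_best(N) are the worst and best candidate-neighbor distances found among the node's own points and its children, and \lambda(N) is FurthestDescendantDistance(). The second term extends the best candidate to every descendant via the triangle inequality: a candidate edge within d_best of one descendant is within d_best + 2\lambda of any other descendant of N.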
double DTBRules<MetricType, TreeType>::Score(const size_t queryIndex,
                                             TreeType& referenceNode)
{
  size_t queryComponentIndex = connections.Find(queryIndex);

  // If the query belongs to the same component as all of the references,
  // then prune.  The cast is to stop a warning about comparing unsigned to
  // signed values.
  if (queryComponentIndex ==
      (size_t) referenceNode.Stat().ComponentMembership())
    return DBL_MAX;

  const arma::vec queryPoint = dataSet.unsafe_col(queryIndex);
  const double distance = referenceNode.MinDistance(queryPoint);

  // If all the points in the reference node are farther than the candidate
  // nearest neighbor for the query's component, we prune.
  return neighborsDistances[queryComponentIndex] < distance
      ? DBL_MAX : distance;
}
/**
 * @brief Return a list of clades, each itself containing a list of edge indices of that clade.
 *
 * The function takes a list of clades with their taxa as input, and a reference tree.
 * It then inspects all clades and finds the edges of the tree that belong to each clade.
 * Furthermore, a clade "basal_branches" is added for those edges of the tree that do not
 * belong to any clade.
 *
 * The edges of a clade are determined by finding the smallest subtree (split) of the tree that
 * contains all nodes of the clade. That means the clades should be monophyletic in order for
 * this algorithm to work properly.
 */
CladeEdgeList get_clade_edges( CladeTaxaList const& clades, TreeType& tree )
{
    // Prepare the result map.
    CladeEdgeList clade_edges;

    // Make a set of all edges that do not belong to any clade (the basal branches of the tree).
    // We first fill it with all edge indices, then remove the clade edges later,
    // so that only the wanted ones remain.
    std::unordered_set<size_t> basal_branches;
    for( auto it = tree.begin_edges(); it != tree.end_edges(); ++it ) {
        basal_branches.insert( it->index() );
    }

    // Process all clades.
    for( auto const& clade : clades ) {

        // Find the edges that are part of the subtree of this clade.
        auto const subedges = get_clade_edges( tree, clade.second );

        // TODO For now, we convert to an unordered set here by hand.
        // This can be cleaned up!
        auto const subedge_map = std::unordered_set<size_t>( subedges.begin(), subedges.end() );

        // Add them to the clade edges list.
        clade_edges.push_back( std::make_pair( clade.first, subedge_map ));

        // Remove the edge indices of this clade from the basal branches (non-clade) edges list.
        for( auto const edge : subedges ) {
            basal_branches.erase( edge );
        }
    }

    // Now that we have processed all clades, also add the non-clade edges (basal branches)
    // to the list as a special clade "basal_branches". This way, all edges of the reference tree
    // are used by exactly one clade.
    clade_edges.push_back( std::make_pair( "basal_branches", basal_branches ));

    return clade_edges;
}
inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Rescore(
    TreeType& queryNode,
    TreeType& /* referenceNode */,
    const double oldScore) const
{
  if (oldScore == DBL_MAX)
    return oldScore;

  const double bestDistance = queryNode.Stat().Bound();

  return (SortPolicy::IsBetter(oldScore, bestDistance)) ? oldScore : DBL_MAX;
}
void TestRangeQuery(TreeType &a_tree, string &db_filename, string &key1,
                    string &key2)
{
  SequenceMap lower_bound(key1);
  SequenceMap upper_bound(key2);

  ifstream rebasefile(db_filename);
  string db_line;
  if (!rebasefile.is_open()) {
    //Program exits if the file is not opened or cannot be opened.
    exit(EXIT_FAILURE);
  }

  while (rebasefile.good()) {
    getline(rebasefile, db_line);

    size_t pos;
    string delimiter = "/";
    string a_reco_seq;
    string an_enz_acro;

    //The enzyme acronym is copied and deleted from db_line.
    pos = db_line.find(delimiter);
    an_enz_acro = db_line.substr(0, pos);
    db_line.erase(0, pos + delimiter.length());

    while ((pos = db_line.find(delimiter)) != string::npos) {
      a_reco_seq = db_line.substr(0, pos);
      if (a_reco_seq == "") {
        break;
      } else {
        SequenceMap new_sequence_map(a_reco_seq, an_enz_acro);
        a_tree.insert(new_sequence_map);
        db_line.erase(0, pos + delimiter.length());
      }
    }
  }

  a_tree.print_range_elem(lower_bound, upper_bound);
}
double DTBRules<MetricType, TreeType>::Score(const size_t queryIndex,
                                             TreeType& referenceNode,
                                             const double baseCaseResult)
{
  // The base case result is simply passed on to the distance call; otherwise
  // this function is the same as the two-argument Score() above.
  size_t queryComponentIndex = connections.Find(queryIndex);

  // If the query belongs to the same component as all of the references,
  // then prune.
  if (queryComponentIndex == referenceNode.Stat().ComponentMembership())
    return DBL_MAX;

  const arma::vec queryPoint = dataSet.unsafe_col(queryIndex);
  const double distance = referenceNode.MinDistance(queryPoint,
                                                    baseCaseResult);

  // If all the points in the reference node are farther than the candidate
  // nearest neighbor for the query's component, we prune.
  return (neighborsDistances[queryComponentIndex] < distance)
      ? DBL_MAX : distance;
}
void FillTree(std::string db_filename, TreeType &a_tree)
{
  ifstream inStream(db_filename);
  std::string db_line, an_enz_acro, a_reco_seq, garbage_line;

  for (int i = 0; i < 10; i++) //Skip over the header and begin reading on line 11.
    getline(inStream, garbage_line);

  while (std::getline(inStream, db_line)) {
    //inStream has reached line 11 and will begin to parse data.
    if (db_line.empty())
      continue;

    size_t first_slash = db_line.find("/");
    an_enz_acro = GetEnzymeAcronym(db_line, first_slash);

    while (GetNextRecognitionSequence(db_line, a_reco_seq, first_slash)) {
      SequenceMap new_sequence_map(a_reco_seq, an_enz_acro);
      a_tree.insert(new_sequence_map);
    }
  }
}
void CheckBound(const TreeType& tree)
{
  typedef typename TreeType::ElemType ElemType;
  for (size_t i = 0; i < tree.NumDescendants(); i++)
  {
    arma::Col<ElemType> point = tree.Dataset().col(tree.Descendant(i));

    // Check that the point is contained in the bound.
    BOOST_REQUIRE_EQUAL(true, tree.Bound().Contains(point));

    const arma::Mat<ElemType>& loBound = tree.Bound().LoBound();
    const arma::Mat<ElemType>& hiBound = tree.Bound().HiBound();

    // Ensure that there is a hyperrectangle that contains the point.
    bool success = false;
    for (size_t j = 0; j < tree.Bound().NumBounds(); j++)
    {
      success = true;
      for (size_t k = 0; k < loBound.n_rows; k++)
      {
        if (point[k] < loBound(k, j) - 1e-14 * std::fabs(loBound(k, j)) ||
            point[k] > hiBound(k, j) + 1e-14 * std::fabs(hiBound(k, j)))
        {
          success = false;
          break;
        }
      }
      if (success)
        break;
    }

    BOOST_REQUIRE_EQUAL(success, true);
  }

  if (!tree.IsLeaf())
  {
    CheckBound(*tree.Left());
    CheckBound(*tree.Right());
  }
}
DualTreeKMeansStatistic(TreeType& node) :
    closestQueryNode(NULL),
    minQueryNodeDistance(DBL_MAX),
    maxQueryNodeDistance(DBL_MAX),
    clustersPruned(0),
    iteration(size_t() - 1)
{
  // Empirically calculate the centroid.
  centroid.zeros(node.Dataset().n_rows);
  for (size_t i = 0; i < node.NumPoints(); ++i)
    centroid += node.Dataset().col(node.Point(i));

  for (size_t i = 0; i < node.NumChildren(); ++i)
    centroid += node.Child(i).NumDescendants() *
        node.Child(i).Stat().Centroid();

  centroid /= node.NumDescendants();
}
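The two loops above form a weighted average: each child contributes its already-computed centroid once per descendant point, so the node centroid is

\mathbf{c}(N) = \frac{\sum_{p \in \text{points}(N)} \mathbf{x}_p \;+\; \sum_{C \in \text{children}(N)} |\text{desc}(C)| \, \mathbf{c}(C)}{|\text{desc}(N)|}

which equals the mean of all points contained in the subtree rooted at N.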