/**
 * Copy constructor.
 *
 * @param other The node to copy.
 * @param deepCopy If true, every child subtree is copied recursively (and a
 *     leaf copies its local dataset); if false, the new node shares `other`'s
 *     children vector contents and aliases `other`'s local dataset.
 *
 * The matrix behind `dataset` is always duplicated and owned by the new node,
 * whatever the value of `deepCopy`.
 * NOTE(review): in a deep copy every node performs this duplication, so each
 * node owns its own full copy of the dataset.
 */
RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
RectangleTree(
    const RectangleTree& other,
    const bool deepCopy) :
    maxNumChildren(other.MaxNumChildren()),
    minNumChildren(other.MinNumChildren()),
    numChildren(other.NumChildren()),
    children(maxNumChildren + 1),
    parent(other.Parent()),
    begin(other.Begin()),
    count(other.Count()),
    maxLeafSize(other.MaxLeafSize()),
    minLeafSize(other.MinLeafSize()),
    bound(other.bound),
    splitHistory(other.SplitHistory()),
    parentDistance(other.ParentDistance()),
    dataset(new MatType(*other.dataset)),
    ownsDataset(true),
    points(other.Points()),
    localDataset(NULL)
{
  if (deepCopy)
  {
    if (numChildren > 0)
    {
      for (size_t i = 0; i < numChildren; i++)
      {
        children[i] = new RectangleTree(*(other.Children()[i]), true);
        // The recursive copy initialised its parent pointer from the original
        // child, i.e. it still points into `other`'s tree.  Re-link it to this
        // copy so that upward walks (e.g. via Parent()) stay inside the new
        // tree.
        children[i]->parent = this;
      }
    }
    else
    {
      // Leaf: take a private copy of the local dataset.
      localDataset = new MatType(other.LocalDataset());
    }
  }
  else
  {
    // Shallow copy: share the children and alias the local dataset.  The cast
    // goes through MatType rather than a hard-coded arma::mat so that trees
    // built on other matrix types still compile.
    children = other.Children();
    localDataset = const_cast<MatType*>(&other.LocalDataset());
  }
}
/**
 * Copy constructor.
 *
 * @param other The node to copy.
 * @param deepCopy When true, the whole subtree rooted at `other` is
 *     duplicated.  The top of the copy (the node created with a NULL
 *     `newParent`) owns a fresh copy of the dataset.  Every node created
 *     recursively below it reuses its parent's dataset pointer.  When false,
 *     the new node shares `other`'s children, dataset and parent, and does not
 *     own the dataset.
 * @param newParent Parent of the new node; only consulted when `deepCopy` is
 *     true (the recursive copies pass themselves here).
 */
RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
              AuxiliaryInformationType>::
RectangleTree(const RectangleTree& other,
              const bool deepCopy,
              RectangleTree* newParent) :
    maxNumChildren(other.MaxNumChildren()),
    minNumChildren(other.MinNumChildren()),
    numChildren(other.NumChildren()),
    // One spare slot, matching maxNumChildren + 1.
    children(other.MaxNumChildren() + 1, NULL),
    parent(deepCopy ? newParent : other.Parent()),
    begin(other.Begin()),
    count(other.Count()),
    numDescendants(other.numDescendants),
    maxLeafSize(other.MaxLeafSize()),
    minLeafSize(other.MinLeafSize()),
    bound(other.bound),
    parentDistance(other.ParentDistance()),
    // In a deep copy `parent` equals `newParent`, so test `newParent`
    // directly.
    dataset(!deepCopy ? &other.Dataset() :
        (newParent != NULL ? newParent->dataset
                           : new MatType(*other.dataset))),
    ownsDataset(deepCopy && newParent == NULL),
    points(other.points),
    auxiliaryInfo(other.auxiliaryInfo, this, deepCopy)
{
  if (!deepCopy)
  {
    // Shallow copy: reuse the very same child pointers.
    children = other.children;
    return;
  }

  // Deep copy: clone every child with this node as its parent.  The clones
  // therefore pick up this node's dataset pointer.  For a leaf the loop body
  // never runs.
  for (size_t c = 0; c < numChildren; ++c)
    children[c] = new RectangleTree(other.Child(c), true, this);
}
/**
 * Repair the tree along the path from this node to the root after a point
 * (usePoint == true) or a subtree (usePoint == false) has been removed at or
 * below this node.
 *
 * - A non-root leaf with fewer than minLeafSize points is unlinked from its
 *   parent.  The ancestors' bounds, descendant counts and auxiliary
 *   information are updated.  Its points are reinserted from the root, the
 *   parent is condensed, and the node is soft-deleted.
 * - A non-root internal node with fewer than minNumChildren children gets the
 *   same treatment, except that its children are reinserted as subtrees.
 * - A root internal node with fewer than minNumChildren children and exactly
 *   one child absorbs that child's children, points and auxiliary information.
 * - Otherwise this node's bound is shrunk and its auxiliary information is
 *   refreshed.  If either changed, the parent is condensed as well.
 *
 * @param point The removed point; only used when usePoint is true.
 * @param relevels Flags forwarded unchanged to InsertPoint()/InsertNode() and
 *     to the recursive calls (presumably per-level reinsertion markers).
 * @param usePoint Shrink around `point` (true) or around this node's own
 *     `bound` (false).
 */
void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
                   AuxiliaryInformationType>::
CondenseTree(const arma::vec& point,
             std::vector<bool>& relevels,
             const bool usePoint)
{
  // First delete the node if we need to.  There's no point in shrinking the
  // bound first.
  if (IsLeaf() && count < minLeafSize && parent != NULL)
  {
    // We can't delete the root node (hence the parent != NULL test above).
    for (size_t i = 0; i < parent->NumChildren(); i++)
    {
      if (parent->children[i] == this)
      {
        // Unlink this node.  Unless the auxiliary information handles the
        // removal itself, move the parent's last child into slot i and
        // decrement the parent's child count.
        if (!auxiliaryInfo.HandleNodeRemoval(parent, i))
        {
          parent->children[i] = parent->children[--parent->NumChildren()];
        }

        // We find the root and shrink bounds at the same time.  No further
        // ShrinkBoundForBound() calls are made once one of them returns false.
        bool stillShrinking = true;
        RectangleTree* root = parent;
        while (root->Parent() != NULL)
        {
          if (stillShrinking)
            stillShrinking = root->ShrinkBoundForBound(bound);
          root = root->Parent();
        }
        if (stillShrinking)
          stillShrinking = root->ShrinkBoundForBound(bound);

        // Every ancestor loses this node's descendants.
        root = parent;
        while (root != NULL)
        {
          root->numDescendants -= numDescendants;
          root = root->Parent();
        }

        // Refresh the ancestors' auxiliary information in a separate pass,
        // again stopping after the first call that returns false.  `root`
        // ends up at the top of the tree.
        stillShrinking = true;
        root = parent;
        while (root->Parent() != NULL)
        {
          if (stillShrinking)
            stillShrinking = root->AuxiliaryInfo().UpdateAuxiliaryInfo(root);
          root = root->Parent();
        }
        if (stillShrinking)
          stillShrinking = root->AuxiliaryInfo().UpdateAuxiliaryInfo(root);

        // Reinsert the points at the root node.
        for (size_t j = 0; j < count; j++)
          root->InsertPoint(points[j], relevels);

        // This will check the minFill of the parent.
        parent->CondenseTree(point, relevels, usePoint);
        // Now it should be safe to delete this node.
        SoftDelete();

        return;
      }
    }
    // Control should never reach here.
    assert(false);
  }
  else if (!IsLeaf() && numChildren < minNumChildren)
  {
    if (parent != NULL)
    {
      // The normal case.  We need to be careful with the root.
      for (size_t j = 0; j < parent->NumChildren(); j++)
      {
        if (parent->children[j] == this)
        {
          // Unlink this node (same scheme as in the leaf case above).
          if (!auxiliaryInfo.HandleNodeRemoval(parent, j))
          {
            parent->children[j] = parent->children[--parent->NumChildren()];
          }
          // Depth value handed to InsertNode() for the orphaned children.
          size_t level = TreeDepth();

          // We find the root and shrink bounds at the same time.
          bool stillShrinking = true;
          RectangleTree* root = parent;
          while (root->Parent() != NULL)
          {
            if (stillShrinking)
              stillShrinking = root->ShrinkBoundForBound(bound);
            root = root->Parent();
          }
          if (stillShrinking)
            stillShrinking = root->ShrinkBoundForBound(bound);

          // Every ancestor loses this node's descendants.
          root = parent;
          while (root != NULL)
          {
            root->numDescendants -= numDescendants;
            root = root->Parent();
          }

          // Refresh auxiliary information upwards; `root` ends at the top.
          stillShrinking = true;
          root = parent;
          while (root->Parent() != NULL)
          {
            if (stillShrinking)
              stillShrinking = root->AuxiliaryInfo().UpdateAuxiliaryInfo(root);
            root = root->Parent();
          }
          if (stillShrinking)
            stillShrinking = root->AuxiliaryInfo().UpdateAuxiliaryInfo(root);

          // Reinsert the nodes at the root node.
          for (size_t i = 0; i < numChildren; i++)
            root->InsertNode(children[i], level, relevels);

          // This will check the minFill of the parent.
          parent->CondenseTree(point, relevels, usePoint);
          // Now it should be safe to delete this node.
          SoftDelete();

          return;
        }
      }
    }
    else if (numChildren == 1)
    {
      // This is the root and it has exactly one child.  The root can never be
      // deleted, so instead pull the child's contents up into it.  The root's
      // bound and numDescendants are left as they are.
      RectangleTree* child = children[0];

      // Required for the X tree.
      if (child->NumChildren() > maxNumChildren)
      {
        maxNumChildren = child->MaxNumChildren();
        children.resize(maxNumChildren + 1);
      }

      for (size_t i = 0; i < child->NumChildren(); i++)
      {
        children[i] = child->children[i];
        children[i]->Parent() = this;
      }

      numChildren = child->NumChildren();

      for (size_t i = 0; i < child->Count(); i++)
      {
        // In case the tree has a height of two.
        points[i] = child->Point(i);
      }

      auxiliaryInfo = child->AuxiliaryInfo();

      count = child->Count();
      child->SoftDelete();
      return;
    }
  }

  // If we didn't delete it, shrink the bound if we need to.  The bound and the
  // auxiliary information are updated independently, as in the removal
  // branches above.  Combining the calls with `||` would skip
  // UpdateAuxiliaryInfo() whenever the bound shrank, leaving this node's
  // auxiliary information stale.
  const bool boundChanged = usePoint ? ShrinkBoundForPoint(point)
                                     : ShrinkBoundForBound(bound);
  const bool auxInfoChanged = auxiliaryInfo.UpdateAuxiliaryInfo(this);

  if ((boundChanged || auxInfoChanged) && parent != NULL)
    parent->CondenseTree(point, relevels, usePoint);
}
/**
 * Dual-tree recursion over a query node and a reference node.
 *
 * - Both leaves: each query point is scored against the reference node.
 *   Unless the score equals DBL_MAX, the rule's base case runs against every
 *   reference point.
 * - Only the reference node is a leaf: each query child whose score against
 *   the reference node is below DBL_MAX is descended into; the others count as
 *   prunes.  Order is irrelevant here.
 * - The reference node has children: the query node itself (if it is a leaf)
 *   or each of its children is scored against every reference child.  The
 *   reference children are sorted with nodeComparator and then visited in
 *   that order.  Visiting stops at the first child whose rescore is not below
 *   DBL_MAX, and every remaining child is counted as pruned.
 *
 * @param queryNode Node of the query tree.
 * @param referenceNode Node of the reference tree.
 */
void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
                   AuxiliaryInformationType>::
DualTreeTraverser<RuleType>::Traverse(RectangleTree& queryNode,
                                      RectangleTree& referenceNode)
{
  ++numVisited;

  // Snapshot of the rule's traversal information on entry; scoring steps
  // restore it before each Score() call.
  traversalInfo = rule.TraversalInfo();

  const bool queryIsLeaf = queryNode.IsLeaf();
  const bool referenceIsLeaf = referenceNode.IsLeaf();

  if (referenceIsLeaf)
  {
    if (queryIsLeaf)
    {
      // Leaf vs. leaf.  Query points are on the outside so that the reference
      // node can be skipped for an individual query point.
      for (size_t q = 0; q < queryNode.Count(); ++q)
      {
        rule.TraversalInfo() = traversalInfo;
        if (rule.Score(queryNode.Point(q), referenceNode) == DBL_MAX)
          continue; // Nothing in this reference leaf is needed for point q.

        for (size_t r = 0; r < referenceNode.Count(); ++r)
          rule.BaseCase(queryNode.Point(q), referenceNode.Point(r));
        numBaseCases += referenceNode.Count();
      }
    }
    else
    {
      // Only the query side can be split; visiting order does not matter.
      for (size_t c = 0; c < queryNode.NumChildren(); ++c)
      {
        rule.TraversalInfo() = traversalInfo;
        ++numScores;
        if (rule.Score(queryNode.Child(c), referenceNode) < DBL_MAX)
          Traverse(queryNode.Child(c), referenceNode);
        else
          numPrunes++;
      }
    }
    return;
  }

  // The reference node has children, so visiting order matters.  A leaf query
  // node is paired with the reference children directly; otherwise each query
  // child is.
  const size_t queryCount = queryIsLeaf ? 1 : queryNode.NumChildren();
  for (size_t q = 0; q < queryCount; ++q)
  {
    RectangleTree& querySide = queryIsLeaf ? queryNode : queryNode.Child(q);

    // Score every reference child against the current query node.
    std::vector<NodeAndScore> candidates(referenceNode.NumChildren());
    for (size_t r = 0; r < candidates.size(); ++r)
    {
      rule.TraversalInfo() = traversalInfo;
      NodeAndScore& candidate = candidates[r];
      candidate.node = &referenceNode.Child(r);
      candidate.score = rule.Score(querySide, *candidate.node);
      candidate.travInfo = rule.TraversalInfo();
    }

    std::sort(candidates.begin(), candidates.end(), nodeComparator);
    numScores += candidates.size();

    // Visit the reference children from best to worst score.
    for (size_t r = 0; r < candidates.size(); ++r)
    {
      NodeAndScore& candidate = candidates[r];
      rule.TraversalInfo() = candidate.travInfo;

      // Written as !(x < DBL_MAX) so that +inf/NaN scores also prune.
      if (!(rule.Rescore(querySide, *candidate.node, candidate.score) <
            DBL_MAX))
      {
        // The remaining candidates are no better, so prune them all.
        numPrunes += candidates.size() - r;
        break;
      }

      Traverse(querySide, *candidate.node);
    }
  }
}