/**
 * Single-tree traversal of referenceNode for one query point.
 *
 * If the node has no children, rule.BaseCase() is evaluated for the query
 * point against each of the NumPoints() consecutive point indices starting at
 * Point(0).  Otherwise every child is scored with rule.Score() and the
 * children are visited recursively in order of increasing score.  A score of
 * DBL_MAX marks a child as pruned; since the children are handled in sorted
 * order, the first pruned child ends the recursion at this node.
 *
 * @param queryIndex Index of the query point; forwarded to the rule.
 * @param referenceNode Node of the reference tree to traverse.
 */
void Octree<MetricType, StatisticType, MatType>::SingleTreeTraverser<RuleType>::
    Traverse(const size_t queryIndex, Octree& referenceNode)
{
  const size_t childCount = referenceNode.NumChildren();

  // A node without children is a leaf: run every base case and stop.
  if (childCount == 0)
  {
    const size_t firstPoint = referenceNode.Point(0);
    const size_t pointCount = referenceNode.NumPoints();
    for (size_t offset = 0; offset < pointCount; ++offset)
      rule.BaseCase(queryIndex, firstPoint + offset);

    return;
  }

  // Score each child so that the recursion can be prioritized.
  arma::vec childScores(childCount);
  for (size_t c = 0; c < childCount; ++c)
    childScores[c] = rule.Score(queryIndex, referenceNode.Child(c));

  // Visit the children from the lowest score to the highest.
  const arma::uvec visitOrder = arma::sort_index(childScores);
  const size_t total = visitOrder.n_elem;
  for (size_t rank = 0; rank < total; ++rank)
  {
    const size_t child = visitOrder[rank];

    // Once a pruned child is reached, every remaining child in sorted order
    // is pruned too; count them all and finish.
    if (childScores[child] == DBL_MAX)
    {
      numPrunes += (total - rank);
      return;
    }

    Traverse(queryIndex, referenceNode.Child(child));
  }
}
/**
 * Dual-tree traversal of the node combination (queryNode, referenceNode).
 *
 * Four cases are handled:
 *  - both nodes are leaves: each query point is scored against the reference
 *    node and, unless pruned, rule.BaseCase() is run against every reference
 *    point;
 *  - only the reference node is a leaf: recurse into each query child that is
 *    not pruned, in index order;
 *  - only the query node is a leaf: score all reference children and recurse
 *    into them in order of increasing score;
 *  - neither is a leaf: for each query child in turn, score all reference
 *    children and recurse into them in order of increasing score.
 *
 * A score of DBL_MAX returned by rule.Score() marks a combination as pruned.
 * numVisited, numPrunes and numBaseCases are updated along the way.
 *
 * @param queryNode Node of the query tree.
 * @param referenceNode Node of the reference tree.
 */
void Octree<MetricType, StatisticType, MatType>::DualTreeTraverser<RuleType>::
    Traverse(Octree& queryNode, Octree& referenceNode)
{
  // Increment the visit counter.
  ++numVisited;

  // Store the current traversal info.  The traversalInfo member is still
  // updated as before, but every nested Traverse() call overwrites it on
  // entry, so it cannot be relied upon after a recursion returns.  Keep a
  // snapshot local to this call and restore from that instead, so each Score()
  // call below sees the info belonging to this node combination.
  traversalInfo = rule.TraversalInfo();
  const typename RuleType::TraversalInfoType nodeInfo = rule.TraversalInfo();

  if (queryNode.IsLeaf() && referenceNode.IsLeaf())
  {
    const size_t begin = queryNode.Point(0);
    const size_t end = begin + queryNode.NumPoints();
    for (size_t q = begin; q < end; ++q)
    {
      // First, see if we can prune the reference node for this query point.
      rule.TraversalInfo() = nodeInfo;
      const double score = rule.Score(q, referenceNode);
      if (score == DBL_MAX)
      {
        ++numPrunes;
        continue;
      }

      const size_t rBegin = referenceNode.Point(0);
      const size_t rEnd = rBegin + referenceNode.NumPoints();
      for (size_t r = rBegin; r < rEnd; ++r)
        rule.BaseCase(q, r);

      numBaseCases += referenceNode.NumPoints();
    }
  }
  else if (!queryNode.IsLeaf() && referenceNode.IsLeaf())
  {
    // We have to recurse down the query node.  Order does not matter.
    for (size_t i = 0; i < queryNode.NumChildren(); ++i)
    {
      // Restore this level's info; the previous iteration's recursion has
      // changed the rule's traversal info.
      rule.TraversalInfo() = nodeInfo;
      const double score = rule.Score(queryNode.Child(i), referenceNode);
      if (score == DBL_MAX)
      {
        ++numPrunes;
        continue;
      }

      Traverse(queryNode.Child(i), referenceNode);
    }
  }
  else if (queryNode.IsLeaf() && !referenceNode.IsLeaf())
  {
    // We have to recurse down the reference node, so we need to do it in an
    // ordered manner.
    arma::vec scores(referenceNode.NumChildren());
    std::vector<typename RuleType::TraversalInfoType>
        tis(referenceNode.NumChildren());
    for (size_t i = 0; i < referenceNode.NumChildren(); ++i)
    {
      rule.TraversalInfo() = nodeInfo;
      scores[i] = rule.Score(queryNode, referenceNode.Child(i));
      // Remember the info produced by scoring this child so that it can be
      // restored right before recursing into it.
      tis[i] = rule.TraversalInfo();
    }

    // Sort the scores.
    arma::uvec scoreOrder = arma::sort_index(scores);
    for (size_t i = 0; i < scoreOrder.n_elem; ++i)
    {
      if (scores[scoreOrder[i]] == DBL_MAX)
      {
        // We don't need to check any more---all children past here are pruned.
        numPrunes += scoreOrder.n_elem - i;
        break;
      }

      rule.TraversalInfo() = tis[scoreOrder[i]];
      Traverse(queryNode, referenceNode.Child(scoreOrder[i]));
    }
  }
  else
  {
    // We have to recurse down both the query and reference nodes.  Query order
    // does not matter, so we will do that in sequence.  However we will
    // allocate the arrays for recursion at this level.
    arma::vec scores(referenceNode.NumChildren());
    std::vector<typename RuleType::TraversalInfoType>
        tis(referenceNode.NumChildren());
    for (size_t j = 0; j < queryNode.NumChildren(); ++j)
    {
      // Now we have to recurse down the reference node, which we will do in a
      // prioritized manner.
      for (size_t i = 0; i < referenceNode.NumChildren(); ++i)
      {
        // For j > 0 the recursions of earlier query children have already
        // run, so the local snapshot is what restores this level's info here.
        rule.TraversalInfo() = nodeInfo;
        scores[i] = rule.Score(queryNode.Child(j), referenceNode.Child(i));
        tis[i] = rule.TraversalInfo();
      }

      // Sort the scores.
      arma::uvec scoreOrder = arma::sort_index(scores);
      for (size_t i = 0; i < scoreOrder.n_elem; ++i)
      {
        if (scores[scoreOrder[i]] == DBL_MAX)
        {
          // We don't need to check any more.
          // All children past here are pruned.
          numPrunes += scoreOrder.n_elem - i;
          break;
        }

        rule.TraversalInfo() = tis[scoreOrder[i]];
        Traverse(queryNode.Child(j), referenceNode.Child(scoreOrder[i]));
      }
    }
  }
}