template<typename SortPolicy, typename MetricType, typename TreeType>
void NeighborSearchRules<SortPolicy, MetricType, TreeType>::
UpdateAfterRecursion(TreeType& queryNode, TreeType& /* referenceNode */)
{
  // Find the worst distance that the children found (including any points),
  // and update the bound accordingly.
  double worstDistance = SortPolicy::BestDistance();

  // First look through children nodes.
  for (size_t i = 0; i < queryNode.NumChildren(); ++i)
  {
    if (SortPolicy::IsBetter(worstDistance, queryNode.Child(i).Stat().Bound()))
      worstDistance = queryNode.Child(i).Stat().Bound();
  }

  // Now look through children points.
  for (size_t i = 0; i < queryNode.NumPoints(); ++i)
  {
    if (SortPolicy::IsBetter(worstDistance,
        distances(distances.n_rows - 1, queryNode.Point(i))))
      worstDistance = distances(distances.n_rows - 1, queryNode.Point(i));
  }

  // Take the worst distance from all of these, and update our bound to
  // reflect that.
  queryNode.Stat().Bound() = worstDistance;
}
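In effect, UpdateAfterRecursion() maintains the recurrence below, written here in LaTeX as a summary (my paraphrase, using the $D_p[k]$ and $B(c)$ notation that the CalculateBound() comments later in this section use; "worst" is max for nearest-neighbor search and min for furthest-neighbor search):

$$B(N) = \mathrm{worst}\left( \mathrm{worst}_{p \in N} D_p[k],\; \mathrm{worst}_{c \in \mathrm{children}(N)} B(c) \right),$$

where $D_p[k]$ is the current $k$'th candidate distance for point $p$.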
FastMKSStat(const TreeType& node) :
    bound(-DBL_MAX),
    lastKernel(0.0),
    lastKernelNode(NULL)
{
  // Do we have to calculate the centroid?
  if (tree::TreeTraits<TreeType>::FirstPointIsCentroid)
  {
    // If this type of tree has self-children, then maybe the evaluation is
    // already done.  These statistics are built bottom-up, so the child stat
    // should already be done.
    if ((tree::TreeTraits<TreeType>::HasSelfChildren) &&
        (node.NumChildren() > 0) &&
        (node.Point(0) == node.Child(0).Point(0)))
    {
      selfKernel = node.Child(0).Stat().SelfKernel();
    }
    else
    {
      selfKernel = sqrt(node.Metric().Kernel().Evaluate(
          node.Dataset().col(node.Point(0)),
          node.Dataset().col(node.Point(0))));
    }
  }
  else
  {
    // Calculate the centroid.
    arma::vec center;
    node.Center(center);
    selfKernel = sqrt(node.Metric().Kernel().Evaluate(center, center));
  }
}
template<typename MetricType, typename KernelType, typename TreeType>
inline double KDERules<MetricType, KernelType, TreeType>::Score(
    const size_t queryIndex, TreeType& referenceNode)
{
  double score, maxKernel, minKernel, bound;
  const arma::vec& queryPoint = querySet.unsafe_col(queryIndex);
  const double minDistance = referenceNode.MinDistance(queryPoint);
  bool newCalculations = true;

  if (tree::TreeTraits<TreeType>::FirstPointIsCentroid &&
      lastQueryIndex == queryIndex &&
      traversalInfo.LastReferenceNode() != NULL &&
      traversalInfo.LastReferenceNode()->Point(0) == referenceNode.Point(0))
  {
    // Don't duplicate calculations.
    newCalculations = false;
    lastQueryIndex = queryIndex;
    lastReferenceIndex = referenceNode.Point(0);
  }
  else
  {
    // Calculations are new.
    maxKernel = kernel.Evaluate(minDistance);
    minKernel = kernel.Evaluate(referenceNode.MaxDistance(queryPoint));
    bound = maxKernel - minKernel;
  }

  if (newCalculations &&
      bound <= (absError + relError * minKernel) / referenceSet.n_cols)
  {
    // Estimate values.
    double kernelValue;

    // Calculate the kernel value based on the reference node centroid.
    if (tree::TreeTraits<TreeType>::FirstPointIsCentroid)
    {
      kernelValue = EvaluateKernel(queryIndex, referenceNode.Point(0));
    }
    else
    {
      kde::KDEStat& referenceStat = referenceNode.Stat();
      kernelValue = EvaluateKernel(queryPoint, referenceStat.Centroid());
    }

    densities(queryIndex) += referenceNode.NumDescendants() * kernelValue;

    // Don't explore this tree branch.
    score = DBL_MAX;
  }
  else
  {
    score = minDistance;
  }

  ++scores;
  traversalInfo.LastReferenceNode() = &referenceNode;
  traversalInfo.LastScore() = score;
  return score;
}
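A sketch of why the pruning tolerance above is safe (my reading of the code, not text from the source): when the branch is pruned, every descendant's kernel contribution is replaced by the single value kernelValue, which lies in $[\mathrm{minKernel}, \mathrm{maxKernel}]$, so each approximated reference point incurs error at most

$$|K(q, r) - \mathrm{kernelValue}| \le \mathrm{maxKernel} - \mathrm{minKernel} \le \frac{\epsilon_{\mathrm{abs}} + \epsilon_{\mathrm{rel}} \, \mathrm{minKernel}}{N},$$

with $N = $ referenceSet.n_cols. Each reference point is accounted for exactly once, so summing over all $N$ points keeps the total error on densities(queryIndex) within $\epsilon_{\mathrm{abs}} + \epsilon_{\mathrm{rel}} f(q)$, a mixed absolute/relative guarantee.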
template<typename MetricType, typename TreeType>
double RangeSearchRules<MetricType, TreeType>::Score(const size_t queryIndex,
                                                     TreeType& referenceNode)
{
  // We must get the minimum and maximum distances and store them in this
  // object.
  math::Range distances;

  if (tree::TreeTraits<TreeType>::FirstPointIsCentroid)
  {
    // In this situation, we calculate the base case, so we should check to be
    // sure we haven't already done that.
    double baseCase;
    if (tree::TreeTraits<TreeType>::HasSelfChildren &&
        (referenceNode.Parent() != NULL) &&
        (referenceNode.Point(0) == referenceNode.Parent()->Point(0)))
    {
      // If the tree has self-children and this is a self-child, the base case
      // was already calculated.
      baseCase = referenceNode.Parent()->Stat().LastDistance();
      lastQueryIndex = queryIndex;
      lastReferenceIndex = referenceNode.Point(0);
    }
    else
    {
      // We must calculate the base case by hand.
      baseCase = BaseCase(queryIndex, referenceNode.Point(0));
    }

    // This bound may be loose for non-ball bound trees.
    distances.Lo() = baseCase - referenceNode.FurthestDescendantDistance();
    distances.Hi() = baseCase + referenceNode.FurthestDescendantDistance();

    // Update the last distance calculation.
    referenceNode.Stat().LastDistance() = baseCase;
  }
  else
  {
    distances = referenceNode.RangeDistance(querySet.unsafe_col(queryIndex));
  }

  // If the ranges do not overlap, prune this node.
  if (!distances.Contains(range))
    return DBL_MAX;

  // In this case, all of the points in the reference node will be part of the
  // results.
  if ((distances.Lo() >= range.Lo()) && (distances.Hi() <= range.Hi()))
  {
    AddResult(queryIndex, referenceNode);
    return DBL_MAX; // We don't need to go any deeper.
  }

  // Otherwise the score doesn't matter.  Recursion order is irrelevant in
  // range search.
  return 0.0;
}
DTBStat(const TreeType& node) :
    maxNeighborDistance(DBL_MAX),
    minNeighborDistance(DBL_MAX),
    bound(DBL_MAX),
    componentMembership(
        ((node.NumPoints() == 1) && (node.NumChildren() == 0)) ?
        node.Point(0) : -1)
{ }
DualTreeKMeansStatistic(TreeType& node) :
    neighbor::NeighborSearchStat<neighbor::NearestNeighborSort>(),
    upperBound(DBL_MAX),
    lowerBound(DBL_MAX),
    owner(size_t(-1)),
    pruned(size_t(-1)),
    staticPruned(false),
    staticUpperBoundMovement(0.0),
    staticLowerBoundMovement(0.0),
    trueParent(node.Parent())
{
  // Empirically calculate the centroid.
  centroid.zeros(node.Dataset().n_rows);
  for (size_t i = 0; i < node.NumPoints(); ++i)
  {
    // Correct handling of cover tree: don't double-count the point which
    // appears in the children.
    if (tree::TreeTraits<TreeType>::HasSelfChildren && i == 0 &&
        node.NumChildren() > 0)
      continue;
    centroid += node.Dataset().col(node.Point(i));
  }

  for (size_t i = 0; i < node.NumChildren(); ++i)
    centroid += node.Child(i).NumDescendants() *
        node.Child(i).Stat().Centroid();

  centroid /= node.NumDescendants();

  // Set the true children correctly.
  trueChildren.resize(node.NumChildren());
  for (size_t i = 0; i < node.NumChildren(); ++i)
    trueChildren[i] = &node.Child(i);
}
template<typename SortPolicy, typename MetricType, typename TreeType>
inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Score(
    const size_t queryIndex,
    TreeType& referenceNode)
{
  ++scores; // Count number of Score() calls.
  double distance;
  if (tree::TreeTraits<TreeType>::FirstPointIsCentroid)
  {
    // The first point in the tree is the centroid, so we can calculate the
    // base case between that and the query point.
    double baseCase;
    if (tree::TreeTraits<TreeType>::HasSelfChildren &&
        (referenceNode.Parent() != NULL) &&
        (referenceNode.Point(0) == referenceNode.Parent()->Point(0)))
    {
      // If the parent node shares this centroid, then we have already
      // calculated the base case.
      baseCase = referenceNode.Parent()->Stat().LastDistance();
    }
    else
    {
      baseCase = BaseCase(queryIndex, referenceNode.Point(0));
    }

    // Save this evaluation.
    referenceNode.Stat().LastDistance() = baseCase;

    distance = SortPolicy::CombineBest(baseCase,
        referenceNode.FurthestDescendantDistance());
  }
  else
  {
    distance = SortPolicy::BestPointToNodeDistance(querySet.col(queryIndex),
        &referenceNode);
  }

  // Compare against the best k'th distance for this query point so far.
  const double bestDistance = distances(distances.n_rows - 1, queryIndex);

  return (SortPolicy::IsBetter(distance, bestDistance)) ? distance : DBL_MAX;
}
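The centroid branch above is a triangle-inequality bound; as a sketch in LaTeX (my paraphrase): if the node's first point $p_0$ is its centroid and every descendant lies within $\lambda(N) = $ FurthestDescendantDistance() of $p_0$, then any descendant $r$ satisfies

$$d(q, r) \ge d(q, p_0) - \lambda(N),$$

so SortPolicy::CombineBest(baseCase, $\lambda(N)$) is a valid lower bound for nearest-neighbor search, and the mirrored combination gives the upper bound used for furthest-neighbor search.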
template<typename MetricType, typename TreeType>
inline double DTBRules<MetricType, TreeType>::CalculateBound(
    TreeType& queryNode) const
{
  double worstPointBound = -DBL_MAX;
  double bestPointBound = DBL_MAX;

  double worstChildBound = -DBL_MAX;
  double bestChildBound = DBL_MAX;

  // Now, find the best and worst point bounds.
  for (size_t i = 0; i < queryNode.NumPoints(); ++i)
  {
    const size_t pointComponent = connections.Find(queryNode.Point(i));
    const double bound = neighborsDistances[pointComponent];

    if (bound > worstPointBound)
      worstPointBound = bound;
    if (bound < bestPointBound)
      bestPointBound = bound;
  }

  // Find the best and worst child bounds.
  for (size_t i = 0; i < queryNode.NumChildren(); ++i)
  {
    const double maxBound = queryNode.Child(i).Stat().MaxNeighborDistance();
    if (maxBound > worstChildBound)
      worstChildBound = maxBound;

    const double minBound = queryNode.Child(i).Stat().MinNeighborDistance();
    if (minBound < bestChildBound)
      bestChildBound = minBound;
  }

  // Now calculate the actual bounds.
  const double worstBound = std::max(worstPointBound, worstChildBound);
  const double bestBound = std::min(bestPointBound, bestChildBound);

  // We must check that bestBound != DBL_MAX; otherwise, we risk overflow.
  const double bestAdjustedBound = (bestBound == DBL_MAX) ? DBL_MAX :
      bestBound + 2 * queryNode.FurthestDescendantDistance();

  // Update the relevant quantities in the node.
  queryNode.Stat().MaxNeighborDistance() = worstBound;
  queryNode.Stat().MinNeighborDistance() = bestBound;
  queryNode.Stat().Bound() = std::min(worstBound, bestAdjustedBound);

  return queryNode.Stat().Bound();
}
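Why bestBound plus twice the furthest descendant distance is also a usable bound here, as a sketch (my reading of the dual-tree Borůvka argument, not text from the source): with $\lambda(N)$ the furthest descendant distance, any two descendants $p, p'$ of queryNode satisfy $d(p, p') \le 2\lambda(N)$ by the triangle inequality through the node's center. So if some descendant $p$ already has a candidate inter-component edge of length bestBound ending at $q'$, any other descendant $p'$ can reach $q'$ with

$$d(p', q') \le d(p', p) + d(p, q') \le 2\lambda(N) + \mathrm{bestBound},$$

which is why Bound() can take the minimum of worstBound and bestAdjustedBound.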
DualTreeKMeansStatistic(TreeType& node) :
    closestQueryNode(NULL),
    minQueryNodeDistance(DBL_MAX),
    maxQueryNodeDistance(DBL_MAX),
    clustersPruned(0),
    iteration(size_t() - 1)
{
  // Empirically calculate the centroid.
  centroid.zeros(node.Dataset().n_rows);
  for (size_t i = 0; i < node.NumPoints(); ++i)
    centroid += node.Dataset().col(node.Point(i));

  for (size_t i = 0; i < node.NumChildren(); ++i)
    centroid += node.Child(i).NumDescendants() *
        node.Child(i).Stat().Centroid();

  centroid /= node.NumDescendants();
}
PellegMooreKMeansStatistic(TreeType& node)
{
  centroid.zeros(node.Dataset().n_rows);

  // This assumes a depth-first build procedure, and it won't work correctly
  // for trees that have self-children.
  for (size_t i = 0; i < node.NumChildren(); ++i)
  {
    centroid += node.Child(i).NumDescendants() *
        node.Child(i).Stat().Centroid();
  }

  for (size_t i = 0; i < node.NumPoints(); ++i)
    centroid += node.Dataset().col(node.Point(i));

  if (node.NumDescendants() > 0)
    centroid /= node.NumDescendants();
  else
    centroid.fill(DBL_MAX); // Invalid centroid.  What else can we do?
}
template<typename TreeType, typename RuleType>
void GreedySingleTreeTraverser<TreeType, RuleType>::Traverse(
    const size_t queryIndex,
    TreeType& referenceNode)
{
  // Run the base case as necessary for all the points in the reference node.
  for (size_t i = 0; i < referenceNode.NumPoints(); ++i)
    rule.BaseCase(queryIndex, referenceNode.Point(i));

  size_t bestChild = rule.GetBestChild(queryIndex, referenceNode);
  size_t numDescendants;
  // Check that referenceNode is not a leaf node while calculating the number
  // of descendants of its best child.
  if (!referenceNode.IsLeaf())
    numDescendants = referenceNode.Child(bestChild).NumDescendants();
  else
    numDescendants = referenceNode.NumPoints();

  // If the number of descendants is more than minBaseCases, then we can go
  // along with the best child; otherwise we must iterate over the descendants
  // to ensure that we calculate at least minBaseCases base cases.
  if (!referenceNode.IsLeaf())
  {
    if (numDescendants > minBaseCases)
    {
      // We are pruning all but one child.
      numPrunes += referenceNode.NumChildren() - 1;
      // Recurse into the best child.
      Traverse(queryIndex, referenceNode.Child(bestChild));
    }
    else
    {
      // Run the base case over the first minBaseCases descendants, guarding
      // against nodes with fewer descendants than that.
      for (size_t i = 0;
          i < minBaseCases && i < referenceNode.NumDescendants(); ++i)
        rule.BaseCase(queryIndex, referenceNode.Descendant(i));
    }
  }
}
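A minimal usage sketch of the traverser (my illustration, not from the source; TreeType and RuleType are placeholders, and the rule set is assumed to be constructed by the caller with whatever task-specific arguments it needs):

// Hypothetical helper: run a greedy single-tree traversal for every query
// point in querySet against a prebuilt reference tree.
template<typename TreeType, typename RuleType>
void GreedyTraverseAll(const arma::mat& querySet,
                       TreeType& referenceTree,
                       RuleType& rules)
{
  tree::GreedySingleTreeTraverser<TreeType, RuleType> traverser(rules);

  // Greedily descend the reference tree once per query point.
  for (size_t i = 0; i < querySet.n_cols; ++i)
    traverser.Traverse(i, referenceTree);
}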
template<typename MetricType, typename TreeType>
void RangeSearchRules<MetricType, TreeType>::AddResult(const size_t queryIndex,
                                                       TreeType& referenceNode)
{
  // Some types of trees calculate the base case evaluation before Score() is
  // called, so if the base case has already been calculated, then we must
  // avoid adding that point to the results again.
  size_t baseCaseMod = 0;
  if (tree::TreeTraits<TreeType>::FirstPointIsCentroid &&
      (queryIndex == lastQueryIndex) &&
      (referenceNode.Point(0) == lastReferenceIndex))
  {
    baseCaseMod = 1;
  }

  // Set the capacity of the output vectors appropriately.  We have to use
  // reserve() and not resize(), because we don't know if we will encounter
  // the case where the datasets and points are the same (and we skip in that
  // case).
  const size_t oldSize = neighbors[queryIndex].size();
  neighbors[queryIndex].reserve(oldSize + referenceNode.NumDescendants() -
      baseCaseMod);
  distances[queryIndex].reserve(oldSize + referenceNode.NumDescendants() -
      baseCaseMod);

  for (size_t i = baseCaseMod; i < referenceNode.NumDescendants(); ++i)
  {
    if ((&referenceSet == &querySet) &&
        (queryIndex == referenceNode.Descendant(i)))
      continue;

    const double distance = metric.Evaluate(querySet.unsafe_col(queryIndex),
        referenceNode.Dataset().unsafe_col(referenceNode.Descendant(i)));

    neighbors[queryIndex].push_back(referenceNode.Descendant(i));
    distances[queryIndex].push_back(distance);
  }
}
template<typename TreeType>
void CheckTrees(TreeType& tree,
                TreeType& xmlTree,
                TreeType& textTree,
                TreeType& binaryTree)
{
  const typename TreeType::Mat* dataset = &tree.Dataset();

  // Make sure that the data matrices are the same.
  if (tree.Parent() == NULL)
  {
    CheckMatrices(*dataset, xmlTree.Dataset(), textTree.Dataset(),
        binaryTree.Dataset());

    // Also ensure that the other parents are null too.
    BOOST_REQUIRE_EQUAL(xmlTree.Parent(), (TreeType*) NULL);
    BOOST_REQUIRE_EQUAL(textTree.Parent(), (TreeType*) NULL);
    BOOST_REQUIRE_EQUAL(binaryTree.Parent(), (TreeType*) NULL);
  }

  // Make sure the number of children is the same.
  BOOST_REQUIRE_EQUAL(tree.NumChildren(), xmlTree.NumChildren());
  BOOST_REQUIRE_EQUAL(tree.NumChildren(), textTree.NumChildren());
  BOOST_REQUIRE_EQUAL(tree.NumChildren(), binaryTree.NumChildren());

  // Make sure the number of descendants is the same.
  BOOST_REQUIRE_EQUAL(tree.NumDescendants(), xmlTree.NumDescendants());
  BOOST_REQUIRE_EQUAL(tree.NumDescendants(), textTree.NumDescendants());
  BOOST_REQUIRE_EQUAL(tree.NumDescendants(), binaryTree.NumDescendants());

  // Make sure the number of points is the same.
  BOOST_REQUIRE_EQUAL(tree.NumPoints(), xmlTree.NumPoints());
  BOOST_REQUIRE_EQUAL(tree.NumPoints(), textTree.NumPoints());
  BOOST_REQUIRE_EQUAL(tree.NumPoints(), binaryTree.NumPoints());

  // Check that each point is the same.
  for (size_t i = 0; i < tree.NumPoints(); ++i)
  {
    BOOST_REQUIRE_EQUAL(tree.Point(i), xmlTree.Point(i));
    BOOST_REQUIRE_EQUAL(tree.Point(i), textTree.Point(i));
    BOOST_REQUIRE_EQUAL(tree.Point(i), binaryTree.Point(i));
  }

  // Check that the parent distance is the same.
  BOOST_REQUIRE_CLOSE(tree.ParentDistance(), xmlTree.ParentDistance(), 1e-8);
  BOOST_REQUIRE_CLOSE(tree.ParentDistance(), textTree.ParentDistance(), 1e-8);
  BOOST_REQUIRE_CLOSE(tree.ParentDistance(), binaryTree.ParentDistance(),
      1e-8);

  // Check that the furthest descendant distance is the same.
  BOOST_REQUIRE_CLOSE(tree.FurthestDescendantDistance(),
      xmlTree.FurthestDescendantDistance(), 1e-8);
  BOOST_REQUIRE_CLOSE(tree.FurthestDescendantDistance(),
      textTree.FurthestDescendantDistance(), 1e-8);
  BOOST_REQUIRE_CLOSE(tree.FurthestDescendantDistance(),
      binaryTree.FurthestDescendantDistance(), 1e-8);

  // Check that the minimum bound distance is the same.
  BOOST_REQUIRE_CLOSE(tree.MinimumBoundDistance(),
      xmlTree.MinimumBoundDistance(), 1e-8);
  BOOST_REQUIRE_CLOSE(tree.MinimumBoundDistance(),
      textTree.MinimumBoundDistance(), 1e-8);
  BOOST_REQUIRE_CLOSE(tree.MinimumBoundDistance(),
      binaryTree.MinimumBoundDistance(), 1e-8);

  // Recurse into the children.
  for (size_t i = 0; i < tree.NumChildren(); ++i)
  {
    // Check that the child dataset is the same.
    BOOST_REQUIRE_EQUAL(&xmlTree.Dataset(), &xmlTree.Child(i).Dataset());
    BOOST_REQUIRE_EQUAL(&textTree.Dataset(), &textTree.Child(i).Dataset());
    BOOST_REQUIRE_EQUAL(&binaryTree.Dataset(), &binaryTree.Child(i).Dataset());

    // Make sure the parent link is right.
    BOOST_REQUIRE_EQUAL(xmlTree.Child(i).Parent(), &xmlTree);
    BOOST_REQUIRE_EQUAL(textTree.Child(i).Parent(), &textTree);
    BOOST_REQUIRE_EQUAL(binaryTree.Child(i).Parent(), &binaryTree);

    CheckTrees(tree.Child(i), xmlTree.Child(i), textTree.Child(i),
        binaryTree.Child(i));
  }
}
template<typename KernelType, typename TreeType>
double FastMKSRules<KernelType, TreeType>::Score(TreeType& queryNode,
                                                 TreeType& referenceNode)
{
  // Update and get the query node's bound.
  queryNode.Stat().Bound() = CalculateBound(queryNode);
  const double bestKernel = queryNode.Stat().Bound();

  // First, see if we can make a parent-child or parent-parent prune.  These
  // four bounds on the maximum kernel value are looser than the bound
  // normally used, but they can prevent a base case from needing to be
  // calculated.

  // Convenience caching so lines are shorter.
  const double queryParentDist = queryNode.ParentDistance();
  const double queryDescDist = queryNode.FurthestDescendantDistance();
  const double refParentDist = referenceNode.ParentDistance();
  const double refDescDist = referenceNode.FurthestDescendantDistance();
  double adjustedScore = traversalInfo.LastBaseCase();

  const double queryDistBound = (queryParentDist + queryDescDist);
  const double refDistBound = (refParentDist + refDescDist);
  double dualQueryTerm;
  double dualRefTerm;

  // The parent-child and parent-parent prunes work by applying the same
  // pruning condition as when the parent node was used, except they are
  // tighter because
  //    queryDistBound < queryNode.Parent()->FurthestDescendantDistance()
  // and
  //    refDistBound < referenceNode.Parent()->FurthestDescendantDistance()
  // so we construct the same bounds that were used when Score() was called
  // with the parents, except with the tighter distance bounds.  Sometimes
  // this allows us to prune nodes without evaluating the base cases between
  // them.
  if (traversalInfo.LastQueryNode() == queryNode.Parent())
  {
    // We can assume that queryNode.Parent() != NULL, because at the root node
    // combination, the traversalInfo.LastQueryNode() pointer will _not_ be
    // NULL.  We also should be guaranteed that
    // traversalInfo.LastReferenceNode() is either the reference node or the
    // parent of the reference node.
    adjustedScore += queryDistBound *
        traversalInfo.LastReferenceNode()->Stat().SelfKernel();
    dualQueryTerm = queryDistBound;
  }
  else
  {
    // The query parent could be NULL, which does weird things that we have to
    // consider.
    if (traversalInfo.LastReferenceNode() != NULL)
    {
      adjustedScore += queryDescDist *
          traversalInfo.LastReferenceNode()->Stat().SelfKernel();
      dualQueryTerm = queryDescDist;
    }
    else
    {
      // This makes it so a child-parent (or parent-parent) prune is not
      // possible.
      dualQueryTerm = 0.0;
      adjustedScore = bestKernel;
    }
  }

  if (traversalInfo.LastReferenceNode() == referenceNode.Parent())
  {
    // We can assume that referenceNode.Parent() != NULL, because at the root
    // node combination, the traversalInfo.LastReferenceNode() pointer will
    // _not_ be NULL.
    adjustedScore += refDistBound *
        traversalInfo.LastQueryNode()->Stat().SelfKernel();
    dualRefTerm = refDistBound;
  }
  else
  {
    // The reference parent could be NULL, which does weird things that we
    // have to consider.
    if (traversalInfo.LastQueryNode() != NULL)
    {
      adjustedScore += refDescDist *
          traversalInfo.LastQueryNode()->Stat().SelfKernel();
      dualRefTerm = refDescDist;
    }
    else
    {
      // This makes it so a child-parent (or parent-parent) prune is not
      // possible.
      dualRefTerm = 0.0;
      adjustedScore = bestKernel;
    }
  }

  // Now add the dual term.
  adjustedScore += (dualQueryTerm * dualRefTerm);

  if (adjustedScore < bestKernel)
  {
    // It is not possible that this node combination can contain a point
    // combination with kernel value better than the minimum kernel value
    // needed to improve any of the results, so we can prune it.
    return DBL_MAX;
  }

  // We were unable to perform a parent-child or parent-parent prune, so now
  // we must calculate a kernel evaluation, if necessary.
  double kernelEval = 0.0;
  if (tree::TreeTraits<TreeType>::FirstPointIsCentroid)
  {
    // For this type of tree, we may have already calculated the base case in
    // the parents.
    if ((traversalInfo.LastQueryNode() != NULL) &&
        (traversalInfo.LastReferenceNode() != NULL) &&
        (traversalInfo.LastQueryNode()->Point(0) == queryNode.Point(0)) &&
        (traversalInfo.LastReferenceNode()->Point(0) == referenceNode.Point(0)))
    {
      // Base case already done.
      kernelEval = traversalInfo.LastBaseCase();

      // When BaseCase() is called after Score(), these must be correct so
      // that another kernel evaluation is not performed.
      lastQueryIndex = queryNode.Point(0);
      lastReferenceIndex = referenceNode.Point(0);
    }
    else
    {
      // The kernel must be evaluated, but it is between points in the
      // dataset, so we can call BaseCase().  BaseCase() will set
      // lastQueryIndex and lastReferenceIndex correctly.
      kernelEval = BaseCase(queryNode.Point(0), referenceNode.Point(0));
    }

    traversalInfo.LastBaseCase() = kernelEval;
  }
  else
  {
    // Calculate the maximum possible kernel value.
    arma::vec queryCentroid;
    arma::vec refCentroid;
    queryNode.Centroid(queryCentroid);
    referenceNode.Centroid(refCentroid);

    kernelEval = kernel.Evaluate(queryCentroid, refCentroid);

    traversalInfo.LastBaseCase() = kernelEval;
  }
  ++scores;

  double maxKernel;
  if (kernel::KernelTraits<KernelType>::IsNormalized)
  {
    // We have a tighter bound for normalized kernels.
    const double querySqDist = std::pow(queryDescDist, 2.0);
    const double refSqDist = std::pow(refDescDist, 2.0);
    const double bothSqDist = std::pow((queryDescDist + refDescDist), 2.0);

    if (kernelEval <= (1 - 0.5 * bothSqDist))
    {
      const double queryDelta = (1 - 0.5 * querySqDist);
      const double queryGamma = queryDescDist * sqrt(1 - 0.25 * querySqDist);
      const double refDelta = (1 - 0.5 * refSqDist);
      const double refGamma = refDescDist * sqrt(1 - 0.25 * refSqDist);

      maxKernel = kernelEval * (queryDelta * refDelta - queryGamma * refGamma)
          + sqrt(1 - std::pow(kernelEval, 2.0)) *
          (queryGamma * refDelta + queryDelta * refGamma);
    }
    else
    {
      maxKernel = 1.0;
    }
  }
  else
  {
    // Use the standard bound; the kernel is not normalized.
    const double refKernelTerm = queryDescDist *
        referenceNode.Stat().SelfKernel();
    const double queryKernelTerm = refDescDist * queryNode.Stat().SelfKernel();

    maxKernel = kernelEval + refKernelTerm + queryKernelTerm +
        (queryDescDist * refDescDist);
  }

  // Store relevant information for parent-child pruning.
  traversalInfo.LastQueryNode() = &queryNode;
  traversalInfo.LastReferenceNode() = &referenceNode;

  // We return the inverse of the maximum kernel so that larger kernels are
  // recursed into first.
  return (maxKernel > bestKernel) ? (1.0 / maxKernel) : DBL_MAX;
}
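A sketch of where the normalized-kernel bound above comes from (my reading of the FastMKS argument, not text from the source): for a normalized kernel, $K(x, y) = \cos\theta_{xy}$, the cosine of the angle between $\phi(x)$ and $\phi(y)$ on the unit sphere in kernel space, and a chord of length $d$ on that sphere subtends an angle $\varphi$ with

$$\cos\varphi = 1 - d^2/2, \qquad \sin\varphi = d\sqrt{1 - d^2/4}.$$

With descendants confined to chords $d_q$ and $d_r$ around the two centroids, the largest attainable kernel is $\cos(\theta - \varphi_q - \varphi_r)$; expanding this with the angle-addition formula gives exactly the queryDelta/queryGamma/refDelta/refGamma expression computed above. The check against $1 - (d_q + d_r)^2/2$ conservatively ensures $\theta \ge \varphi_q + \varphi_r$; otherwise the kernel is simply bounded by 1.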
template<typename MetricType, typename KernelType, typename TreeType>
inline double KDERules<MetricType, KernelType, TreeType>::Score(
    TreeType& queryNode, TreeType& referenceNode)
{
  double score, maxKernel, minKernel, bound;
  const double minDistance = queryNode.MinDistance(referenceNode);
  // Calculations are not duplicated.
  bool newCalculations = true;

  if (tree::TreeTraits<TreeType>::FirstPointIsCentroid &&
      (traversalInfo.LastQueryNode() != NULL) &&
      (traversalInfo.LastReferenceNode() != NULL) &&
      (traversalInfo.LastQueryNode()->Point(0) == queryNode.Point(0)) &&
      (traversalInfo.LastReferenceNode()->Point(0) == referenceNode.Point(0)))
  {
    // Don't duplicate calculations.
    newCalculations = false;
    lastQueryIndex = queryNode.Point(0);
    lastReferenceIndex = referenceNode.Point(0);
  }
  else
  {
    // Calculations are new.
    maxKernel = kernel.Evaluate(minDistance);
    minKernel = kernel.Evaluate(queryNode.MaxDistance(referenceNode));
    bound = maxKernel - minKernel;
  }

  // If possible, avoid some calculations because of the error tolerance.
  if (newCalculations &&
      bound <= (absError + relError * minKernel) / referenceSet.n_cols)
  {
    // Auxiliary variables.
    double kernelValue;
    kde::KDEStat& referenceStat = referenceNode.Stat();
    kde::KDEStat& queryStat = queryNode.Stat();

    // If calculating a center is not required.
    if (tree::TreeTraits<TreeType>::FirstPointIsCentroid)
    {
      kernelValue = EvaluateKernel(queryNode.Point(0), referenceNode.Point(0));
    }
    // Sadly, we have no choice but to calculate the center.
    else
    {
      kernelValue = EvaluateKernel(queryStat.Centroid(),
          referenceStat.Centroid());
    }

    // Sum up the estimations.
    for (size_t i = 0; i < queryNode.NumDescendants(); ++i)
    {
      densities(queryNode.Descendant(i)) +=
          referenceNode.NumDescendants() * kernelValue;
    }
    score = DBL_MAX;
  }
  else
  {
    score = minDistance;
  }

  ++scores;
  traversalInfo.LastQueryNode() = &queryNode;
  traversalInfo.LastReferenceNode() = &referenceNode;
  traversalInfo.LastScore() = score;
  return score;
}
template<typename SortPolicy, typename MetricType, typename TreeType>
inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::
CalculateBound(TreeType& queryNode) const
{
  // We have five possible bounds, and we must take the best of them all.  We
  // don't use min/max here, but instead "best/worst", because this is general
  // to the nearest-neighbors/furthest-neighbors cases.  For nearest
  // neighbors, min = best, max = worst.
  //
  // (1) worst ( worst_{all points p in queryNode} D_p[k],
  //             worst_{all children c in queryNode} B(c) );
  // (2) best_{all points p in queryNode} D_p[k] + worst child distance +
  //         worst descendant distance;
  // (3) best_{all children c in queryNode} B(c) +
  //         2 ( worst descendant distance of queryNode -
  //             worst descendant distance of c );
  // (4) B_1(parent of queryNode);
  // (5) B_2(parent of queryNode);
  //
  // D_p[k] is the current k'th candidate distance for point p.
  // So we will loop over the points in queryNode and the children in
  // queryNode to calculate all five of these quantities.

  // Hm, can we populate our distances vector with estimates from the parent?
  // This is written specifically for the cover tree and assumes only one
  // point in a node.
  //
  // if (queryNode.Parent() != NULL && queryNode.NumPoints() > 0)
  // {
  //   size_t parentIndexStart = 0;
  //   for (size_t i = 0; i < neighbors.n_rows; ++i)
  //   {
  //     const double pointDistance = distances(i, queryNode.Point(0));
  //     if (pointDistance == DBL_MAX)
  //     {
  //       // Cool, can we take an estimate from the parent?
  //       const double parentWorstBound = distances(distances.n_rows - 1,
  //           queryNode.Parent()->Point(0));
  //       if (parentWorstBound != DBL_MAX)
  //       {
  //         const double parentAdjustedDistance = parentWorstBound +
  //             queryNode.ParentDistance();
  //         distances(i, queryNode.Point(0)) = parentAdjustedDistance;
  //       }
  //     }
  //   }
  // }

  double worstPointDistance = SortPolicy::BestDistance();
  double bestPointDistance = SortPolicy::WorstDistance();

  // Loop over all points in this node to find the best and worst distance
  // candidates (for (1) and (2)).
  for (size_t i = 0; i < queryNode.NumPoints(); ++i)
  {
    const double distance = distances(distances.n_rows - 1,
        queryNode.Point(i));
    if (SortPolicy::IsBetter(distance, bestPointDistance))
      bestPointDistance = distance;
    if (SortPolicy::IsBetter(worstPointDistance, distance))
      worstPointDistance = distance;
  }

  // Loop over all the children in this node to find the worst bound (for (1))
  // and the best bound with the correcting factor for descendant distances
  // (for (3)).
  double worstChildBound = SortPolicy::BestDistance();
  double bestAdjustedChildBound = SortPolicy::WorstDistance();
  const double queryMaxDescendantDistance =
      queryNode.FurthestDescendantDistance();

  for (size_t i = 0; i < queryNode.NumChildren(); ++i)
  {
    const double firstBound = queryNode.Child(i).Stat().FirstBound();
    const double secondBound = queryNode.Child(i).Stat().SecondBound();
    const double childMaxDescendantDistance =
        queryNode.Child(i).FurthestDescendantDistance();

    if (SortPolicy::IsBetter(worstChildBound, firstBound))
      worstChildBound = firstBound;

    // Now calculate the adjustment for maximum descendant distances.
    const double adjustedBound = SortPolicy::CombineWorst(secondBound,
        2 * (queryMaxDescendantDistance - childMaxDescendantDistance));
    if (SortPolicy::IsBetter(adjustedBound, bestAdjustedChildBound))
      bestAdjustedChildBound = adjustedBound;
  }

  // This is bound (1).
  const double firstBound =
      (SortPolicy::IsBetter(worstPointDistance, worstChildBound)) ?
      worstChildBound : worstPointDistance;

  // This is bound (2).
  const double secondBound = SortPolicy::CombineWorst(
      SortPolicy::CombineWorst(bestPointDistance, queryMaxDescendantDistance),
      queryNode.FurthestPointDistance());

  // Bound (3) is bestAdjustedChildBound.

  // Bounds (4) and (5) are the parent bounds.
  const double fourthBound = (queryNode.Parent() != NULL) ?
      queryNode.Parent()->Stat().FirstBound() : SortPolicy::WorstDistance();
  // const double fifthBound = (queryNode.Parent() != NULL) ?
  //     queryNode.Parent()->Stat().SecondBound() -
  //         queryNode.Parent()->FurthestDescendantDistance() -
  //         queryNode.Parent()->FurthestPointDistance() +
  //         queryMaxDescendantDistance +
  //         queryNode.FurthestPointDistance() + queryNode.ParentDistance() :
  //     SortPolicy::WorstDistance();

  // Now, we will take the best of these.  Unfortunately due to the way
  // IsBetter() is defined, this sort of has to be a little ugly.
  // The variable interA represents the first bound (B_1), which is the worst
  // candidate distance of any descendants of this node.
  // The variable interB represents the second bound (B_2), which is a bound
  // on the worst distance of any descendants of this node assembled using the
  // best descendant candidate distance modified using the furthest descendant
  // distance.
  const double interA = (SortPolicy::IsBetter(firstBound, fourthBound)) ?
      firstBound : fourthBound;
  const double interB =
      (SortPolicy::IsBetter(bestAdjustedChildBound, secondBound)) ?
      bestAdjustedChildBound : secondBound;
  // const double interC = (SortPolicy::IsBetter(interB, fifthBound)) ? interB
  //     : fifthBound;

  // Update the first and second bounds of the node.
  queryNode.Stat().FirstBound() = interA;
  queryNode.Stat().SecondBound() = interB;

  // Update the actual bound of the node, taking the better of B_1 and B_2.
  queryNode.Stat().Bound() = (SortPolicy::IsBetter(interA, interB)) ? interA :
      interB;

  return queryNode.Stat().Bound();
}
template<typename KernelType, typename TreeType>
double FastMKSRules<KernelType, TreeType>::Score(const size_t queryIndex,
                                                 TreeType& referenceNode)
{
  // Compare with the current best.
  const double bestKernel = products(products.n_rows - 1, queryIndex);

  // See if we can perform a parent-child prune.
  const double furthestDist = referenceNode.FurthestDescendantDistance();
  if (referenceNode.Parent() != NULL)
  {
    double maxKernelBound;
    const double parentDist = referenceNode.ParentDistance();
    const double combinedDistBound = parentDist + furthestDist;
    const double lastKernel = referenceNode.Parent()->Stat().LastKernel();
    if (kernel::KernelTraits<KernelType>::IsNormalized)
    {
      const double squaredDist = std::pow(combinedDistBound, 2.0);
      const double delta = (1 - 0.5 * squaredDist);
      if (lastKernel <= delta)
      {
        const double gamma = combinedDistBound * sqrt(1 - 0.25 * squaredDist);
        maxKernelBound = lastKernel * delta +
            gamma * sqrt(1 - std::pow(lastKernel, 2.0));
      }
      else
      {
        maxKernelBound = 1.0;
      }
    }
    else
    {
      maxKernelBound = lastKernel +
          combinedDistBound * queryKernels[queryIndex];
    }

    if (maxKernelBound < bestKernel)
      return DBL_MAX;
  }

  // Calculate the maximum possible kernel value, either by calculating the
  // centroid or, if the centroid is a point, using that.
  ++scores;
  double kernelEval;
  if (tree::TreeTraits<TreeType>::FirstPointIsCentroid)
  {
    // Could it be that this kernel evaluation has already been calculated?
    if (tree::TreeTraits<TreeType>::HasSelfChildren &&
        referenceNode.Parent() != NULL &&
        referenceNode.Point(0) == referenceNode.Parent()->Point(0))
    {
      kernelEval = referenceNode.Parent()->Stat().LastKernel();
    }
    else
    {
      kernelEval = BaseCase(queryIndex, referenceNode.Point(0));
    }
  }
  else
  {
    const arma::vec queryPoint = querySet.unsafe_col(queryIndex);
    arma::vec refCentroid;
    referenceNode.Centroid(refCentroid);

    kernelEval = kernel.Evaluate(queryPoint, refCentroid);
  }
  referenceNode.Stat().LastKernel() = kernelEval;

  double maxKernel;
  if (kernel::KernelTraits<KernelType>::IsNormalized)
  {
    const double squaredDist = std::pow(furthestDist, 2.0);
    const double delta = (1 - 0.5 * squaredDist);
    if (kernelEval <= delta)
    {
      const double gamma = furthestDist * sqrt(1 - 0.25 * squaredDist);
      maxKernel = kernelEval * delta +
          gamma * sqrt(1 - std::pow(kernelEval, 2.0));
    }
    else
    {
      maxKernel = 1.0;
    }
  }
  else
  {
    maxKernel = kernelEval + furthestDist * queryKernels[queryIndex];
  }

  // We return the inverse of the maximum kernel so that larger kernels are
  // recursed into first.
  return (maxKernel > bestKernel) ? (1.0 / maxKernel) : DBL_MAX;
}
template<typename KernelType, typename TreeType>
double FastMKSRules<KernelType, TreeType>::CalculateBound(TreeType& queryNode)
    const
{
  // We have four possible bounds -- just like NeighborSearchRules, but they
  // are slightly different in this context.
  //
  // (1) min ( min_{all points p in queryNode} P_p[k],
  //           min_{all children c in queryNode} B(c) );
  // (2) max_{all points p in queryNode} P_p[k] + (worst child distance +
  //         worst descendant distance) sqrt(K(I_p[k], I_p[k]));
  // (3) max_{all children c in queryNode} B(c) + <-- not done yet.  ignored.
  // (4) B(parent of queryNode);
  double worstPointKernel = DBL_MAX;
  double bestAdjustedPointKernel = -DBL_MAX;

  const double queryDescendantDistance =
      queryNode.FurthestDescendantDistance();

  // Loop over all points in this node to find the best and worst.
  for (size_t i = 0; i < queryNode.NumPoints(); ++i)
  {
    const size_t point = queryNode.Point(i);
    if (products(products.n_rows - 1, point) < worstPointKernel)
      worstPointKernel = products(products.n_rows - 1, point);

    if (products(products.n_rows - 1, point) == -DBL_MAX)
      continue; // Avoid underflow.

    // This should be (queryDescendantDistance + centroidDistance) for any
    // tree, but it works for cover trees since centroidDistance = 0 for cover
    // trees.
    const double candidateKernel = products(products.n_rows - 1, point) -
        queryDescendantDistance *
        referenceKernels[indices(indices.n_rows - 1, point)];

    if (candidateKernel > bestAdjustedPointKernel)
      bestAdjustedPointKernel = candidateKernel;
  }

  // Loop over all the children in the node.
  double worstChildKernel = DBL_MAX;

  for (size_t i = 0; i < queryNode.NumChildren(); ++i)
  {
    if (queryNode.Child(i).Stat().Bound() < worstChildKernel)
      worstChildKernel = queryNode.Child(i).Stat().Bound();
  }

  // Now assemble bound (1).
  const double firstBound = (worstPointKernel < worstChildKernel) ?
      worstPointKernel : worstChildKernel;

  // Bound (2) is bestAdjustedPointKernel.
  const double fourthBound = (queryNode.Parent() == NULL) ? -DBL_MAX :
      queryNode.Parent()->Stat().Bound();

  // Pick the best of these bounds.
  const double interA = (firstBound > bestAdjustedPointKernel) ? firstBound :
      bestAdjustedPointKernel;
  const double interB = fourthBound;

  return (interA > interB) ? interA : interB;
}
template<typename MetricType, typename TreeType>
double RangeSearchRules<MetricType, TreeType>::Score(TreeType& queryNode,
                                                     TreeType& referenceNode)
{
  math::Range distances;
  if (tree::TreeTraits<TreeType>::FirstPointIsCentroid)
  {
    // It is possible that the base case has already been calculated.
    double baseCase = 0.0;
    bool alreadyDone = false;
    if (tree::TreeTraits<TreeType>::HasSelfChildren)
    {
      TreeType* lastQuery = (TreeType*)
          referenceNode.Stat().LastDistanceNode();
      TreeType* lastRef = (TreeType*) queryNode.Stat().LastDistanceNode();

      // Did the query node's last combination do the base case?
      if ((lastRef != NULL) && (referenceNode.Point(0) == lastRef->Point(0)))
      {
        baseCase = queryNode.Stat().LastDistance();
        alreadyDone = true;
      }

      // Did the reference node's last combination do the base case?
      if ((lastQuery != NULL) && (queryNode.Point(0) == lastQuery->Point(0)))
      {
        baseCase = referenceNode.Stat().LastDistance();
        alreadyDone = true;
      }

      // If the query node is a self-child, did the query parent's last
      // combination do the base case?
      if ((queryNode.Parent() != NULL) &&
          (queryNode.Point(0) == queryNode.Parent()->Point(0)))
      {
        TreeType* lastParentRef = (TreeType*)
            queryNode.Parent()->Stat().LastDistanceNode();
        if ((lastParentRef != NULL) &&
            (referenceNode.Point(0) == lastParentRef->Point(0)))
        {
          baseCase = queryNode.Parent()->Stat().LastDistance();
          alreadyDone = true;
        }
      }

      // If the reference node is a self-child, did the reference parent's
      // last combination do the base case?
      if ((referenceNode.Parent() != NULL) &&
          (referenceNode.Point(0) == referenceNode.Parent()->Point(0)))
      {
        TreeType* lastQueryRef = (TreeType*)
            referenceNode.Parent()->Stat().LastDistanceNode();
        if ((lastQueryRef != NULL) &&
            (queryNode.Point(0) == lastQueryRef->Point(0)))
        {
          baseCase = referenceNode.Parent()->Stat().LastDistance();
          alreadyDone = true;
        }
      }
    }

    if (!alreadyDone)
    {
      // We must calculate the base case.
      baseCase = BaseCase(queryNode.Point(0), referenceNode.Point(0));
    }
    else
    {
      // Make sure that if BaseCase() is called, we don't duplicate results.
      lastQueryIndex = queryNode.Point(0);
      lastReferenceIndex = referenceNode.Point(0);
    }

    distances.Lo() = baseCase - queryNode.FurthestDescendantDistance() -
        referenceNode.FurthestDescendantDistance();
    distances.Hi() = baseCase + queryNode.FurthestDescendantDistance() +
        referenceNode.FurthestDescendantDistance();

    // Update the last distances performed for the query and reference node.
    queryNode.Stat().LastDistanceNode() = (void*) &referenceNode;
    queryNode.Stat().LastDistance() = baseCase;
    referenceNode.Stat().LastDistanceNode() = (void*) &queryNode;
    referenceNode.Stat().LastDistance() = baseCase;
  }
  else
  {
    // Just perform the calculation.
    distances = referenceNode.RangeDistance(&queryNode);
  }

  // If the ranges do not overlap, prune this node.
  if (!distances.Contains(range))
    return DBL_MAX;

  // In this case, all of the points in the reference node will be part of all
  // the results for each point in the query node.
  if ((distances.Lo() >= range.Lo()) && (distances.Hi() <= range.Hi()))
  {
    for (size_t i = 0; i < queryNode.NumDescendants(); ++i)
      AddResult(queryNode.Descendant(i), referenceNode);
    return DBL_MAX; // We don't need to go any deeper.
  }

  // Otherwise the score doesn't matter.  Recursion order is irrelevant in
  // range search.
  return 0.0;
}
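The Lo/Hi assembly in the centroid branch above is the two-sided triangle-inequality bound; as a sketch in LaTeX (my paraphrase): if $q_0$ and $r_0$ are the two centroid points and $\lambda_q$, $\lambda_r$ the furthest descendant distances, then any descendants $q$, $r$ satisfy

$$|d(q, r) - d(q_0, r_0)| \le \lambda_q + \lambda_r,$$

so every pairwise distance lies in $[\mathrm{baseCase} - \lambda_q - \lambda_r,\; \mathrm{baseCase} + \lambda_q + \lambda_r]$, which is exactly what distances.Lo() and distances.Hi() hold.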
template<typename MetricType, typename TreeType>
double PellegMooreKMeansRules<MetricType, TreeType>::Score(
    const size_t /* queryIndex */,
    TreeType& referenceNode)
{
  // Obtain the parent's blacklist.  If this is the root node, we'll start
  // with an empty blacklist.  This means that after each iteration, we don't
  // need to reset any statistics.
  if (referenceNode.Parent() == NULL ||
      referenceNode.Parent()->Stat().Blacklist().n_elem == 0)
    referenceNode.Stat().Blacklist().zeros(centroids.n_cols);
  else
    referenceNode.Stat().Blacklist() =
        referenceNode.Parent()->Stat().Blacklist();

  // The query index is a fake index that we won't use, and the reference node
  // holds all of the points in the dataset.  Our goal is to determine whether
  // or not this node is dominated by a single cluster.
  const size_t whitelisted = centroids.n_cols -
      arma::accu(referenceNode.Stat().Blacklist());

  distanceCalculations += whitelisted;

  // Which cluster has minimum distance to the node?
  size_t closestCluster = centroids.n_cols;
  double minMinDistance = DBL_MAX;
  for (size_t i = 0; i < centroids.n_cols; ++i)
  {
    if (referenceNode.Stat().Blacklist()[i] == 0)
    {
      const double minDistance = referenceNode.MinDistance(centroids.col(i));
      if (minDistance < minMinDistance)
      {
        minMinDistance = minDistance;
        closestCluster = i;
      }
    }
  }

  // Now, for every other whitelisted cluster, determine if the closest
  // cluster owns the point.  This calculation is specific to hyperrectangle
  // trees (but, this implementation is specific to kd-trees, so that's okay).
  // For circular-bound trees, the condition should be simpler and can
  // probably be expressed as a comparison between minimum and maximum
  // distances.
  size_t newBlacklisted = 0;
  for (size_t c = 0; c < centroids.n_cols; ++c)
  {
    if (referenceNode.Stat().Blacklist()[c] == 1 || c == closestCluster)
      continue;

    // This algorithm comes from the proof of Lemma 4 in the extended version
    // of the Pelleg-Moore paper (the CMU tech report, that is).  It has been
    // adapted for speed.
    arma::vec cornerPoint(centroids.n_rows);
    for (size_t d = 0; d < referenceNode.Bound().Dim(); ++d)
    {
      if (centroids(d, c) > centroids(d, closestCluster))
        cornerPoint(d) = referenceNode.Bound()[d].Hi();
      else
        cornerPoint(d) = referenceNode.Bound()[d].Lo();
    }

    const double closestDist = metric.Evaluate(cornerPoint,
        centroids.col(closestCluster));
    const double otherDist = metric.Evaluate(cornerPoint, centroids.col(c));

    distanceCalculations += 3; // One for cornerPoint, then two distances.

    if (closestDist < otherDist)
    {
      // The closest cluster dominates the node with respect to the cluster c.
      // So we can blacklist c.
      referenceNode.Stat().Blacklist()[c] = 1;
      ++newBlacklisted;
    }
  }

  if (whitelisted - newBlacklisted == 1)
  {
    // This node is dominated by the closest cluster.
    counts[closestCluster] += referenceNode.NumDescendants();
    newCentroids.col(closestCluster) += referenceNode.NumDescendants() *
        referenceNode.Stat().Centroid();

    return DBL_MAX;
  }

  // Perform the base case here.
  for (size_t i = 0; i < referenceNode.NumPoints(); ++i)
  {
    size_t bestCluster = centroids.n_cols;
    double bestDistance = DBL_MAX;
    for (size_t c = 0; c < centroids.n_cols; ++c)
    {
      if (referenceNode.Stat().Blacklist()[c] == 1)
        continue;

      ++distanceCalculations;

      // The reference index is the index of the data point.
      const double distance = metric.Evaluate(centroids.col(c),
          dataset.col(referenceNode.Point(i)));

      if (distance < bestDistance)
      {
        bestDistance = distance;
        bestCluster = c;
      }
    }

    // Add to the resulting centroid.
    newCentroids.col(bestCluster) += dataset.col(referenceNode.Point(i));
    ++counts(bestCluster);
  }

  // Otherwise, we're not sure, so we can't prune.  Recursion order doesn't
  // make a difference, so we'll just return a score of 0.
  return 0.0;
}
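A sketch of the corner test above (my paraphrase of the Lemma 4 argument from the Pelleg-Moore tech report): among all points of the node's hyperrectangle $H$, the one most favorable to cluster $c$ relative to the closest cluster $c^*$ is the corner $x^*$ picked per dimension (Hi if $c$ lies above $c^*$ in that dimension, Lo otherwise). So if even that corner satisfies

$$d(x^*, c^*) < d(x^*, c),$$

then every $x \in H$ is closer to $c^*$ than to $c$, and $c$ can safely be blacklisted for the entire subtree.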
template<typename SortPolicy, typename MetricType, typename TreeType>
inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Score(
    TreeType& queryNode,
    TreeType& referenceNode)
{
  ++scores; // Count number of Score() calls.

  // Update our bound.
  const double bestDistance = CalculateBound(queryNode);

  // Use the traversal info to see if a parent-child or parent-parent prune is
  // possible.  This is a looser bound than we could make, but it might be
  // sufficient.
  const double queryParentDist = queryNode.ParentDistance();
  const double queryDescDist = queryNode.FurthestDescendantDistance();
  const double refParentDist = referenceNode.ParentDistance();
  const double refDescDist = referenceNode.FurthestDescendantDistance();
  const double score = traversalInfo.LastScore();
  double adjustedScore;

  // We want to set adjustedScore to be the distance between the centroid of
  // the last query node and last reference node.  We will do this by
  // adjusting the last score.  In some cases, we can just use the last base
  // case.
  if (tree::TreeTraits<TreeType>::FirstPointIsCentroid)
  {
    adjustedScore = traversalInfo.LastBaseCase();
  }
  else if (score == 0.0) // Nothing we can do here.
  {
    adjustedScore = 0.0;
  }
  else
  {
    // The last score is equal to the distance between the centroids minus the
    // radii of the query and reference bounds along the axis of the line
    // between the two centroids.  In the best case, these radii are the
    // furthest descendant distances, but that is not always true.  It would
    // take too long to calculate the exact radii, so we are forced to use
    // MinimumBoundDistance() as a lower-bound approximation.
    const double lastQueryDescDist =
        traversalInfo.LastQueryNode()->MinimumBoundDistance();
    const double lastRefDescDist =
        traversalInfo.LastReferenceNode()->MinimumBoundDistance();
    adjustedScore = SortPolicy::CombineWorst(score, lastQueryDescDist);
    adjustedScore = SortPolicy::CombineWorst(adjustedScore, lastRefDescDist);
  }

  // Assemble an adjusted score.  For nearest neighbor search, this adjusted
  // score is a lower bound on MinDistance(queryNode, referenceNode) that is
  // assembled without actually calculating MinDistance().  For furthest
  // neighbor search, it is an upper bound on
  // MaxDistance(queryNode, referenceNode).  If the traversalInfo isn't
  // usable, then the node should not be pruned by this.
  if (traversalInfo.LastQueryNode() == queryNode.Parent())
  {
    const double queryAdjust = queryParentDist + queryDescDist;
    adjustedScore = SortPolicy::CombineBest(adjustedScore, queryAdjust);
  }
  else if (traversalInfo.LastQueryNode() == &queryNode)
  {
    adjustedScore = SortPolicy::CombineBest(adjustedScore, queryDescDist);
  }
  else
  {
    // The last query node wasn't this query node or its parent.  So we force
    // the adjustedScore to be such that this combination can't be pruned
    // here, because we don't really know anything about it.
    //
    // It would be possible to modify this section to try and make a prune
    // based on the query descendant distance and the distance between the
    // query node and the last traversal query node, but this case doesn't
    // actually happen for kd-trees or cover trees.
    adjustedScore = SortPolicy::BestDistance();
  }

  if (traversalInfo.LastReferenceNode() == referenceNode.Parent())
  {
    const double refAdjust = refParentDist + refDescDist;
    adjustedScore = SortPolicy::CombineBest(adjustedScore, refAdjust);
  }
  else if (traversalInfo.LastReferenceNode() == &referenceNode)
  {
    adjustedScore = SortPolicy::CombineBest(adjustedScore, refDescDist);
  }
  else
  {
    // The last reference node wasn't this reference node or its parent.  So
    // we force the adjustedScore to be such that this combination can't be
    // pruned here, because we don't really know anything about it.
    //
    // It would be possible to modify this section to try and make a prune
    // based on the reference descendant distance and the distance between the
    // reference node and the last traversal reference node, but this case
    // doesn't actually happen for kd-trees or cover trees.
    adjustedScore = SortPolicy::BestDistance();
  }

  // Can we prune?
  if (SortPolicy::IsBetter(bestDistance, adjustedScore))
  {
    if (!(tree::TreeTraits<TreeType>::FirstPointIsCentroid && score == 0.0))
    {
      // There isn't any need to set the traversal information because no
      // descendant combinations will be visited, and those are the only
      // combinations that would depend on the traversal information.
      return DBL_MAX;
    }
  }

  double distance;
  if (tree::TreeTraits<TreeType>::FirstPointIsCentroid)
  {
    // The first point in the node is the centroid, so we can calculate the
    // distance between the two points using BaseCase() and then find the
    // bounds.  This is potentially loose for non-ball bounds.
    double baseCase = -1.0;
    if (tree::TreeTraits<TreeType>::HasSelfChildren &&
        (traversalInfo.LastQueryNode()->Point(0) == queryNode.Point(0)) &&
        (traversalInfo.LastReferenceNode()->Point(0) ==
            referenceNode.Point(0)))
    {
      // We already calculated it.
      baseCase = traversalInfo.LastBaseCase();
    }
    else
    {
      baseCase = BaseCase(queryNode.Point(0), referenceNode.Point(0));
    }

    distance = SortPolicy::CombineBest(baseCase,
        queryNode.FurthestDescendantDistance() +
        referenceNode.FurthestDescendantDistance());

    lastQueryIndex = queryNode.Point(0);
    lastReferenceIndex = referenceNode.Point(0);
    lastBaseCase = baseCase;

    traversalInfo.LastBaseCase() = baseCase;
  }
  else
  {
    distance = SortPolicy::BestNodeToNodeDistance(&queryNode, &referenceNode);
  }

  if (SortPolicy::IsBetter(distance, bestDistance))
  {
    // Set the traversal information.
    traversalInfo.LastQueryNode() = &queryNode;
    traversalInfo.LastReferenceNode() = &referenceNode;
    traversalInfo.LastScore() = distance;

    return distance;
  }
  else
  {
    // There isn't any need to set the traversal information because no
    // descendant combinations will be visited, and those are the only
    // combinations that would depend on the traversal information.
    return DBL_MAX;
  }
}