void CheckTrees(TreeType& tree, TreeType& xmlTree, TreeType& textTree, TreeType& binaryTree) { const typename TreeType::Mat* dataset = &tree.Dataset(); // Make sure that the data matrices are the same. if (tree.Parent() == NULL) { CheckMatrices(*dataset, xmlTree.Dataset(), textTree.Dataset(), binaryTree.Dataset()); // Also ensure that the other parents are null too. BOOST_REQUIRE_EQUAL(xmlTree.Parent(), (TreeType*) NULL); BOOST_REQUIRE_EQUAL(textTree.Parent(), (TreeType*) NULL); BOOST_REQUIRE_EQUAL(binaryTree.Parent(), (TreeType*) NULL); } // Make sure the number of children is the same. BOOST_REQUIRE_EQUAL(tree.NumChildren(), xmlTree.NumChildren()); BOOST_REQUIRE_EQUAL(tree.NumChildren(), textTree.NumChildren()); BOOST_REQUIRE_EQUAL(tree.NumChildren(), binaryTree.NumChildren()); // Make sure the number of descendants is the same. BOOST_REQUIRE_EQUAL(tree.NumDescendants(), xmlTree.NumDescendants()); BOOST_REQUIRE_EQUAL(tree.NumDescendants(), textTree.NumDescendants()); BOOST_REQUIRE_EQUAL(tree.NumDescendants(), binaryTree.NumDescendants()); // Make sure the number of points is the same. BOOST_REQUIRE_EQUAL(tree.NumPoints(), xmlTree.NumPoints()); BOOST_REQUIRE_EQUAL(tree.NumPoints(), textTree.NumPoints()); BOOST_REQUIRE_EQUAL(tree.NumPoints(), binaryTree.NumPoints()); // Check that each point is the same. for (size_t i = 0; i < tree.NumPoints(); ++i) { BOOST_REQUIRE_EQUAL(tree.Point(i), xmlTree.Point(i)); BOOST_REQUIRE_EQUAL(tree.Point(i), textTree.Point(i)); BOOST_REQUIRE_EQUAL(tree.Point(i), binaryTree.Point(i)); } // Check that the parent distance is the same. BOOST_REQUIRE_CLOSE(tree.ParentDistance(), xmlTree.ParentDistance(), 1e-8); BOOST_REQUIRE_CLOSE(tree.ParentDistance(), textTree.ParentDistance(), 1e-8); BOOST_REQUIRE_CLOSE(tree.ParentDistance(), binaryTree.ParentDistance(), 1e-8); // Check that the furthest descendant distance is the same. 
BOOST_REQUIRE_CLOSE(tree.FurthestDescendantDistance(), xmlTree.FurthestDescendantDistance(), 1e-8); BOOST_REQUIRE_CLOSE(tree.FurthestDescendantDistance(), textTree.FurthestDescendantDistance(), 1e-8); BOOST_REQUIRE_CLOSE(tree.FurthestDescendantDistance(), binaryTree.FurthestDescendantDistance(), 1e-8); // Check that the minimum bound distance is the same. BOOST_REQUIRE_CLOSE(tree.MinimumBoundDistance(), xmlTree.MinimumBoundDistance(), 1e-8); BOOST_REQUIRE_CLOSE(tree.MinimumBoundDistance(), textTree.MinimumBoundDistance(), 1e-8); BOOST_REQUIRE_CLOSE(tree.MinimumBoundDistance(), binaryTree.MinimumBoundDistance(), 1e-8); // Recurse into the children. for (size_t i = 0; i < tree.NumChildren(); ++i) { // Check that the child dataset is the same. BOOST_REQUIRE_EQUAL(&xmlTree.Dataset(), &xmlTree.Child(i).Dataset()); BOOST_REQUIRE_EQUAL(&textTree.Dataset(), &textTree.Child(i).Dataset()); BOOST_REQUIRE_EQUAL(&binaryTree.Dataset(), &binaryTree.Child(i).Dataset()); // Make sure the parent link is right. BOOST_REQUIRE_EQUAL(xmlTree.Child(i).Parent(), &xmlTree); BOOST_REQUIRE_EQUAL(textTree.Child(i).Parent(), &textTree); BOOST_REQUIRE_EQUAL(binaryTree.Child(i).Parent(), &binaryTree); CheckTrees(tree.Child(i), xmlTree.Child(i), textTree.Child(i), binaryTree.Child(i)); } }
/**
 * Score a (query node, reference node) combination for a dual-tree traversal.
 *
 * The function first refreshes the bound stored in the query node's
 * statistic, then attempts a cheap parent-child / parent-parent prune that
 * uses only the cached traversal information.  If that prune fails, the
 * kernel between the two nodes is evaluated (or reused from the previous
 * combination), an upper bound on the kernel value between any pair of
 * descendants is built from it, and the traversal information is updated.
 *
 * @param queryNode Query node of the combination.
 * @param referenceNode Reference node of the combination.
 * @return DBL_MAX if the combination is pruned; otherwise 1.0 / maxKernel, so
 *     that combinations with larger kernel bounds are recursed into first.
 */
double FastMKSRules<KernelType, TreeType>::Score(TreeType& queryNode,
                                                 TreeType& referenceNode)
{
  // Refresh the bound cached in the query node's statistic; it is the kernel
  // value that a candidate has to beat.
  queryNode.Stat().Bound() = CalculateBound(queryNode);
  const double bestKernel = queryNode.Stat().Bound();

  // Distances needed by the cheap prune and by the final bound.
  const double qParentDist = queryNode.ParentDistance();
  const double qDescDist = queryNode.FurthestDescendantDistance();
  const double rParentDist = referenceNode.ParentDistance();
  const double rDescDist = referenceNode.FurthestDescendantDistance();

  // The cheap bound starts from the last base case and is loosened below.
  // It reapplies the pruning condition that was used for the previously
  // scored combination, but with distance bounds specific to these nodes, so
  // it can sometimes prune without any new kernel evaluation.
  double cheapBound = traversalInfo.LastBaseCase();

  // Query side of the cheap bound.
  double queryTerm;
  if (traversalInfo.LastQueryNode() == queryNode.Parent())
  {
    // The original author's note states that the last reference node is
    // non-NULL here (it is this reference node or its parent), because the
    // traversal info is never NULL at the root combination.
    queryTerm = (qParentDist + qDescDist);
    cheapBound += queryTerm *
        traversalInfo.LastReferenceNode()->Stat().SelfKernel();
  }
  else if (traversalInfo.LastReferenceNode() != NULL)
  {
    queryTerm = qDescDist;
    cheapBound += queryTerm *
        traversalInfo.LastReferenceNode()->Stat().SelfKernel();
  }
  else
  {
    // No usable traversal information: make the cheap prune impossible.
    queryTerm = 0.0;
    cheapBound = bestKernel;
  }

  // Reference side of the cheap bound; mirrors the query side.
  double referenceTerm;
  if (traversalInfo.LastReferenceNode() == referenceNode.Parent())
  {
    referenceTerm = (rParentDist + rDescDist);
    cheapBound += referenceTerm *
        traversalInfo.LastQueryNode()->Stat().SelfKernel();
  }
  else if (traversalInfo.LastQueryNode() != NULL)
  {
    referenceTerm = rDescDist;
    cheapBound += referenceTerm *
        traversalInfo.LastQueryNode()->Stat().SelfKernel();
  }
  else
  {
    // No usable traversal information: make the cheap prune impossible.
    referenceTerm = 0.0;
    cheapBound = bestKernel;
  }

  // Add the cross (dual) term.
  cheapBound += (queryTerm * referenceTerm);

  // If even the loose bound cannot beat the current best, nothing below this
  // combination can improve any result.
  if (cheapBound < bestKernel)
    return DBL_MAX;

  // The cheap prune failed, so a kernel value between the nodes is needed.
  double kernelEval = 0.0;
  if (tree::TreeTraits<TreeType>::FirstPointIsCentroid)
  {
    // With this kind of tree the previous combination may already have
    // evaluated the kernel between these very two points.
    const bool alreadyEvaluated =
        (traversalInfo.LastQueryNode() != NULL) &&
        (traversalInfo.LastReferenceNode() != NULL) &&
        (traversalInfo.LastQueryNode()->Point(0) == queryNode.Point(0)) &&
        (traversalInfo.LastReferenceNode()->Point(0) ==
            referenceNode.Point(0));

    if (alreadyEvaluated)
    {
      kernelEval = traversalInfo.LastBaseCase();

      // Keep these in sync so that a BaseCase() call made after Score() does
      // not evaluate the kernel again.
      lastQueryIndex = queryNode.Point(0);
      lastReferenceIndex = referenceNode.Point(0);
    }
    else
    {
      // Both centroids are dataset points, so BaseCase() can be used; it
      // sets lastQueryIndex and lastReferenceIndex itself.
      kernelEval = BaseCase(queryNode.Point(0), referenceNode.Point(0));
    }
  }
  else
  {
    // Evaluate the kernel between the two node centroids.
    arma::vec queryCentroid;
    arma::vec refCentroid;
    queryNode.Centroid(queryCentroid);
    referenceNode.Centroid(refCentroid);

    kernelEval = kernel.Evaluate(queryCentroid, refCentroid);
  }

  // Remember the evaluation for the next combination.
  traversalInfo.LastBaseCase() = kernelEval;
  ++scores;

  // Build the upper bound on the kernel between any two descendants.
  double maxKernel;
  if (kernel::KernelTraits<KernelType>::IsNormalized)
  {
    // Normalized kernels admit a tighter bound.
    const double querySqDist = std::pow(qDescDist, 2.0);
    const double refSqDist = std::pow(rDescDist, 2.0);
    const double bothSqDist = std::pow((qDescDist + rDescDist), 2.0);

    maxKernel = 1.0;
    if (kernelEval <= (1 - 0.5 * bothSqDist))
    {
      const double queryDelta = (1 - 0.5 * querySqDist);
      const double queryGamma = qDescDist * sqrt(1 - 0.25 * querySqDist);
      const double refDelta = (1 - 0.5 * refSqDist);
      const double refGamma = rDescDist * sqrt(1 - 0.25 * refSqDist);

      maxKernel = kernelEval * (queryDelta * refDelta - queryGamma * refGamma)
          + sqrt(1 - std::pow(kernelEval, 2.0)) *
          (queryGamma * refDelta + queryDelta * refGamma);
    }
  }
  else
  {
    // Standard bound for kernels that are not normalized.
    const double refKernelTerm = qDescDist * referenceNode.Stat().SelfKernel();
    const double queryKernelTerm = rDescDist * queryNode.Stat().SelfKernel();

    maxKernel = kernelEval + refKernelTerm + queryKernelTerm +
        (qDescDist * rDescDist);
  }

  // Record this combination for the parent-child prune of its descendants.
  traversalInfo.LastQueryNode() = &queryNode;
  traversalInfo.LastReferenceNode() = &referenceNode;

  // Return the inverse of the bound so larger kernels are visited first.
  if (maxKernel > bestKernel)
    return (1.0 / maxKernel);

  return DBL_MAX;
}
/**
 * Score a (query node, reference node) combination for a dual-tree traversal.
 *
 * A cheap prune is attempted first, using only the cached traversal
 * information (the last scored pair of nodes, its score and its base case)
 * together with the parent and furthest descendant distances of the two
 * nodes.  Only if that fails is the real node-to-node bound computed.
 *
 * @param queryNode Query node of the combination.
 * @param referenceNode Reference node of the combination.
 * @return DBL_MAX if the combination is pruned; otherwise the computed bound,
 *     which is also stored in traversalInfo.LastScore().
 */
inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Score(
    TreeType& queryNode,
    TreeType& referenceNode)
{
  ++scores; // Count number of Score() calls.

  // Update our bound.
  const double bestDistance = CalculateBound(queryNode);

  // Use the traversal info to see if a parent-child or parent-parent prune is
  // possible.  This is a looser bound than we could make, but it might be
  // sufficient.
  const double queryParentDist = queryNode.ParentDistance();
  const double queryDescDist = queryNode.FurthestDescendantDistance();
  const double refParentDist = referenceNode.ParentDistance();
  const double refDescDist = referenceNode.FurthestDescendantDistance();
  const double score = traversalInfo.LastScore();
  double adjustedScore;

  // We want to set adjustedScore to be the distance between the centroid of
  // the last query node and last reference node.  We will do this by adjusting
  // the last score.  In some cases, we can just use the last base case.
  if (tree::TreeTraits<TreeType>::FirstPointIsCentroid)
  {
    adjustedScore = traversalInfo.LastBaseCase();
  }
  else if (score == 0.0) // Nothing we can do here.
  {
    adjustedScore = 0.0;
  }
  else
  {
    // The last score is equal to the distance between the centroids minus the
    // radii of the query and reference bounds along the axis of the line
    // between the two centroids.  In the best case, these radii are the
    // furthest descendant distances, but that is not always true.  It would
    // take too long to calculate the exact radii, so we are forced to use
    // MinimumBoundDistance() as a lower-bound approximation.
    const double lastQueryDescDist =
        traversalInfo.LastQueryNode()->MinimumBoundDistance();
    const double lastRefDescDist =
        traversalInfo.LastReferenceNode()->MinimumBoundDistance();

    // Undo BOTH radii.  The second call must build on the result of the
    // first; combining with `score` again would silently discard the query
    // node's contribution.
    adjustedScore = SortPolicy::CombineWorst(score, lastQueryDescDist);
    adjustedScore = SortPolicy::CombineWorst(adjustedScore, lastRefDescDist);
  }

  // Assemble an adjusted score.  For nearest neighbor search, this adjusted
  // score is a lower bound on MinDistance(queryNode, referenceNode) that is
  // assembled without actually calculating MinDistance().  For furthest
  // neighbor search, it is an upper bound on
  // MaxDistance(queryNode, referenceNode).  If the traversalInfo isn't usable
  // then the node should not be pruned by this.
  if (traversalInfo.LastQueryNode() == queryNode.Parent())
  {
    const double queryAdjust = queryParentDist + queryDescDist;
    adjustedScore = SortPolicy::CombineBest(adjustedScore, queryAdjust);
  }
  else if (traversalInfo.LastQueryNode() == &queryNode)
  {
    adjustedScore = SortPolicy::CombineBest(adjustedScore, queryDescDist);
  }
  else
  {
    // The last query node wasn't this query node or its parent.  So we force
    // the adjustedScore to be such that this combination can't be pruned here,
    // because we don't really know anything about it.

    // It would be possible to modify this section to try and make a prune
    // based on the query descendant distance and the distance between the
    // query node and last traversal query node, but this case doesn't actually
    // happen for kd-trees or cover trees.
    adjustedScore = SortPolicy::BestDistance();
  }

  if (traversalInfo.LastReferenceNode() == referenceNode.Parent())
  {
    const double refAdjust = refParentDist + refDescDist;
    adjustedScore = SortPolicy::CombineBest(adjustedScore, refAdjust);
  }
  else if (traversalInfo.LastReferenceNode() == &referenceNode)
  {
    adjustedScore = SortPolicy::CombineBest(adjustedScore, refDescDist);
  }
  else
  {
    // The last reference node wasn't this reference node or its parent.  So we
    // force the adjustedScore to be such that this combination can't be pruned
    // here, because we don't really know anything about it.

    // It would be possible to modify this section to try and make a prune
    // based on the reference descendant distance and the distance between the
    // reference node and last traversal reference node, but this case doesn't
    // actually happen for kd-trees or cover trees.
    adjustedScore = SortPolicy::BestDistance();
  }

  // Can we prune?
  if (SortPolicy::IsBetter(bestDistance, adjustedScore))
  {
    if (!(tree::TreeTraits<TreeType>::FirstPointIsCentroid && score == 0.0))
    {
      // There isn't any need to set the traversal information because no
      // descendant combinations will be visited, and those are the only
      // combinations that would depend on the traversal information.
      return DBL_MAX;
    }
  }

  double distance;
  if (tree::TreeTraits<TreeType>::FirstPointIsCentroid)
  {
    // The first point in the node is the centroid, so we can calculate the
    // distance between the two points using BaseCase() and then find the
    // bounds.  This is potentially loose for non-ball bounds.
    double baseCase = -1.0;
    if (tree::TreeTraits<TreeType>::HasSelfChildren &&
        (traversalInfo.LastQueryNode()->Point(0) == queryNode.Point(0)) &&
        (traversalInfo.LastReferenceNode()->Point(0) ==
            referenceNode.Point(0)))
    {
      // We already calculated it.
      baseCase = traversalInfo.LastBaseCase();
    }
    else
    {
      baseCase = BaseCase(queryNode.Point(0), referenceNode.Point(0));
    }

    distance = SortPolicy::CombineBest(baseCase,
        queryNode.FurthestDescendantDistance() +
        referenceNode.FurthestDescendantDistance());

    // Cache the evaluation so a following BaseCase() call on the same pair of
    // points can reuse it.
    lastQueryIndex = queryNode.Point(0);
    lastReferenceIndex = referenceNode.Point(0);
    lastBaseCase = baseCase;

    traversalInfo.LastBaseCase() = baseCase;
  }
  else
  {
    distance = SortPolicy::BestNodeToNodeDistance(&queryNode, &referenceNode);
  }

  if (SortPolicy::IsBetter(distance, bestDistance))
  {
    // Set traversal information.
    traversalInfo.LastQueryNode() = &queryNode;
    traversalInfo.LastReferenceNode() = &referenceNode;
    traversalInfo.LastScore() = distance;

    return distance;
  }
  else
  {
    // There isn't any need to set the traversal information because no
    // descendant combinations will be visited, and those are the only
    // combinations that would depend on the traversal information.
    return DBL_MAX;
  }
}
/**
 * Score a reference node against a single query point (single-tree
 * traversal).
 *
 * If the reference node has a parent, a prune is first attempted using the
 * kernel value stored in the parent's statistic.  Otherwise (or if that prune
 * fails) the kernel between the query point and the node is evaluated or
 * reused, stored in the node's statistic, and turned into an upper bound on
 * the kernel between the query point and any descendant of the node.
 *
 * @param queryIndex Index of the query point.
 * @param referenceNode Reference node to score.
 * @return DBL_MAX if the node is pruned; otherwise 1.0 / maxKernel, so that
 *     nodes with larger kernel bounds are recursed into first.
 */
double FastMKSRules<KernelType, TreeType>::Score(const size_t queryIndex,
                                                 TreeType& referenceNode)
{
  // The value to beat: the last row of the products matrix for this query.
  const double bestKernel = products(products.n_rows - 1, queryIndex);

  const double furthestDist = referenceNode.FurthestDescendantDistance();

  // Parent-child prune: bound the kernel using the parent's last evaluation.
  if (referenceNode.Parent() != NULL)
  {
    const double parentDist = referenceNode.ParentDistance();
    const double combinedDistBound = parentDist + furthestDist;
    const double lastKernel = referenceNode.Parent()->Stat().LastKernel();

    double maxKernelBound;
    if (kernel::KernelTraits<KernelType>::IsNormalized)
    {
      const double squaredDist = std::pow(combinedDistBound, 2.0);
      const double delta = (1 - 0.5 * squaredDist);

      // 1.0 is the fallback bound when the tighter formula does not apply.
      maxKernelBound = 1.0;
      if (lastKernel <= delta)
      {
        const double gamma = combinedDistBound * sqrt(1 - 0.25 * squaredDist);
        maxKernelBound = lastKernel * delta +
            gamma * sqrt(1 - std::pow(lastKernel, 2.0));
      }
    }
    else
    {
      maxKernelBound = lastKernel +
          combinedDistBound * queryKernels[queryIndex];
    }

    if (maxKernelBound < bestKernel)
      return DBL_MAX;
  }

  // The prune failed (or was not possible), so a kernel value is needed:
  // either against the node's first point, if that is its centroid, or
  // against an explicitly computed centroid.
  ++scores;

  double kernelEval;
  if (tree::TreeTraits<TreeType>::FirstPointIsCentroid)
  {
    // When the node shares its first point with its parent, the parent's
    // stored kernel value is exactly the one wanted here.
    const bool parentHasSamePoint =
        tree::TreeTraits<TreeType>::HasSelfChildren &&
        referenceNode.Parent() != NULL &&
        referenceNode.Point(0) == referenceNode.Parent()->Point(0);

    if (parentHasSamePoint)
      kernelEval = referenceNode.Parent()->Stat().LastKernel();
    else
      kernelEval = BaseCase(queryIndex, referenceNode.Point(0));
  }
  else
  {
    const arma::vec queryPoint = querySet.unsafe_col(queryIndex);

    arma::vec refCentroid;
    referenceNode.Centroid(refCentroid);

    kernelEval = kernel.Evaluate(queryPoint, refCentroid);
  }

  // Store the evaluation so the children of this node can use it.
  referenceNode.Stat().LastKernel() = kernelEval;

  // Upper bound on the kernel between the query point and any descendant.
  double maxKernel;
  if (kernel::KernelTraits<KernelType>::IsNormalized)
  {
    const double squaredDist = std::pow(furthestDist, 2.0);
    const double delta = (1 - 0.5 * squaredDist);

    maxKernel = 1.0;
    if (kernelEval <= delta)
    {
      const double gamma = furthestDist * sqrt(1 - 0.25 * squaredDist);
      maxKernel = kernelEval * delta +
          gamma * sqrt(1 - std::pow(kernelEval, 2.0));
    }
  }
  else
  {
    maxKernel = kernelEval + furthestDist * queryKernels[queryIndex];
  }

  // Return the inverse of the bound so larger kernels are visited first.
  if (maxKernel > bestKernel)
    return (1.0 / maxKernel);

  return DBL_MAX;
}