FastMKSStat(const TreeType& node) : bound(-DBL_MAX), lastKernel(0.0), lastKernelNode(NULL) { // Do we have to calculate the centroid? if (tree::TreeTraits<TreeType>::FirstPointIsCentroid) { // If this type of tree has self-children, then maybe the evaluation is // already done. These statistics are built bottom-up, so the child stat // should already be done. if ((tree::TreeTraits<TreeType>::HasSelfChildren) && (node.NumChildren() > 0) && (node.Point(0) == node.Child(0).Point(0))) { selfKernel = node.Child(0).Stat().SelfKernel(); } else { selfKernel = sqrt(node.Metric().Kernel().Evaluate( node.Dataset().col(node.Point(0)), node.Dataset().col(node.Point(0)))); } } else { // Calculate the centroid. arma::vec centroid; node.Centroid(centroid); selfKernel = sqrt(node.Metric().Kernel().Evaluate(centroid, centroid)); } }
double FastMKSRules<KernelType, TreeType>::Score(TreeType& queryNode,
                                                 TreeType& referenceNode)
{
  // Dual-tree score for a (query node, reference node) combination.
  //
  // Returns DBL_MAX when the combination can be pruned.  That happens when an
  // upper bound on the kernel value of any point pair in it does not exceed
  // the query node's current bound.  Otherwise returns the inverse of that
  // upper bound, so that combinations with larger possible kernel values are
  // recursed into first.
  //
  // Side effects:
  //  - Always updates queryNode.Stat().Bound().
  //  - Unless pruned by the parent-child / parent-parent test, it also:
  //    - stores the kernel evaluation in traversalInfo.LastBaseCase();
  //    - may call BaseCase();
  //    - increments `scores`;
  //    - records this node pair in traversalInfo for when children are scored.

  // Update and get the query node's bound.
  queryNode.Stat().Bound() = CalculateBound(queryNode);
  const double bestKernel = queryNode.Stat().Bound();

  // First, see if we can make a parent-child or parent-parent prune.  These
  // four bounds on the maximum kernel value are looser than the bound normally
  // used, but they can prevent a base case from needing to be calculated.

  // Convenience caching so lines are shorter.
  const double queryParentDist = queryNode.ParentDistance();
  const double queryDescDist = queryNode.FurthestDescendantDistance();
  const double refParentDist = referenceNode.ParentDistance();
  const double refDescDist = referenceNode.FurthestDescendantDistance();
  double adjustedScore = traversalInfo.LastBaseCase();

  const double queryDistBound = (queryParentDist + queryDescDist);
  const double refDistBound = (refParentDist + refDescDist);
  double dualQueryTerm;
  double dualRefTerm;

  // The parent-child and parent-parent prunes work by applying the same
  // pruning condition as when the parent node was used.  They are tighter
  // because
  //    queryDistBound < queryNode.Parent()->FurthestDescendantDistance()
  // and
  //    refDistBound < referenceNode.Parent()->FurthestDescendantDistance()
  // So we construct the same bounds that were used when Score() was called
  // with the parents, except with the tighter distance bounds.  Sometimes this
  // allows us to prune nodes without evaluating the base cases between them.
  if (traversalInfo.LastQueryNode() == queryNode.Parent())
  {
    // We can assume that queryNode.Parent() != NULL, because at the root node
    // combination, the traversalInfo.LastQueryNode() pointer will _not_ be
    // NULL.  We also should be guaranteed that
    // traversalInfo.LastReferenceNode() is either the reference node or the
    // parent of the reference node.
    // NOTE(review): this relies on how the traverser initializes
    // traversalInfo, which is not visible here.
    adjustedScore += queryDistBound *
        traversalInfo.LastReferenceNode()->Stat().SelfKernel();
    dualQueryTerm = queryDistBound;
  }
  else
  {
    // The query parent could be NULL, which does weird things and we have to
    // consider.
    if (traversalInfo.LastReferenceNode() != NULL)
    {
      adjustedScore += queryDescDist *
          traversalInfo.LastReferenceNode()->Stat().SelfKernel();
      dualQueryTerm = queryDescDist;
    }
    else
    {
      // This makes it so a child-parent (or parent-parent) prune is not
      // possible.
      dualQueryTerm = 0.0;
      adjustedScore = bestKernel;
    }
  }

  if (traversalInfo.LastReferenceNode() == referenceNode.Parent())
  {
    // We can assume that referenceNode.Parent() != NULL, because at the root
    // node combination, the traversalInfo.LastReferenceNode() pointer will
    // _not_ be NULL.
    adjustedScore += refDistBound *
        traversalInfo.LastQueryNode()->Stat().SelfKernel();
    dualRefTerm = refDistBound;
  }
  else
  {
    // The reference parent could be NULL, which does weird things and we have
    // to consider.
    if (traversalInfo.LastQueryNode() != NULL)
    {
      adjustedScore += refDescDist *
          traversalInfo.LastQueryNode()->Stat().SelfKernel();
      dualRefTerm = refDescDist;
    }
    else
    {
      // This makes it so a child-parent (or parent-parent) prune is not
      // possible.
      dualRefTerm = 0.0;
      adjustedScore = bestKernel;
    }
  }

  // Now add the dual term.
  adjustedScore += (dualQueryTerm * dualRefTerm);

  if (adjustedScore < bestKernel)
  {
    // No point combination in this node combination can have a kernel value
    // that improves any of the results, so we can prune it.
    return DBL_MAX;
  }

  // We were unable to perform a parent-child or parent-parent prune, so now we
  // must calculate the kernel evaluation, if necessary.
  double kernelEval = 0.0;
  if (tree::TreeTraits<TreeType>::FirstPointIsCentroid)
  {
    // For this type of tree, we may have already calculated the base case in
    // the parents.
    if ((traversalInfo.LastQueryNode() != NULL) &&
        (traversalInfo.LastReferenceNode() != NULL) &&
        (traversalInfo.LastQueryNode()->Point(0) == queryNode.Point(0)) &&
        (traversalInfo.LastReferenceNode()->Point(0) == referenceNode.Point(0)))
    {
      // Base case already done.
      kernelEval = traversalInfo.LastBaseCase();

      // When BaseCase() is called after Score(), these must be correct so that
      // another kernel evaluation is not performed.
      lastQueryIndex = queryNode.Point(0);
      lastReferenceIndex = referenceNode.Point(0);
    }
    else
    {
      // The kernel must be evaluated, but it is between points in the dataset,
      // so we can call BaseCase().  BaseCase() will set lastQueryIndex and
      // lastReferenceIndex correctly.
      kernelEval = BaseCase(queryNode.Point(0), referenceNode.Point(0));
    }

    traversalInfo.LastBaseCase() = kernelEval;
  }
  else
  {
    // Calculate the maximum possible kernel value.
    arma::vec queryCentroid;
    arma::vec refCentroid;
    queryNode.Centroid(queryCentroid);
    referenceNode.Centroid(refCentroid);

    kernelEval = kernel.Evaluate(queryCentroid, refCentroid);

    traversalInfo.LastBaseCase() = kernelEval;
  }
  ++scores;

  double maxKernel;
  if (kernel::KernelTraits<KernelType>::IsNormalized)
  {
    // We have a tighter bound for normalized kernels.
    const double querySqDist = queryDescDist * queryDescDist;
    const double refSqDist = refDescDist * refDescDist;
    const double bothDist = queryDescDist + refDescDist;
    const double bothSqDist = bothDist * bothDist;

    if (kernelEval <= (1 - 0.5 * bothSqDist))
    {
      const double queryDelta = (1 - 0.5 * querySqDist);
      const double queryGamma = queryDescDist * sqrt(1 - 0.25 * querySqDist);
      const double refDelta = (1 - 0.5 * refSqDist);
      const double refGamma = refDescDist * sqrt(1 - 0.25 * refSqDist);

      // The bound treats the kernel value as a cosine and uses
      // sqrt(1 - K^2).  Round-off can leave K slightly below -1 (e.g. for
      // opposite vectors), which makes that argument a tiny negative number.
      // sqrt() would then return NaN, and a NaN maxKernel fails the final
      // comparison, wrongly pruning this combination.  Clamp to zero.
      const double sinTermSq = 1 - kernelEval * kernelEval;
      const double sinTerm = (sinTermSq > 0.0) ? sqrt(sinTermSq) : 0.0;

      maxKernel = kernelEval * (queryDelta * refDelta - queryGamma * refGamma) +
          sinTerm * (queryGamma * refDelta + queryDelta * refGamma);
    }
    else
    {
      maxKernel = 1.0;
    }
  }
  else
  {
    // Use standard bound; kernel is not normalized.
    const double refKernelTerm = queryDescDist *
        referenceNode.Stat().SelfKernel();
    const double queryKernelTerm = refDescDist * queryNode.Stat().SelfKernel();

    maxKernel = kernelEval + refKernelTerm + queryKernelTerm +
        (queryDescDist * refDescDist);
  }

  // Store relevant information for parent-child pruning.
  traversalInfo.LastQueryNode() = &queryNode;
  traversalInfo.LastReferenceNode() = &referenceNode;

  // We return the inverse of the maximum kernel so that larger kernels are
  // recursed into first.
  return (maxKernel > bestKernel) ? (1.0 / maxKernel) : DBL_MAX;
}
double FastMKSRules<KernelType, TreeType>::Score(const size_t queryIndex,
                                                 TreeType& referenceNode)
{
  // Single-tree score for a (query point, reference node) combination.
  //
  // Returns DBL_MAX when the node can be pruned.  That happens when an upper
  // bound on the kernel value between the query point and anything in the
  // node does not exceed `bestKernel`.  Otherwise returns the inverse of that
  // upper bound, so that more promising nodes are visited first.
  //
  // Side effects: unless pruned by the parent-child test, increments `scores`
  // and stores the kernel evaluation in referenceNode.Stat().LastKernel().
  // This function reads that value back (through Parent()) when the node's
  // children are scored.

  // Compare with the current best.  NOTE(review): presumably the last row of
  // `products` holds the current k-th best kernel value for this query.
  const double bestKernel = products(products.n_rows - 1, queryIndex);

  // See if we can perform a parent-child prune.  This assumes the parent's
  // LastKernel() was computed against this same query point.
  const double furthestDist = referenceNode.FurthestDescendantDistance();
  if (referenceNode.Parent() != NULL)
  {
    double maxKernelBound;
    const double parentDist = referenceNode.ParentDistance();
    const double combinedDistBound = parentDist + furthestDist;
    const double lastKernel = referenceNode.Parent()->Stat().LastKernel();
    if (kernel::KernelTraits<KernelType>::IsNormalized)
    {
      const double squaredDist = combinedDistBound * combinedDistBound;
      const double delta = (1 - 0.5 * squaredDist);
      if (lastKernel <= delta)
      {
        const double gamma = combinedDistBound * sqrt(1 - 0.25 * squaredDist);
        // Round-off can push |lastKernel| slightly above 1.  Clamp so sqrt()
        // does not return NaN, which would silently disable this prune.
        const double sinTermSq = 1 - lastKernel * lastKernel;
        maxKernelBound = lastKernel * delta +
            gamma * ((sinTermSq > 0.0) ? sqrt(sinTermSq) : 0.0);
      }
      else
      {
        maxKernelBound = 1.0;
      }
    }
    else
    {
      maxKernelBound = lastKernel +
          combinedDistBound * queryKernels[queryIndex];
    }

    if (maxKernelBound < bestKernel)
      return DBL_MAX;
  }

  // Calculate the maximum possible kernel value, either by calculating the
  // centroid or, if the centroid is a point, use that.
  ++scores;
  double kernelEval;
  if (tree::TreeTraits<TreeType>::FirstPointIsCentroid)
  {
    // Could it be that this kernel evaluation has already been calculated?
    // With self-children, the parent's centroid point is this node's too.
    if (tree::TreeTraits<TreeType>::HasSelfChildren &&
        referenceNode.Parent() != NULL &&
        referenceNode.Point(0) == referenceNode.Parent()->Point(0))
    {
      kernelEval = referenceNode.Parent()->Stat().LastKernel();
    }
    else
    {
      kernelEval = BaseCase(queryIndex, referenceNode.Point(0));
    }
  }
  else
  {
    const arma::vec queryPoint = querySet.unsafe_col(queryIndex);
    arma::vec refCentroid;
    referenceNode.Centroid(refCentroid);

    kernelEval = kernel.Evaluate(queryPoint, refCentroid);
  }

  referenceNode.Stat().LastKernel() = kernelEval;

  double maxKernel;
  if (kernel::KernelTraits<KernelType>::IsNormalized)
  {
    const double squaredDist = furthestDist * furthestDist;
    const double delta = (1 - 0.5 * squaredDist);
    if (kernelEval <= delta)
    {
      const double gamma = furthestDist * sqrt(1 - 0.25 * squaredDist);
      // Round-off can leave kernelEval slightly below -1.  An unclamped
      // sqrt() would yield NaN, and a NaN maxKernel fails the comparison
      // below, wrongly pruning this node.
      const double sinTermSq = 1 - kernelEval * kernelEval;
      maxKernel = kernelEval * delta +
          gamma * ((sinTermSq > 0.0) ? sqrt(sinTermSq) : 0.0);
    }
    else
    {
      maxKernel = 1.0;
    }
  }
  else
  {
    maxKernel = kernelEval + furthestDist * queryKernels[queryIndex];
  }

  // We return the inverse of the maximum kernel so that larger kernels are
  // recursed into first.
  return (maxKernel > bestKernel) ? (1.0 / maxKernel) : DBL_MAX;
}