void NeighborSearchRules<SortPolicy, MetricType, TreeType>::
UpdateAfterRecursion(TreeType& queryNode, TreeType& /* referenceNode */)
{
  // After recursion, reset the node's bound to the worst value among:
  //  - the Stat().Bound() of each child node, and
  //  - the current k'th (last-row) candidate distance of each point held
  //    directly in this node.
  // "Worst" is judged by SortPolicy::IsBetter, starting from BestDistance().
  const size_t kthRow = distances.n_rows - 1;
  double worst = SortPolicy::BestDistance();

  for (size_t c = 0; c < queryNode.NumChildren(); ++c)
  {
    const double childBound = queryNode.Child(c).Stat().Bound();
    if (SortPolicy::IsBetter(worst, childBound))
      worst = childBound;
  }

  for (size_t p = 0; p < queryNode.NumPoints(); ++p)
  {
    const double candidate = distances(kthRow, queryNode.Point(p));
    if (SortPolicy::IsBetter(worst, candidate))
      worst = candidate;
  }

  queryNode.Stat().Bound() = worst;
}
FastMKSStat(const TreeType& node) : bound(-DBL_MAX), lastKernel(0.0), lastKernelNode(NULL) { // Do we have to calculate the centroid? if (tree::TreeTraits<TreeType>::FirstPointIsCentroid) { // If this type of tree has self-children, then maybe the evaluation is // already done. These statistics are built bottom-up, so the child stat // should already be done. if ((tree::TreeTraits<TreeType>::HasSelfChildren) && (node.NumChildren() > 0) && (node.Point(0) == node.Child(0).Point(0))) { selfKernel = node.Child(0).Stat().SelfKernel(); } else { selfKernel = sqrt(node.Metric().Kernel().Evaluate( node.Dataset().col(node.Point(0)), node.Dataset().col(node.Point(0)))); } } else { // Calculate the centroid. arma::vec center; node.Center(center); selfKernel = sqrt(node.Metric().Kernel().Evaluate(center, center)); } }
DualTreeKMeansStatistic(TreeType& node) :
    neighbor::NeighborSearchStat<neighbor::NearestNeighborSort>(),
    upperBound(DBL_MAX),
    lowerBound(DBL_MAX),
    owner(size_t(-1)),
    pruned(size_t(-1)),
    staticPruned(false),
    staticUpperBoundMovement(0.0),
    staticLowerBoundMovement(0.0),
    trueParent(node.Parent())
{
  // Empirically calculate the centroid: the mean of every descendant point.
  // Points held directly in this node are summed, then each child contributes
  // its centroid weighted by its descendant count. This assumes child
  // statistics are built before the parent's (bottom-up build).
  centroid.zeros(node.Dataset().n_rows);
  for (size_t i = 0; i < node.NumPoints(); ++i)
  {
    // Correct handling of cover tree: don't double-count the point which
    // appears in the children.
    if (tree::TreeTraits<TreeType>::HasSelfChildren && i == 0 &&
        node.NumChildren() > 0)
      continue;

    centroid += node.Dataset().col(node.Point(i));
  }

  for (size_t i = 0; i < node.NumChildren(); ++i)
    centroid += node.Child(i).NumDescendants() *
        node.Child(i).Stat().Centroid();

  // Guard against empty nodes. Dividing by zero would give a NaN centroid,
  // and a parent would then add 0 * NaN == NaN into its own centroid. A
  // DBL_MAX fill marks the centroid as invalid, and 0 * DBL_MAX == 0 keeps
  // parents' centroids correct (same convention as the Pelleg-Moore
  // statistic).
  if (node.NumDescendants() > 0)
    centroid /= node.NumDescendants();
  else
    centroid.fill(DBL_MAX);

  // Set the true children correctly.
  trueChildren.resize(node.NumChildren());
  for (size_t i = 0; i < node.NumChildren(); ++i)
    trueChildren[i] = &node.Child(i);
}
void CheckHierarchy(const TreeType& tree)
{
  // Depth-first walk: every child must report this node as its parent,
  // and then its own subtree is checked the same way.
  const size_t numChildren = tree.NumChildren();
  for (size_t i = 0; i != numChildren; ++i)
  {
    BOOST_REQUIRE_EQUAL(&tree, tree.Child(i).Parent());
    CheckHierarchy(tree.Child(i));
  }
}
void CleanTree(TreeType& node)
{
  // Reset the cached LastDistance() statistic to 0.0 on this node, then
  // recurse so every node in the subtree is reset (pre-order).
  node.Stat().LastDistance() = 0.0;

  const size_t numChildren = node.NumChildren();
  for (size_t child = 0; child != numChildren; ++child)
    CleanTree(node.Child(child));
}
inline double DTBRules<MetricType, TreeType>::CalculateBound(
    TreeType& queryNode) const
{
  // Range of neighborsDistances over the components that the points held
  // directly in this node belong to.
  double pointMax = -DBL_MAX;
  double pointMin = DBL_MAX;
  for (size_t p = 0; p < queryNode.NumPoints(); ++p)
  {
    const size_t component = connections.Find(queryNode.Point(p));
    const double candidate = neighborsDistances[component];
    pointMax = std::max(pointMax, candidate);
    pointMin = std::min(pointMin, candidate);
  }

  // Range of the children's cached neighbor-distance statistics.
  double childMax = -DBL_MAX;
  double childMin = DBL_MAX;
  for (size_t c = 0; c < queryNode.NumChildren(); ++c)
  {
    const double childWorst = queryNode.Child(c).Stat().MaxNeighborDistance();
    childMax = std::max(childMax, childWorst);
    const double childBest = queryNode.Child(c).Stat().MinNeighborDistance();
    childMin = std::min(childMin, childBest);
  }

  const double worstBound = std::max(pointMax, childMax);
  const double bestBound = std::min(pointMin, childMin);

  // Widen the best bound by twice the furthest descendant distance. Skip this
  // when bestBound is still DBL_MAX, because the addition could overflow.
  double bestAdjustedBound = DBL_MAX;
  if (bestBound != DBL_MAX)
    bestAdjustedBound = bestBound + 2 * queryNode.FurthestDescendantDistance();

  // Cache both extremes on the node and keep the tighter overall bound.
  queryNode.Stat().MaxNeighborDistance() = worstBound;
  queryNode.Stat().MinNeighborDistance() = bestBound;
  queryNode.Stat().Bound() = std::min(worstBound, bestAdjustedBound);

  return queryNode.Stat().Bound();
}
DualTreeKMeansStatistic(TreeType& node) :
    closestQueryNode(NULL),
    minQueryNodeDistance(DBL_MAX),
    maxQueryNodeDistance(DBL_MAX),
    clustersPruned(0),
    iteration(size_t() - 1)
{
  // Empirically calculate the centroid: the mean of every descendant point.
  // Points held directly in this node are summed, then each child contributes
  // its centroid weighted by its descendant count. This assumes child
  // statistics are built before the parent's (bottom-up build).
  centroid.zeros(node.Dataset().n_rows);
  for (size_t i = 0; i < node.NumPoints(); ++i)
    centroid += node.Dataset().col(node.Point(i));

  for (size_t i = 0; i < node.NumChildren(); ++i)
    centroid += node.Child(i).NumDescendants() *
        node.Child(i).Stat().Centroid();

  // Guard against empty nodes. Dividing by zero would give a NaN centroid,
  // and a parent would then add 0 * NaN == NaN into its own centroid. A
  // DBL_MAX fill marks the centroid as invalid, and 0 * DBL_MAX == 0 keeps
  // parents' centroids correct (same convention as the Pelleg-Moore
  // statistic).
  if (node.NumDescendants() > 0)
    centroid /= node.NumDescendants();
  else
    centroid.fill(DBL_MAX);
}
void CheckExactContainment(const TreeType& tree) { if (tree.NumChildren() == 0) { for (size_t i = 0; i < tree.Bound().Dim(); i++) { double min = DBL_MAX; double max = -1.0 * DBL_MAX; for(size_t j = 0; j < tree.Count(); j++) { if (tree.LocalDataset().col(j)[i] < min) min = tree.LocalDataset().col(j)[i]; if (tree.LocalDataset().col(j)[i] > max) max = tree.LocalDataset().col(j)[i]; } BOOST_REQUIRE_EQUAL(max, tree.Bound()[i].Hi()); BOOST_REQUIRE_EQUAL(min, tree.Bound()[i].Lo()); } } else { for (size_t i = 0; i < tree.Bound().Dim(); i++) { double min = DBL_MAX; double max = -1.0 * DBL_MAX; for (size_t j = 0; j < tree.NumChildren(); j++) { if(tree.Child(j).Bound()[i].Lo() < min) min = tree.Child(j).Bound()[i].Lo(); if(tree.Child(j).Bound()[i].Hi() > max) max = tree.Child(j).Bound()[i].Hi(); } BOOST_REQUIRE_EQUAL(max, tree.Bound()[i].Hi()); BOOST_REQUIRE_EQUAL(min, tree.Bound()[i].Lo()); } for (size_t i = 0; i < tree.NumChildren(); i++) CheckExactContainment(tree.Child(i)); } }
PellegMooreKMeansStatistic(TreeType& node)
{
  // The centroid is the mean of all descendant points. Children contribute
  // their centroid weighted by descendant count, then the points held directly
  // in this node are added.
  // NOTE: this assumes a depth-first (bottom-up) build so that the children's
  // statistics already exist. It does not handle self-children trees
  // correctly, presumably because shared points would be counted twice.
  centroid.zeros(node.Dataset().n_rows);

  const size_t numChildren = node.NumChildren();
  for (size_t c = 0; c < numChildren; ++c)
    centroid += node.Child(c).NumDescendants() *
        node.Child(c).Stat().Centroid();

  const size_t numPoints = node.NumPoints();
  for (size_t p = 0; p < numPoints; ++p)
    centroid += node.Dataset().col(node.Point(p));

  if (node.NumDescendants() == 0)
    centroid.fill(DBL_MAX); // Invalid centroid: nothing to average.
  else
    centroid /= node.NumDescendants();
}
void GreedySingleTreeTraverser<TreeType, RuleType>::Traverse(
    const size_t queryIndex,
    TreeType& referenceNode)
{
  // Greedy single-tree descent for one query point:
  // base-case the points held directly in the node, then either follow only
  // the best child (pruning the rest) or, if that child is too small to yield
  // enough base cases, base-case this node's first descendants directly.

  // Run the base case as necessary for all the points in the reference node.
  for (size_t i = 0; i < referenceNode.NumPoints(); ++i)
    rule.BaseCase(queryIndex, referenceNode.Point(i));

  // Asked for every node, including leaves, exactly as before.
  const size_t bestChild = rule.GetBestChild(queryIndex, referenceNode);

  // A leaf has no children to descend into; its points were handled above.
  if (referenceNode.IsLeaf())
    return;

  const size_t numDescendants =
      referenceNode.Child(bestChild).NumDescendants();

  // If the best child has more than minBaseCases descendants we can follow it
  // alone. Otherwise run base cases directly over this node's descendants, so
  // that at least minBaseCases base cases are computed.
  if (numDescendants > minBaseCases)
  {
    // We are pruning all but one child.
    numPrunes += referenceNode.NumChildren() - 1;
    // Recurse the best child.
    Traverse(queryIndex, referenceNode.Child(bestChild));
  }
  else
  {
    // Run the base case over the first minBaseCases + 1 descendants. This is
    // the count the loop has always used; NOTE(review): confirm whether the
    // extra case is intentional. Clamp to the number of descendants so
    // Descendant(i) is never indexed past the end of this node. The
    // comparison form avoids overflow of minBaseCases + 1.
    const size_t available = referenceNode.NumDescendants();
    const size_t numBaseCases =
        (minBaseCases < available) ? minBaseCases + 1 : available;
    for (size_t i = 0; i < numBaseCases; ++i)
      rule.BaseCase(queryIndex, referenceNode.Descendant(i));
  }
}
// Recursively verify that three other copies of a tree (xmlTree, textTree,
// binaryTree -- judging by the names, deserialized copies; confirm at the
// caller) match `tree` node-for-node.
//
// At the root (tree.Parent() == NULL):
//  - the datasets are compared with CheckMatrices (defined elsewhere);
//  - each copy's parent must also be NULL.
//
// At every node, the following must agree across all four trees:
//  - the number of children, descendants and points;
//  - every point index;
//  - ParentDistance(), FurthestDescendantDistance() and
//    MinimumBoundDistance(). These use BOOST_REQUIRE_CLOSE with tolerance
//    1e-8, which Boost.Test interprets as a percentage.
//
// For each child, each copy's child must share its parent's dataset object
// and point back to that parent. The check then recurses into the four
// corresponding children.
//
// All checks are BOOST_REQUIRE_* assertions, so the first mismatch aborts the
// enclosing test case.
// NOTE(review): the original `tree`'s own children are not checked for
// sharing its dataset or parent link; only the three copies are.
void CheckTrees(TreeType& tree,
                TreeType& xmlTree,
                TreeType& textTree,
                TreeType& binaryTree)
{
  const typename TreeType::Mat* dataset = &tree.Dataset();

  // Make sure that the data matrices are the same. This is only done at the
  // root, since children share the root's dataset.
  if (tree.Parent() == NULL)
  {
    CheckMatrices(*dataset,
                  xmlTree.Dataset(),
                  textTree.Dataset(),
                  binaryTree.Dataset());

    // Also ensure that the other parents are null too.
    BOOST_REQUIRE_EQUAL(xmlTree.Parent(), (TreeType*) NULL);
    BOOST_REQUIRE_EQUAL(textTree.Parent(), (TreeType*) NULL);
    BOOST_REQUIRE_EQUAL(binaryTree.Parent(), (TreeType*) NULL);
  }

  // Make sure the number of children is the same.
  BOOST_REQUIRE_EQUAL(tree.NumChildren(), xmlTree.NumChildren());
  BOOST_REQUIRE_EQUAL(tree.NumChildren(), textTree.NumChildren());
  BOOST_REQUIRE_EQUAL(tree.NumChildren(), binaryTree.NumChildren());

  // Make sure the number of descendants is the same.
  BOOST_REQUIRE_EQUAL(tree.NumDescendants(), xmlTree.NumDescendants());
  BOOST_REQUIRE_EQUAL(tree.NumDescendants(), textTree.NumDescendants());
  BOOST_REQUIRE_EQUAL(tree.NumDescendants(), binaryTree.NumDescendants());

  // Make sure the number of points is the same.
  BOOST_REQUIRE_EQUAL(tree.NumPoints(), xmlTree.NumPoints());
  BOOST_REQUIRE_EQUAL(tree.NumPoints(), textTree.NumPoints());
  BOOST_REQUIRE_EQUAL(tree.NumPoints(), binaryTree.NumPoints());

  // Check that each point is the same.
  for (size_t i = 0; i < tree.NumPoints(); ++i)
  {
    BOOST_REQUIRE_EQUAL(tree.Point(i), xmlTree.Point(i));
    BOOST_REQUIRE_EQUAL(tree.Point(i), textTree.Point(i));
    BOOST_REQUIRE_EQUAL(tree.Point(i), binaryTree.Point(i));
  }

  // Check that the parent distance is the same.
  BOOST_REQUIRE_CLOSE(tree.ParentDistance(), xmlTree.ParentDistance(), 1e-8);
  BOOST_REQUIRE_CLOSE(tree.ParentDistance(), textTree.ParentDistance(), 1e-8);
  BOOST_REQUIRE_CLOSE(tree.ParentDistance(), binaryTree.ParentDistance(), 1e-8);

  // Check that the furthest descendant distance is the same.
  BOOST_REQUIRE_CLOSE(tree.FurthestDescendantDistance(),
      xmlTree.FurthestDescendantDistance(), 1e-8);
  BOOST_REQUIRE_CLOSE(tree.FurthestDescendantDistance(),
      textTree.FurthestDescendantDistance(), 1e-8);
  BOOST_REQUIRE_CLOSE(tree.FurthestDescendantDistance(),
      binaryTree.FurthestDescendantDistance(), 1e-8);

  // Check that the minimum bound distance is the same.
  BOOST_REQUIRE_CLOSE(tree.MinimumBoundDistance(),
      xmlTree.MinimumBoundDistance(), 1e-8);
  BOOST_REQUIRE_CLOSE(tree.MinimumBoundDistance(),
      textTree.MinimumBoundDistance(), 1e-8);
  BOOST_REQUIRE_CLOSE(tree.MinimumBoundDistance(),
      binaryTree.MinimumBoundDistance(), 1e-8);

  // Recurse into the children.
  for (size_t i = 0; i < tree.NumChildren(); ++i)
  {
    // Check that the child dataset is the same object as the parent's
    // (pointer comparison, not a content comparison).
    BOOST_REQUIRE_EQUAL(&xmlTree.Dataset(), &xmlTree.Child(i).Dataset());
    BOOST_REQUIRE_EQUAL(&textTree.Dataset(), &textTree.Child(i).Dataset());
    BOOST_REQUIRE_EQUAL(&binaryTree.Dataset(), &binaryTree.Child(i).Dataset());

    // Make sure the parent link is right.
    BOOST_REQUIRE_EQUAL(xmlTree.Child(i).Parent(), &xmlTree);
    BOOST_REQUIRE_EQUAL(textTree.Child(i).Parent(), &textTree);
    BOOST_REQUIRE_EQUAL(binaryTree.Child(i).Parent(), &binaryTree);

    CheckTrees(tree.Child(i), xmlTree.Child(i), textTree.Child(i),
        binaryTree.Child(i));
  }
}
inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::
CalculateBound(TreeType& queryNode) const
{
  // We have five possible bounds, and we must take the best of them all. We
  // don't use min/max here, but instead "best/worst", because this is general
  // to the nearest-neighbors/furthest-neighbors cases. For nearest neighbors,
  // min = best, max = worst.
  //
  // (1) worst ( worst_{all points p in queryNode} D_p[k],
  //             worst_{all children c in queryNode} B(c) );
  // (2) best_{all points p in queryNode} D_p[k] + worst child distance +
  //        worst descendant distance;
  // (3) best_{all children c in queryNode} B(c) +
  //      2 ( worst descendant distance of queryNode -
  //          worst descendant distance of c );
  // (4) B_1(parent of queryNode)
  // (5) B_2(parent of queryNode);
  //
  // D_p[k] is the current k'th candidate distance for point p. In (1), B(c) is
  // the child's FirstBound(); in (3) it is the child's SecondBound().
  // So we will loop over the points in queryNode and the children in queryNode
  // to calculate all five of these quantities. Bound (5) is currently disabled
  // (see the commented-out fifthBound below).
  //
  // Side effects: the node's FirstBound() (B_1), SecondBound() (B_2) and
  // Bound() (the better of the two) are updated, and Bound() is returned.

  // Hm, can we populate our distances vector with estimates from the parent?
  // This is written specifically for the cover tree and assumes only one point
  // in a node.
  // if (queryNode.Parent() != NULL && queryNode.NumPoints() > 0)
  // {
  //   size_t parentIndexStart = 0;
  //   for (size_t i = 0; i < neighbors.n_rows; ++i)
  //   {
  //     const double pointDistance = distances(i, queryNode.Point(0));
  //     if (pointDistance == DBL_MAX)
  //     {
  //       // Cool, can we take an estimate from the parent?
  //       const double parentWorstBound = distances(distances.n_rows - 1,
  //           queryNode.Parent()->Point(0));
  //       if (parentWorstBound != DBL_MAX)
  //       {
  //         const double parentAdjustedDistance = parentWorstBound +
  //             queryNode.ParentDistance();
  //         distances(i, queryNode.Point(0)) = parentAdjustedDistance;
  //       }
  //     }
  //   }
  // }

  double worstPointDistance = SortPolicy::BestDistance();
  double bestPointDistance = SortPolicy::WorstDistance();

  // Loop over all points in this node to find the best and worst distance
  // candidates (for (1) and (2)).
  for (size_t i = 0; i < queryNode.NumPoints(); ++i)
  {
    const double distance = distances(distances.n_rows - 1,
        queryNode.Point(i));
    if (SortPolicy::IsBetter(distance, bestPointDistance))
      bestPointDistance = distance;
    if (SortPolicy::IsBetter(worstPointDistance, distance))
      worstPointDistance = distance;
  }

  // Loop over all the children in this node to find the worst bound (for (1))
  // and the best bound with the correcting factor for descendant distances (for
  // (3)).
  double worstChildBound = SortPolicy::BestDistance();
  double bestAdjustedChildBound = SortPolicy::WorstDistance();
  const double queryMaxDescendantDistance =
      queryNode.FurthestDescendantDistance();

  for (size_t i = 0; i < queryNode.NumChildren(); ++i)
  {
    const double childFirstBound = queryNode.Child(i).Stat().FirstBound();
    const double childSecondBound = queryNode.Child(i).Stat().SecondBound();
    const double childMaxDescendantDistance =
        queryNode.Child(i).FurthestDescendantDistance();

    if (SortPolicy::IsBetter(worstChildBound, childFirstBound))
      worstChildBound = childFirstBound;

    // Now calculate adjustment for maximum descendant distances.
    const double adjustedBound = SortPolicy::CombineWorst(childSecondBound,
        2 * (queryMaxDescendantDistance - childMaxDescendantDistance));
    if (SortPolicy::IsBetter(adjustedBound, bestAdjustedChildBound))
      bestAdjustedChildBound = adjustedBound;
  }

  // This is bound (1).
  const double firstBound =
      (SortPolicy::IsBetter(worstPointDistance, worstChildBound)) ?
      worstChildBound : worstPointDistance;

  // This is bound (2).
  const double secondBound = SortPolicy::CombineWorst(
      SortPolicy::CombineWorst(bestPointDistance, queryMaxDescendantDistance),
      queryNode.FurthestPointDistance());

  // Bound (3) is bestAdjustedChildBound.

  // Bounds (4) and (5) are the parent bounds.
  const double fourthBound = (queryNode.Parent() != NULL) ?
      queryNode.Parent()->Stat().FirstBound() : SortPolicy::WorstDistance();
  // const double fifthBound = (queryNode.Parent() != NULL) ?
  //     queryNode.Parent()->Stat().SecondBound() -
  //     queryNode.Parent()->FurthestDescendantDistance() -
  //     queryNode.Parent()->FurthestPointDistance() + queryMaxDescendantDistance +
  //     queryNode.FurthestPointDistance() + queryNode.ParentDistance() :
  //     SortPolicy::WorstDistance();

  // Now, we will take the best of these. Unfortunately due to the way
  // IsBetter() is defined, this sort of has to be a little ugly.
  // The variable interA represents the first bound (B_1), which is the worst
  // candidate distance of any descendants of this node.
  // The variable interB represents the second bound (B_2), which is a bound on
  // the worst distance of any descendants of this node assembled using the best
  // descendant candidate distance modified using the furthest descendant
  // distance. (interC, which would also fold in bound (5), is disabled.)
  const double interA = (SortPolicy::IsBetter(firstBound, fourthBound)) ?
      firstBound : fourthBound;
  const double interB =
      (SortPolicy::IsBetter(bestAdjustedChildBound, secondBound)) ?
      bestAdjustedChildBound : secondBound;
  // const double interC = (SortPolicy::IsBetter(interB, fifthBound)) ? interB :
  //     fifthBound;

  // Update the first and second bounds of the node.
  queryNode.Stat().FirstBound() = interA;
  queryNode.Stat().SecondBound() = interB;

  // Update the actual bound of the node. B_1 and B_2 are both valid bounds,
  // so use whichever is better (tighter). Both ternary branches used to yield
  // interB, which silently discarded B_1.
  queryNode.Stat().Bound() = (SortPolicy::IsBetter(interA, interB)) ? interA :
      interB;

  return queryNode.Stat().Bound();
}
double FastMKSRules<KernelType, TreeType>::CalculateBound(TreeType& queryNode)
    const
{
  // Compute a lower bound on the k'th-best kernel value still held by any
  // query point under queryNode. This is a max-kernel search, so larger is
  // better. Candidate bounds (mirroring NeighborSearchRules::CalculateBound):
  //
  // (1) min ( min_{points p in queryNode} P_p[k],
  //           min_{children c of queryNode} B(c) );
  // (2) max_{points p in queryNode} P_p[k]
  //         - (furthest descendant distance) * referenceKernels[I_p[k]],
  //     where referenceKernels[...] plays the role of sqrt(K(I_p[k], I_p[k]));
  // (3) a child-adjusted bound -- not implemented, ignored;
  // (4) B(parent of queryNode).
  //
  // The largest of (1), (2) and (4) is returned. queryNode is not modified.
  const double queryDescendantDistance = queryNode.FurthestDescendantDistance();

  // Point contributions to bounds (1) and (2).
  double minPointKernel = DBL_MAX;
  double maxAdjustedKernel = -DBL_MAX;
  for (size_t i = 0; i < queryNode.NumPoints(); ++i)
  {
    const size_t queryPoint = queryNode.Point(i);
    const double kthKernel = products(products.n_rows - 1, queryPoint);

    if (kthKernel < minPointKernel)
      minPointKernel = kthKernel;

    // Skip the -DBL_MAX sentinel (presumably "no k'th candidate yet") so the
    // subtraction below cannot underflow.
    if (kthKernel != -DBL_MAX)
    {
      // This should be (queryDescendantDistance + centroidDistance) for any
      // tree, but it works for cover trees since centroidDistance = 0 there.
      const double adjustedKernel = kthKernel - queryDescendantDistance *
          referenceKernels[indices(indices.n_rows - 1, queryPoint)];
      if (adjustedKernel > maxAdjustedKernel)
        maxAdjustedKernel = adjustedKernel;
    }
  }

  // Child contribution to bound (1).
  double minChildKernel = DBL_MAX;
  for (size_t i = 0; i < queryNode.NumChildren(); ++i)
  {
    const double childBound = queryNode.Child(i).Stat().Bound();
    if (childBound < minChildKernel)
      minChildKernel = childBound;
  }

  // Bound (1).
  const double firstBound = (minPointKernel < minChildKernel) ?
      minPointKernel : minChildKernel;

  // Bound (4): the parent's bound, or -DBL_MAX at the root.
  double parentBound = -DBL_MAX;
  if (queryNode.Parent() != NULL)
    parentBound = queryNode.Parent()->Stat().Bound();

  // Keep the best (largest) of bounds (1) and (2), then compare with (4).
  double best = maxAdjustedKernel;
  if (firstBound > maxAdjustedKernel)
    best = firstBound;

  if (best > parentBound)
    return best;
  return parentBound;
}