void GreedySingleTreeTraverser<TreeType, RuleType>::Traverse(
    const size_t queryIndex,
    TreeType& referenceNode)
{
  // Greedy descent: evaluate the points held directly in referenceNode, then
  // follow only the child chosen by rule.GetBestChild(), pruning its
  // siblings.  When that child has too few descendants to supply enough base
  // cases, evaluate the leading descendants of referenceNode instead of
  // descending further.
  //
  // queryIndex    - index of the query point handed to the rule.
  // referenceNode - node of the reference tree being traversed.

  // Run the base case for all the points held directly in the reference node.
  for (size_t i = 0; i < referenceNode.NumPoints(); ++i)
    rule.BaseCase(queryIndex, referenceNode.Point(i));

  // GetBestChild() is still called for leaves (as the original code did), so
  // the rule sees exactly the same sequence of calls.
  const size_t bestChild = rule.GetBestChild(queryIndex, referenceNode);

  // A leaf has no children; all of its points were evaluated above.
  if (referenceNode.IsLeaf())
    return;

  const size_t numDescendants =
      referenceNode.Child(bestChild).NumDescendants();

  if (numDescendants > minBaseCases)
  {
    // The best child alone holds more than minBaseCases descendants, so prune
    // all of its siblings and recurse into it.
    numPrunes += referenceNode.NumChildren() - 1;
    Traverse(queryIndex, referenceNode.Child(bestChild));
  }
  else
  {
    // Evaluate descendants 0..minBaseCases of this node (inclusive bound, i.e.
    // up to minBaseCases + 1 base cases).  NOTE(review): the inclusive bound
    // is preserved on purpose; it may be deliberate slack (e.g. for a skipped
    // self-match) -- confirm before changing it.  The loop is capped at
    // NumDescendants() so a node with few descendants is never indexed past
    // its last descendant.
    const size_t numAvailable = referenceNode.NumDescendants();
    for (size_t i = 0; i <= minBaseCases && i < numAvailable; ++i)
      rule.BaseCase(queryIndex, referenceNode.Descendant(i));
  }
}
int GetMinLevel(const TreeType& tree)
{
  // Returns the number of levels on the shallowest root-to-leaf path of
  // `tree`, counting both the root and the leaf (a lone leaf has level 1).

  // A leaf contributes exactly one level.
  if (tree.IsLeaf())
    return 1;

  // Otherwise: this node's level plus the shallowest child subtree.
  int shallowest = INT_MAX;
  for (size_t child = 0; child < tree.NumChildren(); ++child)
  {
    const int depth = GetMinLevel(*tree.Children()[child]);
    shallowest = (depth < shallowest) ? depth : shallowest;
  }

  return 1 + shallowest;
}
int GetMaxLevel(const TreeType& tree)
{
  // Returns the number of levels on the deepest root-to-leaf path of `tree`,
  // counting both the root and the leaf (a lone leaf has level 1).

  // A leaf contributes exactly one level.
  if (tree.IsLeaf())
    return 1;

  // Otherwise: this node's level plus the deepest child subtree.
  int deepest = 0;
  for (size_t child = 0; child < tree.NumChildren(); ++child)
  {
    const int depth = GetMaxLevel(*tree.Children()[child]);
    if (depth > deepest)
      deepest = depth;
  }

  return deepest + 1;
}
void CheckFills(const TreeType& tree) { if (tree.IsLeaf()) { BOOST_REQUIRE(tree.Count() >= tree.MinLeafSize() || tree.Parent() == NULL); BOOST_REQUIRE(tree.Count() <= tree.MaxLeafSize()); } else { for (size_t i = 0; i < tree.NumChildren(); i++) { BOOST_REQUIRE(tree.NumChildren() >= tree.MinNumChildren() || tree.Parent() == NULL); BOOST_REQUIRE(tree.NumChildren() <= tree.MaxNumChildren()); CheckFills(*tree.Children()[i]); } } }
void CheckSplit(const TreeType& tree) { typedef typename TreeType::ElemType ElemType; typedef typename std::conditional<sizeof(ElemType) * CHAR_BIT <= 32, uint32_t, uint64_t>::type AddressElemType; if (tree.IsLeaf()) return; arma::Col<AddressElemType> lo(tree.Bound().Dim()); arma::Col<AddressElemType> hi(tree.Bound().Dim()); lo.fill(std::numeric_limits<AddressElemType>::max()); hi.fill(0); arma::Col<AddressElemType> address(tree.Bound().Dim()); // Find the highest address of the left node. for (size_t i = 0; i < tree.Left()->NumDescendants(); i++) { addr::PointToAddress(address, tree.Dataset().col(tree.Left()->Descendant(i))); if (addr::CompareAddresses(address, hi) > 0) hi = address; } // Find the lowest address of the right node. for (size_t i = 0; i < tree.Right()->NumDescendants(); i++) { addr::PointToAddress(address, tree.Dataset().col(tree.Right()->Descendant(i))); if (addr::CompareAddresses(address, lo) < 0) lo = address; } // Addresses in the left node should be less than addresses in the right node. BOOST_REQUIRE_LE(addr::CompareAddresses(hi, lo), 0); CheckSplit(*tree.Left()); CheckSplit(*tree.Right()); }
void CheckSync(const TreeType& tree)
{
  // Recursively asserts that, in every leaf, column i of LocalDataset()
  // equals the Dataset() column indexed by Points()[i], element by element.

  // Internal nodes are not compared themselves; just visit every child.
  if (!tree.IsLeaf())
  {
    for (size_t c = 0; c < tree.NumChildren(); ++c)
      CheckSync(*tree.Children()[c]);
    return;
  }

  const size_t numRows = tree.LocalDataset().n_rows;
  for (size_t i = 0; i < tree.Count(); ++i)
  {
    for (size_t j = 0; j < numRows; ++j)
    {
      BOOST_REQUIRE_EQUAL(tree.LocalDataset().col(i)[j],
          tree.Dataset().col(tree.Points()[i])[j]);
    }
  }
}
void CheckBound(const TreeType& tree) { typedef typename TreeType::ElemType ElemType; for (size_t i = 0; i < tree.NumDescendants(); i++) { arma::Col<ElemType> point = tree.Dataset().col(tree.Descendant(i)); // Check that the point is contained in the bound. BOOST_REQUIRE_EQUAL(true, tree.Bound().Contains(point)); const arma::Mat<ElemType>& loBound = tree.Bound().LoBound(); const arma::Mat<ElemType>& hiBound = tree.Bound().HiBound(); // Ensure that there is a hyperrectangle that contains the point. bool success = false; for (size_t j = 0; j < tree.Bound().NumBounds(); j++) { success = true; for (size_t k = 0; k < loBound.n_rows; k++) { if (point[k] < loBound(k, j) - 1e-14 * std::fabs(loBound(k, j)) || point[k] > hiBound(k, j) + 1e-14 * std::fabs(hiBound(k, j))) { success = false; break; } } if (success) break; } BOOST_REQUIRE_EQUAL(success, true); } if (!tree.IsLeaf()) { CheckBound(*tree.Left()); CheckBound(*tree.Right()); } }
void CheckDistance(TreeType& tree, TreeType* node = NULL)
{
  // Asserts that the distance bounds reported by tree.Bound()
  // (MinDistance, MaxDistance, RangeDistance) enclose the true distances, as
  // computed by MetricType::Evaluate, between actual points.  Every check
  // allows a relative slack of 10 * epsilon of ElemType.  Because the slack is
  // multiplicative, a true minimum distance of 0 requires the bound's lower
  // value to be <= 0.
  //
  // Two modes:
  //  - node == NULL: locate the root of `tree`, check `tree` against every
  //    node of the whole tree (second mode), check `tree`'s bound against
  //    every point of the dataset, then repeat for tree's Left()/Right()
  //    children.
  //  - node != NULL: check tree's bound against node's bound (skipped when
  //    they are the same node), then recurse over node's Left()/Right()
  //    children.
  //
  // Left()/Right() are used throughout, so this assumes a binary tree.  Each
  // node pair compares all pairs of their descendants, which is expensive on
  // large trees.
  typedef typename TreeType::ElemType ElemType;
  if (node == NULL)
  {
    // Walk up to the root so `tree` is compared against every node.
    node = &tree;

    while (node->Parent() != NULL)
      node = node->Parent();

    CheckDistance<TreeType, MetricType>(tree, node);

    // Point-to-node checks: for each dataset point, find the true min/max
    // distance to tree's descendants and compare with the bound's values.
    for (size_t j = 0; j < tree.Dataset().n_cols; j++)
    {
      const arma::Col<ElemType>& point = tree.
          Dataset().col(j);
      ElemType maxDist = 0;
      ElemType minDist = std::numeric_limits<ElemType>::max();
      for (size_t i = 0; i < tree.NumDescendants(); i++)
      {
        ElemType dist = MetricType::Evaluate(
            tree.Dataset().col(tree.Descendant(i)),
            tree.Dataset().col(j));

        if (dist > maxDist)
          maxDist = dist;
        if (dist < minDist)
          minDist = dist;
      }

      BOOST_REQUIRE_LE(tree.Bound().MinDistance(point), minDist *
          (1.0 + 10 * std::numeric_limits<ElemType>::epsilon()));
      BOOST_REQUIRE_LE(maxDist, tree.Bound().MaxDistance(point) *
          (1.0 + 10 * std::numeric_limits<ElemType>::epsilon()));

      math::RangeType<ElemType> r = tree.Bound().RangeDistance(point);

      BOOST_REQUIRE_LE(r.Lo(), minDist *
          (1.0 + 10 * std::numeric_limits<ElemType>::epsilon()));
      BOOST_REQUIRE_LE(maxDist, r.Hi() *
          (1.0 + 10 * std::numeric_limits<ElemType>::epsilon()));
    }

    // Repeat the whole procedure for each subtree.
    if (!tree.IsLeaf())
    {
      CheckDistance<TreeType, MetricType>(*tree.Left());
      CheckDistance<TreeType, MetricType>(*tree.Right());
    }
  }
  else
  {
    // Node-to-node checks: compare tree's bound with node's bound using the
    // true min/max distance over all pairs of their descendants.
    if (&tree != node)
    {
      ElemType maxDist = 0;
      ElemType minDist = std::numeric_limits<ElemType>::max();
      for (size_t i = 0; i < tree.NumDescendants(); i++)
        for (size_t j = 0; j < node->NumDescendants(); j++)
        {
          ElemType dist = MetricType::Evaluate(
              tree.Dataset().col(tree.Descendant(i)),
              node->Dataset().col(node->Descendant(j)));

          if (dist > maxDist)
            maxDist = dist;
          if (dist < minDist)
            minDist = dist;
        }

      BOOST_REQUIRE_LE(tree.Bound().MinDistance(node->Bound()), minDist *
          (1.0 + 10 * std::numeric_limits<ElemType>::epsilon()));
      BOOST_REQUIRE_LE(maxDist, tree.Bound().MaxDistance(node->Bound()) *
          (1.0 + 10 * std::numeric_limits<ElemType>::epsilon()));

      math::RangeType<ElemType> r =
          tree.Bound().RangeDistance(node->Bound());

      BOOST_REQUIRE_LE(r.Lo(), minDist *
          (1.0 + 10 * std::numeric_limits<ElemType>::epsilon()));
      BOOST_REQUIRE_LE(maxDist, r.Hi() *
          (1.0 + 10 * std::numeric_limits<ElemType>::epsilon()));
    }
    // Continue over the remaining nodes below `node`.
    if (!node->IsLeaf())
    {
      CheckDistance<TreeType, MetricType>(tree, node->Left());
      CheckDistance<TreeType, MetricType>(tree, node->Right());
    }
  }
}