// Greedy (defeatist) single-tree traversal for one query point.
//
// Runs the base case against every point held directly by `referenceNode`.
// It then either descends only into the child chosen by the rule's
// GetBestChild(), pruning all siblings, or, when that child is too small to
// supply minBaseCases base cases, evaluates the leading descendants of
// `referenceNode` directly.
//
// @param queryIndex    Index of the query point handed to the rule.
// @param referenceNode Node of the reference tree currently being visited.
void GreedySingleTreeTraverser<TreeType, RuleType>::Traverse(
    const size_t queryIndex,
    TreeType& referenceNode)
{
  // Run the base case as necessary for all the points in the reference node.
  for (size_t i = 0; i < referenceNode.NumPoints(); ++i)
    rule.BaseCase(queryIndex, referenceNode.Point(i));

  // Ask the rule for its preferred child. This is done even for leaves, as
  // before, because the rule may keep statistics about these calls.
  const size_t bestChild = rule.GetBestChild(queryIndex, referenceNode);

  // A leaf has no children; its points were already handled above.
  if (referenceNode.IsLeaf())
    return;

  const size_t numDescendants =
      referenceNode.Child(bestChild).NumDescendants();

  // If the best child alone holds more than minBaseCases descendants, follow
  // it greedily and prune every other child.
  if (numDescendants > minBaseCases)
  {
    // We are pruning all but one child.
    numPrunes += referenceNode.NumChildren() - 1;
    // Recurse the best child.
    Traverse(queryIndex, referenceNode.Child(bestChild));
    return;
  }

  // Otherwise evaluate the first minBaseCases + 1 descendants of this node
  // directly.
  // NOTE(review): the "+ 1" is presumably there to compensate for a base case
  // the rule may skip (e.g. a query matched against itself); confirm before
  // changing it.
  //
  // Recursion only enters children with more than minBaseCases descendants.
  // However, the node Traverse() was first called on may have fewer, so the
  // count is clamped to avoid indexing past the last descendant.
  const size_t totalDescendants = referenceNode.NumDescendants();
  const size_t numBaseCases = (minBaseCases < totalDescendants) ?
      minBaseCases + 1 : totalDescendants;
  for (size_t i = 0; i < numBaseCases; ++i)
    rule.BaseCase(queryIndex, referenceNode.Descendant(i));
}
Esempio n. 2
0
int GetMinLevel(const TreeType& tree)
{
  // Returns the depth of the shallowest leaf below `tree`, counting `tree`
  // itself as level 1 (so a lone leaf has depth 1).
  if (tree.IsLeaf())
    return 1;

  // An internal node sits one level above the shallowest of its children.
  int shallowest = INT_MAX;
  for (size_t child = 0; child < tree.NumChildren(); ++child)
  {
    const int depth = GetMinLevel(*tree.Children()[child]);
    shallowest = (depth < shallowest) ? depth : shallowest;
  }

  return 1 + shallowest;
}
Esempio n. 3
0
int GetMaxLevel(const TreeType& tree)
{
  // Returns the depth of the deepest leaf below `tree`, counting `tree`
  // itself as level 1 (so a lone leaf has depth 1).
  if (tree.IsLeaf())
    return 1;

  // An internal node adds one level on top of its deepest child subtree.
  int deepest = 0;
  for (size_t child = 0; child < tree.NumChildren(); ++child)
  {
    const int depth = GetMaxLevel(*tree.Children()[child]);
    if (depth > deepest)
      deepest = depth;
  }

  return deepest + 1;
}
Esempio n. 4
0
// Recursively asserts that every node of `tree` respects its fill limits:
// leaves must hold between MinLeafSize() and MaxLeafSize() points, and
// internal nodes between MinNumChildren() and MaxNumChildren() children.
// A node without a parent (the root) is exempt from the lower limits.
void CheckFills(const TreeType& tree)
{
  if (tree.IsLeaf())
  {
    BOOST_REQUIRE(tree.Count() >= tree.MinLeafSize() || tree.Parent() == NULL);
    BOOST_REQUIRE(tree.Count() <= tree.MaxLeafSize());
    return;
  }

  // The fan-out limits depend only on this node, so check them once rather
  // than once per child. This also means they are still checked when a
  // non-leaf node reports zero children.
  BOOST_REQUIRE(tree.NumChildren() >= tree.MinNumChildren() ||
                tree.Parent() == NULL);
  BOOST_REQUIRE(tree.NumChildren() <= tree.MaxNumChildren());

  for (size_t i = 0; i < tree.NumChildren(); i++)
    CheckFills(*tree.Children()[i]);
}
Esempio n. 5
0
void CheckSplit(const TreeType& tree)
{
  typedef typename TreeType::ElemType ElemType;
  typedef typename std::conditional<sizeof(ElemType) * CHAR_BIT <= 32,
                                    uint32_t,
                                    uint64_t>::type AddressElemType;

  if (tree.IsLeaf())
    return;

  arma::Col<AddressElemType> lo(tree.Bound().Dim());
  arma::Col<AddressElemType> hi(tree.Bound().Dim());

  lo.fill(std::numeric_limits<AddressElemType>::max());
  hi.fill(0);

  arma::Col<AddressElemType> address(tree.Bound().Dim());

  // Find the highest address of the left node.
  for (size_t i = 0; i < tree.Left()->NumDescendants(); i++)
  {
    addr::PointToAddress(address,
        tree.Dataset().col(tree.Left()->Descendant(i)));

    if (addr::CompareAddresses(address, hi) > 0)
      hi = address;
  }

  // Find the lowest address of the right node.
  for (size_t i = 0; i < tree.Right()->NumDescendants(); i++)
  {
    addr::PointToAddress(address,
        tree.Dataset().col(tree.Right()->Descendant(i)));

    if (addr::CompareAddresses(address, lo) < 0)
      lo = address;
  }

  // Addresses in the left node should be less than addresses in the right node.
  BOOST_REQUIRE_LE(addr::CompareAddresses(hi, lo), 0);

  CheckSplit(*tree.Left());
  CheckSplit(*tree.Right());
}
Esempio n. 6
0
void CheckSync(const TreeType& tree)
{
  // Recursively asserts that every leaf's LocalDataset() column i equals,
  // element by element, the column of Dataset() whose index is Points()[i].

  // Internal nodes are not compared directly; validate each subtree instead.
  if (!tree.IsLeaf())
  {
    for (size_t child = 0; child < tree.NumChildren(); ++child)
      CheckSync(*tree.Children()[child]);
    return;
  }

  for (size_t i = 0; i < tree.Count(); ++i)
  {
    for (size_t j = 0; j < tree.LocalDataset().n_rows; ++j)
    {
      BOOST_REQUIRE_EQUAL(tree.LocalDataset().col(i)[j],
                          tree.Dataset().col(tree.Points()[i])[j]);
    }
  }
}
Esempio n. 7
0
void CheckBound(const TreeType& tree)
{
  typedef typename TreeType::ElemType ElemType;
  for (size_t i = 0; i < tree.NumDescendants(); i++)
  {
    arma::Col<ElemType> point = tree.Dataset().col(tree.Descendant(i));

    // Check that the point is contained in the bound.
    BOOST_REQUIRE_EQUAL(true, tree.Bound().Contains(point));

    const arma::Mat<ElemType>& loBound = tree.Bound().LoBound();
    const arma::Mat<ElemType>& hiBound = tree.Bound().HiBound();

    // Ensure that there is a hyperrectangle that contains the point.
    bool success = false;
    for (size_t j = 0; j < tree.Bound().NumBounds(); j++)
    {
      success = true;
      for (size_t k = 0; k < loBound.n_rows; k++)
      {
        if (point[k] < loBound(k, j) - 1e-14 * std::fabs(loBound(k, j)) ||
            point[k] > hiBound(k, j) + 1e-14 * std::fabs(hiBound(k, j)))
        {
          success = false;
          break;
        }
      }
      if (success)
        break;
    }

    BOOST_REQUIRE_EQUAL(success, true);
  }

  if (!tree.IsLeaf())
  {
    CheckBound(*tree.Left());
    CheckBound(*tree.Right());
  }
}
Esempio n. 8
0
// Recursively verifies that the distances reported by a tree's Bound() agree
// with true distances computed by MetricType::Evaluate().
//
// With node == NULL (the default):
//   - follows Parent() links from `tree` up to the root, then calls itself
//     with that root as `node`. This compares `tree` against every node in
//     the root's subtree (see the `else` branch).
//   - for every column of tree.Dataset(), checks that Bound().MinDistance(),
//     Bound().MaxDistance() and Bound().RangeDistance() for that point
//     bracket the true minimum / maximum distance from it to `tree`'s
//     descendants.
//   - recurses into tree.Left() and tree.Right(), again with node == NULL.
//
// With a non-NULL `node`:
//   - unless `node` is `tree` itself, checks the bound-to-bound MinDistance(),
//     MaxDistance() and RangeDistance() against the true minimum / maximum
//     pairwise distance between the descendants of `tree` and of `node`.
//   - recurses into node->Left() and node->Right() while `node` is not a leaf.
//
// In every comparison the upper-limit side is multiplied by
// (1 + 10 * epsilon of ElemType) to absorb floating-point rounding.
// NOTE(review): the recursive calls spell out <TreeType, MetricType>, so this
// is presumably a function template over both. The template header is not
// visible here.
void CheckDistance(TreeType& tree, TreeType* node = NULL)
{
  typedef typename TreeType::ElemType ElemType;
  if (node == NULL)
  {
    // Top-level call for `tree`: locate the root of the whole tree.
    node = &tree;

    while (node->Parent() != NULL)
      node = node->Parent();

    // Node-vs-node checks of `tree` against every node under the root.
    CheckDistance<TreeType, MetricType>(tree, node);

    // Point-vs-node checks: compare `tree`'s bound with each dataset point.
    for (size_t j = 0; j < tree.Dataset().n_cols; j++)
    {
      const arma::Col<ElemType>& point = tree.  Dataset().col(j);
      ElemType maxDist = 0;
      ElemType minDist = std::numeric_limits<ElemType>::max();
      // True min / max distance from point j to the descendants of `tree`.
      for (size_t i = 0; i < tree.NumDescendants(); i++)
      {
        ElemType dist = MetricType::Evaluate(
            tree.Dataset().col(tree.Descendant(i)),
            tree.Dataset().col(j));

        if (dist > maxDist)
          maxDist = dist;
        if (dist < minDist)
          minDist = dist;
      }

      // The bound's estimates must bracket the true distances.
      BOOST_REQUIRE_LE(tree.Bound().MinDistance(point), minDist *
          (1.0 + 10 * std::numeric_limits<ElemType>::epsilon()));
      BOOST_REQUIRE_LE(maxDist, tree.Bound().MaxDistance(point) *
          (1.0 + 10 * std::numeric_limits<ElemType>::epsilon()));

      math::RangeType<ElemType> r = tree.Bound().RangeDistance(point);

      BOOST_REQUIRE_LE(r.Lo(), minDist *
          (1.0 + 10 * std::numeric_limits<ElemType>::epsilon()));
      BOOST_REQUIRE_LE(maxDist, r.Hi() *
          (1.0 + 10 * std::numeric_limits<ElemType>::epsilon()));
    }
      
    // Repeat the full set of checks for each child of `tree`.
    if (!tree.IsLeaf())
    {
      CheckDistance<TreeType, MetricType>(*tree.Left());
      CheckDistance<TreeType, MetricType>(*tree.Right());
    }
  }
  else
  {
    // Node-vs-node checks of `tree` against `node`; a node is not compared
    // with itself.
    if (&tree != node)
    {
      ElemType maxDist = 0;
      ElemType minDist = std::numeric_limits<ElemType>::max();
      // True min / max pairwise distance between the descendants of the two
      // nodes.
      for (size_t i = 0; i < tree.NumDescendants(); i++)
        for (size_t j = 0; j < node->NumDescendants(); j++)
        {
          ElemType dist = MetricType::Evaluate(
              tree.Dataset().col(tree.Descendant(i)),
              node->Dataset().col(node->Descendant(j)));

          if (dist > maxDist)
            maxDist = dist;
          if (dist < minDist)
            minDist = dist;
        }

      // The bound-to-bound estimates must bracket the true distances.
      BOOST_REQUIRE_LE(tree.Bound().MinDistance(node->Bound()), minDist *
          (1.0 + 10 * std::numeric_limits<ElemType>::epsilon()));
      BOOST_REQUIRE_LE(maxDist, tree.Bound().MaxDistance(node->Bound()) *
          (1.0 + 10 * std::numeric_limits<ElemType>::epsilon()));

      math::RangeType<ElemType> r = tree.Bound().RangeDistance(node->Bound());

      BOOST_REQUIRE_LE(r.Lo(), minDist *
          (1.0 + 10 * std::numeric_limits<ElemType>::epsilon()));
      BOOST_REQUIRE_LE(maxDist, r.Hi() *
          (1.0 + 10 * std::numeric_limits<ElemType>::epsilon()));
    }
    // Continue through the rest of the subtree rooted at `node`.
    if (!node->IsLeaf())
    {
      CheckDistance<TreeType, MetricType>(tree, node->Left());
      CheckDistance<TreeType, MetricType>(tree, node->Right());
    }
  }
}