コード例 #1
0
  // Build the statistic for `node`.  Both bounds start at DBL_MAX, owner and
  // pruned start at size_t(-1), the static-pruning state is cleared, and the
  // node's real parent and children are recorded.  The centroid is assembled
  // from the node's own points plus the (descendant-weighted) centroids already
  // stored in the children's statistics.
  DualTreeKMeansStatistic(TreeType& node) :
      neighbor::NeighborSearchStat<neighbor::NearestNeighborSort>(),
      upperBound(DBL_MAX),
      lowerBound(DBL_MAX),
      owner(size_t(-1)),
      pruned(size_t(-1)),
      staticPruned(false),
      staticUpperBoundMovement(0.0),
      staticLowerBoundMovement(0.0),
      trueParent(node.Parent())
  {
    // For trees with self-children, a non-leaf node's first point is held by
    // one of its children as well; skipping it avoids counting it twice.
    const bool skipFirstPoint = tree::TreeTraits<TreeType>::HasSelfChildren &&
        (node.NumChildren() > 0);

    // Sum the points held directly by this node.
    centroid.zeros(node.Dataset().n_rows);
    for (size_t p = (skipFirstPoint ? 1 : 0); p < node.NumPoints(); ++p)
      centroid += node.Dataset().col(node.Point(p));

    // Add each child's centroid weighted by its descendant count, and remember
    // the address of every child while we are at it.
    trueChildren.resize(node.NumChildren());
    for (size_t c = 0; c < node.NumChildren(); ++c)
    {
      centroid += node.Child(c).NumDescendants() *
          node.Child(c).Stat().Centroid();
      trueChildren[c] = &node.Child(c);
    }

    // Turn the accumulated sum into a mean over all descendants.
    centroid /= node.NumDescendants();
  }
コード例 #2
0
// Set LastDistance() in the statistic of `node` to zero, then do the same for
// every node below it (parent first, children in index order).
void CleanTree(TreeType& node)
{
  node.Stat().LastDistance() = 0.0;

  size_t child = 0;
  while (child < node.NumChildren())
  {
    CleanTree(node.Child(child));
    ++child;
  }
}
コード例 #3
0
void NeighborSearchRules<
    SortPolicy,
    MetricType,
    TreeType>::
UpdateAfterRecursion(TreeType& queryNode, TreeType& /* referenceNode */)
{
  // Recompute the bound of queryNode as the worst value found among its
  // children's bounds and the last-row candidate distance of each of its
  // points.  Starting from BestDistance() means any real value replaces it.
  double worst = SortPolicy::BestDistance();

  // Children: compare against each child's stored bound.
  for (size_t c = 0; c < queryNode.NumChildren(); ++c)
  {
    const double childBound = queryNode.Child(c).Stat().Bound();
    if (SortPolicy::IsBetter(worst, childBound))
      worst = childBound;
  }

  // Points: compare against the entry in the last row of `distances`.
  const size_t lastRow = distances.n_rows - 1;
  for (size_t p = 0; p < queryNode.NumPoints(); ++p)
  {
    const double candidate = distances(lastRow, queryNode.Point(p));
    if (SortPolicy::IsBetter(worst, candidate))
      worst = candidate;
  }

  // Store the result as the node's new bound.
  queryNode.Stat().Bound() = worst;
}
コード例 #4
0
ファイル: fastmks_stat.hpp プロジェクト: YaweiZhao/mlpack
  // Build the statistic for `node`: the bound starts at -DBL_MAX, the last
  // kernel evaluation is cleared, and selfKernel is set to sqrt(K(c, c)) where
  // c is the node's first point (for trees whose first point is the centroid)
  // or the center reported by node.Center() otherwise.
  FastMKSStat(const TreeType& node) :
      bound(-DBL_MAX),
      lastKernel(0.0),
      lastKernelNode(NULL)
  {
    if (!tree::TreeTraits<TreeType>::FirstPointIsCentroid)
    {
      // No point of the node is guaranteed to be the centroid, so ask the
      // node for its center and evaluate the kernel there.
      arma::vec center;
      node.Center(center);

      selfKernel = sqrt(node.Metric().Kernel().Evaluate(center, center));
    }
    else
    {
      // With self-children, the first child may hold the very same first
      // point; statistics are built bottom-up, so that child's value can be
      // reused instead of evaluating the kernel again.
      const bool childSharesPoint =
          (tree::TreeTraits<TreeType>::HasSelfChildren) &&
          (node.NumChildren() > 0) &&
          (node.Point(0) == node.Child(0).Point(0));

      if (childSharesPoint)
      {
        selfKernel = node.Child(0).Stat().SelfKernel();
      }
      else
      {
        selfKernel = sqrt(node.Metric().Kernel().Evaluate(
            node.Dataset().col(node.Point(0)),
            node.Dataset().col(node.Point(0))));
      }
    }
  }
コード例 #5
0
ファイル: dtb_stat.hpp プロジェクト: shenzebang/mlpack
 // Build the statistic for `node`.  The max/min neighbor distances and the
 // bound all start at DBL_MAX.
 DTBStat(const TreeType& node) :
     maxNeighborDistance(DBL_MAX),
     minNeighborDistance(DBL_MAX),
     bound(DBL_MAX),
     // A node holding exactly one point and no children takes that point's
     // index as its component membership; every other node gets -1.
     // NOTE(review): the two ternary arms have different types (the type of
     // Point(0) and int); confirm the conversion matches the member's type.
     componentMembership(
         ((node.NumPoints() == 1) && (node.NumChildren() == 0)) ?
           node.Point(0) : -1) { }
コード例 #6
0
// Test helper: require that every child of `tree` reports `tree` as its
// parent, and verify the same property recursively for each subtree.
void CheckHierarchy(const TreeType& tree)
{
  size_t i = 0;
  while (i < tree.NumChildren())
  {
    // The child's parent pointer must be the address of this node.
    BOOST_REQUIRE_EQUAL(&tree, tree.Child(i).Parent());

    CheckHierarchy(tree.Child(i));
    ++i;
  }
}
コード例 #7
0
// Test helper: require that every node respects its size limits.  A leaf must
// hold between MinLeafSize() and MaxLeafSize() points; an internal node must
// have between MinNumChildren() and MaxNumChildren() children.  The lower
// limits are waived for a node without a parent.
void CheckFills(const TreeType& tree)
{
  if (!tree.IsLeaf())
  {
    // The child-count checks are repeated once per child, then each child is
    // verified recursively.
    for (size_t child = 0; child < tree.NumChildren(); ++child)
    {
      BOOST_REQUIRE(tree.NumChildren() >= tree.MinNumChildren() ||
                    tree.Parent() == NULL);
      BOOST_REQUIRE(tree.NumChildren() <= tree.MaxNumChildren());
      CheckFills(*tree.Children()[child]);
    }

    return;
  }

  // Leaf node: check the number of points it holds.
  BOOST_REQUIRE(tree.Count() >= tree.MinLeafSize() || tree.Parent() == NULL);
  BOOST_REQUIRE(tree.Count() <= tree.MaxLeafSize());
}
コード例 #8
0
// Test helper: require that each node's bound contains what lies below it.
// For a node with children, every dimension of the node's bound must contain
// the matching dimension of each child's bound; for a node without children,
// the bound must contain every point the node holds.
void CheckContainment(const TreeType& tree)
{
  if (tree.NumChildren() > 0)
  {
    size_t i = 0;
    while (i < tree.NumChildren())
    {
      // Compare this node's bound with the child's bound, one dimension at a
      // time.
      for (size_t j = 0; j < tree.Bound().Dim(); j++)
      {
        BOOST_REQUIRE(tree.Bound()[j].Contains(tree.Children()[i]->Bound()[j]));
      }

      CheckContainment(*(tree.Children()[i]));
      ++i;
    }

    return;
  }

  // No children: every held point has to fall inside the bound.
  for (size_t i = 0; i < tree.Count(); i++)
  {
    BOOST_REQUIRE(tree.Bound().Contains(
        tree.Dataset().unsafe_col(tree.Points()[i])));
  }
}
コード例 #9
0
// Test helper: require that each node's bound is exactly tight.  In every
// dimension the bound's Lo()/Hi() must equal the minimum/maximum taken over
// the node's local points (node without children) or over the children's
// bounds (node with children).  Nodes with children are then checked
// recursively.
void CheckExactContainment(const TreeType& tree)
{
  const bool hasNoChildren = (tree.NumChildren() == 0);

  for (size_t i = 0; i < tree.Bound().Dim(); i++)
  {
    // Running extremes for dimension i.
    double min = DBL_MAX;
    double max = -1.0 * DBL_MAX;

    if (hasNoChildren)
    {
      // Scan coordinate i of every point stored locally in this node.
      for (size_t j = 0; j < tree.Count(); j++)
      {
        const double coordinate = tree.LocalDataset().col(j)[i];
        if (coordinate < min)
          min = coordinate;
        if (coordinate > max)
          max = coordinate;
      }
    }
    else
    {
      // Scan the extent of every child's bound in dimension i.
      for (size_t j = 0; j < tree.NumChildren(); j++)
      {
        const double childLo = tree.Child(j).Bound()[i].Lo();
        const double childHi = tree.Child(j).Bound()[i].Hi();
        if (childLo < min)
          min = childLo;
        if (childHi > max)
          max = childHi;
      }
    }

    BOOST_REQUIRE_EQUAL(max, tree.Bound()[i].Hi());
    BOOST_REQUIRE_EQUAL(min, tree.Bound()[i].Lo());
  }

  // Descend only after all dimensions of this node have been verified.
  if (!hasNoChildren)
  {
    for (size_t i = 0; i < tree.NumChildren(); i++)
      CheckExactContainment(tree.Child(i));
  }
}
コード例 #10
0
std::vector<arma::vec*> GetAllPointsInTree(const TreeType& tree)
{
  std::vector<arma::vec*> vec;
  if (tree.NumChildren() > 0)
  {
    for (size_t i = 0; i < tree.NumChildren(); i++)
    {
      std::vector<arma::vec*> tmp = GetAllPointsInTree(*(tree.Children()[i]));
      vec.insert(vec.begin(), tmp.begin(), tmp.end());
    }
  }
  else
  {
    for (size_t i = 0; i < tree.Count(); i++)
    {
      arma::vec* c = new arma::vec(tree.Dataset().col(tree.Points()[i]));
      vec.push_back(c);
    }
  }
  return vec;
}
コード例 #11
0
inline double DTBRules<MetricType, TreeType>::CalculateBound(
    TreeType& queryNode) const
{
  // Recompute the pruning bound of queryNode from the candidate distances of
  // the components its points belong to and from the min/max neighbor
  // distances stored in its children.  The node's statistic is updated and the
  // new bound is returned.

  // Largest and smallest candidate distance over the node's own points.
  double largestPointBound = -DBL_MAX;
  double smallestPointBound = DBL_MAX;
  for (size_t p = 0; p < queryNode.NumPoints(); ++p)
  {
    // The candidate distance is stored per component, so look the component
    // up first.
    const size_t component = connections.Find(queryNode.Point(p));
    const double candidate = neighborsDistances[component];

    largestPointBound = std::max(largestPointBound, candidate);
    smallestPointBound = std::min(smallestPointBound, candidate);
  }

  // Largest max-distance and smallest min-distance over the children.
  double largestChildBound = -DBL_MAX;
  double smallestChildBound = DBL_MAX;
  for (size_t c = 0; c < queryNode.NumChildren(); ++c)
  {
    const double childMax = queryNode.Child(c).Stat().MaxNeighborDistance();
    largestChildBound = std::max(largestChildBound, childMax);

    const double childMin = queryNode.Child(c).Stat().MinNeighborDistance();
    smallestChildBound = std::min(smallestChildBound, childMin);
  }

  // Combine point and child information.
  const double worstBound = std::max(largestPointBound, largestChildBound);
  const double bestBound = std::min(smallestPointBound, smallestChildBound);

  // Adding to DBL_MAX would overflow, so the adjustment is only applied when a
  // real best bound exists.
  double bestAdjustedBound = DBL_MAX;
  if (bestBound != DBL_MAX)
    bestAdjustedBound = bestBound + 2 * queryNode.FurthestDescendantDistance();

  // Write everything back into the node's statistic.
  queryNode.Stat().MaxNeighborDistance() = worstBound;
  queryNode.Stat().MinNeighborDistance() = bestBound;
  queryNode.Stat().Bound() = std::min(worstBound, bestAdjustedBound);

  return queryNode.Stat().Bound();
}
コード例 #12
0
// Return the number of levels on the shortest path from `tree` down to a leaf,
// counting `tree` itself (a leaf therefore yields 1).
int GetMinLevel(const TreeType& tree)
{
  if (tree.IsLeaf())
    return 1;

  // Find the shallowest subtree among the children.
  int shallowest = INT_MAX;
  for (size_t c = 0; c < tree.NumChildren(); ++c)
  {
    const int depth = GetMinLevel(*tree.Children()[c]);
    if (depth < shallowest)
      shallowest = depth;
  }

  // One more level for this node.
  return 1 + shallowest;
}
コード例 #13
0
// Return the number of levels on the longest path from `tree` down to a leaf,
// counting `tree` itself (a leaf therefore yields 1).
int GetMaxLevel(const TreeType& tree)
{
  if (tree.IsLeaf())
    return 1;

  // Find the deepest subtree among the children.
  int deepest = 0;
  for (size_t c = 0; c < tree.NumChildren(); ++c)
  {
    const int depth = GetMaxLevel(*tree.Children()[c]);
    if (depth > deepest)
      deepest = depth;
  }

  // One more level for this node.
  return 1 + deepest;
}
コード例 #14
0
  // Build the statistic for `node`: no closest query node yet, both query-node
  // distances at DBL_MAX, zero clusters pruned, and the iteration marker set
  // to the largest size_t value.  The centroid is the mean of all descendants,
  // assembled from the node's own points and its children's centroids.
  DualTreeKMeansStatistic(TreeType& node) :
      closestQueryNode(NULL),
      minQueryNodeDistance(DBL_MAX),
      maxQueryNodeDistance(DBL_MAX),
      clustersPruned(0),
      iteration(size_t() - 1)
  {
    centroid.zeros(node.Dataset().n_rows);

    // Sum of the points held directly by this node.
    size_t p = 0;
    while (p < node.NumPoints())
    {
      centroid += node.Dataset().col(node.Point(p));
      ++p;
    }

    // Each child contributes its centroid weighted by its descendant count.
    size_t c = 0;
    while (c < node.NumChildren())
    {
      centroid += node.Child(c).NumDescendants() *
          node.Child(c).Stat().Centroid();
      ++c;
    }

    // Convert the sum into a mean.
    centroid /= node.NumDescendants();
  }
コード例 #15
0
// Test helper: require that, in every leaf, each column of the leaf's local
// dataset matches the column of the main dataset that Points() maps it to.
void CheckSync(const TreeType& tree)
{
  if (!tree.IsLeaf())
  {
    // Internal node: nothing to compare here, just visit the children.
    size_t child = 0;
    while (child < tree.NumChildren())
    {
      CheckSync(*tree.Children()[child]);
      ++child;
    }

    return;
  }

  // Leaf: compare the two copies of every point element by element.
  for (size_t i = 0; i < tree.Count(); i++)
  {
    for (size_t j = 0; j < tree.LocalDataset().n_rows; j++)
    {
      BOOST_REQUIRE_EQUAL(tree.LocalDataset().col(i)[j],
                          tree.Dataset().col(tree.Points()[i])[j]);
    }
  }
}
コード例 #16
0
    // Build the statistic for `node` by computing the centroid of all of its
    // descendants.  The children's statistics are read here, so they must have
    // been built already; trees with self-children are not handled.
    PellegMooreKMeansStatistic(TreeType& node)
    {
        centroid.zeros(node.Dataset().n_rows);

        // Children first: each contributes its centroid weighted by the number
        // of descendants it has.
        size_t c = 0;
        while (c < node.NumChildren())
        {
            centroid += node.Child(c).NumDescendants() *
                        node.Child(c).Stat().Centroid();
            ++c;
        }

        // Then the points held directly by this node.
        size_t p = 0;
        while (p < node.NumPoints())
        {
            centroid += node.Dataset().col(node.Point(p));
            ++p;
        }

        // Average over the descendants; with none, there is no meaningful
        // centroid, so mark it invalid with DBL_MAX in every dimension.
        if (node.NumDescendants() > 0)
        {
            centroid /= node.NumDescendants();
        }
        else
        {
            centroid.fill(DBL_MAX);
        }
    }
コード例 #17
0
void GreedySingleTreeTraverser<TreeType, RuleType>::Traverse(
    const size_t queryIndex,
    TreeType& referenceNode)
{
  // Greedy descent for a single query point.  At every node the base case is
  // run against the points held directly by the node; then either the single
  // best child is followed (all other children are counted as pruned), or, if
  // that child is too small, base cases are run directly on this node's
  // descendants and the descent stops.
  //
  // queryIndex: index of the query point, passed through to the rule.
  // referenceNode: node of the reference tree to process.

  // Run the base case as necessary for all the points in the reference node.
  for (size_t i = 0; i < referenceNode.NumPoints(); ++i)
    rule.BaseCase(queryIndex, referenceNode.Point(i));

  // The rule is asked for its best child unconditionally (also for leaves), as
  // before, in case the call has side effects inside the rule.
  const size_t bestChild = rule.GetBestChild(queryIndex, referenceNode);

  // A leaf has no child to descend into; its points were handled above.
  if (referenceNode.IsLeaf())
    return;

  const size_t numDescendants =
      referenceNode.Child(bestChild).NumDescendants();

  if (numDescendants > minBaseCases)
  {
    // The best child is large enough to provide the required number of base
    // cases, so we are pruning all but one child.
    numPrunes += referenceNode.NumChildren() - 1;

    // Recurse into the best child.
    Traverse(queryIndex, referenceNode.Child(bestChild));
  }
  else
  {
    // The best child is too small, so run the base case over the first
    // (minBaseCases + 1) descendants of this node instead.  The count is
    // clamped to the number of descendants the node actually has, so that
    // Descendant() is never called with an out-of-range index when the node
    // holds fewer descendants than requested.
    const size_t available = referenceNode.NumDescendants();
    const size_t count =
        (minBaseCases < available) ? (minBaseCases + 1) : available;

    for (size_t i = 0; i < count; ++i)
      rule.BaseCase(queryIndex, referenceNode.Descendant(i));
  }
}
コード例 #18
0
inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::
    CalculateBound(TreeType& queryNode) const
{
  // Recompute the pruning bound for queryNode, store the first bound, second
  // bound and overall bound in the node's statistic, and return the overall
  // bound.
  //
  // We have five possible bounds, and we must take the best of them all.  We
  // don't use min/max here, but instead "best/worst", because this is general
  // to the nearest-neighbors/furthest-neighbors cases.  For nearest neighbors,
  // min = best, max = worst.
  //
  // (1) worst ( worst_{all points p in queryNode} D_p[k],
  //             worst_{all children c in queryNode} B(c) );
  // (2) best_{all points p in queryNode} D_p[k] + worst child distance +
  //        worst descendant distance;
  // (3) best_{all children c in queryNode} B(c) +
  //      2 ( worst descendant distance of queryNode -
  //          worst descendant distance of c );
  // (4) B_1(parent of queryNode)
  // (5) B_2(parent of queryNode);
  //
  // D_p[k] is the current k'th candidate distance for point p.
  // So we will loop over the points in queryNode and the children in queryNode
  // to calculate these quantities.  Bound (5) is currently disabled (see the
  // commented-out code below), so only (1)-(4) take part in the result.

  // Hm, can we populate our distances vector with estimates from the parent?
  // This is written specifically for the cover tree and assumes only one point
  // in a node.
//  if (queryNode.Parent() != NULL && queryNode.NumPoints() > 0)
//  {
//    size_t parentIndexStart = 0;
//    for (size_t i = 0; i < neighbors.n_rows; ++i)
//    {
//      const double pointDistance = distances(i, queryNode.Point(0));
//      if (pointDistance == DBL_MAX)
//      {
//      // Cool, can we take an estimate from the parent?
//        const double parentWorstBound = distances(distances.n_rows - 1,
//              queryNode.Parent()->Point(0));
//        if (parentWorstBound != DBL_MAX)
//        {
//          const double parentAdjustedDistance = parentWorstBound +
//              queryNode.ParentDistance();
//          distances(i, queryNode.Point(0)) = parentAdjustedDistance;
//        }
//      }
//    }
//  }

  double worstPointDistance = SortPolicy::BestDistance();
  double bestPointDistance = SortPolicy::WorstDistance();

  // Loop over all points in this node to find the best and worst distance
  // candidates (for (1) and (2)).
  for (size_t i = 0; i < queryNode.NumPoints(); ++i)
  {
    const double distance = distances(distances.n_rows - 1,
        queryNode.Point(i));
    if (SortPolicy::IsBetter(distance, bestPointDistance))
      bestPointDistance = distance;
    if (SortPolicy::IsBetter(worstPointDistance, distance))
      worstPointDistance = distance;
  }

  // Loop over all the children in this node to find the worst bound (for (1))
  // and the best bound with the correcting factor for descendant distances (for
  // (3)).
  double worstChildBound = SortPolicy::BestDistance();
  double bestAdjustedChildBound = SortPolicy::WorstDistance();
  const double queryMaxDescendantDistance =
      queryNode.FurthestDescendantDistance();

  for (size_t i = 0; i < queryNode.NumChildren(); ++i)
  {
    const double firstBound = queryNode.Child(i).Stat().FirstBound();
    const double secondBound = queryNode.Child(i).Stat().SecondBound();
    const double childMaxDescendantDistance =
        queryNode.Child(i).FurthestDescendantDistance();

    if (SortPolicy::IsBetter(worstChildBound, firstBound))
      worstChildBound = firstBound;

    // Now calculate adjustment for maximum descendant distances.
    const double adjustedBound = SortPolicy::CombineWorst(secondBound,
        2 * (queryMaxDescendantDistance - childMaxDescendantDistance));
    if (SortPolicy::IsBetter(adjustedBound, bestAdjustedChildBound))
      bestAdjustedChildBound = adjustedBound;
  }

  // This is bound (1).
  const double firstBound =
      (SortPolicy::IsBetter(worstPointDistance, worstChildBound)) ?
      worstChildBound : worstPointDistance;

  // This is bound (2).
  const double secondBound = SortPolicy::CombineWorst(
      SortPolicy::CombineWorst(bestPointDistance, queryMaxDescendantDistance),
      queryNode.FurthestPointDistance());

  // Bound (3) is bestAdjustedChildBound.

  // Bounds (4) and (5) are the parent bounds.  Without a parent, (4) falls
  // back to the worst possible distance so that it never wins.
  const double fourthBound = (queryNode.Parent() != NULL) ?
      queryNode.Parent()->Stat().FirstBound() : SortPolicy::WorstDistance();
//  const double fifthBound = (queryNode.Parent() != NULL) ?
//      queryNode.Parent()->Stat().SecondBound() -
//      queryNode.Parent()->FurthestDescendantDistance() -
//      queryNode.Parent()->FurthestPointDistance() + queryMaxDescendantDistance +
//      queryNode.FurthestPointDistance() + queryNode.ParentDistance() :
//      SortPolicy::WorstDistance();

  // Now, we will take the best of these.  Unfortunately due to the way
  // IsBetter() is defined, this sort of has to be a little ugly.
  // The variable interA represents the first bound (B_1), which is the worst
  // candidate distance of any descendants of this node.
  // The variable interB represents the second bound (B_2), which is a bound on
  // the worst distance of any descendants of this node assembled using the best
  // descendant candidate distance modified using the furthest descendant
  // distance.
  const double interA = (SortPolicy::IsBetter(firstBound, fourthBound)) ?
      firstBound : fourthBound;
  const double interB =
      (SortPolicy::IsBetter(bestAdjustedChildBound, secondBound)) ?
      bestAdjustedChildBound : secondBound;
//  const double interC = (SortPolicy::IsBetter(interB, fifthBound)) ? interB :
//      fifthBound;

  // Update the first and second bounds of the node.
  queryNode.Stat().FirstBound() = interA;
  queryNode.Stat().SecondBound() = interB;

  // Update the actual bound of the node: the better of B_1 and B_2.  (Both
  // arms of this conditional used to be interB, which silently discarded B_1
  // even when it was the better bound.)
  queryNode.Stat().Bound() = (SortPolicy::IsBetter(interA, interB)) ? interA :
      interB;

  return queryNode.Stat().Bound();
}
コード例 #19
0
// Test helper: require that three trees (named for the xml, text and binary
// formats) are structurally identical to the reference `tree`.  At every node
// the child, descendant and point counts, the point indices, and the
// parent/furthest-descendant/minimum-bound distances are compared; the
// function then recurses into the corresponding children of all four trees.
//
// tree: the reference tree.
// xmlTree, textTree, binaryTree: the trees compared against `tree`.
void CheckTrees(TreeType& tree,
                TreeType& xmlTree,
                TreeType& textTree,
                TreeType& binaryTree)
{
  const typename TreeType::Mat* dataset = &tree.Dataset();

  // Make sure that the data matrices are the same.  This is only done at the
  // root (the node without a parent).
  if (tree.Parent() == NULL)
  {
    CheckMatrices(*dataset,
                  xmlTree.Dataset(),
                  textTree.Dataset(),
                  binaryTree.Dataset());

    // Also ensure that the other parents are null too.
    BOOST_REQUIRE_EQUAL(xmlTree.Parent(), (TreeType*) NULL);
    BOOST_REQUIRE_EQUAL(textTree.Parent(), (TreeType*) NULL);
    BOOST_REQUIRE_EQUAL(binaryTree.Parent(), (TreeType*) NULL);
  }

  // Make sure the number of children is the same.
  BOOST_REQUIRE_EQUAL(tree.NumChildren(), xmlTree.NumChildren());
  BOOST_REQUIRE_EQUAL(tree.NumChildren(), textTree.NumChildren());
  BOOST_REQUIRE_EQUAL(tree.NumChildren(), binaryTree.NumChildren());

  // Make sure the number of descendants is the same.
  BOOST_REQUIRE_EQUAL(tree.NumDescendants(), xmlTree.NumDescendants());
  BOOST_REQUIRE_EQUAL(tree.NumDescendants(), textTree.NumDescendants());
  BOOST_REQUIRE_EQUAL(tree.NumDescendants(), binaryTree.NumDescendants());

  // Make sure the number of points is the same.
  BOOST_REQUIRE_EQUAL(tree.NumPoints(), xmlTree.NumPoints());
  BOOST_REQUIRE_EQUAL(tree.NumPoints(), textTree.NumPoints());
  BOOST_REQUIRE_EQUAL(tree.NumPoints(), binaryTree.NumPoints());

  // Check that each point is the same.
  for (size_t i = 0; i < tree.NumPoints(); ++i)
  {
    BOOST_REQUIRE_EQUAL(tree.Point(i), xmlTree.Point(i));
    BOOST_REQUIRE_EQUAL(tree.Point(i), textTree.Point(i));
    BOOST_REQUIRE_EQUAL(tree.Point(i), binaryTree.Point(i));
  }

  // Check that the parent distance is the same (to within a tolerance of 1e-8
  // passed to BOOST_REQUIRE_CLOSE).
  BOOST_REQUIRE_CLOSE(tree.ParentDistance(), xmlTree.ParentDistance(), 1e-8);
  BOOST_REQUIRE_CLOSE(tree.ParentDistance(), textTree.ParentDistance(), 1e-8);
  BOOST_REQUIRE_CLOSE(tree.ParentDistance(), binaryTree.ParentDistance(), 1e-8);

  // Check that the furthest descendant distance is the same.
  BOOST_REQUIRE_CLOSE(tree.FurthestDescendantDistance(),
      xmlTree.FurthestDescendantDistance(), 1e-8);
  BOOST_REQUIRE_CLOSE(tree.FurthestDescendantDistance(),
      textTree.FurthestDescendantDistance(), 1e-8);
  BOOST_REQUIRE_CLOSE(tree.FurthestDescendantDistance(),
      binaryTree.FurthestDescendantDistance(), 1e-8);

  // Check that the minimum bound distance is the same.
  BOOST_REQUIRE_CLOSE(tree.MinimumBoundDistance(),
      xmlTree.MinimumBoundDistance(), 1e-8);
  BOOST_REQUIRE_CLOSE(tree.MinimumBoundDistance(),
      textTree.MinimumBoundDistance(), 1e-8);
  BOOST_REQUIRE_CLOSE(tree.MinimumBoundDistance(),
      binaryTree.MinimumBoundDistance(), 1e-8);

  // Recurse into the children.
  for (size_t i = 0; i < tree.NumChildren(); ++i)
  {
    // Check that the child dataset is the same object as its parent's dataset
    // (the addresses are compared, not the contents).
    BOOST_REQUIRE_EQUAL(&xmlTree.Dataset(), &xmlTree.Child(i).Dataset());
    BOOST_REQUIRE_EQUAL(&textTree.Dataset(), &textTree.Child(i).Dataset());
    BOOST_REQUIRE_EQUAL(&binaryTree.Dataset(), &binaryTree.Child(i).Dataset());

    // Make sure the parent link is right.
    BOOST_REQUIRE_EQUAL(xmlTree.Child(i).Parent(), &xmlTree);
    BOOST_REQUIRE_EQUAL(textTree.Child(i).Parent(), &textTree);
    BOOST_REQUIRE_EQUAL(binaryTree.Child(i).Parent(), &binaryTree);

    CheckTrees(tree.Child(i), xmlTree.Child(i), textTree.Child(i),
        binaryTree.Child(i));
  }
}
コード例 #20
0
// Handle an overfull non-leaf node.  The node's children are first
// redistributed among cooperating siblings if possible; otherwise a new
// sibling is inserted into the parent and the children are redistributed
// among the enlarged group, splitting the parent in turn if it overflows.
//
// tree: the node to split.
// relevels: not used directly here; it is only forwarded to the recursive
//     calls.
// Returns true only when `tree` is the root (it has no parent), false in all
// other cases.
bool HilbertRTreeSplit<splitOrder>::
SplitNonLeafNode(TreeType* tree, std::vector<bool>& relevels)
{
  // If we are splitting the root node, we need to do things differently so
  // that the constructor and other methods don't confuse the end user by giving
  // an address of another node.  The root keeps its address: its contents move
  // into a copy, which becomes the root's only child and is split instead.
  if (tree->Parent() == NULL)
  {
    // We actually want to copy this way.  Pointers and everything.
    // NOTE(review): the meaning of the second constructor argument (false) is
    // not visible here -- confirm against the TreeType copy constructor.
    TreeType* copy = new TreeType(*tree, false);
    // Only the root node owns this variable.
    copy->AuxiliaryInfo().HilbertValue().OwnsValueToInsert() = false;
    copy->Parent() = tree;
    tree->NumChildren() = 0;
    tree->NullifyData();
    tree->children[(tree->NumChildren())++] = copy;

    SplitNonLeafNode(copy, relevels);
    return true;
  }

  TreeType* parent = tree->Parent();

  // Locate the position of `tree` among its parent's children (the loop body
  // is intentionally empty).
  size_t iTree = 0;
  for (iTree = 0; parent->children[iTree] != tree; iTree++);

  // Try to find splitOrder cooperating siblings in order to redistribute
  // children among them and avoid split.
  size_t firstSibling, lastSibling;
  if (FindCooperatingSiblings(parent, iTree, firstSibling, lastSibling))
  {
    RedistributeNodesEvenly(parent, firstSibling, lastSibling);
    return false;
  }

  // We can not find splitOrder cooperating siblings since they are all full.
  // We introduce new one instead.  It is placed splitOrder positions after
  // `tree`, or at the end of the child list if that would run past it.
  size_t iNewSibling = (iTree + splitOrder < parent->NumChildren() ?
                        iTree + splitOrder : parent->NumChildren());

  // Shift the children at and after the insertion position one slot to the
  // right to make room.
  for (size_t i = parent->NumChildren(); i > iNewSibling ; i--)
    parent->children[i] = parent->children[i - 1];

  parent->NumChildren()++;

  parent->children[iNewSibling] = new TreeType(parent);

  // Select the range of siblings (ending at the new one, or at the last child)
  // that takes part in the redistribution.
  lastSibling = (iTree + splitOrder < parent->NumChildren() ?
                 iTree + splitOrder : parent->NumChildren() - 1);
  firstSibling = (lastSibling > splitOrder ?
                  lastSibling - splitOrder : 0);

  assert(lastSibling - firstSibling <= splitOrder);
  // firstSibling is a size_t, so this assertion always holds.
  assert(firstSibling >= 0);
  assert(lastSibling < parent->NumChildren());

  // Redistribute children among (splitOrder + 1) cooperating siblings evenly.
  RedistributeNodesEvenly(parent, firstSibling, lastSibling);

  // Adding the sibling may have pushed the parent one past its capacity; if
  // so, split the parent as well.
  if (parent->NumChildren() == parent->MaxNumChildren() + 1)
    SplitNonLeafNode(parent, relevels);
  return false;
}
コード例 #21
0
double FastMKSRules<KernelType, TreeType>::CalculateBound(TreeType& queryNode)
const
{
    // Compute and return the bound for queryNode; nothing is stored in the
    // node here.  The numbering of the candidate bounds mirrors
    // NeighborSearchRules:
    //
    // (1) the smallest of the last-row kernel value of every point in the
    //     node and the Bound() of every child;
    // (2) the largest, over the node's points, of the last-row kernel value
    //     minus the furthest descendant distance times the matching entry of
    //     referenceKernels;
    // (3) a child-based adjusted bound -- not implemented, ignored;
    // (4) the Bound() of the parent of queryNode.
    const size_t lastProductRow = products.n_rows - 1;
    const size_t lastIndexRow = indices.n_rows - 1;
    const double descendantDistance = queryNode.FurthestDescendantDistance();

    // Scan the points held by this node.
    double smallestPointKernel = DBL_MAX;
    double largestAdjustedKernel = -DBL_MAX;
    for (size_t p = 0; p < queryNode.NumPoints(); ++p)
    {
        const size_t point = queryNode.Point(p);
        const double kthKernel = products(lastProductRow, point);

        if (kthKernel < smallestPointKernel)
            smallestPointKernel = kthKernel;

        // A value of -DBL_MAX is left out of bound (2): subtracting from it
        // would underflow.
        if (kthKernel != -DBL_MAX)
        {
            // This should be (descendant distance + centroid distance) for any
            // tree, but it works for cover trees since the centroid distance
            // is 0 for cover trees.
            const double adjusted = kthKernel - descendantDistance *
                referenceKernels[indices(lastIndexRow, point)];

            if (adjusted > largestAdjustedKernel)
                largestAdjustedKernel = adjusted;
        }
    }

    // Scan the children for the smallest stored bound.
    double smallestChildBound = DBL_MAX;
    for (size_t c = 0; c < queryNode.NumChildren(); ++c)
    {
        const double childBound = queryNode.Child(c).Stat().Bound();
        if (childBound < smallestChildBound)
            smallestChildBound = childBound;
    }

    // Bound (1): the smaller of the point and child values.
    double result = smallestChildBound;
    if (smallestPointKernel < smallestChildBound)
        result = smallestPointKernel;

    // Bound (4): taken from the parent, or -DBL_MAX for a node without one.
    double parentBound = -DBL_MAX;
    if (queryNode.Parent() != NULL)
        parentBound = queryNode.Parent()->Stat().Bound();

    // Keep the largest of (1), (2) and (4).
    if (!(result > largestAdjustedKernel))
        result = largestAdjustedKernel;
    if (!(result > parentBound))
        result = parentBound;

    return result;
}