RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
              AuxiliaryInformationType>::
RectangleTree(MatType&& data,
              const size_t maxLeafSize,
              const size_t minLeafSize,
              const size_t maxNumChildren,
              const size_t minNumChildren,
              const size_t firstDataIndex) :
    maxNumChildren(maxNumChildren),
    minNumChildren(minNumChildren),
    numChildren(0),
    children(maxNumChildren + 1), // Add one to make splitting the node simpler.
    parent(NULL),
    begin(0),
    count(0),
    numDescendants(0),
    maxLeafSize(maxLeafSize),
    minLeafSize(minLeafSize),
    bound(data.n_rows),
    parentDistance(0),
    dataset(new MatType(std::move(data))),
    ownsDataset(true),
    points(maxLeafSize + 1), // Add one to make splitting the node simpler.
    auxiliaryInfo(this)
{
  // Build the statistic for this (still empty) root node.
  stat = StatisticType(*this);

  // The moved-in matrix is now owned by the tree.  Its columns are inserted
  // one at a time, in column order, beginning at firstDataIndex.
  for (size_t col = firstDataIndex; col < dataset->n_cols; ++col)
    InsertPoint(col);
}
bool RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
                   AuxiliaryInformationType>::
    RemoveNode(const RectangleTree* node, std::vector<bool>& relevels)
{
  // Search the subtree for `node`.  Only children whose bound contains the
  // bound of `node` in every dimension are descended into.  Returns true once
  // `node` has been found and detached.
  for (size_t c = 0; c < numChildren; ++c)
  {
    if (children[c] != node)
    {
      bool encloses = true;
      for (size_t d = 0; d < node->Bound().Dim(); ++d)
        encloses &= Child(c).Bound()[d].Contains(node->Bound()[d]);

      if (encloses && children[c]->RemoveNode(node, relevels))
        return true;

      continue;
    }

    // `node` is a direct child of this node.  Detach it by moving the last
    // child into its slot, unless the auxiliary information handles the
    // removal itself.
    if (!auxiliaryInfo.HandleNodeRemoval(this, c))
      children[c] = children[--numChildren];

    // This node and all of its ancestors lose the removed subtree's
    // descendants.
    for (RectangleTree* ancestor = this; ancestor != NULL;
         ancestor = ancestor->Parent())
      ancestor->numDescendants -= node->numDescendants;

    CondenseTree(arma::vec(), relevels, false);
    return true;
  }

  return false;
}
Esempio n. 3
0
bool RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
    DeletePoint(const size_t point)
{
  // Removing a point can trigger reinsertions.  Every level of the whole tree
  // therefore starts out flagged in the per-level vector passed down to
  // CondenseTree().
  RectangleTree* top = this;
  while (top->Parent() != NULL)
    top = top->Parent();

  std::vector<bool> lvls(top->TreeDepth(), true);

  if (numChildren == 0)
  {
    for (size_t slot = 0; slot < count; ++slot)
    {
      if (points[slot] != point)
        continue;

      // Fill the hole with the last point of this leaf, both in the local
      // dataset and in the index list, and shrink the leaf by one.
      --count;
      localDataset->col(slot) = localDataset->col(count);
      points[slot] = points[count];

      // CondenseTree() restores the minimum fill if necessary.
      CondenseTree(dataset->col(point), lvls, true);
      return true;
    }
  }

  // The point is not stored here.  Try every child whose bound contains it.
  for (size_t c = 0; c < numChildren; ++c)
  {
    if (!children[c]->Bound().Contains(dataset->col(point)))
      continue;

    if (children[c]->DeletePoint(point, lvls))
      return true;
  }

  return false;
}
bool RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
                   AuxiliaryInformationType>::
    DeletePoint(const size_t point, std::vector<bool>& relevels)
{
  // Remove dataset index `point` from this subtree.  `relevels` is forwarded
  // to CondenseTree() and to the recursive calls.  Returns true if the point
  // was found and removed.
  if (numChildren == 0)
  {
    for (size_t slot = 0; slot < count; ++slot)
    {
      if (points[slot] != point)
        continue;

      // Overwrite the slot with the last point of the leaf, unless the
      // auxiliary information handles the deletion itself.
      if (!auxiliaryInfo.HandlePointDeletion(this, slot))
      {
        --count;
        points[slot] = points[count];
      }

      // One fewer descendant for this leaf and every ancestor.
      for (RectangleTree* node = this; node != NULL; node = node->Parent())
        --node->numDescendants;

      // CondenseTree() restores the minimum fill if necessary.
      CondenseTree(dataset->col(point), relevels, true);
      return true;
    }
  }

  // Recurse only into children whose bound contains the point.
  for (size_t c = 0; c < numChildren; ++c)
  {
    if (children[c]->Bound().Contains(dataset->col(point)) &&
        children[c]->DeletePoint(point, relevels))
      return true;
  }

  return false;
}
Esempio n. 5
0
RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
RectangleTree(const MatType& data,
              const size_t maxLeafSize,
              const size_t minLeafSize,
              const size_t maxNumChildren,
              const size_t minNumChildren,
              const size_t firstDataIndex) :
    maxNumChildren(maxNumChildren),
    minNumChildren(minNumChildren),
    numChildren(0),
    children(maxNumChildren + 1), // Add one to make splitting the node simpler.
    parent(NULL),
    begin(0),
    count(0),
    maxLeafSize(maxLeafSize),
    minLeafSize(minLeafSize),
    bound(data.n_rows),
    splitHistory(bound.Dim()),
    parentDistance(0),
    dataset(new MatType(data)),
    ownsDataset(true),
    points(maxLeafSize + 1), // Add one to make splitting the node simpler.
    localDataset(new MatType(arma::zeros<MatType>(data.n_rows,
                                                  maxLeafSize + 1)))
{
  // Build the statistic for this (still empty) root node.
  stat = StatisticType(*this);

  // The tree holds its own copy of `data`.  Column indices are the same in
  // both, so insert them one by one, beginning at firstDataIndex.
  for (size_t col = firstDataIndex; col < data.n_cols; ++col)
    InsertPoint(col);
}
size_t RectangleTree<MetricType, StatisticType, MatType, SplitType,
                     DescentType, AuxiliaryInformationType>::TreeDepth() const
{
  // Number of levels in the subtree rooted at this node; this node counts as
  // level 1 and a lone leaf has depth 1.
  //
  // Only the first child is followed at each level, so this assumes every
  // leaf of the subtree sits at the same depth (the balance R-tree-style
  // insertion maintains).
  //
  // A const pointer is enough for a read-only walk, so no const_cast is
  // needed.  The counter is size_t to match the return type.
  size_t depth = 1;
  const RectangleTree* node = this;

  while (!node->IsLeaf())
  {
    node = node->children[0];
    ++depth;
  }

  return depth;
}
bool RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
                   AuxiliaryInformationType>::
    DeletePoint(const size_t point)
{
  // Remove dataset index `point` from this subtree.  Returns true if it was
  // found and removed.
  //
  // Deleting a point may cause reinsertions, so build a per-level flag vector
  // sized to the depth of the whole tree, with every level initially set.
  RectangleTree* root = this;
  while (root->Parent() != NULL)
    root = root->Parent();

  std::vector<bool> lvls(root->TreeDepth(), true);

  // The two-argument overload performs the same leaf search, removal,
  // numDescendants bookkeeping, CondenseTree() call and child recursion.
  // Delegate to it instead of keeping a second copy of that logic.
  return DeletePoint(point, lvls);
}
void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
                   AuxiliaryInformationType>::
    CondenseTree(const arma::vec& point,
                 std::vector<bool>& relevels,
                 const bool usePoint)
{
  // Restore the fill invariants after something was removed from this node,
  // then propagate the change upward.
  //
  //  - A non-root leaf holding fewer than minLeafSize points is unlinked from
  //    its parent.  Its points are reinserted from the root, and the parent
  //    is condensed in turn.
  //  - A non-root internal node with fewer than minNumChildren children is
  //    unlinked from its parent.  Its children are reinserted from the root,
  //    and the parent is condensed in turn.
  //  - A root internal node with fewer than minNumChildren children and
  //    exactly one child absorbs that child.
  //  - Otherwise the bound is shrunk and the auxiliary information updated.
  //    If either call returns true, the parent is condensed too.
  //
  // point:    the removed point; only used for shrinking when usePoint is true.
  // relevels: per-level flags forwarded unchanged to InsertPoint(),
  //           InsertNode() and the recursive CondenseTree() calls.
  // usePoint: shrink using `point` if true, otherwise using this node's bound.

  // First delete the node if we need to.  There's no point in shrinking the
  // bound first.
  if (IsLeaf() && count < minLeafSize && parent != NULL)
  {
    // We can't delete the root node (hence the parent != NULL check above).
    for (size_t i = 0; i < parent->NumChildren(); i++)
    {
      if (parent->children[i] == this)
      {
        // Unlink this node from the parent.  The parent's last child moves
        // into slot i and numChildren is decremented, unless the auxiliary
        // information handles the removal itself.
        if (!auxiliaryInfo.HandleNodeRemoval(parent, i))
        {
          parent->children[i] = parent->children[--parent->NumChildren()];
        }

        // We find the root and shrink bounds at the same time.  Shrinking
        // stops at the first ancestor whose ShrinkBoundForBound() returns
        // false, but the walk continues until the root is reached.
        bool stillShrinking = true;
        RectangleTree* root = parent;
        while (root->Parent() != NULL)
        {
          if (stillShrinking)
            stillShrinking = root->ShrinkBoundForBound(bound);
          root = root->Parent();
        }
        if (stillShrinking)
          stillShrinking = root->ShrinkBoundForBound(bound);

        // The removed node's descendants no longer count toward any ancestor.
        root = parent;
        while (root != NULL)
        {
          root->numDescendants -= numDescendants;
          root = root->Parent();
        }

        // Refresh the ancestors' auxiliary information.  Updating stops at
        // the first ancestor whose update returns false; `root` still ends up
        // at the tree root.
        stillShrinking = true;
        root = parent;
        while (root->Parent() != NULL)
        {
          if (stillShrinking)
            stillShrinking = root->AuxiliaryInfo().UpdateAuxiliaryInfo(root);
          root = root->Parent();
        }
        if (stillShrinking)
          stillShrinking = root->AuxiliaryInfo().UpdateAuxiliaryInfo(root);

        // Reinsert the orphaned points at the root node.
        for (size_t j = 0; j < count; j++)
          root->InsertPoint(points[j], relevels);

        // This will check the minFill of the parent.
        parent->CondenseTree(point, relevels, usePoint);
        // Now it should be safe to delete this node.
        SoftDelete();

        return;
      }
    }
    // Control should never reach here: this node must be one of its parent's
    // children.
    assert(false);
  }
  else if (!IsLeaf() && numChildren < minNumChildren)
  {
    if (parent != NULL)
    {
      // The normal case.  We need to be careful with the root.
      for (size_t j = 0; j < parent->NumChildren(); j++)
      {
        if (parent->children[j] == this)
        {
          // Unlink this node from the parent (last child moves into slot j and
          // numChildren is decremented), unless the auxiliary information
          // handles the removal itself.
          if (!auxiliaryInfo.HandleNodeRemoval(parent,j))
          {
            parent->children[j] = parent->children[--parent->NumChildren()];
          }
          // Depth of the subtree rooted here.  It is passed to InsertNode()
          // below, presumably so each orphaned child is reinserted at its
          // original level.
          size_t level = TreeDepth();

          // We find the root and shrink bounds at the same time.  Shrinking
          // stops at the first ancestor whose ShrinkBoundForBound() returns
          // false, but the walk continues until the root is reached.
          bool stillShrinking = true;
          RectangleTree* root = parent;
          while (root->Parent() != NULL)
          {
            if (stillShrinking)
              stillShrinking = root->ShrinkBoundForBound(bound);
            root = root->Parent();
          }
          if (stillShrinking)
            stillShrinking = root->ShrinkBoundForBound(bound);

          // The removed node's descendants no longer count toward any
          // ancestor.
          root = parent;
          while (root != NULL)
          {
            root->numDescendants -= numDescendants;
            root = root->Parent();
          }

          // Refresh the ancestors' auxiliary information.  Updating stops at
          // the first ancestor whose update returns false; `root` still ends
          // up at the tree root.
          stillShrinking = true;
          root = parent;
          while (root->Parent() != NULL)
          {
            if (stillShrinking)
              stillShrinking = root->AuxiliaryInfo().UpdateAuxiliaryInfo(root);
            root = root->Parent();
          }
          if (stillShrinking)
            stillShrinking = root->AuxiliaryInfo().UpdateAuxiliaryInfo(root);

          // Reinsert the orphaned child nodes at the root node.
          for (size_t i = 0; i < numChildren; i++)
            root->InsertNode(children[i], level, relevels);

          // This will check the minFill of the parent.
          parent->CondenseTree(point, relevels, usePoint);
          // Now it should be safe to delete this node.
          SoftDelete();

          return;
        }
      }
      // NOTE(review): unlike the leaf branch, there is no assert here if this
      // node is not found among its parent's children; control falls through
      // to the bound shrinking below.
    }
    else if (numChildren == 1)
    {
      // This is the root and it has a single child, so pull the child's
      // contents up into the root.  A root with several children is left
      // alone and falls through to the bound shrinking below.
      RectangleTree* child = children[0];

      // Required for the X tree: the child may hold more than this node's
      // maxNumChildren children, so grow the child array to fit.
      if (child->NumChildren() > maxNumChildren)
      {
        maxNumChildren = child->MaxNumChildren();
        children.resize(maxNumChildren+1);
      }

      // Adopt the grandchildren directly.
      for (size_t i = 0; i < child->NumChildren(); i++) {
        children[i] = child->children[i];
        children[i]->Parent() = this;
      }

      numChildren = child->NumChildren();

      for (size_t i = 0; i < child->Count(); i++)
      {
        // In case the tree has a height of two (the child is a leaf), copy
        // its points as well.
        points[i] = child->Point(i);
      }

      auxiliaryInfo = child->AuxiliaryInfo();

      count = child->Count();
      child->SoftDelete();
      return;
    }
  }

  // If we didn't delete it, shrink the bound if we need to.  If the shrink
  // or the auxiliary-information update returns true, continue at the parent.
  // NOTE(review): `||` short-circuits, so UpdateAuxiliaryInfo(this) is only
  // called when the shrink returns false.  Confirm this is intended.
  if (usePoint &&
      (ShrinkBoundForPoint(point) || auxiliaryInfo.UpdateAuxiliaryInfo(this)) &&
      parent != NULL)
    parent->CondenseTree(point, relevels, usePoint);
  else if (!usePoint &&
           (ShrinkBoundForBound(bound) || auxiliaryInfo.UpdateAuxiliaryInfo(this)) &&
           parent != NULL)
    parent->CondenseTree(point, relevels, usePoint);
}
RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
              AuxiliaryInformationType>::
RectangleTree(
    const RectangleTree& other,
    const bool deepCopy,
    RectangleTree* newParent) :
    maxNumChildren(other.MaxNumChildren()),
    minNumChildren(other.MinNumChildren()),
    numChildren(other.NumChildren()),
    children(maxNumChildren + 1, NULL),
    parent(deepCopy ? newParent : other.Parent()),
    begin(other.Begin()),
    count(other.Count()),
    numDescendants(other.numDescendants),
    maxLeafSize(other.MaxLeafSize()),
    minLeafSize(other.MinLeafSize()),
    bound(other.bound),
    parentDistance(other.ParentDistance()),
    dataset(deepCopy ?
        (parent ? parent->dataset : new MatType(*other.dataset)) :
        &other.Dataset()),
    ownsDataset(deepCopy && (!parent)),
    points(other.points),
    auxiliaryInfo(other.auxiliaryInfo, this, deepCopy)
{
  if (!deepCopy)
  {
    // Shallow copy: reuse the child pointers of `other` as-is.
    children = other.children;
    return;
  }

  // Deep copy: clone each child with this node as its new parent.  A clone
  // with a parent reuses that parent's dataset pointer (see the initializer
  // of `dataset`), so only the top of the copied tree owns a dataset.
  for (size_t c = 0; c < numChildren; ++c)
    children[c] = new RectangleTree(other.Child(c), true, this);
}
Esempio n. 10
0
RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType>::
RectangleTree(
    const RectangleTree& other,
    const bool deepCopy) :
    maxNumChildren(other.MaxNumChildren()),
    minNumChildren(other.MinNumChildren()),
    numChildren(other.NumChildren()),
    children(maxNumChildren + 1),
    parent(other.Parent()),
    begin(other.Begin()),
    count(other.Count()),
    maxLeafSize(other.MaxLeafSize()),
    minLeafSize(other.MinLeafSize()),
    bound(other.bound),
    splitHistory(other.SplitHistory()),
    parentDistance(other.ParentDistance()),
    dataset(new MatType(*other.dataset)),
    ownsDataset(true),
    points(other.Points()),
    localDataset(NULL)
{
  // Copy `other`.  With deepCopy, children (or, for a leaf, the local
  // dataset) are duplicated.  Without it, child pointers and the local
  // dataset are shared with `other`.
  if (deepCopy)
  {
    if (numChildren > 0)
    {
      for (size_t i = 0; i < numChildren; i++)
      {
        children[i] = new RectangleTree(*(other.Children()[i]));
        // The cloned child was initialised with parent(other.Parent()), i.e.
        // a node of the source tree.  Re-point it at this copy so the new
        // tree does not reference (and later dangle into) the original.
        children[i]->parent = this;
        // NOTE(review): each cloned child also allocates its own copy of the
        // full dataset (dataset(new MatType(*other.dataset))).
      }
    }
    else
    {
      localDataset = new MatType(other.LocalDataset());
    }
  }
  else
  {
    children = other.Children();
    // NOTE(review): this casts to arma::mat rather than MatType, so it only
    // compiles when MatType is arma::mat.
    arma::mat& otherData = const_cast<arma::mat&>(other.LocalDataset());
    localDataset = &otherData;
  }
}
void RectangleTree<MetricType, StatisticType, MatType, SplitType, DescentType,
                   AuxiliaryInformationType>::
DualTreeTraverser<RuleType>::Traverse(RectangleTree& queryNode,
                                      RectangleTree& referenceNode)
{
  // Recursively traverse a query node and a reference node together.
  // rule.Score()/Rescore() decide which node pairs can be pruned, and
  // rule.BaseCase() is called on point pairs once both nodes are leaves.
  // numVisited, numScores, numBaseCases and numPrunes are updated as
  // statistics.

  // Increment the visit counter.
  ++numVisited;

  // Store the current traversal info.  It is restored before each Score()
  // call below so that every sibling is scored from the same starting state.
  traversalInfo = rule.TraversalInfo();

  // We now have four options.
  // 1)  Both nodes are leaf nodes.
  // 2)  Only the reference node is a leaf node.
  // 3)  Only the query node is a leaf node.
  // 4)  Neither node is a leaf node.
  // We go through those options in that order.

  if (queryNode.IsLeaf() && referenceNode.IsLeaf())
  {
    // Evaluate the base case.  Do the query points on the outside so we can
    // possibly prune the reference node for that particular point.
    // NOTE(review): these per-point Score() calls are not counted in
    // numScores, and skipped points are not counted in numPrunes, unlike the
    // other cases.  Confirm this is intended.
    for (size_t query = 0; query < queryNode.Count(); ++query)
    {
      // Restore the traversal information.
      rule.TraversalInfo() = traversalInfo;
      const double childScore = rule.Score(queryNode.Point(query),
          referenceNode);

      if (childScore == DBL_MAX)
        continue;  // We don't require a search in this reference node.

      for(size_t ref = 0; ref < referenceNode.Count(); ++ref)
        rule.BaseCase(queryNode.Point(query), referenceNode.Point(ref));

      numBaseCases += referenceNode.Count();
    }
  }
  else if (!queryNode.IsLeaf() && referenceNode.IsLeaf())
  {
    // We only need to traverse down the query node.  Order doesn't matter here.
    for (size_t i = 0; i < queryNode.NumChildren(); ++i)
    {
      // Before recursing, we have to set the traversal information correctly.
      rule.TraversalInfo() = traversalInfo;
      ++numScores;
      if (rule.Score(queryNode.Child(i), referenceNode) < DBL_MAX)
        Traverse(queryNode.Child(i), referenceNode);
      else
        numPrunes++;
    }
  }
  else if (queryNode.IsLeaf() && !referenceNode.IsLeaf())
  {
    // We only need to traverse down the reference node.  Order does matter
    // here.

    // We sort the children of the reference node by their scores.  Each
    // entry keeps the traversal info produced by its own Score() call so it
    // can be reinstated before that child is rescored.
    std::vector<NodeAndScore> nodesAndScores(referenceNode.NumChildren());
    for (size_t i = 0; i < referenceNode.NumChildren(); i++)
    {
      rule.TraversalInfo() = traversalInfo;
      nodesAndScores[i].node = &(referenceNode.Child(i));
      nodesAndScores[i].score = rule.Score(queryNode,
          *(nodesAndScores[i].node));
      nodesAndScores[i].travInfo = rule.TraversalInfo();
    }
    std::sort(nodesAndScores.begin(), nodesAndScores.end(), nodeComparator);
    numScores += nodesAndScores.size();

    for (size_t i = 0; i < nodesAndScores.size(); i++)
    {
      rule.TraversalInfo() = nodesAndScores[i].travInfo;
      if (rule.Rescore(queryNode, *(nodesAndScores[i].node),
          nodesAndScores[i].score) < DBL_MAX)
      {
        Traverse(queryNode, *(nodesAndScores[i].node));
      }
      else
      {
        // Once one child is pruned, it and all remaining children are counted
        // as pruned and skipped.  This presumes nodeComparator puts the most
        // promising scores first.
        numPrunes += nodesAndScores.size() - i;
        break;
      }
    }
  }
  else
  {
    // We need to traverse down both the reference and the query trees.
    // We loop through all of the query nodes, and for each of them, we
    // loop through the reference nodes to see where we need to descend.
    for (size_t j = 0; j < queryNode.NumChildren(); j++)
    {
      // We sort the children of the reference node by their scores against
      // query child j, keeping each child's traversal info as above.
      std::vector<NodeAndScore> nodesAndScores(referenceNode.NumChildren());
      for (size_t i = 0; i < referenceNode.NumChildren(); i++)
      {
        rule.TraversalInfo() = traversalInfo;
        nodesAndScores[i].node = &(referenceNode.Child(i));
        nodesAndScores[i].score = rule.Score(queryNode.Child(j),
            *nodesAndScores[i].node);
        nodesAndScores[i].travInfo = rule.TraversalInfo();
      }
      std::sort(nodesAndScores.begin(), nodesAndScores.end(), nodeComparator);
      numScores += nodesAndScores.size();

      for (size_t i = 0; i < nodesAndScores.size(); i++)
      {
        rule.TraversalInfo() = nodesAndScores[i].travInfo;
        if (rule.Rescore(queryNode.Child(j), *(nodesAndScores[i].node),
            nodesAndScores[i].score) < DBL_MAX)
        {
          Traverse(queryNode.Child(j), *(nodesAndScores[i].node));
        }
        else
        {
          // The pruned child and all remaining children are counted as
          // pruned for this query child.
          numPrunes += nodesAndScores.size() - i;
          break;
        }
      }
    }
  }
}