コード例 #1
0
  /**
   * Decide how a child node relates to an axis-aligned cut.
   *
   * @param child Node whose bound is inspected.
   * @param axis Dimension along which the cut is made.
   * @param cut Coordinate of the cut on that dimension.
   * @return AssignToFirstTree if the child's interval on the axis ends at or
   *     before the cut, AssignToSecondTree if it starts at or after the cut,
   *     and SplitRequired if the interval straddles the cut.  When both
   *     conditions hold, AssignToFirstTree is returned.
   */
  static int GetSplitPolicy(const TreeType& child,
                            const size_t axis,
                            const typename TreeType::ElemType cut)
  {
    // Extent of the child's bound along the requested dimension.
    const auto& extent = child.Bound()[axis];

    // The whole interval lies on the low side of the cut.
    if (extent.Hi() <= cut)
      return AssignToFirstTree;

    // The whole interval lies on the high side of the cut.
    if (extent.Lo() >= cut)
      return AssignToSecondTree;

    // The cut passes through the interior of the interval.
    return SplitRequired;
  }
コード例 #2
0
void CheckSplit(const TreeType& tree)
{
  typedef typename TreeType::ElemType ElemType;
  typedef typename std::conditional<sizeof(ElemType) * CHAR_BIT <= 32,
                                    uint32_t,
                                    uint64_t>::type AddressElemType;

  if (tree.IsLeaf())
    return;

  arma::Col<AddressElemType> lo(tree.Bound().Dim());
  arma::Col<AddressElemType> hi(tree.Bound().Dim());

  lo.fill(std::numeric_limits<AddressElemType>::max());
  hi.fill(0);

  arma::Col<AddressElemType> address(tree.Bound().Dim());

  // Find the highest address of the left node.
  for (size_t i = 0; i < tree.Left()->NumDescendants(); i++)
  {
    addr::PointToAddress(address,
        tree.Dataset().col(tree.Left()->Descendant(i)));

    if (addr::CompareAddresses(address, hi) > 0)
      hi = address;
  }

  // Find the lowest address of the right node.
  for (size_t i = 0; i < tree.Right()->NumDescendants(); i++)
  {
    addr::PointToAddress(address,
        tree.Dataset().col(tree.Right()->Descendant(i)));

    if (addr::CompareAddresses(address, lo) < 0)
      lo = address;
  }

  // Addresses in the left node should be less than addresses in the right node.
  BOOST_REQUIRE_LE(addr::CompareAddresses(hi, lo), 0);

  CheckSplit(*tree.Left());
  CheckSplit(*tree.Right());
}
コード例 #3
0
/**
 * Recursively verify that bounds nest properly throughout the tree.
 *
 * For a node with children, every child's bound must be contained in the
 * node's bound in each dimension.  For a node without children, every point
 * held by the node must be contained in the node's bound.
 *
 * @param tree Node at which to start the check.
 */
void CheckContainment(const TreeType& tree)
{
  if (tree.NumChildren() != 0)
  {
    // Internal node: check each child dimension by dimension, then descend.
    for (size_t i = 0; i < tree.NumChildren(); i++)
    {
      for (size_t j = 0; j < tree.Bound().Dim(); j++)
      {
        BOOST_REQUIRE(tree.Bound()[j].Contains(tree.Children()[i]->Bound()[j]));
      }

      CheckContainment(*(tree.Children()[i]));
    }

    return;
  }

  // Node without children: every held point must lie inside the bound.
  for (size_t i = 0; i < tree.Count(); i++)
  {
    BOOST_REQUIRE(tree.Bound().Contains(
        tree.Dataset().unsafe_col(tree.Points()[i])));
  }
}
コード例 #4
0
/**
 * Recursively verify that every bound in the tree is tight.
 *
 * For a node without children, the bound in each dimension must equal the
 * minimum and maximum of the coordinates stored in the node's local dataset.
 * For a node with children, the bound in each dimension must equal the
 * smallest Lo() and the largest Hi() over the children's bounds.
 *
 * @param tree Node at which to start the check.
 */
void CheckExactContainment(const TreeType& tree)
{
  if (tree.NumChildren() == 0)
  {
    // No children: compare the bound with the extent of the local points.
    for (size_t i = 0; i < tree.Bound().Dim(); i++)
    {
      double max = -1.0 * DBL_MAX;
      double min = DBL_MAX;

      for (size_t j = 0; j < tree.Count(); j++)
      {
        const auto coordinate = tree.LocalDataset().col(j)[i];

        if (coordinate < min)
          min = coordinate;
        if (coordinate > max)
          max = coordinate;
      }

      BOOST_REQUIRE_EQUAL(max, tree.Bound()[i].Hi());
      BOOST_REQUIRE_EQUAL(min, tree.Bound()[i].Lo());
    }

    return;
  }

  // Has children: compare the bound with the union of the children's bounds.
  for (size_t i = 0; i < tree.Bound().Dim(); i++)
  {
    double max = -1.0 * DBL_MAX;
    double min = DBL_MAX;

    for (size_t j = 0; j < tree.NumChildren(); j++)
    {
      const auto childLo = tree.Child(j).Bound()[i].Lo();
      if (childLo < min)
        min = childLo;

      const auto childHi = tree.Child(j).Bound()[i].Hi();
      if (childHi > max)
        max = childHi;
    }

    BOOST_REQUIRE_EQUAL(max, tree.Bound()[i].Hi());
    BOOST_REQUIRE_EQUAL(min, tree.Bound()[i].Lo());
  }

  // All dimensions of this node are checked before descending.
  for (size_t c = 0; c < tree.NumChildren(); c++)
    CheckExactContainment(tree.Child(c));
}
コード例 #5
0
/**
 * Recursively verify that every descendant point of a node is covered by the
 * node's bound.
 *
 * Two things are checked for each descendant: Bound().Contains() must accept
 * the point, and at least one of the hyperrectangles described by column j of
 * LoBound() / HiBound() must enclose it, allowing a relative slack of 1e-14
 * on each side.
 *
 * @param tree Node at which to start the check.
 */
void CheckBound(const TreeType& tree)
{
  typedef typename TreeType::ElemType ElemType;

  for (size_t d = 0; d < tree.NumDescendants(); d++)
  {
    arma::Col<ElemType> point = tree.Dataset().col(tree.Descendant(d));

    // The bound itself must report the point as contained.
    BOOST_REQUIRE_EQUAL(true, tree.Bound().Contains(point));

    const arma::Mat<ElemType>& loBound = tree.Bound().LoBound();
    const arma::Mat<ElemType>& hiBound = tree.Bound().HiBound();

    // Search for a hyperrectangle enclosing the point; stop at the first hit.
    bool success = false;
    for (size_t j = 0; !success && j < tree.Bound().NumBounds(); j++)
    {
      // Stays true while every coordinate examined so far is in range.
      bool enclosed = true;
      for (size_t k = 0; enclosed && k < loBound.n_rows; k++)
      {
        enclosed =
            !(point[k] < loBound(k, j) - 1e-14 * std::fabs(loBound(k, j)) ||
              point[k] > hiBound(k, j) + 1e-14 * std::fabs(hiBound(k, j)));
      }

      success = enclosed;
    }

    BOOST_REQUIRE_EQUAL(success, true);
  }

  // Leaves have no children to descend into.
  if (tree.IsLeaf())
    return;

  CheckBound(*tree.Left());
  CheckBound(*tree.Right());
}
コード例 #6
0
/**
 * Score a reference node: update its cluster blacklist, and either credit the
 * whole node to a single cluster or assign the points it holds one by one.
 *
 * The node's blacklist is first copied from its parent (or zeroed when there
 * is no parent or the parent's blacklist is empty).  A blacklist entry of 1
 * marks a cluster as excluded; 0 marks it as still whitelisted.  The
 * whitelisted cluster with the smallest MinDistance() to the node is then
 * compared against every other whitelisted cluster, and each cluster it
 * dominates is blacklisted.
 *
 * Side effects: modifies referenceNode.Stat().Blacklist(), and adds to
 * distanceCalculations, counts and newCentroids.
 *
 * @param queryIndex Unused.
 * @param referenceNode Node to score.
 * @return DBL_MAX when exactly one cluster remains whitelisted (the node's
 *     descendants have then all been credited to that cluster); 0.0
 *     otherwise.
 */
double PellegMooreKMeansRules<MetricType, TreeType>::Score(
    const size_t /* queryIndex */,
    TreeType& referenceNode)
{
  // Obtain the parent's blacklist.  If this is the root node, we'll start with
  // an empty blacklist.  This means that after each iteration, we don't need to
  // reset any statistics.
  if (referenceNode.Parent() == NULL ||
      referenceNode.Parent()->Stat().Blacklist().n_elem == 0)
    referenceNode.Stat().Blacklist().zeros(centroids.n_cols);
  else
    referenceNode.Stat().Blacklist() =
        referenceNode.Parent()->Stat().Blacklist();

  // The query index is a fake index that we won't use, and the reference node
  // holds all of the points in the dataset.  Our goal is to determine whether
  // or not this node is dominated by a single cluster.
  const size_t whitelisted = centroids.n_cols -
      arma::accu(referenceNode.Stat().Blacklist());

  // One MinDistance() evaluation is made per whitelisted cluster below.
  distanceCalculations += whitelisted;

  // Which cluster has minimum distance to the node?
  // NOTE(review): closestCluster keeps the out-of-range value centroids.n_cols
  // if no cluster is whitelisted, and is used as a column index further down —
  // confirm that at least one cluster is always whitelisted here.
  size_t closestCluster = centroids.n_cols;
  double minMinDistance = DBL_MAX;
  for (size_t i = 0; i < centroids.n_cols; ++i)
  {
    if (referenceNode.Stat().Blacklist()[i] == 0)
    {
      const double minDistance = referenceNode.MinDistance(centroids.col(i));
      if (minDistance < minMinDistance)
      {
        minMinDistance = minDistance;
        closestCluster = i;
      }
    }
  }

  // Now, for every other whitelisted cluster, determine if the closest cluster
  // owns the point.  This calculation is specific to hyperrectangle trees (but,
  // this implementation is specific to kd-trees, so that's okay).  For
  // circular-bound trees, the condition should be simpler and can probably be
  // expressed as a comparison between minimum and maximum distances.
  size_t newBlacklisted = 0;
  for (size_t c = 0; c < centroids.n_cols; ++c)
  {
    // Skip clusters that are already excluded, and the closest cluster itself.
    if (referenceNode.Stat().Blacklist()[c] == 1 || c == closestCluster)
      continue;

    // This algorithm comes from the proof of Lemma 4 in the extended version
    // of the Pelleg-Moore paper (the CMU tech report, that is).  It has been
    // adapted for speed.
    // Build the corner of the node's bound that, in each dimension, lies on
    // the side of cluster c relative to the closest cluster.
    arma::vec cornerPoint(centroids.n_rows);
    for (size_t d = 0; d < referenceNode.Bound().Dim(); ++d)
    {
      if (centroids(d, c) > centroids(d, closestCluster))
        cornerPoint(d) = referenceNode.Bound()[d].Hi();
      else
        cornerPoint(d) = referenceNode.Bound()[d].Lo();
    }

    const double closestDist = metric.Evaluate(cornerPoint,
        centroids.col(closestCluster));
    const double otherDist = metric.Evaluate(cornerPoint, centroids.col(c));

    distanceCalculations += 3; // One for cornerPoint, then two distances.

    if (closestDist < otherDist)
    {
      // The closest cluster dominates the node with respect to the cluster c.
      // So we can blacklist c.
      referenceNode.Stat().Blacklist()[c] = 1;
      ++newBlacklisted;
    }
  }

  // whitelisted - newBlacklisted is the number of clusters still whitelisted
  // after the pass above.
  if (whitelisted - newBlacklisted == 1)
  {
    // This node is dominated by the closest cluster.
    // Credit every descendant to it at once, using the node's stored centroid
    // weighted by the number of descendants.
    counts[closestCluster] += referenceNode.NumDescendants();
    newCentroids.col(closestCluster) += referenceNode.NumDescendants() *
        referenceNode.Stat().Centroid();

    // Presumably DBL_MAX tells the traverser not to descend further — confirm
    // against the traverser in use.
    return DBL_MAX;
  }

  // Perform the base case here.
  // Each point held directly by this node goes to the nearest cluster that is
  // not blacklisted.
  for (size_t i = 0; i < referenceNode.NumPoints(); ++i)
  {
    size_t bestCluster = centroids.n_cols;
    double bestDistance = DBL_MAX;
    for (size_t c = 0; c < centroids.n_cols; ++c)
    {
      if (referenceNode.Stat().Blacklist()[c] == 1)
        continue;

      ++distanceCalculations;

      // The reference index is the index of the data point.
      const double distance = metric.Evaluate(centroids.col(c),
          dataset.col(referenceNode.Point(i)));

      if (distance < bestDistance)
      {
        bestDistance = distance;
        bestCluster = c;
      }
    }

    // Add to resulting centroid.
    newCentroids.col(bestCluster) += dataset.col(referenceNode.Point(i));
    ++counts(bestCluster);
  }

  // Otherwise, we're not sure, so we can't prune.  Recursion order doesn't make
  // a difference, so we'll just return a score of 0.
  return 0.0;
}
コード例 #7
0
/**
 * Verify that the distance bounds reported by a node's bound are consistent
 * with brute-force distances computed by MetricType::Evaluate().
 *
 * The function runs in two modes, selected by `node`:
 *
 *  - node == NULL (entry mode): walk up from `tree` to the topmost ancestor
 *    and compare `tree` against every node reachable from it; then check the
 *    point-to-bound distances of `tree` for every column of the dataset; then
 *    repeat the whole procedure for the children of `tree`.
 *  - node != NULL (pairwise mode): compare the bound of `tree` with the bound
 *    of `node` (unless they are the same node), then descend through the
 *    children of `node`.
 *
 * In every comparison the reported minimum must not exceed the brute-force
 * minimum, and the brute-force maximum must not exceed the reported maximum,
 * each with a relative slack of 10 machine epsilons.
 *
 * @param tree Node whose bound is being checked.
 * @param node Node to compare against, or NULL to start a full check.
 */
void CheckDistance(TreeType& tree, TreeType* node = NULL)
{
  typedef typename TreeType::ElemType ElemType;

  if (node != NULL)
  {
    // Pairwise mode.  A node is never compared against itself.
    if (&tree != node)
    {
      ElemType minDist = std::numeric_limits<ElemType>::max();
      ElemType maxDist = 0;

      // Brute-force extremes over all pairs of descendants.
      for (size_t i = 0; i < tree.NumDescendants(); i++)
      {
        for (size_t j = 0; j < node->NumDescendants(); j++)
        {
          const ElemType dist = MetricType::Evaluate(
              tree.Dataset().col(tree.Descendant(i)),
              node->Dataset().col(node->Descendant(j)));

          if (dist < minDist)
            minDist = dist;
          if (dist > maxDist)
            maxDist = dist;
        }
      }

      BOOST_REQUIRE_LE(tree.Bound().MinDistance(node->Bound()), minDist *
          (1.0 + 10 * std::numeric_limits<ElemType>::epsilon()));
      BOOST_REQUIRE_LE(maxDist, tree.Bound().MaxDistance(node->Bound()) *
          (1.0 + 10 * std::numeric_limits<ElemType>::epsilon()));

      math::RangeType<ElemType> r = tree.Bound().RangeDistance(node->Bound());

      BOOST_REQUIRE_LE(r.Lo(), minDist *
          (1.0 + 10 * std::numeric_limits<ElemType>::epsilon()));
      BOOST_REQUIRE_LE(maxDist, r.Hi() *
          (1.0 + 10 * std::numeric_limits<ElemType>::epsilon()));
    }

    // Keep comparing `tree` against everything below `node`.
    if (!node->IsLeaf())
    {
      CheckDistance<TreeType, MetricType>(tree, node->Left());
      CheckDistance<TreeType, MetricType>(tree, node->Right());
    }

    return;
  }

  // Entry mode.  Find the topmost ancestor of `tree`.
  TreeType* root = &tree;
  while (root->Parent() != NULL)
    root = root->Parent();

  // Compare `tree` against every node reachable from that ancestor.
  CheckDistance<TreeType, MetricType>(tree, root);

  // Point-to-bound checks against every column of the dataset.
  for (size_t j = 0; j < tree.Dataset().n_cols; j++)
  {
    const arma::Col<ElemType>& point = tree.Dataset().col(j);

    ElemType minDist = std::numeric_limits<ElemType>::max();
    ElemType maxDist = 0;

    // Brute-force extremes between this point and the node's descendants.
    for (size_t i = 0; i < tree.NumDescendants(); i++)
    {
      const ElemType dist = MetricType::Evaluate(
          tree.Dataset().col(tree.Descendant(i)),
          tree.Dataset().col(j));

      if (dist < minDist)
        minDist = dist;
      if (dist > maxDist)
        maxDist = dist;
    }

    BOOST_REQUIRE_LE(tree.Bound().MinDistance(point), minDist *
        (1.0 + 10 * std::numeric_limits<ElemType>::epsilon()));
    BOOST_REQUIRE_LE(maxDist, tree.Bound().MaxDistance(point) *
        (1.0 + 10 * std::numeric_limits<ElemType>::epsilon()));

    math::RangeType<ElemType> r = tree.Bound().RangeDistance(point);

    BOOST_REQUIRE_LE(r.Lo(), minDist *
        (1.0 + 10 * std::numeric_limits<ElemType>::epsilon()));
    BOOST_REQUIRE_LE(maxDist, r.Hi() *
        (1.0 + 10 * std::numeric_limits<ElemType>::epsilon()));
  }

  // Restart the full check from each child.
  if (!tree.IsLeaf())
  {
    CheckDistance<TreeType, MetricType>(*tree.Left());
    CheckDistance<TreeType, MetricType>(*tree.Right());
  }
}