예제 #1
0
/**
 * Reset the cached LastDistance() value of a node's statistic to 0.0, then do
 * the same for every node reachable from it through Child().
 *
 * @param node Root of the subtree whose statistics are reset.
 */
void CleanTree(TreeType& node)
{
  // Clear this node before descending into its children.
  node.Stat().LastDistance() = 0.0;

  size_t childIndex = 0;
  while (childIndex < node.NumChildren())
  {
    CleanTree(node.Child(childIndex));
    ++childIndex;
  }
}
예제 #2
0
/*
	Turns the weakest internal node(s) of the tree into terminal nodes.

	measureStrength() is run first (presumably it fills in the g value of every
	node -- confirm against its definition). The smallest g among the nodes
	with leftNodeId > 0 is then determined, and every such node whose g equals
	that minimum gets both child ids reset to 0. The former children are left
	in the container; the pruneNode() calls that would erase them are disabled.
*/
void CARTTrainer::pruneTree(TreeType & tree){

	measureStrength(tree, 0, 0);

	// Pass 1: smallest g over the internal nodes.
	double weakest = std::numeric_limits<double>::max();
	for(std::size_t idx = 0; idx != tree.size(); ++idx){
		if(!(tree[idx].leftNodeId > 0))
			continue; // terminal node, not a pruning candidate
		if(tree[idx].g < weakest)
			weakest = tree[idx].g;
	}

	// Pass 2: every internal node carrying exactly that g becomes terminal.
	// The comparison is exact on purpose: weakest was copied from one of the
	// g values inspected here.
	for(std::size_t idx = 0; idx != tree.size(); ++idx){
		if(!(tree[idx].leftNodeId > 0))
			continue;
		if(!(tree[idx].g == weakest))
			continue;
		// Disabled in the original: pruneNode(tree, tree[idx].leftNodeId);
		//                           pruneNode(tree, tree[idx].rightNodeId);
		tree[idx].leftNodeId = 0;
		tree[idx].rightNodeId = 0;
	}
}
예제 #3
0
 /**
  * Build the statistic for the given tree node.
  *
  * Both neighbor distances and the bound start at DBL_MAX.  The component
  * membership is the index of the node's only point when the node holds
  * exactly one point and has no children; in every other case it is -1.
  *
  * @param node Node this statistic is created for.
  */
 DTBStat(const TreeType& node) :
     maxNeighborDistance(DBL_MAX),
     minNeighborDistance(DBL_MAX),
     bound(DBL_MAX),
     // Only a childless single-point node has a well-defined component yet.
     componentMembership(
         ((node.NumPoints() == 1) && (node.NumChildren() == 0)) ?
           node.Point(0) : -1) { }
예제 #4
0
/**
 * Test helper: require that every child of the given node reports that node
 * as its Parent(), and apply the same check recursively to each child.
 *
 * @param tree Node whose subtree is verified.
 */
void CheckHierarchy(const TreeType& tree)
{
  size_t i = 0;
  while (i < tree.NumChildren())
  {
    // The child's parent pointer must be the address of this very node.
    BOOST_REQUIRE_EQUAL(&tree, tree.Child(i).Parent());
    CheckHierarchy(tree.Child(i));
    i++;
  }
}
예제 #5
0
/**
 * Re-evaluate a previously computed score for a query node.
 *
 * The query node's bound is recomputed with CalculateBound() and stored in its
 * statistic.  The old score is kept only if its reciprocal is still greater
 * than that bound; otherwise DBL_MAX is returned.
 *
 * @param queryNode Query node whose bound is refreshed.
 * @param oldScore Score computed earlier for this combination.
 * @return oldScore, or DBL_MAX when the combination can be pruned.
 */
double FastMKSRules<KernelType, TreeType>::Rescore(TreeType& queryNode,
        TreeType& /*referenceNode*/,
        const double oldScore) const
{
    // Refresh the stored bound, then read it back as the value to beat.
    queryNode.Stat().Bound() = CalculateBound(queryNode);
    const double bestKernel = queryNode.Stat().Bound();

    const double oldKernel = 1.0 / oldScore;
    if (oldKernel > bestKernel)
        return oldScore;

    return DBL_MAX;
}
예제 #6
0
  FastMKSStat(const TreeType& node) :
      bound(-DBL_MAX),
      lastKernel(0.0),
      lastKernelNode(NULL)
  {
    // Do we have to calculate the centroid?
    if (tree::TreeTraits<TreeType>::FirstPointIsCentroid)
    {
      // If this type of tree has self-children, then maybe the evaluation is
      // already done.  These statistics are built bottom-up, so the child stat
      // should already be done.
      if ((tree::TreeTraits<TreeType>::HasSelfChildren) &&
          (node.NumChildren() > 0) &&
          (node.Point(0) == node.Child(0).Point(0)))
      {
        selfKernel = node.Child(0).Stat().SelfKernel();
      }
      else
      {
        selfKernel = sqrt(node.Metric().Kernel().Evaluate(
            node.Dataset().col(node.Point(0)),
            node.Dataset().col(node.Point(0))));
      }
    }
    else
    {
      // Calculate the centroid.
      arma::vec center;
      node.Center(center);

      selfKernel = sqrt(node.Metric().Kernel().Evaluate(center, center));
    }
  }
예제 #7
0
  /**
   * Decide on which side of an axis-aligned cut a child node falls.
   *
   * @param child Node whose bound is inspected.
   * @param axis Dimension along which the cut is made.
   * @param cut Coordinate of the cut on that axis.
   * @return AssignToFirstTree when the bound's upper end on the axis is at or
   *     below the cut, AssignToSecondTree when its lower end is at or above
   *     the cut, and SplitRequired in every other case.
   */
  static int GetSplitPolicy(const TreeType& child,
                            const size_t axis,
                            const typename TreeType::ElemType cut)
  {
    const bool endsBeforeCut = (child.Bound()[axis].Hi() <= cut);
    if (endsBeforeCut)
      return AssignToFirstTree;

    const bool startsAfterCut = (child.Bound()[axis].Lo() >= cut);
    if (startsAfterCut)
      return AssignToSecondTree;

    // The bound straddles the cut.
    return SplitRequired;
  }
예제 #8
0
/*
	Removes the branch rooted at the node with id nodeId, including that node.

	The child ids are copied before anything is erased, and the position of
	the node is looked up again after the recursive calls. Every erase()
	shifts the elements stored behind the erased one, so an index obtained
	before the recursion may no longer refer to this node afterwards.

	findNode() is assumed to return the position of the node carrying the
	given id -- confirm against its definition.
*/
void CARTTrainer::pruneNode(TreeType & tree, std::size_t nodeId){
	std::size_t i = findNode(tree,nodeId);

	if(tree[i].leftNodeId>0){
		// Read both ids now; tree[i] may be a different node once the left
		// branch has been erased.
		const std::size_t leftId = tree[i].leftNodeId;
		const std::size_t rightId = tree[i].rightNodeId;

		//Prune left branch
		pruneNode(tree, leftId);
		//Prune right branch
		pruneNode(tree, rightId);

		// The erasures above may have moved this node; locate it again.
		i = findNode(tree,nodeId);
	}
	//Remove node
	tree.erase(tree.begin()+i);
}
/**
 * Score a reference node for a single query point in range search.
 *
 * The range of possible distances between the query point and the node is
 * computed first.  If that range fails the Contains() check against the
 * search range, the node is pruned.  If it lies completely inside the search
 * range, the node is handed to AddResult() and is not descended into either.
 *
 * @param queryIndex Index of the query point.
 * @param referenceNode Reference node being scored.
 * @return DBL_MAX when the node needs no further traversal, 0.0 otherwise.
 */
double RangeSearchRules<MetricType, TreeType>::Score(const size_t queryIndex,
                                                     TreeType& referenceNode)
{
  math::Range distances;

  if (!tree::TreeTraits<TreeType>::FirstPointIsCentroid)
  {
    // No centroid point available: ask the node for its distance range.
    distances = referenceNode.RangeDistance(querySet.unsafe_col(queryIndex));
  }
  else
  {
    // The node's first point is its centroid, so the range is derived from
    // the distance to that point.  With self-children the parent may already
    // have evaluated exactly this point.
    const bool parentSharesPoint =
        tree::TreeTraits<TreeType>::HasSelfChildren &&
        (referenceNode.Parent() != NULL) &&
        (referenceNode.Point(0) == referenceNode.Parent()->Point(0));

    double centroidDistance;
    if (parentSharesPoint)
    {
      // Reuse the parent's cached distance and record the pair as evaluated.
      centroidDistance = referenceNode.Parent()->Stat().LastDistance();
      lastQueryIndex = queryIndex;
      lastReferenceIndex = referenceNode.Point(0);
    }
    else
    {
      centroidDistance = BaseCase(queryIndex, referenceNode.Point(0));
    }

    // This may be loose for trees that do not use ball bounds.
    distances.Lo() =
        centroidDistance - referenceNode.FurthestDescendantDistance();
    distances.Hi() =
        centroidDistance + referenceNode.FurthestDescendantDistance();

    // Cache the distance for this node's own self-children.
    referenceNode.Stat().LastDistance() = centroidDistance;
  }

  // Nothing in this node can fall into the search range.
  if (!distances.Contains(range))
    return DBL_MAX;

  // Every descendant is within the search range: add them all at once.
  const bool entirelyInRange =
      (distances.Lo() >= range.Lo()) && (distances.Hi() <= range.Hi());
  if (entirelyInRange)
  {
    AddResult(queryIndex, referenceNode);
    return DBL_MAX;
  }

  // Recursion order is irrelevant in range search, so any finite score works.
  return 0.0;
}
/**
 * Refresh the stored Hilbert value of a node from its last child.
 *
 * Leaf nodes are treated as already up to date.  For any other node the
 * stored value is replaced by the last child's value when CompareWith()
 * reports the stored one as smaller.
 *
 * @param node Node whose auxiliary information is updated.
 * @return true for a leaf or when the value was replaced, false otherwise.
 */
bool HilbertRTreeAuxiliaryInformation<TreeType, HilbertValueType>::
UpdateAuxiliaryInfo(TreeType* node)
{
  if (!node->IsLeaf())
  {
    TreeType* lastChild = node->Children()[node->NumChildren() - 1];

    const bool storedIsSmaller =
        hilbertValue.CompareWith(lastChild->AuxiliaryInfo().hilbertValue()) < 0;
    if (!storedIsSmaller)
      return false;

    hilbertValue.Copy(node, lastChild);
  }

  return true;
}
예제 #11
0
// Prompts on standard output, reads one item of type T from standard input
// and passes it to the tree's SearchItem().
void SearchTree(TreeType<T>& tree)
{
	T wanted;

	cout << "Enter item to search: ";
	cin >> wanted;

	tree.SearchItem(wanted);
}
예제 #12
0
// Prompts on standard output, reads one item of type T from standard input
// and passes it to the tree's DeleteItem().
void DeleteFromTree(TreeType<T>& tree)
{
	T doomed;

	cout << "Enter item to delete: ";
	cin >> doomed;

	tree.DeleteItem(doomed);
}
예제 #13
0
// Prompts on standard output, reads one value of type T from standard input
// and passes it to the tree's InsertItem().
void InsertItem(TreeType<T>& tree)
{
	cout << "Enter number: ";

	T value;
	cin >> value;

	tree.InsertItem(value);
}
예제 #14
0
/**
 * Count the nodes on the longest path from the given node down to a leaf.
 *
 * @param tree Node at which counting starts.
 * @return 1 for a leaf; otherwise 1 plus the largest value over all children.
 */
int GetMaxLevel(const TreeType& tree)
{
  if (tree.IsLeaf())
    return 1;

  int deepest = 0;
  for (size_t c = 0; c < tree.NumChildren(); ++c)
  {
    const int childLevels = GetMaxLevel(*tree.Children()[c]);
    if (childLevels > deepest)
      deepest = childLevels;
  }

  // Account for this node itself.
  return 1 + deepest;
}
예제 #15
0
/**
 * Count the nodes on the shortest path from the given node down to a leaf.
 *
 * @param tree Node at which counting starts.
 * @return 1 for a leaf; otherwise 1 plus the smallest value over all children.
 */
int GetMinLevel(const TreeType& tree)
{
  if (tree.IsLeaf())
    return 1;

  int shallowest = INT_MAX;
  for (size_t c = 0; c < tree.NumChildren(); ++c)
  {
    const int childLevels = GetMinLevel(*tree.Children()[c]);
    if (childLevels < shallowest)
      shallowest = childLevels;
  }

  // Account for this node itself.
  return 1 + shallowest;
}
/**
 * Recompute the bound of a query node once a recursion has finished.
 *
 * The new bound is the worst value, according to SortPolicy, among the bounds
 * of the node's children and the entries in the last row of the distances
 * matrix for the points held directly by the node.
 *
 * @param queryNode Query node whose bound is rewritten.
 */
void NeighborSearchRules<
    SortPolicy,
    MetricType,
    TreeType>::
UpdateAfterRecursion(TreeType& queryNode, TreeType& /* referenceNode */)
{
  // Start from the best possible value; anything worse replaces it.
  double worstSoFar = SortPolicy::BestDistance();

  // Bounds already stored in the children.
  for (size_t c = 0; c < queryNode.NumChildren(); ++c)
  {
    const double childBound = queryNode.Child(c).Stat().Bound();
    if (SortPolicy::IsBetter(worstSoFar, childBound))
      worstSoFar = childBound;
  }

  // Last-row candidate distance of every point held directly by this node.
  for (size_t p = 0; p < queryNode.NumPoints(); ++p)
  {
    const double lastCandidate =
        distances(distances.n_rows - 1, queryNode.Point(p));
    if (SortPolicy::IsBetter(worstSoFar, lastCandidate))
      worstSoFar = lastCandidate;
  }

  queryNode.Stat().Bound() = worstSoFar;
}
예제 #17
0
/**
 * Score a query node against a reference node.
 *
 * The pair is pruned when the query node has a non-negative component
 * membership equal to that of the reference node, or when the minimum
 * distance between the nodes exceeds the query node's bound.
 *
 * @param queryNode Query node; its bound is recomputed by CalculateBound().
 * @param referenceNode Reference node.
 * @return DBL_MAX when pruned, otherwise the minimum node-to-node distance.
 */
double DTBRules<MetricType, TreeType>::Score(TreeType& queryNode,
                                             TreeType& referenceNode)
{
  // Both nodes lie entirely within one and the same component.
  const bool sameComponent =
      (queryNode.Stat().ComponentMembership() >= 0) &&
      (queryNode.Stat().ComponentMembership() ==
           referenceNode.Stat().ComponentMembership());
  if (sameComponent)
    return DBL_MAX;

  ++scores;
  const double distance = queryNode.MinDistance(&referenceNode);
  const double bound = CalculateBound(queryNode);

  // The reference node is farther away than every current candidate.
  if (bound < distance)
    return DBL_MAX;

  return distance;
}
예제 #18
0
/**
 * Test helper checking that bounds enclose what lies beneath them.
 *
 * For a node without children, every point listed in Points() must be
 * contained in the node's bound.  For any other node, the bound must contain
 * each child's bound in every dimension, and the children are checked
 * recursively.
 *
 * @param tree Node whose subtree is verified.
 */
void CheckContainment(const TreeType& tree)
{
  if (tree.NumChildren() != 0)
  {
    // Interior node: compare bounds dimension by dimension, then recurse.
    for (size_t i = 0; i < tree.NumChildren(); i++)
    {
      for (size_t j = 0; j < tree.Bound().Dim(); j++)
        BOOST_REQUIRE(tree.Bound()[j].Contains(tree.Children()[i]->Bound()[j]));

      CheckContainment(*(tree.Children()[i]));
    }
    return;
  }

  // Childless node: every held point must be inside the bound.
  for (size_t i = 0; i < tree.Count(); i++)
    BOOST_REQUIRE(tree.Bound().Contains(
        tree.Dataset().unsafe_col(tree.Points()[i])));
}
예제 #19
0
/**
 * Score a reference node for a single query point.
 *
 * When the kernel values over the node differ by no more than the error
 * tolerance allows, the node's contribution is estimated from one kernel
 * evaluation, added to the query's density, and the node is pruned.
 * Otherwise the minimum distance to the node is returned as its score.
 *
 * @param queryIndex Index of the query point.
 * @param referenceNode Reference node being scored.
 * @return DBL_MAX when the node was approximated, otherwise the minimum
 *     distance between the query point and the node.
 */
inline double KDERules<MetricType, KernelType, TreeType>::
Score(const size_t queryIndex, TreeType& referenceNode)
{
  // maxKernel, minKernel and bound are only assigned on the "new
  // calculations" path below; they are only read when newCalculations is true.
  double score, maxKernel, minKernel, bound;
  const arma::vec& queryPoint = querySet.unsafe_col(queryIndex);
  const double minDistance = referenceNode.MinDistance(queryPoint);
  bool newCalculations = true;

  // Same query as last time, and the previously scored reference node shares
  // its first point with this one.
  if (tree::TreeTraits<TreeType>::FirstPointIsCentroid &&
      lastQueryIndex == queryIndex &&
      traversalInfo.LastReferenceNode() != NULL &&
      traversalInfo.LastReferenceNode()->Point(0) == referenceNode.Point(0))
  {
    // Don't duplicate calculations.
    newCalculations = false;
    lastQueryIndex = queryIndex;
    lastReferenceIndex = referenceNode.Point(0);
  }
  else
  {
    // Calculations are new.
    maxKernel = kernel.Evaluate(minDistance);
    minKernel = kernel.Evaluate(referenceNode.MaxDistance(queryPoint));
    bound = maxKernel - minKernel;
  }

  // newCalculations is tested first, so bound and minKernel are never read
  // while uninitialized.
  if (newCalculations &&
      bound <= (absError + relError * minKernel) / referenceSet.n_cols)
  {
    // Estimate values.
    double kernelValue;

    // Calculate kernel value based on reference node centroid.
    if (tree::TreeTraits<TreeType>::FirstPointIsCentroid)
    {
      kernelValue = EvaluateKernel(queryIndex, referenceNode.Point(0));
    }
    else
    {
      kde::KDEStat& referenceStat = referenceNode.Stat();
      kernelValue = EvaluateKernel(queryPoint, referenceStat.Centroid());
    }

    // Every descendant is credited with the same estimated kernel value.
    densities(queryIndex) += referenceNode.NumDescendants() * kernelValue;

    // Don't explore this tree branch.
    score = DBL_MAX;
  }
  else
  {
    // Reached both when the approximation is not tight enough and when the
    // calculation was skipped as a duplicate.
    score = minDistance;
  }

  // Record this evaluation for the duplicate check of the next call.
  ++scores;
  traversalInfo.LastReferenceNode() = &referenceNode;
  traversalInfo.LastScore() = score;
  return score;
}
size_t nodeDepth(const TreeType& tree, const NodeType& node)
{
    const auto* currentNode = &node;
    size_t depth = 0;
    while (!currentNode->isRoot()) {
        currentNode = &tree.nodeAt(currentNode->parentNodeIndex);
        depth++;
    }
    return depth;
}
/**
 * Score a query node against a reference node.
 *
 * @param queryNode Query node; its stored bound is the value to beat.
 * @param referenceNode Reference node.
 * @return The best node-to-node distance when SortPolicy rates it better than
 *     the query node's bound, DBL_MAX otherwise.
 */
inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Score(
    TreeType& queryNode,
    TreeType& referenceNode) const
{
  const double distance = SortPolicy::BestNodeToNodeDistance(&queryNode,
      &referenceNode);
  const double bestDistance = queryNode.Stat().Bound();

  // Only keep combinations that can still improve on the current bound.
  if (SortPolicy::IsBetter(distance, bestDistance))
    return distance;

  return DBL_MAX;
}
예제 #22
0
/**
 * Recompute the pruning bound of a query node and store it in its statistic.
 *
 * The largest and smallest candidate distances are gathered from the points
 * held by the node (through their components) and from the statistics of its
 * children.  The bound is the smaller of the largest candidate and the
 * smallest candidate widened by twice the furthest descendant distance.
 *
 * @param queryNode Node whose MaxNeighborDistance(), MinNeighborDistance()
 *     and Bound() are rewritten.
 * @return The bound stored in the node.
 */
inline double DTBRules<MetricType, TreeType>::CalculateBound(
    TreeType& queryNode) const
{
  double maxOverPoints = -DBL_MAX;
  double minOverPoints = DBL_MAX;

  // Points held directly by the node: use the candidate distance recorded
  // for the component each point belongs to.
  for (size_t p = 0; p < queryNode.NumPoints(); ++p)
  {
    const size_t component = connections.Find(queryNode.Point(p));
    const double candidate = neighborsDistances[component];

    if (candidate > maxOverPoints)
      maxOverPoints = candidate;
    if (candidate < minOverPoints)
      minOverPoints = candidate;
  }

  double maxOverChildren = -DBL_MAX;
  double minOverChildren = DBL_MAX;

  // Children: reuse the extremes they have already recorded.
  for (size_t c = 0; c < queryNode.NumChildren(); ++c)
  {
    const double childMax = queryNode.Child(c).Stat().MaxNeighborDistance();
    if (childMax > maxOverChildren)
      maxOverChildren = childMax;

    const double childMin = queryNode.Child(c).Stat().MinNeighborDistance();
    if (childMin < minOverChildren)
      minOverChildren = childMin;
  }

  const double worstBound = std::max(maxOverPoints, maxOverChildren);
  const double bestBound = std::min(minOverPoints, minOverChildren);

  // Adding to DBL_MAX would overflow, so it is passed through untouched.
  double bestAdjustedBound = DBL_MAX;
  if (bestBound != DBL_MAX)
    bestAdjustedBound = bestBound + 2 * queryNode.FurthestDescendantDistance();

  // Publish the new values in the node's statistic.
  queryNode.Stat().MaxNeighborDistance() = worstBound;
  queryNode.Stat().MinNeighborDistance() = bestBound;
  queryNode.Stat().Bound() = std::min(worstBound, bestAdjustedBound);

  return queryNode.Stat().Bound();
}
예제 #23
0
/**
 * Score a reference node for a single query point.
 *
 * @param queryIndex Index of the query point.
 * @param referenceNode Reference node being scored.
 * @return DBL_MAX when the query's component equals the node's component
 *     membership or when the node is farther away than the component's
 *     current candidate distance; otherwise the minimum distance to the node.
 */
double DTBRules<MetricType, TreeType>::Score(const size_t queryIndex,
                                             TreeType& referenceNode)
{
  size_t queryComponentIndex = connections.Find(queryIndex);

  // The membership is converted explicitly to avoid a signed/unsigned
  // comparison warning.
  const size_t referenceComponent =
      static_cast<size_t>(referenceNode.Stat().ComponentMembership());
  if (queryComponentIndex == referenceComponent)
    return DBL_MAX;

  const arma::vec queryPoint = dataSet.unsafe_col(queryIndex);
  const double distance = referenceNode.MinDistance(queryPoint);

  // Nothing in the node can beat the candidate already known for the
  // query's component.
  if (neighborsDistances[queryComponentIndex] < distance)
    return DBL_MAX;

  return distance;
}
예제 #24
0
/**
 * @brief Return a list of clades, each itself containing a list of edge indices of that clade.
 *
 * The function takes a list of clades with their taxa as input, and a reference tree.
 * It then inspects all clades and finds the edges of the tree that belong to a clade.
 * Furthermore, a clade "basal_branches" is added for those edges of the tree that do not
 * belong to any clade.
 *
 * The edges of a clade are determined by finding the smallest subtree (split) of the tree that
 * contains all nodes of the clade. That means, the clades should be monophyletic in order for this
 * algorithm to work properly.
 *
 * The per-clade edge sets and the basal branches set are moved into the result instead of
 * being copied.
 */
CladeEdgeList get_clade_edges( CladeTaxaList const& clades, TreeType& tree )
{
    // Prepare the result list.
    CladeEdgeList clade_edges;

    // Make a set of all edges that do not belong to any clade (the basal branches of the tree).
    // We first fill it with all edge indices, then remove the clade edges later,
    // so that only the wanted ones remain.
    std::unordered_set<size_t> basal_branches;
    for( auto it = tree.begin_edges(); it != tree.end_edges(); ++it ) {
        basal_branches.insert( it->index() );
    }

    // Process all clades.
    for( auto const& clade : clades ) {

        // Find the edges that are part of the subtree of this clade.
        auto const subedges = get_clade_edges( tree, clade.second );

        // Remove the edge indices of this clade from the basal branches (non-clade) edges.
        for( auto const edge : subedges ) {
            basal_branches.erase( edge );
        }

        // TODO for now, we convert to an unordered set here by hand.
        // This can be cleaned up!
        std::unordered_set<size_t> subedge_set( subedges.begin(), subedges.end() );

        // Add them to the clade edges list. The set is not needed afterwards, so hand it over.
        clade_edges.push_back( std::make_pair( clade.first, std::move( subedge_set )));
    }

    // Now that we have processed all clades, also add the non-clade edges (basal branches)
    // to the list as a special clade "basal_branches". This way, all edges of the reference tree
    // are used by exactly one clade, provided the clades do not overlap.
    clade_edges.push_back( std::make_pair( "basal_branches", std::move( basal_branches )));

    return clade_edges;
}
/**
 * Re-evaluate a previously computed score against the query node's current
 * bound.
 *
 * @param queryNode Query node whose stored bound is consulted.
 * @param oldScore Score computed earlier for this combination.
 * @return oldScore when SortPolicy still rates it better than the bound,
 *     DBL_MAX otherwise (including when oldScore already was DBL_MAX).
 */
inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::Rescore(
    TreeType& queryNode,
    TreeType& /* referenceNode */,
    const double oldScore) const
{
  // A combination that was pruned before stays pruned.
  if (oldScore == DBL_MAX)
    return oldScore;

  const double currentBound = queryNode.Stat().Bound();
  if (SortPolicy::IsBetter(oldScore, currentBound))
    return oldScore;

  return DBL_MAX;
}
/**
 * Fill the tree from a '/'-delimited database file, then print the elements
 * between two keys.
 *
 * On each line, the text before the first '/' is taken as the enzyme acronym.
 * Every following '/'-terminated field is inserted into the tree as a
 * SequenceMap built from that field and the acronym. Parsing of a line stops
 * at the first empty field. Lines that contain no '/' are skipped.
 *
 * The process exits with EXIT_FAILURE when the file cannot be opened.
 *
 * @param a_tree Tree that receives the parsed entries.
 * @param db_filename Path of the database file.
 * @param key1 Key used to build the lower bound of the printed range.
 * @param key2 Key used to build the upper bound of the printed range.
 */
void TestRangeQuery(TreeType &a_tree, string &db_filename, string &key1, string &key2) {

    SequenceMap lower_bound(key1);
    SequenceMap upper_bound(key2);

    ifstream rebasefile(db_filename);
    if (!rebasefile.is_open()) {   // Nothing to query without the database.
        exit(EXIT_FAILURE);
    }

    string delimiter = "/";
    string db_line;

    // Test the stream after each read, so a failed read is never parsed.
    while (getline(rebasefile, db_line)) {

        size_t pos = db_line.find(delimiter);
        if (pos == string::npos) {
            // No acronym separator, hence no recognition sequences either.
            continue;
        }

        // The enzyme acronym is copied and removed from db_line.
        string an_enz_acro = db_line.substr(0, pos);
        db_line.erase(0, pos + delimiter.length());

        // Consume one recognition sequence per iteration.
        while ((pos = db_line.find(delimiter)) != string::npos) {
            string a_reco_seq = db_line.substr(0, pos);
            if (a_reco_seq.empty()) {
                break;   // an empty field ends the list
            }

            SequenceMap new_sequence_map(a_reco_seq, an_enz_acro);
            a_tree.insert(new_sequence_map);
            db_line.erase(0, pos + delimiter.length());
        }
    }
    a_tree.print_range_elem(lower_bound, upper_bound);
}
예제 #27
0
/**
 * Score a reference node for a single query point, given an already computed
 * base case result.
 *
 * baseCaseResult is only forwarded to the node's MinDistance(); apart from
 * that, the steps are those of the two-argument point/node overload.
 *
 * @param queryIndex Index of the query point.
 * @param referenceNode Reference node being scored.
 * @param baseCaseResult Previously computed base case value passed on to
 *     MinDistance().
 * @return DBL_MAX when the query's component equals the node's component
 *     membership or when the node is farther away than the component's
 *     current candidate distance; otherwise the minimum distance to the node.
 */
double DTBRules<MetricType, TreeType>::Score(const size_t queryIndex,
                                             TreeType& referenceNode,
                                             const double baseCaseResult)
{
  const size_t queryComponentIndex = connections.Find(queryIndex);

  // If the query belongs to the same component as all of the references,
  // then prune.  The membership is converted explicitly, as the two-argument
  // overload does, so the comparison does not mix signed and unsigned values.
  if (queryComponentIndex ==
      static_cast<size_t>(referenceNode.Stat().ComponentMembership()))
    return DBL_MAX;

  const arma::vec queryPoint = dataSet.unsafe_col(queryIndex);
  const double distance = referenceNode.MinDistance(queryPoint,
                                                    baseCaseResult);

  // If all the points in the reference node are farther than the candidate
  // nearest neighbor for the query's component, we prune.
  return (neighborsDistances[queryComponentIndex] < distance) ? DBL_MAX :
      distance;
}
예제 #28
0
void FillTree (std::string db_filename, TreeType &a_tree) {
    ifstream inStream(db_filename);
    std::string db_line, an_enz_acro, a_reco_seq, garbage_line;
    for (int i = 0; i < 10; i ++)                               //Skip over the header and begin reading on line 11. 
        getline(inStream, garbage_line);
    while (std::getline (inStream, db_line)) {                  //inStream has reached line 11 and will begin to parse data.
        if (db_line.empty()) continue;
        size_t first_slash = db_line.find("/");
        an_enz_acro = GetEnzymeAcronym(db_line, first_slash);
        while (GetNextRecognitionSequence(db_line, a_reco_seq, first_slash)) {
            SequenceMap new_sequence_map(a_reco_seq, an_enz_acro);
            a_tree.insert(new_sequence_map);
        }
    }
}
예제 #29
0
/**
 * Test helper verifying the bound of a node and of its whole subtree.
 *
 * Every descendant point must be reported as contained by the node's bound,
 * and at least one of the bound's hyperrectangles (the columns of LoBound()
 * and HiBound()) must enclose the point, allowing a relative slack of 1e-14
 * on each side.  Non-leaf nodes are checked recursively through Left() and
 * Right().
 *
 * @param tree Node whose subtree is verified.
 */
void CheckBound(const TreeType& tree)
{
  typedef typename TreeType::ElemType ElemType;

  for (size_t d = 0; d < tree.NumDescendants(); d++)
  {
    arma::Col<ElemType> point = tree.Dataset().col(tree.Descendant(d));

    // The bound itself must claim the point.
    BOOST_REQUIRE_EQUAL(true, tree.Bound().Contains(point));

    const arma::Mat<ElemType>& loBound = tree.Bound().LoBound();
    const arma::Mat<ElemType>& hiBound = tree.Bound().HiBound();

    // Search for one hyperrectangle that encloses the point in all
    // dimensions; stop at the first match.
    bool success = false;
    for (size_t rect = 0; !success && rect < tree.Bound().NumBounds(); rect++)
    {
      bool inside = true;
      for (size_t dim = 0; inside && dim < loBound.n_rows; dim++)
      {
        const bool outside =
            point[dim] < loBound(dim, rect) -
                1e-14 * std::fabs(loBound(dim, rect)) ||
            point[dim] > hiBound(dim, rect) +
                1e-14 * std::fabs(hiBound(dim, rect));
        inside = !outside;
      }
      success = inside;
    }

    BOOST_REQUIRE_EQUAL(success, true);
  }

  if (tree.IsLeaf())
    return;

  CheckBound(*tree.Left());
  CheckBound(*tree.Right());
}
  /**
   * Build the statistic for a node and compute the centroid of all of its
   * descendant points.
   *
   * The centroid is the sum of the points held directly by the node plus each
   * child's centroid weighted by that child's number of descendants, divided
   * by the node's total number of descendants.
   *
   * @param node Node this statistic is created for.
   */
  DualTreeKMeansStatistic(TreeType& node) :
      closestQueryNode(NULL),
      minQueryNodeDistance(DBL_MAX),
      maxQueryNodeDistance(DBL_MAX),
      clustersPruned(0),
      iteration(size_t() - 1)
  {
    // Start from a zero vector with one entry per dataset row.
    centroid.zeros(node.Dataset().n_rows);

    // Points stored directly in this node.
    size_t p = 0;
    while (p < node.NumPoints())
    {
      centroid += node.Dataset().col(node.Point(p));
      ++p;
    }

    // Children contribute their centroid once per descendant they hold.
    size_t c = 0;
    while (c < node.NumChildren())
    {
      centroid += node.Child(c).NumDescendants() *
          node.Child(c).Stat().Centroid();
      ++c;
    }

    // Turn the accumulated sum into a mean.
    centroid /= node.NumDescendants();
  }