コード例 #1
0
ファイル: tree.cpp プロジェクト: dongyu1990/gbdt
// Grows the subtree rooted at `node` from the first `len` samples in `*data`.
//
// The node always receives a prediction (node->pred) for the configured loss:
// Average() for SQUARED_ERROR, LogitOptimalValue() for LOG_LIKELIHOOD.
// It is then marked a leaf if any of these hold, checked in this order:
//   - `depth` equals g_conf.max_depth,
//   - Same(*data, len) is true,
//   - len <= g_conf.min_leaf_size,
//   - FindSplit() returns false,
//   - SplitData() leaves the LT or the GE partition empty.
// Otherwise the split written to node->index / node->value is kept:
// gain[node->index] is raised to this split's gain when that is larger, the
// feature's entry in g_conf.feature_costs is increased by 1e-4 when feature
// cost tuning is enabled, and LT / GE children (plus an UNKNOWN child when that
// partition is non-empty) are allocated with `new` and fitted at depth + 1.
//
// @param data  sample set to fit.
// @param len   number of samples of *data to use.
// @param node  node to populate (prediction, split, leaf flag, children).
// @param depth depth of `node`, compared against g_conf.max_depth.
// @param gain  array indexed by node->index holding the best gain seen so far.
void RegressionTree::Fit(DataVector *data,
                         size_t len,
                         Node *node,
                         size_t depth,
                         double *gain) {
  const size_t depth_limit = g_conf.max_depth;

  // Prediction for this node under the active loss function.
  if (g_conf.loss == SQUARED_ERROR) {
    node->pred = Average(*data, len);
  } else if (g_conf.loss == LOG_LIKELIHOOD) {
    node->pred = LogitOptimalValue(*data, len);
  }

  // Stopping rules, evaluated left to right: FindSplit() only runs when none
  // of the cheaper checks has already made this node a leaf.
  double split_gain = 0.0;
  const bool stop = depth == depth_limit
      || Same(*data, len)
      || len <= g_conf.min_leaf_size
      || !FindSplit(data, len, &node->index, &node->value, &split_gain);
  if (stop) {
    node->leaf = true;
    return;
  }

  DataVector parts[Node::CHILDSIZE];
  SplitData(*data, len, node->index, node->value, parts);

  // A split that sends every sample to one side is rejected.
  if (parts[Node::LT].empty() || parts[Node::GE].empty()) {
    node->leaf = true;
    return;
  }

  // Keep the largest gain observed for this feature.
  double &best_gain = gain[node->index];
  if (best_gain < split_gain) {
    best_gain = split_gain;
  }

  // Charge a small usage cost to the feature chosen for this split.
  if (g_conf.feature_costs && g_conf.enable_feature_tunning) {
    g_conf.feature_costs[node->index] += 1.0e-4;
  }

  node->child[Node::LT] = new Node();
  node->child[Node::GE] = new Node();

  Fit(&parts[Node::LT], node->child[Node::LT], depth + 1, gain);
  Fit(&parts[Node::GE], node->child[Node::GE], depth + 1, gain);

  DataVector &unknown = parts[Node::UNKNOWN];
  if (unknown.empty()) {
    return;
  }
  node->child[Node::UNKNOWN] = new Node();
  Fit(&unknown, node->child[Node::UNKNOWN], depth + 1, gain);
}
コード例 #2
0
ファイル: tree.cpp プロジェクト: nadult/Snail
void DBVH::Construct(const vector<ObjectInstance> &tElements) {
	elements = tElements;
	nodes.reserve(elements.size() * 2);
	nodes.clear();
	
	depth = 0;

	BBox bbox(elements[0].GetBBox());
	for(size_t n = 1; n < elements.size(); n++)
		bbox += BBox(elements[n].GetBBox());

	nodes.push_back(Node(bbox));
	FindSplit(0, 0, elements.size(), 0);

	ASSERT(depth <= maxDepth);
}
コード例 #3
0
ファイル: tree.cpp プロジェクト: nadult/Snail
void DBVH::FindSplit(int nNode, int first, int count, int sdepth) {
	BBox bbox = nodes[nNode].bbox;

	if(count <= 1) {
	LEAF_NODE:
		depth = Max(depth, sdepth);
		nodes[nNode].first = first | 0x80000000;
		nodes[nNode].count = count;
	}
	else {
		int minAxis = MaxAxis(bbox.Size());
		const float traverseCost = 0.0;
		const float intersectCost = 1.0;

		struct Bin {
			Bin() :count(0) {
				box.min = Vec3f(constant::inf, constant::inf, constant::inf);
				box.max = Vec3f(-constant::inf, -constant::inf, -constant::inf);
			}

			BBox box;
			int count;
		};

		int nBins = count < 8? 8 : 16;
		Bin bins[nBins];

		float mul = nBins * (1.0f - constant::epsilon) /
			((&bbox.max.x)[minAxis] - (&bbox.min.x)[minAxis]);
		float sub = (&bbox.min.x)[minAxis];

		for(int n = 0; n < count; n++) {
			const BBox &box = elements[first + n].GetBBox();
			float c = ((&box.max.x)[minAxis] + (&box.min.x)[minAxis]) * 0.5f;
			int nBin = int( (c - sub) * mul );
			bins[nBin].count++;
			bins[nBin].box += box;
		}

		BBox leftBoxes[nBins], rightBoxes[nBins];
		int leftCounts[nBins], rightCounts[nBins];

		rightBoxes[nBins - 1] = bins[nBins - 1].box;
		rightCounts[nBins - 1] = bins[nBins - 1].count;
		leftBoxes[0] = bins[0].box;
		leftCounts[0] = bins[0].count;

		for(size_t n = 1; n < nBins; n++) {
			leftBoxes[n] = leftBoxes[n - 1] + bins[n].box;
			leftCounts[n] = leftCounts[n - 1] + bins[n].count;
		}
		for(int n = nBins - 2; n >= 0; n--) {
			rightBoxes[n] = rightBoxes[n + 1] + bins[n].box;
			rightCounts[n] = rightCounts[n + 1] + bins[n].count;
		}
		
		float minCost = constant::inf;
		float noSplitCost = intersectCost * count * BoxSA(bbox);
		int minIdx = 1;

		for(size_t n = 1; n < nBins; n++) {
			float cost =
				(leftCounts[n - 1]?	BoxSA(leftBoxes[n - 1]) * leftCounts[n - 1] : 0) +
				(rightCounts[n]?	BoxSA(rightBoxes[n]) * rightCounts[n] : 0);

			if(cost < minCost) {
				minCost = cost;
				minIdx = n;
			}
		}

		minCost = traverseCost + intersectCost * minCost;
		if(noSplitCost < minCost)
			goto LEAF_NODE;
		
		ObjectInstance *it =
			std::partition(&elements[first], &elements[first + count], TestBoxes(minAxis, minIdx, sub, mul));

		BBox leftBox = leftBoxes[minIdx - 1];
		BBox rightBox = rightBoxes[minIdx];
		int leftCount = leftCounts[minIdx - 1];
		int rightCount = rightCounts[minIdx];

		if(leftCount == 0 || rightCount == 0) {
			minIdx = count / 2;
			leftBox = elements[first].GetBBox();
			rightBox = elements[first + count - 1].GetBBox();

			for(size_t n = 1; n < minIdx; n++)
				leftBox  += elements[first + n].GetBBox();
			for(size_t n = minIdx; n < count; n++)
				rightBox += elements[first + n].GetBBox();
			leftCount = minIdx;
			rightCount = count - leftCount;
		}
			
		int subNode = nodes.size();
		nodes[nNode].subNode = subNode;

		nodes[nNode].axis = minAxis;
		nodes[nNode].firstNode = leftBox.min[minAxis] > rightBox.min[minAxis]? 1 : 0;
		nodes[nNode].firstNode =
			leftBox.min[minAxis] == rightBox.min[minAxis]? leftBox.max[minAxis] < rightBox.max[minAxis]? 0 :1 : 0;

		nodes.push_back(Node(leftBox));
		nodes.push_back(Node(rightBox));

		FindSplit(subNode + 0, first, leftCount, sdepth + 1);
		FindSplit(subNode + 1, first + leftCount, rightCount, sdepth + 1);
	}
}