Ejemplo n.º 1
0
/**
 * Builds the tree layer by layer from the given training set.
 *
 * At every depth the same split feature is used for all nodes of the layer:
 * each remaining candidate feature is tried on every node of the current
 * layer, the values returned by node::split() are summed, and the feature
 * with the smallest sum is kept. A feature is used for at most one depth.
 *
 * @param train_set  training samples; the feature count is taken from the
 *                   first sample. An empty set produces a tree consisting of
 *                   the root only (no candidate features, no splits).
 * @param max_leafs  growth stops once `leafs` reaches this value.
 * @param max_depth  maximum number of split levels to create.
 */
tree::tree(data_set& train_set, int max_leafs, int max_depth) : max_leafs(max_leafs)
{
	// Candidate split features are the indices 0..N-1 of the first sample's
	// feature vector. An empty training set has no first sample to inspect,
	// so it is guarded here instead of reading train_set[0] unconditionally.
	std::set<int> features;
	if (train_set.begin() != train_set.end())
	{
		const size_t feature_count = train_set[0].features.size();
		for (size_t i = 0; i < feature_count; i++)
		{
			features.insert(static_cast<int>(i));
		}
	}

	// The root covers the whole training range.
	// NOTE(review): calc_avg()/calc_mse() are still invoked when the range is
	// empty - confirm they tolerate begin == end.
	leafs = 1;
	int depth = 0;
	root = new node(0);
	root->data_begin = train_set.begin();
	root->data_end = train_set.end();
	root->calc_avg();
	root->node_mse = calc_mse(root->data_begin, root->data_end, root->output_value, root->size);
	std::vector<node*> layer;
	layer.push_back(root);
	layers.push_back(layer);

	while (leafs < max_leafs && depth < max_depth && !features.empty())
	{
		float min_error = INF;
		int best_feature = -1;
		make_layer(depth);

		// Choose the best split feature at the current depth.
		for (std::set<int>::iterator cur_split_feature = features.begin(); cur_split_feature != features.end(); ++cur_split_feature)
		{
			float cur_error = 0;
			for (size_t i = 0; i < layers[depth].size(); i++)
			{
				cur_error += layers[depth][i]->split(*cur_split_feature);
			}
			if (cur_error < min_error)
			{
				min_error = cur_error;
				best_feature = *cur_split_feature;
			}
		}

		// No candidate produced an error below INF (e.g. NaN or overly large
		// sums). Fall back to the first remaining feature rather than passing
		// the -1 placeholder to split() and recording it as a feature id.
		if (best_feature < 0)
		{
			best_feature = *features.begin();
		}

		// Run split() once more with the chosen feature so that it is the
		// last split applied to every node of this layer.
		for (size_t i = 0; i < layers[depth].size(); i++)
		{
			layers[depth][i]->split(best_feature);
		}
		feature_id_at_depth.push_back(best_feature);
		features.erase(best_feature);
		depth++;
		//std::cout << "level " << depth << " created. training error: " << min_error << " best feat: " << best_feature << " split_val: "
			//<< root->split_value << std::endl;
	}

	// Nodes of the last layer are the leaves.
	// NOTE(review): this runs before trailing empty layers are dropped below;
	// if make_layer() can append an empty layer, nothing is marked here -
	// confirm the intended order.
	for (size_t i = 0; i < layers.back().size(); i++)
	{
		layers.back()[i]->is_leaf = true;
	}
	//std::cout << "leafs before pruning: " << leafs << std::endl;
	//prune(root);
	//std::cout << "new tree! leafs after pruning: " << leafs << std::endl;

	// Drop trailing empty layers; layers[0] always holds the root, so this
	// loop stops before the vector becomes empty.
	while (layers.back().empty())
	{
		layers.pop_back();
	}
}
Ejemplo n.º 2
0
/**
 * Computes the mean squared error of the tree's predictions over test_set.
 *
 * For every sample the prediction from calculate_anwser() is compared with
 * the sample's stored `anwser`; the squared differences are summed and the
 * sum is divided by the number of samples.
 *
 * @param test_set  samples to evaluate; for an empty set the result is the
 *                  quotient 0/0.
 * @return the average squared difference between prediction and `anwser`.
 */
float tree::calculate_error(data_set& test_set)
{
	float squared_error_sum = 0;
	data_set::iterator sample = test_set.begin();
	while (sample != test_set.end())
	{
		float predicted = calculate_anwser(*sample);
		// Squared residual for this sample.
		squared_error_sum += ((predicted - sample->anwser) * (predicted - sample->anwser));
		++sample;
	}
	// Average over the sample count; the 1.0 factor makes the division floating point.
	return squared_error_sum / (1.0 * test_set.size());
}