-
Notifications
You must be signed in to change notification settings - Fork 2
/
forest.hpp
120 lines (98 loc) · 2.71 KB
/
forest.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
/*
* CRForest.h
*
* Created on: May 4, 2011
* Author: Matthias Dantone
*/
#ifndef CRFOREST_H_
#define CRFOREST_H_
#include "tree.hpp"
template<typename Sample>
class Forest {
public:
typedef typename Sample::Split Split;
typedef typename Sample::Leaf Leaf;
Forest() {
};
Forest(ForestParam tp) :
param(tp) {
};
Forest(const std::vector<Sample*> data, ForestParam tp, boost::mt19937* rng) {
for (int i = 0; i < tp.nTrees; i++) {
Tree<Sample>* tree = new Tree<Sample>(data, tp, rng);
trees.push_back(tree);
}
};
void addTree(Tree<Sample>* t) {
trees.push_back(t);
}
//sends the Sample down the tree
void evaluate(const Sample* f, std::vector<Leaf*>& leafs) const {
for (unsigned int i = 0; i < trees.size(); i++)
trees[i]->evaluate(f, trees[i]->root, leafs);
}
void evaluate_mt(const Sample* f, Leaf** leafs) const {
for (unsigned int i = 0; i < trees.size(); i++) {
trees[i]->evaluate_mt(f, trees[i]->root, leafs);
leafs++;
}
}
//stores the tree
void save(std::string url, int offset = 0) {
for (unsigned int i = 0; i < trees.size(); i++) {
char buffer[200];
sprintf(buffer, "%s%03d.txt", url.c_str(), i + offset);
std::string path = buffer;
trees[i]->save(buffer);
}
}
void load(std::string url, ForestParam tp, int max_trees = -1) {
param = tp;
if (max_trees == -1)
max_trees = tp.nTrees;
std::cout << tp.nTrees << " to load." << std::endl;
for (int i = 0; i < tp.nTrees; i++) {
if (static_cast<int>(trees.size()) > max_trees)
continue;
char buffer[200];
sprintf(buffer, "%s%03d.txt", url.c_str(), i);
std::string tree_path = buffer;
load_tree(tree_path, trees);
}
std::cout << trees.size() << " trees loaded" << std::endl;
}
static bool load_tree(std::string url, std::vector<Tree<Sample>*>& trees) {
Tree<Sample>* tree;
Tree<Sample>::load(&tree, url);
if (tree->isFinished()) {
trees.push_back(tree);
} else {
delete tree;
return false;
}
return true;
}
ForestParam getParam() const {
return param;
}
void setParam(ForestParam fp) {
param = fp;
}
std::vector<float> getClassWeights() {
return trees[0]->getClassWeights();
}
void getAllLeafs(std::vector<std::vector<Leaf*> >& leafs) {
leafs.resize(trees.size());
for (unsigned int i = 0; i < trees.size(); i++)
trees[i]->root->collectLeafs(leafs[i]);
}
std::vector<Tree<Sample>*> trees;
private:
ForestParam param;
friend class boost::serialization::access;
template<class Archive>
void serialize(Archive & ar, const unsigned int version) {
ar & trees;
}
};
#endif /* CRFOREST_H_ */