-
Notifications
You must be signed in to change notification settings - Fork 0
/
DecisionTree.cpp
114 lines (96 loc) · 2.91 KB
/
DecisionTree.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
#include "DecisionTree.h"
#include <algorithm>
#include <cmath>
namespace dt {
using std::sort;
// Evaluates a trained tree on one feature vector.
//
// Starting at `tree`, descends until it reaches a node whose `left` child is
// null (such a node is treated as a leaf) and returns that node's `value`.
// At each inner node the example goes left when
// features[splitFeature] < splitValue and right otherwise.
//
// `tree` is dereferenced unconditionally, so it must not be null.
float evalTree(DecisionTree* tree, const FeatureVector& features) {
  DecisionTree* node = tree;
  while (node->left) {
    const bool goLeft = features[node->splitFeature] < node->splitValue;
    node = goLeft ? node->left.get() : node->right.get();
  }
  return node->value;
}
// Count-weighted binary entropy (natural log) of a set holding `p` positive
// and `n` negative examples.
//
// Both counts are incremented by one before use so that the logarithm is
// never taken of zero when either class is absent. With p' = p + 1 and
// n' = n + 1 the result is
//   -p' * ln(p' / (p' + n')) - n' * ln(n' / (p' + n')),
// which is always non-negative.
//
// The arithmetic is done on floating-point values on purpose: negating an
// unsigned size_t wraps around to a huge positive number instead of producing
// a negative one.
float entropy(size_t p, size_t n) {
  const float pos = static_cast<float>(p) + 1.0f;
  const float neg = static_cast<float>(n) + 1.0f;
  const float total = pos + neg;
  return -pos * std::log(pos / total) - neg * std::log(neg / total);
}
// Gain of splitting a set with `p` positive / `n` negative examples into a
// first part holding `p1` positive / `n1` negative examples and a second part
// holding the remainder.
//
// Computed as entropy(whole) - entropy(first part) - entropy(second part).
// The remainder counts are obtained by unsigned subtraction, so callers must
// pass p1 <= p and n1 <= n.
float infoGain(size_t p, size_t n, size_t p1, size_t n1) {
  const float whole = entropy(p, n);
  const float firstPart = entropy(p1, n1);
  const float secondPart = entropy(p - p1, n - n1);
  return whole - firstPart - secondPart;
}
// Recursively trains a binary decision tree on the examples in [first, last).
//
// Parameters:
//   first, last - random-access range of examples; the range is reordered in
//                 place (it is sorted once per feature and finally partitioned
//                 around the chosen split).
//   maxDepth    - remaining depth budget; 0 yields no node at all.
//   minGain     - a split is only applied when its gain exceeds this value.
//
// Returns nullptr when the range is empty or maxDepth is 0. Otherwise returns
// a node whose `value` is the fraction of positive labels in the range. When
// an acceptable split exists, `splitFeature` / `splitValue` are filled in and
// the children are trained with maxDepth - 1 (so with maxDepth == 1 the split
// is recorded but both children are null).
unique_ptr<DecisionTree> trainTree(ExampleIt first,
                                   ExampleIt last,
                                   size_t maxDepth,
                                   float minGain) {
  if (first >= last) return nullptr;
  if (maxDepth == 0) return nullptr;

  // The feature count is taken from the first example.
  const size_t numFeatures = first->features.size();

  // Label totals for the whole range.
  size_t p = 0, n = 0;
  for (auto it = first; it != last; ++it) {
    if (it->label) {
      ++p;
    } else {
      ++n;
    }
  }

  // Best split found so far. `haveSplit` guards against using the split
  // variables when no candidate was ever recorded (e.g. every feature is
  // constant over the range and minGain is below the initial bestGain).
  bool haveSplit = false;
  float bestGain = -1;  // candidates with gain <= -1 are never selected
  size_t splitFeature = 0;
  float splitValue = 0;

  for (size_t i = 0; i < numFeatures; ++i) {
    // Sort by feature i so that every threshold corresponds to a prefix.
    sort(first, last, [=](const Example& a, const Example& b) {
      return a.features[i] < b.features[i];
    });

    // p1 / n1 count the labels of the examples strictly before `it`.
    size_t p1 = 0, n1 = 0;
    for (ExampleIt it = first; it != last; ++it) {
      // Only positions where the feature value changes are valid thresholds;
      // the first position would leave the left side empty.
      if (it != first && it->features[i] != (it - 1)->features[i]) {
        const float gain = infoGain(p, n, p1, n1);
        if (gain > bestGain) {
          haveSplit = true;
          bestGain = gain;
          splitFeature = i;
          splitValue = it->features[i];
        }
      }
      if (it->label) {
        ++p1;
      } else {
        ++n1;
      }
    }
  }

  unique_ptr<DecisionTree> tree(new DecisionTree);
  tree->value = static_cast<float>(p) / (p + n);

  if (haveSplit && bestGain > minGain) {
    tree->splitFeature = splitFeature;
    tree->splitValue = splitValue;

    // Move every example that goes left (feature < splitValue) to the front.
    // A linear partition yields the same two sets as a full re-sort.
    const ExampleIt mid =
        std::partition(first, last, [=](const Example& e) {
          return e.features[splitFeature] < splitValue;
        });

    tree->left = trainTree(first, mid, maxDepth - 1, minGain);
    tree->right = trainTree(mid, last, maxDepth - 1, minGain);
  }
  return tree;
}
// Measures how well `tree` classifies the examples in [first, last).
//
// An example counts as correctly classified when the comparison
// (evalTree(...) >= 0.5) equals its label. Returns the number of correctly
// classified examples divided by the size of the range.
//
// NOTE(review): for an empty range this divides zero by zero, which is
// presumably NaN on IEEE floats - callers should pass a non-empty range.
float validateTree(DecisionTree* tree, ExampleIt first, ExampleIt last) {
  const auto total = last - first;
  size_t hits = 0;
  for (ExampleIt cur = first; cur != last; ++cur) {
    const bool predictedPositive = evalTree(tree, cur->features) >= 0.5;
    if (predictedPositive == cur->label) {
      hits += 1;
    }
  }
  return static_cast<float>(hits) / total;
}
}