// Train a single decision tree on the given data.
//
// featMat : one row per sample, one column per feature (size x dims).
// labels  : integer class label per sample; labels are assumed to lie in
//           [0, labels.maxCoeff()] since the histogram/weight vectors are
//           sized numClasses+1.
//
// Strategy: breadth-first growth using a queue of node indices. Each node
// owns the contiguous slice nodeix[start..end] of sample indices; splits are
// chosen by randomly sampling ~sqrt(dims) features and NUM_CHECKS random
// thresholds per feature, keeping the pair with the best information gain.
// Uses member state: nodes, nodeix, indices, classWts, depth (max depth),
// plus helpers partition() and informationGain() defined elsewhere.
void Tree::trainTree(const MatrixReal& featMat, const VectorInteger& labels) {
    // We work with a queue of nodes, initially containing only the root node.
    // We process the queue until it becomes empty.
    std::queue<int> toTrain;
    int size,numClasses,numVars,dims;
    size = labels.size();
    dims = featMat.cols();
    // Highest label value; histograms are indexed 0..numClasses inclusive.
    numClasses = labels.maxCoeff();
    // Class weights = empirical class frequencies over the whole training set.
    classWts = VectorReal::Zero(numClasses+1);
    for(int i = 0; i < labels.size(); ++i)
        classWts(labels(i)) += 1.0;
    classWts /= size;
    // nodeix holds sample indices; partition() presumably reorders this
    // array in place so each node's samples stay contiguous — TODO confirm.
    for(int i = 0; i < size; ++i)
        nodeix.push_back(i);
    std::cout<<"Training tree, dimensions set\n";
    // Number of candidate features per split: floor(sqrt(dims)) + 1,
    // the usual random-forest heuristic (mtry ~ sqrt(d)).
    numVars = (int)((double)sqrt((double)dims)) + 1;
    int cur;
    // The relevant indices for the root node is the entire set of training data.
    // NOTE(review): assumes nodes[] already contains a root element at index 0
    // (presumably created by the constructor) — verify against caller.
    nodes[0].start = 0;
    nodes[0].end = size-1;
    // Initialise the queue with just the root node
    toTrain.push(0);
    // Stores the relevant feature column for the split currently being scored.
    VectorReal relFeat;
    // Per-sample boolean "goes left" flags, reused by every candidate split
    // and read by informationGain()/partition() — more on this later.
    indices.resize(size);
    std::cout<<"Starting the queue\n";
    int lpoints,rpoints;
    // While the queue isn't empty, continue processing.
    while(!toTrain.empty()) {
        int featNum;
        double threshold;
        lpoints = rpoints = 0;
        cur = toTrain.front();
        // std::cout<<"In queue, node being processed is d :"<<cur.depth<<"\n";
        // There are two ways for a node to get out of the queue trivially,
        // a) it doesn't have enough data to be a non-trivial split, or
        // b) it has hit the maximum permissible depth (the 'depth' member).
        if((nodes[cur].end - nodes[cur].start < DATA_MIN) || (nodes[cur].depth == depth)) {
            // Tell ourselves that this is a leaf node, and remove the node
            // from the queue.
            // std::cout<<"Popping a leaf node\n";
            nodes[cur].setType(true);
            // Initialize the histogram and set it to zero
            nodes[cur].hist = VectorReal::Zero(numClasses+1);
            // The below code should give the histogram of all the elements
            // belonging to this node (nodeix maps node slot -> sample index).
            for(int i = nodes[cur].start; i <= nodes[cur].end; ++i) {
                nodes[cur].hist[labels(nodeix[i])] += 1.0;
            }
            // Normalize by global class frequency to counter class imbalance.
            // NOTE(review): classWts[i] is 0 for any label value never seen in
            // the training set, which makes this a divide-by-zero — confirm
            // labels are dense in [0, numClasses].
            for(int i = 0 ; i < classWts.size(); ++i)
                nodes[cur].hist[i] = nodes[cur].hist[i] / classWts[i];
            toTrain.pop();
            continue;
        }
        // Sentinel lower than any achievable gain, so the first candidate
        // split is always accepted.
        double infoGain(-100.0);
        relFeat.resize(size);
        // In case this isn't a trivial node, we need to process it.
        for(int i = 0; i < numVars; ++i) {
            // std::cout<<"Choosing a random variable\n";
            // Randomly select a feature (sampling with replacement, so the
            // same feature may be tried more than once).
            featNum = rand()%dims;
            // std::cout<<"Feat: "<<featNum<<std::endl;
            // Extract the relevant feature column from the training data
            relFeat = featMat.col(featNum);
            double tmax,tmin,curInfo;
            // Range over ALL samples (not just this node's), used to scale
            // the random thresholds below.
            tmax = relFeat.maxCoeff();
            tmin = relFeat.minCoeff();
            // infoGain = -100;
            //std::cout<<"Min "<<tmin<<"Max: "<<tmax<<std::endl;
            // NUM_CHECKS is a macro defined at the start
            for(int j = 0; j < NUM_CHECKS; ++j) {
                // std::cout<<"Choosing a random threshold\n";
                // Generate a random threshold, uniform on [tmin, tmax] in
                // steps of (tmax-tmin)/100.
                threshold = ((rand()%100)/100.0)*(tmax - tmin) + tmin;
                //std::cout<<"Thresh: "<<threshold<<std::endl;
                // NOTE(review): this indexes relFeat by the node slot k, not
                // by the sample index nodeix[k], unlike the histogram code
                // above which uses labels(nodeix[i]). Once partition() has
                // reordered nodeix for any non-root node these disagree —
                // looks like it should be relFeat(nodeix[k]); verify against
                // partition()/informationGain().
                for(int k = nodes[cur].start; k <= nodes[cur].end ; ++k)
                    indices[k] = (relFeat(k) < threshold);
                // Check if we have enough information gain
                curInfo = informationGain(nodes[cur].start,nodes[cur].end, labels);
                // std::cout<<"Info gain : "<<curInfo<<"\n";
                // curInfo = (double) ((rand()%10)/10.0);
                // Keep the best (feature, threshold) pair seen so far.
                if(curInfo > infoGain) {
                    infoGain = curInfo;
                    nodes[cur].x = featNum;
                    nodes[cur].threshold = threshold;
                }
            }
        }
        // We have selected a feature and a threshold for it that maximises
        // the information gain; recompute the partition flags for the winner.
        relFeat = featMat.col(nodes[cur].x);
        // We just set the indices depending on whether the features are
        // greater or lesser.
        // Conventions followed : greater goes to the right.
        for(int k = nodes[cur].start; k <= nodes[cur].end; ++k) {
            // If relfeat is lesser, indices[k] will be true, which will put
            // it in the left side of the partition.
            // NOTE(review): same relFeat(k) vs relFeat(nodeix[k]) concern as
            // in the scoring loop above.
            indices[k] = relFeat(k) < nodes[cur].threshold;
            // indices[k] = (bool)(rand()%2);
            if(indices[k])
                lpoints++;
            else
                rpoints++;
        }
        // If either side of the best split is still too small, give up and
        // make this a leaf instead.
        if( (lpoints < DATA_MIN) || (rpoints < DATA_MIN) ) {
            // Tell ourselves that this is a leaf node, and remove the node
            // from the queue.
            // std::cout<<"Popping a leaf node\n";
            nodes[cur].setType(true);
            // Initialize the histogram and set it to zero (the resize is
            // redundant: the Zero() assignment already sizes it).
            nodes[cur].hist.resize(numClasses+1);
            nodes[cur].hist = VectorReal::Zero(numClasses+1);
            // The below code should give the histogram of all the elements
            for(int i = nodes[cur].start; i <= nodes[cur].end; ++i) {
                nodes[cur].hist[labels(nodeix[i])] += 1.0;
            }
            // NOTE(review): unlike the depth/size leaf branch above, this
            // leaf's histogram is NOT divided by classWts — the two leaf
            // types therefore produce differently-scaled histograms. Confirm
            // whether this asymmetry is intentional.
            toTrain.pop();
            continue;
        }
        int part;
        // Use the prebuilt function to linearly partition our data
        // (presumably reorders nodeix[start..end] by indices[] and returns
        // the first right-side slot — TODO confirm).
        part = partition(nodes[cur].start,nodes[cur].end);
        Node right, left;
        // Increase the depth of the children
        right.depth = left.depth = nodes[cur].depth + 1;
        // Correctly assign the partitions
        left.start = nodes[cur].start;
        left.end = part -1;
        // Push back into the relevant places and also link the parent and
        // the child. (nodes[cur] is re-indexed after each push_back, so the
        // vector reallocating here is safe — no stale references are held.)
        nodes.push_back(left);
        nodes[cur].leftChild = nodes.size()-1;
        toTrain.push(nodes[cur].leftChild);
        // Ditto with the right node.
        right.start = part;
        right.end = nodes[cur].end;
        nodes.push_back(right);
        nodes[cur].rightChild = nodes.size()-1;
        toTrain.push(nodes[cur].rightChild);
        // Finally remove our node from the queue.
        toTrain.pop();
    }
}