Exemplo n.º 1
0
// Decide whether node `nodeID` is split further or becomes a terminal node.
//
// The node is registered as terminal (via addToTerminalNodes) and true is
// returned when any of the following holds:
//   * it contains at most min_node_size samples,
//   * every sample in it has the same value of the dependent variable,
//   * findBestSplit() reports that splitting should stop.
// Otherwise false is returned and the node is left for the caller to split.
bool TreeProbability::splitNodeInternal(size_t nodeID, std::vector<size_t>& possible_split_varIDs) {

  const size_t num_samples = sampleIDs[nodeID].size();

  // Small nodes are never split
  if (num_samples <= min_node_size) {
    addToTerminalNodes(nodeID);
    return true;
  }

  // Purity scan: compare every response with the first one and stop at the
  // first value that differs
  bool mixed = false;
  if (num_samples > 0) {
    const double first_value = data->get(sampleIDs[nodeID][0], dependent_varID);
    for (size_t pos = 1; pos < num_samples; ++pos) {
      const double current = data->get(sampleIDs[nodeID][pos], dependent_varID);
      if (current != first_value) {
        mixed = true;
        break;
      }
    }
  }

  // A pure node is terminal; an impure one is terminal only if the split
  // search asks to stop (the search is not run at all for pure nodes)
  if (!mixed || findBestSplit(nodeID, possible_split_varIDs)) {
    addToTerminalNodes(nodeID);
    return true;
  }

  return false;
}
Exemplo n.º 2
0
// Decide whether node `nodeID` is split further or becomes a leaf.
//
// The node becomes a leaf when it holds at most min_node_size samples or when
// findBestSplit() reports that splitting should stop. In both cases the node's
// estimate is stored in split_values[nodeID] and true is returned. Otherwise
// false is returned.
bool TreeRegression::splitNodeInternal(size_t nodeID, std::vector<size_t>& possible_split_varIDs) {

  const bool too_small = sampleIDs[nodeID].size() <= min_node_size;

  // Short-circuit: the split search only runs for nodes that are large enough
  if (too_small || findBestSplit(nodeID, possible_split_varIDs)) {
    split_values[nodeID] = estimate(nodeID);
    return true;
  }

  return false;
}
Exemplo n.º 3
0
// Recursively subdivide this node.
//
// Nothing happens if terminate() says the node is a leaf or if findBestSplit()
// finds no split on the chosen axis. Otherwise two children are created, the
// node's objects are distributed to them, and each child is subdivided in its
// own OpenMP task; the function returns once both tasks have completed.
void KdTreeNode::computeChildren()
{
    // Split axis is picked from the depth, modulo 3.
    const int axis = depth % 3;

    // Leaf node: nothing to subdivide.
    if (terminate()) {
        return;
    }

    float limit;
    const bool splitFound = findBestSplit(axis, &limit);
    if (!splitFound) {
        return;
    }

    createChildren(axis, limit);
    splitObjects(axis, &objects, &(left->objects), &(right->objects));

    // Interior nodes drop their object list; only leaves keep theirs.
    objects.clear();

    // Recurse into each child as a separate task.
#pragma omp task
    left->computeChildren();

#pragma omp task
    right->computeChildren();

    // Wait for both child tasks before returning.
#pragma omp taskwait

}
Exemplo n.º 4
0
/*
 * Grow a single regression tree.
 *
 * Nodes are stored in flat arrays indexed by node number; node 0 is the root
 * and the daughters of a split node are appended at positions ncur+1 and
 * ncur+2.
 *
 * Inputs:
 *   x, y      predictor data and response, forwarded to findBestSplit()
 *   mdim      number of variables (forwarded to findBestSplit())
 *   nsample   number of cases used to grow the tree
 *   nrnodes   capacity of the per-node arrays
 *   nthsize   a daughter holding at most this many cases is made terminal
 *   mtry, cat forwarded to findBestSplit()
 *
 * Outputs (per node k unless noted):
 *   lDaughter, rDaughter  1-based node numbers of the two daughters
 *   upper                 split point returned by findBestSplit()
 *   avnode                mean response of the cases in the node
 *   nodestatus            NODE_TERMINAL / NODE_INTERIOR (0 = unused slot)
 *   mbest                 1-based index of the split variable
 *   treeSize              number of slots whose status is non-zero
 *   tgini                 per variable: accumulated decrease reported by
 *                         findBestSplit() (added to, not reset here)
 *   varUsed               per variable: set to 1 if used in any split
 *
 * Work buffers are static and allocated on the first call (in_regTree == 0),
 * sized from that call's nrnodes and nsample. Calling with the global
 * in_regTree == -99 only releases those buffers and returns; all other
 * arguments are ignored in that mode.
 */
void regTree(double *x, double *y, int mdim, int nsample, int *lDaughter,
        int *rDaughter,
        double *upper, double *avnode, SMALL_INT *nodestatus, int nrnodes,
        int *treeSize, int nthsize, int mtry, int *mbest, int *cat,
        double *tgini, int *varUsed) {
    int i, j, k, m, ncur;
    static int *jdex = NULL, *nodestart = NULL, *nodepop = NULL;
    int ndstart, ndend, ndendl, nodecnt, jstat, msplit;
    double d, av, decsplit, ubest, sumnode;

    if (in_regTree == -99) {
        /* Cleanup request: release the work buffers. The pointers are reset
         * so that a repeated cleanup call is a harmless free(NULL) instead
         * of a double free. */
        free(nodestart);
        free(jdex);
        free(nodepop);
        nodestart = NULL;
        jdex = NULL;
        nodepop = NULL;
        return;
    }

    if (in_regTree == 0) {
        /* First call: allocate the work buffers once and reuse them. */
        in_regTree = 1;
        nodestart = (int *) calloc(nrnodes, sizeof(int));
        nodepop   = (int *) calloc(nrnodes, sizeof(int));
        jdex = (int *) calloc(nsample, sizeof(int));
    }

    /* initialize some arrays for the tree */
    zeroSMALLInt(nodestatus, nrnodes);
    zeroInt(nodestart, nrnodes);
    zeroInt(nodepop, nrnodes);
    zeroDouble(avnode, nrnodes);

    /* jdex holds 1-based case numbers; findBestSplit() may reorder it. */
    for (i = 1; i <= nsample; ++i) jdex[i-1] = i;

    /* The root covers all cases and is the only node awaiting a split. */
    ncur = 0;
    nodestart[0] = 0;
    nodepop[0] = nsample;
    nodestatus[0] = NODE_TOSPLIT;

    /* running mean of Y over all cases -> root node mean */
    av = 0.0;
    for (i = 0; i < nsample; ++i) {
        d = y[jdex[i] - 1];
        av = (i * av + d) / (i + 1);
    }
    avnode[0] = av;

    /* start main loop */
    for (k = 0; k < nrnodes - 2; ++k) {
        /* Stop when no node is left to visit or when there is no room for
         * two more daughters. */
        if (k > ncur || ncur >= nrnodes - 2) break;
        /* skip if the node is not to be split */
        if (nodestatus[k] != NODE_TOSPLIT) continue;

        /* initialize for next call to findbestsplit */
        ndstart = nodestart[k];
        ndend = ndstart + nodepop[k] - 1;
        nodecnt = nodepop[k];
        sumnode = nodecnt * avnode[k];
        jstat = 0;
        decsplit = 0.0;

        findBestSplit(x, jdex, y, mdim, nsample, ndstart, ndend, &msplit,
                &decsplit, &ubest, &ndendl, &jstat, mtry, sumnode,
                nodecnt, cat);
        if (jstat == 1) {
            /* Node is terminal: Mark it as such and move on to the next. */
            nodestatus[k] = NODE_TERMINAL;
            continue;
        }
        /* Found the best split. */
        mbest[k] = msplit;
        varUsed[msplit - 1] = 1;
        upper[k] = ubest;
        tgini[msplit - 1] += decsplit;
        nodestatus[k] = NODE_INTERIOR;

        /* leftnode no.= ncur+1, rightnode no. = ncur+2.
         * Left daughter gets jdex[ndstart..ndendl], right daughter gets
         * jdex[ndendl+1..ndend]. */
        nodepop[ncur + 1] = ndendl - ndstart + 1;
        nodepop[ncur + 2] = ndend - ndendl;
        nodestart[ncur + 1] = ndstart;
        nodestart[ncur + 2] = ndendl + 1;

        /* running mean of Y for the left daughter node */
        av = 0.0;
        for (j = ndstart; j <= ndendl; ++j) {
            d = y[jdex[j] - 1];
            m = j - ndstart;
            av = (m * av + d) / (m + 1);
        }
        avnode[ncur + 1] = av;
        nodestatus[ncur + 1] = NODE_TOSPLIT;
        if (nodepop[ncur + 1] <= nthsize) {
            nodestatus[ncur + 1] = NODE_TERMINAL;
        }

        /* running mean of Y for the right daughter node */
        av = 0.0;
        for (j = ndendl + 1; j <= ndend; ++j) {
            d = y[jdex[j] - 1];
            m = j - (ndendl + 1);
            av = (m * av + d) / (m + 1);
        }
        avnode[ncur + 2] = av;
        nodestatus[ncur + 2] = NODE_TOSPLIT;
        if (nodepop[ncur + 2] <= nthsize) {
            nodestatus[ncur + 2] = NODE_TERMINAL;
        }

        /* map the daughter nodes (stored as 1-based node numbers) */
        lDaughter[k] = ncur + 1 + 1;
        rDaughter[k] = ncur + 2 + 1;
        /* Augment the tree by two nodes. */
        ncur += 2;
    }

    /* Count the used slots and close any node still waiting for a split
     * (left over when the loop ran out of room). */
    *treeSize = nrnodes;
    for (k = nrnodes - 1; k >= 0; --k) {
        if (nodestatus[k] == 0) (*treeSize)--;
        if (nodestatus[k] == NODE_TOSPLIT) {
            nodestatus[k] = NODE_TERMINAL;
        }
    }

}
Exemplo n.º 5
0
/*
 * Grow a regression random forest and compute out-of-bag (OOB) predictions,
 * optional test-set predictions, proximities and variable importance.
 */
void regRF(double *x, double *y, int *xdim, int *sampsize,
        int *nthsize, int *nrnodes, int *nTree, int *mtry, int *imp,
        int *cat, int maxcat, int *jprint, int doProx, int oobprox,
        int biasCorr, double *yptr, double *errimp, double *impmat,
        double *impSD, double *prox, int *treeSize, SMALL_INT *nodestatus,
        int *lDaughter, int *rDaughter, double *avnode, int *mbest,
        double *upper, double *mse, const int *keepf, int *replace,
        int testdat, double *xts, int *nts, double *yts, int labelts,
        double *yTestPred, double *proxts, double *msets, double *coef,
        int *nout, int *inbag) {
    /*************************************************************************
     * Input:
     * xdim[0]=nsample=number of cases
     * xdim[1]=mdim=number of variables in data set
     * x is indexed as x[variable + case * mdim]
     *
     * sampsize=number of cases drawn for growing each tree, with or without
     * replacement depending on *replace
     *
     * nthsize=number of cases in a node below which the tree will not split,
     * setting nthsize=5 generally gives good results.
     *
     * nTree=number of trees in run.  200-500 gives pretty good results
     *
     * mtry=number of variables to pick to split on at each node.  mdim/3
     * seems to give generally good performance, but it can be
     * altered up or down
     *
     * imp[0]=1 turns on variable importance.  This is computed for the
     * mth variable as the percent rise in the test set mean sum-of-
     * squared errors when the mth variable is randomly permuted.
     * imp[1] turns on casewise (local) importance written to impmat,
     * imp[2] is the number of permutations per variable.
     *
     * keepf[0]: if non-zero, tree j is stored at offset j * *nrnodes in the
     * per-node arrays; otherwise every tree overwrites offset 0.
     * keepf[1]: if non-zero, the in-bag indicator of every tree is written
     * to inbag.
     *
     * jprint: running output is printed every *jprint trees; 0 is replaced
     * by *nTree + 1 (written back through the pointer), which disables it.
     *
     * testdat / xts / nts / yts / labelts: optional test set; test MSE is
     * only computed when labelts is set.
     *
     * Output:
     * yptr=aggregated OOB prediction per case, nout=times each case was OOB
     * mse[j]=OOB error after tree j, msets[j]=test-set MSE after tree j
     * yTestPred=average test-set prediction over all trees
     * errimp/impSD/impmat=importance measures; the slot used for the
     * per-variable split decrease is errimp+mdim when importance is on,
     * errimp otherwise
     * prox/proxts=proximity matrices (when doProx)
     * coef=bias-correction regression coefficients (when biasCorr)
     *
     * Side effects: x is permuted in place during importance computation
     * and restored afterwards; the static buffers of findBestSplit() and
     * regTree() are released at the end through the in_findBestSplit /
     * in_regTree globals.
     *************************************************************************/

    double errts = 0.0, meanY, meanYts, varY, varYts, r, xrand,
            errb = 0.0, resid = 0.0, ooberr, ooberrperm, delta, *resOOB;

    double *yb, *xtmp, *xb, *ytr, *ytree = NULL, *tgini;

    int k, m, mr, n, nOOB, j, jout, idx, ntest, last, ktmp, nPerm,
            nsample, mdim, keepF, keepInbag;
    int *oobpair, varImp, localImp, *varUsed;

    int *in, *nind, *nodex, *nodexts = NULL;

    /* Dummy arguments for the cleanup calls at the end. They are passed by
     * value there, so they must be initialized. */
    double tmp_d = 0;
    int tmp_i = 0;
    SMALL_INT tmp_c = 0;

    /* Seed COKUS's random generator with an odd number. Done in unsigned
     * arithmetic so that a large rand() value cannot overflow a signed int. */
    seedMT(2u * (unsigned int) rand() + 1u);

    nsample = xdim[0];
    mdim = xdim[1];
    ntest = *nts;
    varImp = imp[0];
    localImp = imp[1];
    nPerm = imp[2];
    keepF = keepf[0];
    keepInbag = keepf[1];

    if (*jprint == 0) *jprint = *nTree + 1;

    yb         = (double *) calloc(*sampsize, sizeof(double));
    xb         = (double *) calloc(mdim * *sampsize, sizeof(double));
    ytr        = (double *) calloc(nsample, sizeof(double));
    xtmp       = (double *) calloc(nsample, sizeof(double));
    resOOB     = (double *) calloc(nsample, sizeof(double));

    in         = (int *) calloc(nsample, sizeof(int));
    nodex      = (int *) calloc(nsample, sizeof(int));
    varUsed    = (int *) calloc(mdim, sizeof(int));
    nind = *replace ? NULL : (int *) calloc(nsample, sizeof(int));

    if (testdat) {
        ytree      = (double *) calloc(ntest, sizeof(double));
        nodexts    = (int *) calloc(ntest, sizeof(int));
    }
    oobpair = (doProx && oobprox) ?
        (int *) calloc(nsample * nsample, sizeof(int)) : NULL;

    /* If variable importance is requested, tgini points to the second
       "column" of errimp, otherwise it's just the same as errimp. */
    tgini = varImp ? errimp + mdim : errimp;

    meanY = 0.0;
    varY = 0.0;

    zeroDouble(yptr, nsample);
    zeroInt(nout, nsample);
    /* running mean and variance of the training response */
    for (n = 0; n < nsample; ++n) {
        varY += n * (y[n] - meanY)*(y[n] - meanY) / (n + 1);
        meanY = (n * meanY + y[n]) / (n + 1);
    }
    varY /= nsample;

    varYts = 0.0;
    meanYts = 0.0;
    if (testdat) {
        for (n = 0; n < ntest; ++n) {
            varYts += n * (yts[n] - meanYts)*(yts[n] - meanYts) / (n + 1);
            meanYts = (n * meanYts + yts[n]) / (n + 1);
        }
        varYts /= ntest;
    }

    if (doProx) {
        zeroDouble(prox, nsample * nsample);
        if (testdat) zeroDouble(proxts, ntest * (nsample + ntest));
    }

    if (varImp) {
        zeroDouble(errimp, mdim * 2);
        if (localImp) zeroDouble(impmat, nsample * mdim);
    } else {
        zeroDouble(errimp, mdim);
    }
    /* yTestPred is a running average that is updated whenever testdat is
     * set, so it must start from zero in that case as well. */
    if (labelts || testdat) zeroDouble(yTestPred, ntest);

    /* print header for running output */
    if (*jprint <= *nTree) {
        PRINTF("     |      Out-of-bag   ");
        if (testdat) PRINTF("|       Test set    ");
        PRINTF("|\n");
        PRINTF("Tree |      MSE  %%Var(y) ");
        if (testdat) PRINTF("|      MSE  %%Var(y) ");
        PRINTF("|\n");
    }
    GetRNGstate();
    /*************************************
     * Start the loop over trees.
     *************************************/
    for (j = 0; j < *nTree; ++j) {
        idx = keepF ? j * *nrnodes : 0;
        zeroInt(in, nsample);
        zeroInt(varUsed, mdim);

        /* Draw a random sample for growing a tree. */
        if (*replace) { /* sampling with replacement */
            for (n = 0; n < *sampsize; ++n) {
                xrand = unif_rand();
                /* assumes unif_rand() < 1 so that k < nsample -- TODO confirm */
                k = (int)(xrand * nsample);
                in[k] = 1;
                yb[n] = y[k];
                for (m = 0; m < mdim; ++m) {
                    xb[m + n * mdim] = x[m + k * mdim];
                }
            }
        } else { /* sampling w/o replacement */
            for (n = 0; n < nsample; ++n) nind[n] = n;
            last = nsample - 1;
            for (n = 0; n < *sampsize; ++n) {
                /* pick among the not-yet-drawn cases nind[0..last] */
                ktmp = (int) (unif_rand() * (last+1));
                k = nind[ktmp];
                swapInt(nind[ktmp], nind[last]);
                last--;
                in[k] = 1;
                yb[n] = y[k];
                for (m = 0; m < mdim; ++m) {
                    xb[m + n * mdim] = x[m + k * mdim];
                }
            }
        }
        if (keepInbag) {
            for (n = 0; n < nsample; ++n) inbag[n + j * nsample] = in[n];
        }

        /* grow the regression tree */
        regTree(xb, yb, mdim, *sampsize, lDaughter + idx, rDaughter + idx,
                upper + idx, avnode + idx, nodestatus + idx, *nrnodes,
                treeSize + j, *nthsize, *mtry, mbest + idx, cat, tgini,
                varUsed);

        /* predict the OOB data with the current tree */
        /* ytr is the prediction on OOB data by the current tree */
        predictRegTree(x, nsample, mdim, lDaughter + idx,
                rDaughter + idx, nodestatus + idx, ytr, upper + idx,
                avnode + idx, mbest + idx, treeSize[j], cat, maxcat,
                nodex);
        /* yptr is the aggregated prediction by all trees grown so far */
        errb = 0.0;
        ooberr = 0.0;
        jout = 0; /* jout is the number of cases that has been OOB so far */
        nOOB = 0; /* nOOB is the number of OOB samples for this tree */
        for (n = 0; n < nsample; ++n) {
            if (in[n] == 0) {
                nout[n]++;
                nOOB++;
                yptr[n] = ((nout[n]-1) * yptr[n] + ytr[n]) / nout[n];
                resOOB[n] = ytr[n] - y[n];
                ooberr += resOOB[n] * resOOB[n];
            }
            if (nout[n]) {
                jout++;
                errb += (y[n] - yptr[n]) * (y[n] - yptr[n]);
            }
        }
        errb /= jout;
        /* Do simple linear regression of y on yhat for bias correction. */
        if (biasCorr) simpleLinReg(nsample, yptr, y, coef, &errb, nout);

        /* predict testset data with the current tree */
        if (testdat) {
            predictRegTree(xts, ntest, mdim, lDaughter + idx,
                    rDaughter + idx, nodestatus + idx, ytree,
                    upper + idx, avnode + idx,
                    mbest + idx, treeSize[j], cat, maxcat, nodexts);
            /* ytree is the prediction for test data by the current tree */
            /* yTestPred is the average prediction by all trees grown so far */
            errts = 0.0;
            for (n = 0; n < ntest; ++n) {
                yTestPred[n] = (j * yTestPred[n] + ytree[n]) / (j + 1);
            }
            /* compute testset MSE */
            if (labelts) {
                for (n = 0; n < ntest; ++n) {
                    resid = biasCorr ?
                        yts[n] - (coef[0] + coef[1]*yTestPred[n]) :
                        yts[n] - yTestPred[n];
                    errts += resid * resid;
                }
                errts /= ntest;
            }
        }

        /* Print running output. */
        if ((j + 1) % *jprint == 0) {
            PRINTF("%4d |", j + 1);
            PRINTF(" %8.4g %8.2f ", errb, 100 * errb / varY);
            if (labelts == 1) PRINTF("| %8.4g %8.2f ",
                    errts, 100.0 * errts / varYts);
            PRINTF("|\n");
            fflush(stdout);
        }

        mse[j] = errb;
        if (labelts) msets[j] = errts;

        /*  DO PROXIMITIES */
        if (doProx) {
            computeProximity(prox, oobprox, nodex, in, oobpair, nsample);
            /* proximity for test data */
            if (testdat) {
                /* In the next call, in and oobpair are not used. */
                computeProximity(proxts, 0, nodexts, in, oobpair, ntest);
                /* test-vs-training part: stored after the ntest x ntest
                 * test-vs-test part of proxts */
                for (n = 0; n < ntest; ++n) {
                    for (k = 0; k < nsample; ++k) {
                        if (nodexts[n] == nodex[k]) {
                            proxts[n + ntest * (k+ntest)] += 1.0;
                        }
                    }
                }
            }
        }

        /* Variable importance */
        if (varImp) {
            for (mr = 0; mr < mdim; ++mr) {
                if (varUsed[mr]) { /* Go ahead if the variable is used */
                    /* make a copy of the m-th variable into xtmp */
                    for (n = 0; n < nsample; ++n)
                        xtmp[n] = x[mr + n * mdim];
                    ooberrperm = 0.0;
                    for (k = 0; k < nPerm; ++k) {
                        permuteOOB(mr, x, in, nsample, mdim);
                        predictRegTree(x, nsample, mdim, lDaughter + idx,
                                rDaughter + idx, nodestatus + idx, ytr,
                                upper + idx, avnode + idx, mbest + idx,
                                treeSize[j], cat, maxcat, nodex);
                        for (n = 0; n < nsample; ++n) {
                            if (in[n] == 0) {
                                r = ytr[n] - y[n];
                                ooberrperm += r * r;
                                if (localImp) {
                                    impmat[mr + n * mdim] +=
                                            (r*r - resOOB[n]*resOOB[n]) / nPerm;
                                }
                            }
                        }
                    }
                    /* mean increase in OOB squared error per OOB case */
                    delta = (ooberrperm / nPerm - ooberr) / nOOB;
                    errimp[mr] += delta;
                    impSD[mr] += delta * delta;
                    /* copy original data back */
                    for (n = 0; n < nsample; ++n)
                        x[mr + n * mdim] = xtmp[n];
                }
            }
        }
    }
    PutRNGstate();
    /* end of tree iterations=======================================*/

    if (biasCorr) {  /* bias correction for predicted values */
        for (n = 0; n < nsample; ++n) {
            if (nout[n]) yptr[n] = coef[0] + coef[1] * yptr[n];
        }
        if (testdat) {
            for (n = 0; n < ntest; ++n) {
                yTestPred[n] = coef[0] + coef[1] * yTestPred[n];
            }
        }
    }

    if (doProx) {
        /* Normalize one triangle, mirror it, and set the diagonal to 1. */
        for (n = 0; n < nsample; ++n) {
            for (k = n + 1; k < nsample; ++k) {
                prox[nsample*k + n] /= oobprox ?
                    (oobpair[nsample*k + n] > 0 ? oobpair[nsample*k + n] : 1) :
                    *nTree;
                prox[nsample * n + k] = prox[nsample * k + n];
            }
            prox[nsample * n + n] = 1.0;
        }
        if (testdat) {
            for (n = 0; n < ntest; ++n)
                for (k = 0; k < ntest + nsample; ++k)
                    proxts[ntest*k + n] /= *nTree;
        }
    }

    if (varImp) {
        for (m = 0; m < mdim; ++m) {
            errimp[m] = errimp[m] / *nTree;
            impSD[m] = sqrt( ((impSD[m] / *nTree) -
                    (errimp[m] * errimp[m])) / *nTree );
            if (localImp) {
                /* NOTE(review): nout[n] is 0 for a case that was never
                 * out-of-bag; the division is then 0/0. */
                for (n = 0; n < nsample; ++n) {
                    impmat[m + n * mdim] /= nout[n];
                }
            }
        }
    }
    for (m = 0; m < mdim; ++m) tgini[m] /= *nTree;

    /* Release the static work space held by findBestSplit(): with
     * in_findBestSplit set to -99 the call is a cleanup request and the
     * arguments are dummies. */
    in_findBestSplit = -99;
    findBestSplit(&tmp_d, &tmp_i, &tmp_d, tmp_i, tmp_i,
            tmp_i, tmp_i, &tmp_i, &tmp_d,
            &tmp_d, &tmp_i, &tmp_i, tmp_i,
            tmp_d, tmp_i, &tmp_i);

    /* Same for regTree(): in_regTree == -99 makes it free its buffers. */
    in_regTree = -99;
    regTree(&tmp_d, &tmp_d, tmp_i, tmp_i, &tmp_i,
            &tmp_i,
            &tmp_d, &tmp_d, &tmp_c, tmp_i,
            &tmp_i, tmp_i, tmp_i, &tmp_i, &tmp_i,
            &tmp_d, &tmp_i);

    free(yb);
    free(xb);
    free(ytr);
    free(xtmp);
    free(resOOB);
    free(in);
    free(nodex);
    free(varUsed);
    /* The optional buffers are NULL when they were not allocated, and
     * free(NULL) is a no-op. */
    free(nind);
    free(ytree);
    free(nodexts);
    free(oobpair);
}