Ejemplo n.º 1
0
void
CTD(int n, double *y[], double *x, int nclass,
		int edge, double *improve, double *split, int *csplit,
		double myrisk, double *wt, double *treatment, int minsize, 
		double alpha, int bucketnum, int bucketMax, double train_to_est_ratio)
{
	int i, j;
	double temp;
	double left_sum, right_sum;
	double left_tr_sum, right_tr_sum;
	double left_tr_sqr_sum, right_tr_sqr_sum;
	double left_sqr_sum, right_sqr_sum;
	double tr_var, con_var;
	double left_tr_var, left_con_var, right_tr_var, right_con_var;
	double left_tr, right_tr;
	double left_wt, right_wt;
	int left_n, right_n;
	double best;
	int direction = LEFT;
	int where = 0;
	double node_effect, left_effect, right_effect;
	double left_temp, right_temp;
	int min_node_size = minsize;
	int bucketTmp;
	double trsum = 0.;
	int Numbuckets;

	double *cum_wt, *tmp_wt, *fake_x;
	double tr_wt_sum, con_wt_sum, con_cum_wt, tr_cum_wt;
	
	// for overlap:
	double tr_min, tr_max, con_min, con_max;
	double left_bd, right_bd;
	double cut_point;

	right_wt = 0;
	right_tr = 0;
	right_sum = 0;
	right_tr_sum = 0;
	right_sqr_sum = 0;
	right_tr_sqr_sum = 0;
	right_n = n;
	for (i = 0; i < n; i++) {
		right_wt += wt[i];
		right_tr += wt[i] * treatment[i];
		right_sum += *y[i] * wt[i];
		right_tr_sum += *y[i] * wt[i] * treatment[i];
		right_sqr_sum += (*y[i]) * (*y[i]) * wt[i];
		right_tr_sqr_sum += (*y[i]) * (*y[i]) * wt[i] * treatment[i];
		trsum += treatment[i];
	}

	temp = right_tr_sum / right_tr - (right_sum - right_tr_sum) / (right_wt - right_tr);
	tr_var = right_tr_sqr_sum / right_tr - right_tr_sum * right_tr_sum / (right_tr * right_tr);
	con_var = (right_sqr_sum - right_tr_sqr_sum) / (right_wt - right_tr)
		- (right_sum - right_tr_sum) * (right_sum - right_tr_sum) 
		/ ((right_wt - right_tr) * (right_wt - right_tr));
	node_effect = alpha * temp * temp * right_wt - (1 - alpha) * (1 + train_to_est_ratio) 
		* right_wt * (tr_var / right_tr  + con_var / (right_wt - right_tr));

	if (nclass == 0) {
		/* continuous predictor */
		cum_wt = (double *) ALLOC(n, sizeof(double));
		tmp_wt = (double *) ALLOC(n, sizeof(double));
		fake_x = (double *) ALLOC(n, sizeof(double));

		tr_wt_sum = 0.;
		con_wt_sum = 0.;
		con_cum_wt = 0.;
		tr_cum_wt = 0.;  
		
		// find the abs max and min of x:
		double max_abs_tmp = fabs(x[0]);
		for (i = 0; i < n; i++) {
		    if (max_abs_tmp < fabs(x[i])) {
		        max_abs_tmp = fabs(x[i]);
		    }
		}
		
		// set tr_min, con_min, tr_max, con_max to a large/small value
		tr_min = max_abs_tmp;
		tr_max = -max_abs_tmp;
		con_min = max_abs_tmp;
		con_max = -max_abs_tmp;
		
		for (i = 0; i < n; i++) {
			if (treatment[i] == 0) {
				con_wt_sum += wt[i];
			    if (con_min > x[i]) {
			        con_min = x[i];
			    }
			    if (con_max < x[i]) {
			        con_max = x[i];
			    }
			} else {
				tr_wt_sum += wt[i];
			    if (tr_min > x[i]) {
			        tr_min = x[i];
			    }
			    if (tr_max < x[i]) {
			        tr_max = x[i];
			    }
			}
			cum_wt[i] = 0.;
			tmp_wt[i] = 0.;
			fake_x[i] = 0.;
		}
		
		// compute the left bound and right bound
		left_bd = max(tr_min, con_min);
		right_bd = min(tr_max, con_max);
		
		bucketTmp = min(round(trsum / (double)bucketnum), round(((double)n - trsum) / (double)bucketnum));
		Numbuckets = max(minsize, min(bucketTmp, bucketMax));

		for (i = 0; i < n; i++) {
			if (treatment[i] == 0) {
				tmp_wt[i] = wt[i] / con_wt_sum;     
				con_cum_wt += tmp_wt[i];
				cum_wt[i] = con_cum_wt;
				fake_x[i] = (int)floor(Numbuckets * cum_wt[i]);
			} else {
				tmp_wt[i] = wt[i] / tr_wt_sum;
				tr_cum_wt += tmp_wt[i];
				cum_wt[i] = tr_cum_wt;
				fake_x[i] = (int)floor(Numbuckets * cum_wt[i]);
			}
		}
        
		n_bucket = (int *) ALLOC(Numbuckets + 1,  sizeof(int));
		n_tr_bucket = (int *) ALLOC(Numbuckets + 1, sizeof(int));
		n_con_bucket = (int *) ALLOC(Numbuckets + 1, sizeof(int));
		wts_bucket = (double *) ALLOC(Numbuckets + 1, sizeof(double));
		trs_bucket = (double *) ALLOC(Numbuckets + 1, sizeof(double));
		wtsums_bucket = (double *) ALLOC(Numbuckets + 1, sizeof(double));
		trsums_bucket = (double *) ALLOC(Numbuckets + 1, sizeof(double));
		wtsqrsums_bucket = (double *) ALLOC(Numbuckets + 1, sizeof(double));
		trsqrsums_bucket = (double *) ALLOC(Numbuckets + 1, sizeof(double));
		tr_end_bucket = (double *) ALLOC(Numbuckets + 1, sizeof(double));
		con_end_bucket = (double *) ALLOC (Numbuckets + 1, sizeof(double));
		

		for (j = 0; j < Numbuckets + 1; j++) {
			n_bucket[j] = 0;
			n_tr_bucket[j] = 0;
			n_con_bucket[j] = 0;
			wts_bucket[j] = 0.;
			trs_bucket[j] = 0.;
			wtsums_bucket[j] = 0.;
			trsums_bucket[j] = 0.;
			wtsqrsums_bucket[j] = 0.;
			trsqrsums_bucket[j] = 0.;
		}

		for (i = 0; i < n; i++) {
			j = fake_x[i];
			n_bucket[j]++;
			wts_bucket[j] += wt[i];
			trs_bucket[j] += wt[i] * treatment[i];
			wtsums_bucket[j] += *y[i] * wt[i];
			trsums_bucket[j] += *y[i] * wt[i] * treatment[i];
			wtsqrsums_bucket[j] += (*y[i]) * (*y[i]) * wt[i];
			trsqrsums_bucket[j] += (*y[i]) * (*y[i]) * wt[i] * treatment[i];
			if (treatment[i] == 1) {
				tr_end_bucket[j] = x[i];
			} else {
				con_end_bucket[j] = x[i];
			}
		}

		left_wt = 0;
		left_tr = 0;
		left_n = 0;
		left_sum = 0;
		left_tr_sum = 0;
		left_sqr_sum = 0;
		left_tr_sqr_sum = 0;
		left_temp = 0.;
		right_temp = 0.;

		best = 0;

		for (j = 0; j < Numbuckets; j++) {
			left_n += n_bucket[j];
			right_n -= n_bucket[j];
			left_wt += wts_bucket[j];
			right_wt -= wts_bucket[j];
			left_tr += trs_bucket[j]; 
			right_tr -= trs_bucket[j];

			left_sum += wtsums_bucket[j];
			right_sum -= wtsums_bucket[j];

			left_tr_sum += trsums_bucket[j];
			right_tr_sum -= trsums_bucket[j];

			left_sqr_sum += wtsqrsums_bucket[j];
			right_sqr_sum -= wtsqrsums_bucket[j];

			left_tr_sqr_sum += trsqrsums_bucket[j];
			right_tr_sqr_sum -= trsqrsums_bucket[j];
            
            cut_point = (tr_end_bucket[j] + con_end_bucket[j]) / 2.0;
			if (left_n >= edge && right_n >= edge &&
					(int) left_tr >= min_node_size &&
					(int) left_wt - (int) left_tr >= min_node_size &&
					(int) right_tr >= min_node_size &&
					(int) right_wt - (int) right_tr >= min_node_size &&
					cut_point < right_bd && cut_point > left_bd) {
				left_temp = left_tr_sum / left_tr - 
					(left_sum - left_tr_sum) / (left_wt - left_tr);
				left_tr_var = left_tr_sqr_sum / left_tr - 
					left_tr_sum  * left_tr_sum / (left_tr * left_tr);
				left_con_var = (left_sqr_sum - left_tr_sqr_sum) / (left_wt - left_tr)  
					- (left_sum - left_tr_sum) * (left_sum - left_tr_sum)
					/ ((left_wt - left_tr) * (left_wt - left_tr));        
				left_effect = alpha * left_temp * left_temp * left_wt
					- (1 - alpha) * (1 + train_to_est_ratio) * left_wt 
					* (left_tr_var / left_tr + left_con_var / (left_wt - left_tr));

				right_temp = right_tr_sum / right_tr -
					(right_sum - right_tr_sum) / (right_wt - right_tr);
				right_tr_var = right_tr_sqr_sum / right_tr -
					right_tr_sum * right_tr_sum / (right_tr * right_tr);
				right_con_var = (right_sqr_sum - right_tr_sqr_sum) / (right_wt - right_tr)
					- (right_sum - right_tr_sum) * (right_sum - right_tr_sum) 
					/ ((right_wt - right_tr) * (right_wt - right_tr));
				right_effect = alpha * right_temp * right_temp * right_wt
					- (1 - alpha) * (1 + train_to_est_ratio) * right_wt 
					* (right_tr_var / right_tr + right_con_var / (right_wt - right_tr));
				temp = left_effect + right_effect - node_effect;
				if (temp > best) {
					best = temp;                  
					where = j; 
					if (left_temp < right_temp)
						direction = LEFT;
					else
						direction = RIGHT;
				}
			}
		}

		*improve = best;
		if (best > 0) {         /* found something */
			csplit[0] = direction;
			*split = (tr_end_bucket[where] + con_end_bucket[where]) / 2.0;
		}

	} else {
		/*
		 * Categorical predictor
		 */
		for (i = 0; i < nclass; i++) {
			countn[i] = 0;
			wts[i] = 0;
			trs[i] = 0;
			sums[i] = 0;
			wtsums[i] = 0;
			trsums[i] = 0;
			wtsqrsums[i] = 0;
			wttrsqrsums[i] = 0;
		}


		for (i = 0; i < n; i++) {
			j = (int) x[i] - 1;
			countn[j]++;
			wts[j] += wt[i];
			trs[j] += wt[i] * treatment[i];
			sums[j] += *y[i];
			wtsums[j] += *y[i] * wt[i];
			trsums[j] += *y[i] * wt[i] * treatment[i];
			wtsqrsums[j] += (*y[i]) * (*y[i]) * wt[i];
			wttrsqrsums[j] += (*y[i]) * (*y[i]) * wt[i] * treatment[i];
		}

		for (i = 0; i < nclass; i++) {
			if (countn[i] > 0) {
				tsplit[i] = RIGHT;
				treatment_effect[i] = trsums[j] / trs[j] - (wtsums[j] - trsums[j]) / (wts[j] - trs[j]);
			} else
				tsplit[i] = 0;
		}
		graycode_init2(nclass, countn, treatment_effect);

		/*
		 * Now find the split that we want
		 */
		left_wt = 0;
		left_tr = 0;
		left_n = 0;
		left_sum = 0;
		left_tr_sum = 0;
		left_sqr_sum = 0;
		left_tr_sqr_sum = 0;

		best = 0;
		where = 0;
		while ((j = graycode()) < nclass) {
			tsplit[j] = LEFT;
			left_n += countn[j];
			right_n -= countn[j];

			left_wt += wts[j];
			right_wt -= wts[j];

			left_tr += trs[j];
			right_tr -= trs[j];

			left_sum += wtsums[j];
			right_sum -= wtsums[j];

			left_tr_sum += trsums[j];
			right_tr_sum -= trsums[j];

			left_sqr_sum += wtsqrsums[j];
			right_sqr_sum -= wtsqrsums[j];

			left_tr_sqr_sum += wttrsqrsums[j];
			right_tr_sqr_sum -= wttrsqrsums[j];


			if (left_n >= edge && right_n >= edge &&
					(int) left_tr >= min_node_size &&
					(int) left_wt - (int) left_tr >= min_node_size &&
					(int) right_tr >= min_node_size &&
					(int) right_wt - (int) right_tr >= min_node_size) {

				left_temp = left_tr_sum / left_tr - (left_sum - left_tr_sum) / (left_wt - left_tr);
				left_tr_var = left_tr_sqr_sum / left_tr - left_tr_sum  * left_tr_sum / (left_tr * left_tr);
				left_con_var = (left_sqr_sum - left_tr_sqr_sum) / (left_wt - left_tr) 
					- (left_sum - left_tr_sum) * (left_sum - left_tr_sum)
					/ ((left_wt - left_tr) * (left_wt - left_tr));        
				left_effect = alpha * left_temp * left_temp * left_wt
					- (1 - alpha) * (1 + train_to_est_ratio) * left_wt 
					* (left_tr_var / left_tr + left_con_var / (left_wt - left_tr));

				right_temp = right_tr_sum / right_tr - (right_sum - right_tr_sum) / (right_wt - right_tr);
				right_tr_var = right_tr_sqr_sum / right_tr - 
					right_tr_sum * right_tr_sum / (right_tr * right_tr);
				right_con_var = (right_sqr_sum - right_tr_sqr_sum) / (right_wt - right_tr)
					- (right_sum - right_tr_sum) * (right_sum - right_tr_sum) 
					/ ((right_wt - right_tr) * (right_wt - right_tr));
				right_effect = alpha * right_temp * right_temp * right_wt
					- (1 - alpha) * (1 + train_to_est_ratio) * right_wt 
					* (right_tr_var / right_tr + right_con_var / (right_wt - right_tr)); 

				temp = left_effect + right_effect - node_effect;

				if (temp > best) {
					best = temp;
					if (left_temp > right_temp)
						for (i = 0; i < nclass; i++) csplit[i] = -tsplit[i];
					else
						for (i = 0; i < nclass; i++) csplit[i] = tsplit[i];
				}
			}
		}
		*improve = best;
	}
}
Ejemplo n.º 2
0
/*
 * Treatment-effect estimate and variance term for one node or one side of a
 * split.  Inputs: wt = sum of weights, tr = sum of weight * treatment,
 * sum / tr_sum = weighted sums of y over all / treated observations,
 * sqr_sum / tr_sqr_sum = the same for y^2.  With con_wt = wt - tr:
 *   *tau = tr_sum / tr - (sum - tr_sum) / con_wt
 *   *var = tr_var / tr + con_var / con_wt
 * where tr_var / con_var are weighted second moment minus squared weighted
 * mean of y in the treated / control part.
 */
static void
tstats_side(double wt, double tr, double sum, double tr_sum,
        double sqr_sum, double tr_sqr_sum, double *tau, double *var)
{
    double con_wt = wt - tr;
    double con_sum = sum - tr_sum;
    double tr_var = tr_sqr_sum / tr - tr_sum * tr_sum / (tr * tr);
    double con_var = (sqr_sum - tr_sqr_sum) / con_wt
        - con_sum * con_sum / (con_wt * con_wt);

    *tau = tr_sum / tr - con_sum / con_wt;
    *var = tr_var / tr + con_var / con_wt;
}

/* alpha * tau^2 * size - (1 - alpha) * (1 + train_to_est_ratio) * size * var */
static double
tstats_effect(double tau, double var, double size, double alpha,
        double train_to_est_ratio)
{
    return alpha * tau * tau * size
        - (1 - alpha) * (1 + train_to_est_ratio) * size * var;
}

/*
 * Nonzero when, after truncation to int, each side carries at least
 * min_node_size treated weight and min_node_size control weight.
 */
static int
tstats_weights_ok(double left_wt, double left_tr, double right_wt,
        double right_tr, int min_node_size)
{
    return (int) left_tr >= min_node_size &&
        (int) left_wt - (int) left_tr >= min_node_size &&
        (int) right_tr >= min_node_size &&
        (int) right_wt - (int) right_tr >= min_node_size;
}

/*
 * tstats: choose the split of a node on one predictor that maximises the
 * t-statistic |tau_left - tau_right| / sqrt(V_left / wt_left + V_right / wt_right),
 * and report the effect-based improvement of that split
 * (left_effect + right_effect - node_effect, see tstats_effect).
 *
 * x          continuous (nclass == 0; assumed sorted ascending by the caller
 *            -- NOTE(review): confirm) or categorical coded 1..nclass
 * edge       minimum number of observations on each side
 * improve    out: improvement of the selected split (0 if none)
 * split      out (continuous, only when *improve > 0): cut point
 * csplit     out: continuous -> csplit[0] = LEFT/RIGHT direction when
 *            *improve > 0; categorical -> per-category side of the selected
 *            split (meaningful only when *improve > 0)
 * myrisk     unused
 * wt, treatment  case weights and treatment indicator
 * minsize    minimum truncated treated / control weight per side
 * alpha, train_to_est_ratio  weights of the effect and variance terms
 */
void
tstats(int n, double *y[], double *x, int nclass,
        int edge, double *improve, double *split, int *csplit,
        double myrisk, double *wt, double *treatment, int minsize, double alpha, 
        double train_to_est_ratio)
{
    int i, j;
    double temp;
    double left_sum, right_sum;
    double left_tr_sum, right_tr_sum;
    double left_tr_sqr_sum, right_tr_sqr_sum;
    double left_sqr_sum, right_sqr_sum;
    double left_tr, right_tr;
    double left_wt, right_wt;
    int left_n, right_n;
    double best;
    int direction = LEFT;
    int where = 0;
    double node_temp, node_var, node_effect, left_effect, right_effect;
    double left_var, right_var, sd;
    double left_temp, right_temp;
    int min_node_size = minsize;
    double improve_temp, improve_best;

    /* Start with every observation on the right-hand side. */
    right_wt = 0.;
    right_tr = 0.;
    right_sum = 0.;
    right_tr_sum = 0.;
    right_sqr_sum = 0.;
    right_tr_sqr_sum = 0.;
    right_n = n;
    improve_best = 0.;
    for (i = 0; i < n; i++) {
        right_wt += wt[i];
        right_tr += wt[i] * treatment[i];
        right_sum += *y[i] * wt[i];
        right_tr_sum += *y[i] * wt[i] * treatment[i];
        right_sqr_sum += (*y[i]) * (*y[i]) * wt[i];
        right_tr_sqr_sum += (*y[i]) * (*y[i]) * wt[i] * treatment[i];
    }

    /* The node-level term is scaled by the observation count n. */
    tstats_side(right_wt, right_tr, right_sum, right_tr_sum,
            right_sqr_sum, right_tr_sqr_sum, &node_temp, &node_var);
    node_effect = tstats_effect(node_temp, node_var, n, alpha, train_to_est_ratio);

    left_wt = 0.;
    left_tr = 0.;
    left_n = 0;
    left_sum = 0.;
    left_tr_sum = 0.;
    left_sqr_sum = 0.;
    left_tr_sqr_sum = 0.;
    best = 0.;

    if (nclass == 0) {
        /* continuous predictor: move observations left one at a time */
        for (i = 0; right_n > edge; i++) {
            left_wt += wt[i];
            right_wt -= wt[i];

            left_tr += wt[i] * treatment[i];
            right_tr -= wt[i] * treatment[i];

            left_n++;
            right_n--;

            temp = *y[i] * wt[i] * treatment[i];
            left_tr_sum += temp;
            right_tr_sum -= temp;

            left_sum += *y[i] * wt[i];
            right_sum -= *y[i] * wt[i];

            temp = (*y[i]) * (*y[i]) * wt[i] * treatment[i];
            left_tr_sqr_sum += temp;
            right_tr_sqr_sum -= temp;

            temp = (*y[i]) * (*y[i]) * wt[i];
            left_sqr_sum += temp;
            right_sqr_sum -= temp;

            /* Only cut between distinct x values with enough data on the left. */
            if (x[i + 1] == x[i] || left_n < edge ||
                    !tstats_weights_ok(left_wt, left_tr, right_wt, right_tr,
                        min_node_size))
                continue;

            tstats_side(left_wt, left_tr, left_sum, left_tr_sum,
                    left_sqr_sum, left_tr_sqr_sum, &left_temp, &left_var);
            tstats_side(right_wt, right_tr, right_sum, right_tr_sum,
                    right_sqr_sum, right_tr_sqr_sum, &right_temp, &right_var);
            left_effect = tstats_effect(left_temp, left_var, left_n,
                    alpha, train_to_est_ratio);
            right_effect = tstats_effect(right_temp, right_var, right_n,
                    alpha, train_to_est_ratio);

            sd = sqrt(left_var / left_wt  + right_var / right_wt);
            temp = fabs(left_temp - right_temp) / sd;
            improve_temp = left_effect + right_effect - node_effect;

            if (temp > best) {
                best = temp;
                where = i;
                improve_best = improve_temp;
                direction = (left_temp < right_temp) ? LEFT : RIGHT;
            }
        }

        *improve = improve_best;
        if (improve_best > 0) {         /* found something */
            csplit[0] = direction;
            *split = (x[where] + x[where + 1]) / 2;
        }
    } else {
        /* categorical predictor: categories are coded 1..nclass in x */
        for (i = 0; i < nclass; i++) {
            countn[i] = 0;
            wts[i] = 0;
            trs[i] = 0;
            sums[i] = 0;
            wtsums[i] = 0;
            trsums[i] = 0;
            wtsqrsums[i] = 0;
            wttrsqrsums[i] = 0;
        }

        for (i = 0; i < n; i++) {
            j = (int) x[i] - 1;
            countn[j]++;
            wts[j] += wt[i];
            trs[j] += wt[i] * treatment[i];
            sums[j] += *y[i];
            wtsums[j] += *y[i] * wt[i];
            trsums[j] += *y[i] * wt[i] * treatment[i];
            wtsqrsums[j] += (*y[i]) * (*y[i]) * wt[i];
            wttrsqrsums[j] += (*y[i]) * (*y[i]) * wt[i] * treatment[i];
        }

        /*
         * Key passed to graycode_init2 (presumably the category ordering).
         * NOTE(review): sums is an unweighted sum of y while wts is a sum of
         * weights, so this is a mean only for unit weights; wtsums[i] / wts[i]
         * may have been intended.
         */
        for (i = 0; i < nclass; i++) {
            if (countn[i] > 0) {
                tsplit[i] = RIGHT;
                mean[i] = sums[i] / wts[i];
            } else
                tsplit[i] = 0;
        }
        graycode_init2(nclass, countn, mean);

        /* graycode() yields the next category to move left, or >= nclass when done. */
        while ((j = graycode()) < nclass) {
            tsplit[j] = LEFT;
            left_n += countn[j];
            right_n -= countn[j];

            left_wt += wts[j];
            right_wt -= wts[j];

            left_tr += trs[j];
            right_tr -= trs[j];

            left_sum += wtsums[j];
            right_sum -= wtsums[j];

            left_tr_sum += trsums[j];
            right_tr_sum -= trsums[j];

            left_sqr_sum += wtsqrsums[j];
            right_sqr_sum -= wtsqrsums[j];

            left_tr_sqr_sum += wttrsqrsums[j];
            right_tr_sqr_sum -= wttrsqrsums[j];

            if (left_n < edge || right_n < edge ||
                    !tstats_weights_ok(left_wt, left_tr, right_wt, right_tr,
                        min_node_size))
                continue;

            tstats_side(left_wt, left_tr, left_sum, left_tr_sum,
                    left_sqr_sum, left_tr_sqr_sum, &left_temp, &left_var);
            tstats_side(right_wt, right_tr, right_sum, right_tr_sum,
                    right_sqr_sum, right_tr_sqr_sum, &right_temp, &right_var);
            /*
             * NOTE(review): side terms are scaled by weight here, while
             * node_effect (and the continuous branch) scale by counts; the
             * two agree only for unit weights.
             */
            left_effect = tstats_effect(left_temp, left_var, left_wt,
                    alpha, train_to_est_ratio);
            right_effect = tstats_effect(right_temp, right_var, right_wt,
                    alpha, train_to_est_ratio);

            sd = sqrt(left_var / left_wt  + right_var / right_wt);
            temp = fabs(left_temp - right_temp) / sd;
            improve_temp = left_effect + right_effect - node_effect;

            /*
             * Capture the category assignment now: after the loop tsplit
             * holds the last enumerated state, not the best one.
             */
            if (temp > best) {
                best = temp;
                improve_best = improve_temp;
                if (left_temp > right_temp)
                    for (i = 0; i < nclass; i++) csplit[i] = -tsplit[i];
                else
                    for (i = 0; i < nclass; i++) csplit[i] = tsplit[i];
            }
        }
        /* Report the improvement, as the continuous branch does. */
        *improve = improve_best;
    }
}
Ejemplo n.º 3
0
Archivo: dist.c Proyecto: cran/mvpart
void dist(int n,    double *y[],  FLOAT *x,     int nclass, 
       int edge, double *improve, FLOAT *split, int *csplit, double myrisk, double *wt)
    {
    int i, j, k, kj;
    double temp, sumdiffs_sq;
    double left_sum, right_sum;
/*    double left_wt, right_wt;  */
    int left_n, right_n;
    double best, total;
    int direction = LEFT;
    int where = 0;

    right_n = n;
        
    if (nclass==0) {
    
    left_n=0;
    best=0;
    
    total=0;
    for (k=1; k<n; k++)
    for (j=0; j<k; j++) 
    total += *y[rp.n*j-j*(j+1)/2+k-j-1];
    total = total/n;    

    for (i=0; right_n>edge; i++) {
	    temp=0; sumdiffs_sq=0; left_n++;  right_n--;
		right_sum=0; left_sum=0;

		if (i==0) left_sum=0;
		else {
			for (k=1; k<=i; k++)
			for (j=0; j<k; j++)  
			left_sum += *y[rp.n*j-j*(j+1)/2+k-j-1];
			left_sum = left_sum/(i+1);
		}
		
		if (i==(n-1)) right_sum=0;
		else {
			for (k=i+2; k<n; k++) 
			for (j=i+1; j<k; j++) 
			right_sum += *y[rp.n*j-j*(j+1)/2+k-j-1];
			right_sum = right_sum/(n-i-1);
		}

        if (x[i+1] !=x[i] &&  left_n>=edge) {
	        temp = total-left_sum-right_sum;

        if (temp > best) {
            best = temp;
            where = i;
            if (left_sum > right_sum) direction = LEFT;
                      else    direction = RIGHT;
            }
        }
    }

    *improve =  best/ myrisk;
    if (best>0) {   /* found something */
        csplit[0] = direction;
        *split = (x[where] + x[where+1]) /2;
        }
    }

    else {
    
/*
**  Do the easy coding for now - gd
**  Take countn and dsts as square matrices and fold them over
**  Fix it up later !!! 
*/
    for (i=0; i<nclass; i++) {
        count[i] =0;
    for (j=0; j<nclass; j++) {
        countn[i+nclass*j] =0;
        dsts[i+nclass*j] =0;
    }
    }

    k = x[0]-1;
    count[k]++;

    for (i=1; i<n; i++) {
    k = x[i]-1;
    count[k]++;
    for (j=0; j<i; j++) {
        kj = x[j]-1;    
        countn[k+nclass*kj]++;
            dsts[k+nclass*kj] += *y[rp.n*j-j*(j+1)/2+i-j-1];       
    }
    }

    for (i=0; i<nclass; i++) 
    for (j=0; j<=i; j++) {
    if (i!=j) {
        countn[i+nclass*j]=countn[i+nclass*j]+countn[j+nclass*i];
            dsts[i+nclass*j]=dsts[i+nclass*j]+dsts[j+nclass*i];    
    }
    }

        for (i=0; i<nclass; i++) {
        if (count[i]==0) tsplit[i] = 0;
        else tsplit[i] = RIGHT;
    }

    total = 0;
    for (k=0; k<nclass; k++) 
        if (tsplit[k]!=0) {
        for (j=0; j<=k; j++) 
            if (tsplit[j]!=0) 
        total += dsts[k+nclass*j];
    }

    /*
    ** Now find the split that we want
    */

    best = 0;
    /*
    ** Insert gray code bit here
    */

/*  if (numclass==2) graycode_init2(nclass, count, rate);
**              else graycode_init1(nclass, count);
**
**     Just use graycode_init1 here -- gd
*/

    graycode_init1(nclass, count);

    while((i=graycode()) < nclass) {

/* item i changes groups */

    left_n =0;  right_n = 0;
    left_sum = 0; right_sum = 0; 

    if (tsplit[i]==LEFT)  tsplit[i]=RIGHT;
    else tsplit[i]=LEFT;
        
    for (k=0; k<nclass; k++) 
        if (tsplit[k]==LEFT) {
        for (j=0; j<=k; j++) 
            if (tsplit[j]==LEFT)   {        
            left_n += countn[k+nclass*j];
            left_sum += dsts[k+nclass*j]; 
            }
        }
        else if (tsplit[k]==RIGHT) {
        for (j=0; j<=k; j++) 
            if (tsplit[j]==RIGHT)   {       
            right_n += countn[k+nclass*j];
            right_sum += dsts[k+nclass*j];   
                }
        }

    left_n = (int) (sqrt(2*left_n+0.25)+0.5);   
    right_n = (int) (sqrt(2*right_n+0.25)+0.5); 
    
    if (left_n>=edge  &&  right_n>=edge) {
    temp = total/n - left_sum/left_n - right_sum/right_n;

        if (temp > best) {
                best=temp;
                if (left_sum > right_sum)
                for (j=0; j<nclass; j++) csplit[j] = tsplit[j];
            else
                for (j=0; j<nclass; j++) csplit[j] = -tsplit[j];
                }
        }
        }
    }
    *improve = best / myrisk;      /* % improvement */

  }
Ejemplo n.º 4
0
void
tstatsD(int n, double *y[], double *x, int nclass,
        int edge, double *improve, double *split, int *csplit,
        double myrisk, double *wt, double *treatment, int minsize, double alpha, 
        int bucketnum, int bucketMax, double train_to_est_ratio)
{
    int i, j;
    double temp;
    double left_sum, right_sum;
    double left_tr_sum, right_tr_sum;
    double left_tr_sqr_sum, right_tr_sqr_sum;
    double left_sqr_sum, right_sqr_sum;
    double tr_var, con_var;
    double left_tr_var, left_con_var, right_tr_var, right_con_var;
    double left_tr, right_tr;
    double left_wt, right_wt;
    int left_n, right_n;
    double best;
    int direction = LEFT;
    int where = 0;
    double node_effect, left_effect, right_effect;
    double left_var, right_var, sd;
    double left_temp, right_temp;
    int min_node_size = minsize;
    double improve_temp, improve_best;
    int bucketTmp;
    double trsum = 0.;
    int Numbuckets;
    
    double *cum_wt, *tmp_wt, *fake_x;
    double tr_wt_sum, con_wt_sum, con_cum_wt, tr_cum_wt;
    
    right_wt = 0.;
    right_tr = 0.;
    right_sum = 0.;
    right_tr_sum = 0.;
    right_sqr_sum = 0.;
    right_tr_sqr_sum = 0.;
    right_n = n;
    
    improve_temp = 0.;
    improve_best = 0.;
    for (i = 0; i < n; i++) {
        right_wt += wt[i];
        right_tr += wt[i] * treatment[i];
        right_sum += *y[i] * wt[i];
        right_tr_sum += *y[i] * wt[i] * treatment[i];
        right_sqr_sum += (*y[i]) * (*y[i]) * wt[i];
        right_tr_sqr_sum += (*y[i]) * (*y[i]) * wt[i] * treatment[i];
        trsum += treatment[i];
    }
    
    temp = right_tr_sum / right_tr - (right_sum - right_tr_sum) / (right_wt - right_tr);
    tr_var = right_tr_sqr_sum / right_tr - right_tr_sum * right_tr_sum / (right_tr * right_tr);
    con_var = (right_sqr_sum - right_tr_sqr_sum) / (right_wt - right_tr)
        - (right_sum - right_tr_sum) * (right_sum - right_tr_sum) / ((right_wt - right_tr) * (right_wt - right_tr));
    //now we use t-statistic
    node_effect = alpha * temp * temp * right_wt
        - (1 - alpha) * (1 + train_to_est_ratio) * right_wt 
        * (tr_var / right_tr  + con_var / (right_wt - right_tr));

   
    if (nclass == 0) {
        /* continuous predictor */
        
        cum_wt = (double *) ALLOC(n, sizeof(double));
        tmp_wt = (double *) ALLOC(n, sizeof(double));
        fake_x = (double *) ALLOC(n, sizeof(double));
        
        tr_wt_sum = 0.;
        con_wt_sum = 0.;
        con_cum_wt = 0.;
        tr_cum_wt = 0.;  
        
        for (i = 0; i < n; i ++) {
            if (treatment[i] == 0) {
                con_wt_sum += wt[i];
            } else {
                tr_wt_sum += wt[i];
            }
            cum_wt[i] = 0.;
            tmp_wt[i] = 0.;
            fake_x[i] = 0.;
        }
        bucketTmp = min(round(trsum / (double)bucketnum), round(((double)n - trsum) / (double)bucketnum));
        Numbuckets = max(minsize, min(bucketTmp, bucketMax));
        
        for (i = 0; i < n; i++) {
            if (treatment[i] == 0) {
                tmp_wt[i] = wt[i] / con_wt_sum;     
                con_cum_wt += tmp_wt[i];
                cum_wt[i] = con_cum_wt;
                fake_x[i] = (int)floor(Numbuckets * cum_wt[i]);
            } else {
                tmp_wt[i] = wt[i] / tr_wt_sum;
                tr_cum_wt += tmp_wt[i];
                cum_wt[i] = tr_cum_wt;
                fake_x[i] = (int)floor(Numbuckets * cum_wt[i]);
            }
        }
        
        n_bucket = (int *) ALLOC(Numbuckets,  sizeof(int));
        n_tr_bucket = (int *) ALLOC(Numbuckets, sizeof(int));
        n_con_bucket = (int *) ALLOC(Numbuckets, sizeof(int));
        wts_bucket = (double *) ALLOC(Numbuckets, sizeof(double));
        trs_bucket = (double *) ALLOC(Numbuckets, sizeof(double));
        wtsums_bucket = (double *) ALLOC(Numbuckets, sizeof(double));
        trsums_bucket = (double *) ALLOC(Numbuckets, sizeof(double));
        wtsqrsums_bucket = (double *) ALLOC(Numbuckets, sizeof(double));
        trsqrsums_bucket = (double *) ALLOC(Numbuckets, sizeof(double));
        tr_end_bucket = (double *) ALLOC(Numbuckets, sizeof(double));
        con_end_bucket = (double *) ALLOC (Numbuckets, sizeof(double));
        
        for (j = 0; j < Numbuckets; j++) {
            n_bucket[j] = 0;
            n_tr_bucket[j] = 0;
            n_con_bucket[j] = 0;
            wts_bucket[j] = 0.;
            trs_bucket[j] = 0.;
            wtsums_bucket[j] = 0.;
            trsums_bucket[j] = 0.;
            wtsqrsums_bucket[j] = 0.;
            trsqrsums_bucket[j] = 0.;
        }
        
        for (i = 0; i < n; i++) {
            j = fake_x[i];
            n_bucket[j]++;
            wts_bucket[j] += wt[i];
            trs_bucket[j] += wt[i] * treatment[i];
            wtsums_bucket[j] += *y[i] * wt[i];
            trsums_bucket[j] += *y[i] * wt[i] * treatment[i];
            wtsqrsums_bucket[j] += (*y[i]) * (*y[i]) * wt[i];
            trsqrsums_bucket[j] += (*y[i]) * (*y[i]) * wt[i] * treatment[i];
            if (treatment[i] == 1) {
                tr_end_bucket[j] = x[i];
            } else {
                con_end_bucket[j] = x[i];
            }
        }
        
        left_wt = 0.;
        left_tr = 0.;
        left_n = 0;
        left_sum = 0.;
        left_tr_sum = 0.;
        left_sqr_sum = 0.;
        left_tr_sqr_sum = 0.;
        left_temp = 0.;
        right_temp = 0.;
        

        best = 0.;
        for (j = 0; j < Numbuckets; j++) {
            left_n += n_bucket[j];
            right_n -= n_bucket[j];
            left_wt += wts_bucket[j];
            right_wt -= wts_bucket[j];
            left_tr += trs_bucket[j]; 
            right_tr -= trs_bucket[j];
            
            left_sum += wtsums_bucket[j];
            right_sum -= wtsums_bucket[j];
            
            left_tr_sum += trsums_bucket[j];
            right_tr_sum -= trsums_bucket[j];
            
            left_sqr_sum += wtsqrsums_bucket[j];
            right_sqr_sum -= wtsqrsums_bucket[j];
            
            left_tr_sqr_sum += trsqrsums_bucket[j];
            right_tr_sqr_sum -= trsqrsums_bucket[j];
            
            if (left_n >= edge && right_n >= edge &&
                (int) left_tr >= min_node_size &&
                (int) left_wt - (int) left_tr >= min_node_size &&
                (int) right_tr >= min_node_size &&
                (int) right_wt - (int) right_tr >= min_node_size) {
                
                left_temp = left_tr_sum / left_tr 
                    - (left_sum - left_tr_sum) / (left_wt - left_tr);
                left_tr_var = left_tr_sqr_sum / left_tr 
                    - left_tr_sum  * left_tr_sum / (left_tr * left_tr);
                left_con_var = (left_sqr_sum - left_tr_sqr_sum) / (left_wt - left_tr) 
                    - (left_sum - left_tr_sum) * (left_sum - left_tr_sum)
                    / ((left_wt - left_tr) * (left_wt - left_tr));        
                
                left_var = left_tr_var / left_tr + left_con_var / (left_wt - left_tr);
                left_effect = alpha * left_temp * left_temp * left_wt
                    - (1 - alpha) * (1 + train_to_est_ratio) * left_wt 
                    * (left_tr_var / left_tr + left_con_var / (left_wt - left_tr));
                
                right_temp = right_tr_sum / right_tr - (right_sum - right_tr_sum) / (right_wt - right_tr);
                right_tr_var = right_tr_sqr_sum / right_tr
                    - right_tr_sum * right_tr_sum / (right_tr * right_tr);
                right_con_var = (right_sqr_sum - right_tr_sqr_sum) / (right_wt - right_tr)
                    - (right_sum - right_tr_sum) * (right_sum - right_tr_sum) 
                    / ((right_wt - right_tr) * (right_wt - right_tr));
                right_var = right_tr_var / right_tr + right_con_var / (right_wt - right_tr);
                right_effect = alpha * right_temp * right_temp * right_wt
                    - (1 - alpha) * (1 + train_to_est_ratio) * right_wt 
                    * (right_tr_var / right_tr + right_con_var / (right_wt - right_tr));    
                
                sd = sqrt(left_var / left_wt  + right_var / right_wt);
                temp = fabs(left_temp - right_temp) / sd;
                improve_temp = left_effect + right_effect - node_effect;
                if (temp > best) {
                    best = temp;
                    where = j;
                    improve_best = improve_temp;
                    if (left_temp < right_temp)
                        direction = LEFT;
                    else
                        direction = RIGHT;
                }
                
            }
        }
        
        *improve = improve_best;
        if (improve_best > 0) {
            csplit[0] = direction;
            *split = (tr_end_bucket[where] + con_end_bucket[where]) / 2;
        }
    } else {
        /*
         * Categorical predictor
         */
        for (i = 0; i < nclass; i++) {
            countn[i] = 0;
            wts[i] = 0;
            trs[i] = 0;
            sums[i] = 0;
            wtsums[i] = 0;
            trsums[i] = 0;
            wtsqrsums[i] = 0;
            wttrsqrsums[i] = 0;
        }

        /* rank the classes by their mean y value */
        /* RANK THE CLASSES BY THEI */
        for (i = 0; i < n; i++) {
            j = (int) x[i] - 1;
            countn[j]++;
            wts[j] += wt[i];
            trs[j] += wt[i] * treatment[i];
            sums[j] += *y[i];
            wtsums[j] += *y[i] * wt[i];
            trsums[j] += *y[i] * wt[i] * treatment[i];
            wtsqrsums[j] += (*y[i]) * (*y[i]) * wt[i];
            wttrsqrsums[j] += (*y[i]) * (*y[i]) * wt[i] * treatment[i];
        }

        for (i = 0; i < nclass; i++) {
            if (countn[i] > 0) {
                tsplit[i] = RIGHT;
                mean[i] = sums[i] / wts[i];
            } else
                tsplit[i] = 0;
        }
        graycode_init2(nclass, countn, mean);

        /*
         * Now find the split that we want
         */

        left_wt = 0;
        left_tr = 0;
        left_n = 0;
        left_sum = 0;
        left_tr_sum = 0;
        left_sqr_sum = 0;
        left_tr_sqr_sum = 0;

        best = 0;
        where = 0;
        while ((j = graycode()) < nclass) {
            tsplit[j] = LEFT;
            left_n += countn[j];
            right_n -= countn[j];

            left_wt += wts[j];
            right_wt -= wts[j];

            left_tr += trs[j];
            right_tr -= trs[j];

            left_sum += wtsums[j];
            right_sum -= wtsums[j];

            left_tr_sum += trsums[j];
            right_tr_sum -= trsums[j];

            left_sqr_sum += wtsqrsums[j];
            right_sqr_sum -= wtsqrsums[j];

            left_tr_sqr_sum += wttrsqrsums[j];
            right_tr_sqr_sum -= wttrsqrsums[j];

            if (left_n >= edge && right_n >= edge &&
                    (int) left_tr >= min_node_size &&
                    (int) left_wt - (int) left_tr >= min_node_size &&
                    (int) right_tr >= min_node_size &&
                    (int) right_wt - (int) right_tr >= min_node_size) {

                left_temp = left_tr_sum / left_tr - (left_sum - left_tr_sum) / (left_wt - left_tr);
                left_tr_var = left_tr_sqr_sum / left_tr - left_tr_sum  * left_tr_sum / (left_tr * left_tr);
                left_con_var = (left_sqr_sum - left_tr_sqr_sum) / (left_wt - left_tr) 
                    - (left_sum - left_tr_sum) * (left_sum - left_tr_sum)/ ((left_wt - left_tr) * (left_wt - left_tr));   
                left_var = left_tr_var / left_tr + left_con_var / (left_wt - left_tr);
                left_effect = alpha * left_temp * left_temp * left_wt
                  - (1 - alpha) * (1 + train_to_est_ratio) * left_wt
                    * (left_tr_var / left_tr + left_con_var / (left_wt - left_tr));

                //Rprintf("left_sum = %f, left_wt_sum = %f, left_wt = %f, left_n = %d\n", left_sum, left_wt_sum, left_wt, left_n);             
                right_temp = right_tr_sum / right_tr - (right_sum - right_tr_sum) / (right_wt - right_tr);
                right_tr_var = right_tr_sqr_sum / right_tr - right_tr_sum * right_tr_sum / (right_tr * right_tr);
                right_con_var = (right_sqr_sum - right_tr_sqr_sum) / (right_wt - right_tr)
                    - (right_sum - right_tr_sum) * (right_sum - right_tr_sum) / ((right_wt - right_tr) * (right_wt - right_tr));
                right_var = right_tr_var / right_tr + right_con_var / (right_wt - right_tr);
                right_effect = alpha * right_temp * right_temp * right_wt
                - (1 - alpha) * (1 + train_to_est_ratio) * right_wt 
                    * (right_tr_var / right_tr + right_con_var / (right_wt - right_tr)); 

                sd = sqrt(left_var / left_wt  + right_var / right_wt);
                temp = fabs(left_temp - right_temp) / sd; 
                improve_temp = left_effect + right_effect - node_effect;
                if (temp > best) {
                    best = temp;
                    improve_best = improve_temp;
                }
            }
        }
        *improve = best;
        if (improve_best > 0) {
            if (left_temp > right_temp)
                for (i = 0; i < nclass; i++) csplit[i] = -tsplit[i];
            else
                for (i = 0; i < nclass; i++) csplit[i] = tsplit[i];
        }
    }
}
Ejemplo n.º 5
0
/*
 * totD -- split function for the transformed-outcome ("TOT") causal tree.
 *
 * Every response is replaced by the transformed outcome
 *
 *     ystar = y * (T - p) / (p * (1 - p)),
 *
 * where T = treatment[i] and p = propensity.  A candidate split is scored by
 * the weighted between-group sum of squares of ystar,
 *
 *     left_wt * (grandmean - left_mean)^2 + right_wt * (grandmean - right_mean)^2,
 *
 * and the best admissible split is reported.  A split is admissible only if
 * each child has at least `edge` observations and the integer parts of its
 * treated weight and of its control weight are both at least `minsize`.
 *
 * Continuous predictor (nclass == 0): treated and control observations are
 * assigned to buckets separately, by cumulative weight within their own
 * group.  Candidate cut points are bucket boundaries lying strictly inside
 * the x-range where treated and control observations overlap.
 * NOTE(review): the bucketing assumes x arrives sorted ascending -- confirm
 * against the caller.
 *
 * Categorical predictor (nclass > 0): x holds 1-based category codes.
 * Categories are ordered by their estimated treatment effect, and the
 * ordered partitions are visited with graycode().
 *
 * Outputs: *improve = best / myrisk.  For a continuous predictor with a
 * positive score, csplit[0] receives the direction and *split the cut point.
 * For a categorical predictor, csplit[0..nclass-1] receives the per-category
 * assignment of the best split found (untouched if none is admissible).
 */
void totD(int n, double *y[], double *x, int nclass, int edge, double *improve, 
         double *split, int *csplit, double myrisk, double *wt, double *treatment, 
         double propensity, int minsize, int bucketnum, int bucketMax)
{
    int i, j;
    double temp;
    double left_sum, right_sum;
    double left_mean, right_mean;
    double left_wt, right_wt;
    int left_n, right_n;
    double left_tr, right_tr;
    double grandmean, best;
    int direction = LEFT;
    int where = 0;
    double ystar;
    int min_node_size = minsize;

    int bucketTmp;
    double trsum = 0.;
    int Numbuckets;

    double *cum_wt, *tmp_wt, *fake_x;
    double tr_wt_sum, con_wt_sum, con_cum_wt, tr_cum_wt;

    /* for overlap: */
    double tr_min, tr_max, con_min, con_max;
    double left_bd, right_bd;
    double cut_point;

    /* Whole-node totals; every observation starts in the right child. */
    right_wt = 0.;
    right_n = n;
    right_tr = 0.;
    right_sum = 0.;
    for (i = 0; i < n; i++) {
        ystar = *y[i] * (treatment[i] - propensity) / (propensity * (1 - propensity));
        right_sum += ystar * wt[i];
        right_wt += wt[i];
        right_tr += treatment[i] * wt[i];
        trsum += treatment[i];
    }
    grandmean = right_sum / right_wt;

    if (nclass == 0) {
        /*
         * Continuous predictor
         */
        cum_wt = (double *) ALLOC(n, sizeof(double));
        tmp_wt = (double *) ALLOC(n, sizeof(double));
        fake_x = (double *) ALLOC(n, sizeof(double));

        tr_wt_sum = 0.;
        con_wt_sum = 0.;
        con_cum_wt = 0.;
        tr_cum_wt = 0.;

        /* Largest |x|, used to seed the min/max searches below. */
        double max_abs_tmp = fabs(x[0]);
        for (i = 0; i < n; i++) {
            if (max_abs_tmp < fabs(x[i])) {
                max_abs_tmp = fabs(x[i]);
            }
        }

        tr_min = max_abs_tmp;
        tr_max = -max_abs_tmp;
        con_min = max_abs_tmp;
        con_max = -max_abs_tmp;

        /* Per-group total weight and x-range; clear the scratch arrays. */
        for (i = 0; i < n; i++) {
            if (treatment[i] == 0) {
                con_wt_sum += wt[i];
                if (con_min > x[i]) {
                    con_min = x[i];
                }
                if (con_max < x[i]) {
                    con_max = x[i];
                }
            } else {
                tr_wt_sum += wt[i];
                if (tr_min > x[i]) {
                    tr_min = x[i];
                }
                if (tr_max < x[i]) {
                    tr_max = x[i];
                }
            }
            cum_wt[i] = 0.;
            tmp_wt[i] = 0.;
            fake_x[i] = 0.;
        }

        /* Cut points must fall inside the treated/control overlap. */
        left_bd = max(tr_min, con_min);
        right_bd = min(tr_max, con_max);

        int test1 = round(trsum / (double)bucketnum);
        int test2 = round(((double)n - trsum) / (double)bucketnum);
        bucketTmp = min(test1, test2);
        Numbuckets = max(minsize, min(bucketTmp, bucketMax));

        n_bucket = (int *) ALLOC(Numbuckets + 1,  sizeof(int));
        wts_bucket = (double *) ALLOC(Numbuckets + 1, sizeof(double));
        trs_bucket = (double *) ALLOC(Numbuckets + 1, sizeof(double));
        tr_end_bucket = (double *) ALLOC(Numbuckets + 1, sizeof(double));
        con_end_bucket = (double *) ALLOC (Numbuckets + 1, sizeof(double));
        wtsums_bucket = (double *) ALLOC (Numbuckets + 1, sizeof(double));

        /* Bucket index = floor(Numbuckets * cumulative weight share within the group). */
        for (i = 0; i < n; i++) {
            if (treatment[i] == 0) {
                tmp_wt[i] = wt[i] / con_wt_sum;
                con_cum_wt += tmp_wt[i];
                cum_wt[i] = con_cum_wt;
                fake_x[i] = (int)floor(Numbuckets * cum_wt[i]);
            } else {
                tmp_wt[i] = wt[i] / tr_wt_sum;
                tr_cum_wt += tmp_wt[i];
                cum_wt[i] = tr_cum_wt;
                fake_x[i] = (int)floor(Numbuckets * cum_wt[i]);
            }
        }

        /*
         * Index Numbuckets is reachable (when the cumulative share reaches
         * 1), so clear it as well; the end-point arrays are cleared too so a
         * bucket holding only one group never reads stale memory.
         */
        for (j = 0; j <= Numbuckets; j++) {
            n_bucket[j] = 0;
            wts_bucket[j] = 0.;
            trs_bucket[j] = 0.;
            wtsums_bucket[j] = 0.;
            tr_end_bucket[j] = 0.;
            con_end_bucket[j] = 0.;
        }

        for (i = 0; i < n; i++) {
            j = fake_x[i];
            n_bucket[j]++;
            wts_bucket[j] += wt[i];
            trs_bucket[j] += wt[i] * treatment[i];
            ystar = *y[i] * (treatment[i] - propensity) / (propensity * (1 - propensity));
            wtsums_bucket[j] += ystar * wt[i];
            /* last x seen in this bucket for each group */
            if (treatment[i] == 1) {
                tr_end_bucket[j] = x[i];
            } else {
                con_end_bucket[j] = x[i];
            }
        }

        left_sum = 0;
        left_wt = 0;
        left_n = 0;
        left_tr = 0.;
        best = 0;

        /* Move one bucket at a time from the right child to the left child. */
        for (j = 0; j < Numbuckets; j++) {
            left_n += n_bucket[j];
            right_n -= n_bucket[j];
            left_wt += wts_bucket[j];
            right_wt -= wts_bucket[j];
            left_tr += trs_bucket[j];
            right_tr -= trs_bucket[j];

            left_sum += wtsums_bucket[j];
            right_sum -= wtsums_bucket[j];

            cut_point = (tr_end_bucket[j] + con_end_bucket[j]) / 2.0;

            if (left_n >= edge && right_n >= edge &&
                (int) left_tr >= min_node_size &&
                (int) left_wt - (int) left_tr >= min_node_size &&
                (int) right_tr >= min_node_size &&
                (int) right_wt - (int) right_tr >= min_node_size &&
                cut_point < right_bd && cut_point > left_bd) {

                left_mean = left_sum / left_wt;
                right_mean = right_sum / right_wt;
                temp = left_wt * (grandmean - left_mean) * (grandmean - left_mean) +
                       right_wt * (grandmean - right_mean) * (grandmean - right_mean);

                if (temp > best) {
                    best = temp;
                    where = j;
                    /*
                     * Compare means, not sums: ystar is not centred, so the
                     * sums also scale with the child weights.
                     */
                    direction = (left_mean < right_mean) ? LEFT : RIGHT;
                }
            }
        }

        *improve = best / myrisk;
        if (best > 0) {
            csplit[0] = direction;
            *split = (tr_end_bucket[where] + con_end_bucket[where]) / 2;
        }
    } else {
        /*
         * Categorical predictor
         */
        for (i = 0; i < nclass; i++) {
            countn[i] = 0;
            wts[i] = 0;
            trs[i] = 0;
            sums[i] = 0;
            wtsums[i] = 0;
            trsums[i] = 0;
        }

        for (i = 0; i < n; i++) {
            j = (int) x[i] - 1;
            countn[j]++;
            wts[j] += wt[i];
            trs[j] += wt[i] * treatment[i];
            wtsums[j] += *y[i] * wt[i];
            trsums[j] += *y[i] * wt[i] * treatment[i];
            /*
             * sums[] holds the weighted transformed outcome -- the quantity
             * the split criterion (and right_sum above) is built from.
             */
            ystar = *y[i] * (treatment[i] - propensity) / (propensity * (1 - propensity));
            sums[j] += ystar * wt[i];
        }

        /* Rank the non-empty classes by their estimated treatment effect. */
        for (i = 0; i < nclass; i++) {
            if (countn[i] > 0) {
                tsplit[i] = RIGHT;
                /*
                 * NOTE(review): a class with no treated (or no control)
                 * weight makes this a division by zero.
                 */
                treatment_effect[i] = trsums[i] / trs[i]
                    - (wtsums[i] - trsums[i]) / (wts[i] - trs[i]);
            } else
                tsplit[i] = 0;
        }
        graycode_init2(nclass, countn, treatment_effect);

        /*
         * Now find the split that we want.  The right_* accumulators still
         * hold the whole-node totals computed at the top.
         */
        left_wt = 0;
        left_tr = 0;
        left_sum = 0;
        left_n = 0;
        best = 0;
        while ((j = graycode()) < nclass) {
            tsplit[j] = LEFT;
            left_n += countn[j];
            right_n -= countn[j];
            left_wt += wts[j];
            right_wt -= wts[j];
            left_tr += trs[j];
            right_tr -= trs[j];
            left_sum += sums[j];
            right_sum -= sums[j];

            if (left_n >= edge && right_n >= edge &&
                (int) left_tr >= min_node_size &&
                (int) left_wt - (int) left_tr >= min_node_size &&
                (int) right_tr >= min_node_size &&
                (int) right_wt - (int) right_tr >= min_node_size) {

                left_mean = left_sum / left_wt;
                right_mean = right_sum / right_wt;
                temp = left_wt * (grandmean - left_mean) * (grandmean - left_mean) +
                       right_wt * (grandmean - right_mean) * (grandmean - right_mean);

                if (temp > best) {
                    best = temp;
                    /* Keep the lower-mean group on the LEFT side. */
                    if (left_mean > right_mean) {
                        for (i = 0; i < nclass; i++) csplit[i] = -tsplit[i];
                    } else {
                        for (i = 0; i < nclass; i++) csplit[i] = tsplit[i];
                    }
                }
            }
        }

        *improve = best / myrisk;  /* improvement */
    }
}
Ejemplo n.º 6
0
Archivo: gini.c Proyecto: csilles/cxxr
/*
 * The gini splitting function.  Find that split point in x such that
 *  the rss within the two groups is decreased as much
 *  as possible.
 */
void
gini(int n, double *y[], double *x, int numcat,
     int edge, double *improve, double *split, int *csplit, double my_risk,
     double *wt)
{
    int i, j, k;
    double lwt, rwt;
    int rtot, ltot;
    int direction = LEFT, where = 0;
    double total_ss, best, temp, p;
    double lmean, rmean;        /* used to decide direction */

    for (i = 0; i < numclass; i++) {
	left[i] = 0;
	right[i] = 0;
    }
    lwt = 0;
    rwt = 0;
    rtot = 0;
    ltot = 0;
    for (i = 0; i < n; i++) {
	j = (int) *y[i] - 1;
	rwt += aprior[j] * wt[i];  /* altered weight = prior * case_weight */
	right[j] += wt[i];
	rtot++;
    }
    total_ss = 0;
    for (i = 0; i < numclass; i++) {
	temp = aprior[i] * right[i] / rwt;      /* p(class=i, given node A) */
	total_ss += rwt * (*impurity) (temp);   /* p(A) * I(A) */
    }
    best = total_ss;  /* total weight of right * impurity of right + 0 *0 */

    /*
     * at this point we split into 2 disjoint paths
     */
    if (numcat > 0)
	goto categorical;

    for (i = 0; rtot > edge; i++) {
	j = (int) *y[i] - 1;
	rwt -= aprior[j] * wt[i];
	lwt += aprior[j] * wt[i];
	rtot--;
	ltot++;
	right[j] -= wt[i];
	left[j] += wt[i];

	if (x[i + 1] != x[i] && (ltot >= edge)) {
	    temp = 0;
	    lmean = 0;
	    rmean = 0;
	    for (j = 0; j < numclass; j++) {
		p = aprior[j] * left[j] / lwt;  /* p(j | left) */
		temp += lwt * (*impurity) (p);  /* p(left) * I(left) */
		lmean += p * j;
		p = aprior[j] * right[j] / rwt; /* p(j | right) */
		temp += rwt * (*impurity) (p);  /* p(right) * I(right) */
		rmean += p * j;
	    }
	    if (temp < best) {
		best = temp;
		where = i;
		direction = lmean < rmean ? LEFT : RIGHT;
	    }
	}
    }

    *improve = total_ss - best;
    if (*improve > 0) {         /* found something */
	csplit[0] = direction;
	*split = (x[where] + x[where + 1]) / 2;
    }
    return;

categorical:;
    /*
     * First collapse the data into a numclass x numcat array
     *  ccnt[i][j] = number of class i obs, category j of the predictor
     */
    for (j = 0; j < numcat; j++) {
	awt[j] = 0;
	countn[j] = 0;
	for (i = 0; i < numclass; i++)
	    ccnt[i][j] = 0;
    }
    for (i = 0; i < n; i++) {
	j = (int) *y[i] - 1;
	k = (int) x[i] - 1;
	awt[k] += aprior[j] * wt[i];
	countn[k]++;
	ccnt[j][k] += wt[i];
    }

    for (i = 0; i < numcat; i++) {
	if (awt[i] == 0)
	    tsplit[i] = 0;
	else {
	    rate[i] = ccnt[0][i] / awt[i];      /* a scratch array */
	    tsplit[i] = RIGHT;
	}
    }

    if (numclass == 2)
	graycode_init2(numcat, countn, rate);
    else
	graycode_init1(numcat, countn);

    while ((i = graycode()) < numcat) {
       /* item i changes groups */
	if (tsplit[i] == LEFT) {
	    tsplit[i] = RIGHT;
	    rwt += awt[i];
	    lwt -= awt[i];
	    rtot += countn[i];
	    ltot -= countn[i];
	    for (j = 0; j < numclass; j++) {
		right[j] += ccnt[j][i];
		left[j] -= ccnt[j][i];
	    }
	} else {
	    tsplit[i] = LEFT;
	    rwt -= awt[i];
	    lwt += awt[i];
	    rtot -= countn[i];
	    ltot += countn[i];
	    for (j = 0; j < numclass; j++) {
		right[j] -= ccnt[j][i];
		left[j] += ccnt[j][i];
	    }
	}

	if (ltot >= edge && rtot >= edge) {
	    temp = 0;
	    lmean = 0;
	    rmean = 0;
	    for (j = 0; j < numclass; j++) {
		p = aprior[j] * left[j] / lwt;
		temp += lwt * (*impurity) (p);
		lmean += p * j;
		p = aprior[j] * right[j] / rwt; /* p(j | right) */
		temp += rwt * (*impurity) (p);  /* p(right) * I(right) */
		rmean += p * j;
	    }
	    if (temp < best) {
		best = temp;
		if (lmean < rmean)
		    for (j = 0; j < numcat; j++) csplit[j] = tsplit[j];
		else
		    for (j = 0; j < numcat; j++) csplit[j] = -tsplit[j];
	    }
	}
    }
    *improve = total_ss - best;
}