コード例 #1
0
ファイル: Node.cpp プロジェクト: JiehuaChen/BART
void Node::CopyTree(Node *copy)
// Copies this node, and recursively everything below it, into copy.
// copy: an already existing node that receives the flags, VarAvail
//       entries 1..NumX and (for non-bottom nodes) the rule and two
//       newly allocated children.
// If this node is flagged Top, SetData() is called on copy after the
// whole subtree has been copied.
{
	// node status flags
	copy->Top = Top;
	copy->Bot = Bot;
	copy->Nog = Nog;

	// VarAvail is indexed from 1 to NumX
	for (int v = 1; v <= NumX; v++) {
		copy->VarAvail[v] = VarAvail[v];
	}

	if (!Bot) {
		// a non-bottom node has a rule and two children to duplicate
		CopyRule(&rule, &(copy->rule));

		Node *leftCopy = new Node;
		Node *rightCopy = new Node;
		copy->LeftC = leftCopy;
		copy->RightC = rightCopy;

		// recurse first, then hook the new children to their parent
		LeftC->CopyTree(leftCopy);
		RightC->CopyTree(rightCopy);
		leftCopy->Parent = copy;
		rightCopy->Parent = copy;
	}

	if (Top) {
		copy->SetData();
	}
}
コード例 #2
0
ファイル: BirthDeath.cpp プロジェクト: cran/BayesTree
double BirthDeath(Node *top,int *BD,int *Done)
// Does either a birth or a death step on the tree.
// top:  top of tree
// BD:   on exit, 1 if birth, 0 if death
// Done: on exit, 1 if step taken, 0 otherwise
// returns alpha, the probability that was handed to Bern() for the
// accept/reject draw of the proposed step
{

	double PGn,Pbot,PBx,PGl,PGr,PDy,Pnog;
	double PDx,PBy;
	double temprob;

	int VarI;
	int LeftEx,RightEx;
	double alpha1,alpha2,alpha;

	double Ly,Lx;

	Node *n,*tempnode;

	// PBirth also hands back a node n and the value Pbot through its
	// out-parameters; both are used by the birth branch below
	PBx=PBirth(top,&n,&Pbot);

	if(Bern(PBx)) {

		// ---- birth: give n a rule and two children ----
		*BD=1;

		PGn = PGrow(n);

		// value of LogLT before the change
		Lx = LogLT(n,top);

		VarI = DrPriVar(n);//draw variable
		DrPriRule(VarI,n,&LeftEx,&RightEx);//draw rule
		SpawnChildren(n,LeftEx,RightEx); //create children

		// A check that backs out when either new child holds fewer than
		// 5 observations is deliberately not done here; per the original
		// note it is implemented at the metrop level.

		PGl = PGrow(n->LeftC);
		PGr = PGrow(n->RightC);

		// value of LogLT after the change
		Ly = LogLT(n,top);

		Pnog = 1.0/((double)(top->NumNogNodes()));

		// death probability evaluated on the proposed tree
		PDy = 1.0-PBirth(top,&tempnode,&temprob);

		alpha1 = (PGn*(1.0-PGl)*(1.0-PGr)*PDy*Pnog)/((1.0-PGn)*PBx*Pbot);
		alpha2 = alpha1*exp(Ly-Lx);
		alpha = min(1.0,alpha2);

		if(Bern(alpha)) {
			*Done=1;
		} else {
			// rejected: remove the children that were just created
			KillChildren(n);
			*Done=0;
		}

	} else {

		// ---- death: remove the children of a node drawn by DrNogNode ----
		*BD=0;
		PDx=1-PBx;
		Pnog = DrNogNode(top,&n);
		PGl = PGrow(n->LeftC);
		PGr = PGrow(n->RightC);

		// value of LogLT before the change
		Lx = LogLT(n,top);

		// Remember the rule and the arguments needed to respawn the
		// children in case the step is rejected. The saved rule lives on
		// the stack, so it is released on every exit path.
		Rule rule;
		CopyRule(&(n->rule),&rule);
		LeftEx=1-((n->LeftC)->VarAvail[(n->rule).Var]);
		RightEx=1-((n->RightC)->VarAvail[(n->rule).Var]);

		KillChildren(n);

		// value of LogLT after the change
		Ly = LogLT(n,top);
		PBy = PBirth(top,&tempnode,&temprob);
		PGn = PGrow(n);
		Pbot=PrBotNode(top,n);
		alpha1 =((1.0-PGn)*PBy*Pbot)/(PGn*(1.0-PGl)*(1.0-PGr)*PDx*Pnog);
		alpha2 = alpha1*exp(Ly-Lx);
		// both arguments are double, matching the birth branch
		alpha = min(1.0,alpha2);

		if(Bern(alpha)) {
			*Done=1;
		} else {
			//put back rule and children
			CopyRule(&rule,&(n->rule));
			SpawnChildren(n,LeftEx,RightEx);

			*Done=0;
		}

	}

	return alpha;

}
コード例 #3
0
ファイル: ChangeRule.cpp プロジェクト: cran/BayesTree
double ChangeRule(Node *top,int *Done)
// Step which tries changing the rule at a randomly chosen non-bottom node.
// top:  top of tree
// Done: on exit, 1 if the new rule was kept, 0 otherwise (this includes
//       every case where the step is backed out)
// returns alpha, the probability handed to Bern() for the accept/reject
// draw, or -1 if the step was backed out (no non-bottom node, or no
// usable rule for the drawn variable)
{

	int i,j;
	double XLogPi,XLogL,YLogPi,YLogL;
	int ruleI;

	double alpha;
	double u;
	int Nnotbot;
	NodeP *notbotvec;

	// Default to "not done" so the back-out paths below never leave
	// *Done unwritten.
	*Done=0;

	// get list of nodes with rule = nodes which are not bottom
	MakeNotBotVec(top,&notbotvec,&Nnotbot);
	if(Nnotbot==0) {
		delete [] notbotvec;
		return -1;
	}

	// randomly choose a notbot node = cnode
	// (notbotvec is indexed from 1; for u in [0,1) NodeI is in 1..Nnotbot)
	u= unif_rand();
	int NodeI =  (int)floor(u*Nnotbot)+1;
	Node *cnode = notbotvec[NodeI];

	//given the node, choose a new variable for the new rule
	int YVarI = DrPriVar(cnode);

	// if new var is CAT do one thing, if ORD another
	if(VarType[YVarI]==CAT) {

		// get the list of good cat rules given var choice;
		// RuleInd[1..numr] is summed below to count the good ones
		int firstone;
		int NR = RuleNum[YVarI];
		int numr = (int)pow(2.0,NR-1)-1;
		int *RuleInd = new int [numr+1];
		FindGoodCatRules(cnode,YVarI,RuleInd,firstone);
		int sum = 0;
		for(i=1;i<=numr;i++) sum += RuleInd[i];

		//if there are any good cat rules
		if(sum) {

			// draw the rule from list of good ones
			u = unif_rand();
			ruleI = (int)floor(u*sum)+1;
			ruleI = GetSkipBadInd(numr,RuleInd,ruleI);

			//get logpri and logL from current tree (X)
			XLogPi = LogPriT(top);
			XLogL = LogLT(cnode,top);

			// copy old rule
			Rule rule;
			CopyRule(&(cnode->rule),&rule);

			// change rule at cnode to the new one:
			// sel[1..NR-1] is filled by indtd and spread over
			// CatRule[1..NR], skipping position firstone which is set to 1
			int *sel = new int [NR-1+1];
			indtd(NR-1,ruleI-1,sel);
			(cnode->rule).Var = YVarI;
			delete [] (cnode->rule).CatRule;
			(cnode->rule).CatRule = new int [NR+1];
			for(j=1;j<firstone;j++) (cnode->rule).CatRule[j]=sel[j];
			(cnode->rule).CatRule[firstone]=1;
			for(j=(firstone+1);j<=NR;j++) (cnode->rule).CatRule[j] = sel[j-1];

			//fix data at nodes below cnode given new rule
			FixDataBelow(cnode);

			//  fix VarAvail for the new variable and, if different, the old one
			UpDateVarAvail(cnode,YVarI);
			if(!(YVarI==rule.Var)) UpDateVarAvail(cnode,rule.Var);

			//get logpri and logL from candidate tree (Y)
			YLogPi = LogPriT(top);
			YLogL = LogLT(cnode,top);

			//draw go nogo
			alpha = min(1.0,exp(YLogPi+YLogL-XLogPi-XLogL));
			if(Bern(alpha)) {

				// go: the candidate rule stays in place
				*Done=1;

			} else {

				// if nogo put rule, data, and VarAvail back
				CopyRule(&rule,&(cnode->rule));
				FixDataBelow(cnode);

				//  fix VarAvail
				UpDateVarAvail(cnode,YVarI);
				if(!(YVarI==rule.Var)) UpDateVarAvail(cnode,rule.Var);

				*Done=0;
			}

			delete [] sel;

		} else {

			// if no rules for that var abort step
			alpha = -1;
		}

		delete [] RuleInd;

	} else {

		//ORD variable

		// get the set of good rules = [l,r]
		int l,r;
		FindGoodOrdRules(cnode,YVarI,l,r);
		int numsplit = r-l+1;

		// if there are any rules
		if(numsplit>0) {

			//draw the rule
			u = unif_rand();
			ruleI = l+(int)floor(u*numsplit);

			//get logpri and logL from current tree (X)
			XLogPi = LogPriT(top);
			XLogL = LogLT(cnode,top);

			// copy old rule
			int XVarI = (cnode->rule).Var;
			int XOrdRule = (cnode->rule).OrdRule;

			// change rule at cnode to the new one
			(cnode->rule).Var = YVarI;
			(cnode->rule).OrdRule = ruleI;

			//fix data at nodes below cnode given new rule
			FixDataBelow(cnode);

			//  fix VarAvail for the new variable and, if different, the old one
			UpDateVarAvail(cnode,YVarI);
			if(!(YVarI==XVarI)) UpDateVarAvail(cnode,XVarI);

			//get logpri and logL from candidate tree (Y)
			YLogPi = LogPriT(top);
			YLogL = LogLT(cnode,top);

			//draw go nogo
			alpha = min(1.0,exp(YLogPi+YLogL-XLogPi-XLogL));
			if(Bern(alpha)) {
				// go: rule, data and VarAvail were already updated above
				*Done=1;

			} else {
				// if nogo put rule, data, and VarAvail back
				(cnode->rule).Var = XVarI;
				(cnode->rule).OrdRule = XOrdRule;

				FixDataBelow(cnode);

				UpDateVarAvail(cnode,YVarI);
				if(!(YVarI==XVarI)) UpDateVarAvail(cnode,XVarI);

				*Done=0;
			}

		} else {
			// if no rules for that var abort step
			alpha=-1;
		}

	}
	delete [] notbotvec;
	return alpha; // note -1 means backed out
}