// コード例 #1 (Code example #1)
// ファイル (File): gbt_train.cpp — プロジェクト (Project): colinsongf/TreeExtra
int main(int argc, char* argv[])
{	
	try{

//1. Analyze input parameters
	//convert input parameters to string from char*
	stringv args(argc); 
	for(int argNo = 0; argNo < argc; argNo++)
		args[argNo] = string(argv[argNo]);
	
	//check that the number of arguments is even (flags + value pairs)
	if(argc % 2 == 0)
		throw INPUT_ERR;

#ifndef _WIN32
	int threadN = 6;	//number of threads
#endif

	TrainInfo ti; //model training parameters
	int topAttrN = 0;  //how many top attributes to output and keep in the cut data 
						//(0 = do not do feature selection)
						//(-1 = output all available features)

	//parse and save input parameters
	//indicators of presence of required flags in the input
	bool hasTrain = false;
	bool hasVal = false; 
	bool hasAttr = false; 

	int treeN = 100;
	double shrinkage = 0.01;
	double subsample = -1;

	for(int argNo = 1; argNo < argc; argNo += 2)
	{
		if(!args[argNo].compare("-t"))
		{
			ti.trainFName = args[argNo + 1];
			hasTrain = true;
		}
		else if(!args[argNo].compare("-v"))
		{
			ti.validFName = args[argNo + 1];
			hasVal = true;
		}
		else if(!args[argNo].compare("-r"))
		{
			ti.attrFName = args[argNo + 1];
			hasAttr = true;
		}
		else if(!args[argNo].compare("-a"))
			ti.alpha = atofExt(argv[argNo + 1]);
		else if(!args[argNo].compare("-n"))
			treeN = atoiExt(argv[argNo + 1]);
		else if(!args[argNo].compare("-i"))
			ti.seed = atoiExt(argv[argNo + 1]);
		else if(!args[argNo].compare("-k"))
			topAttrN = atoiExt(argv[argNo + 1]);
		else if(!args[argNo].compare("-sh"))
			shrinkage = atofExt(argv[argNo + 1]);
		else if(!args[argNo].compare("-sub"))
			subsample = atofExt(argv[argNo + 1]);
		else if(!args[argNo].compare("-c"))
		{
			if(!args[argNo + 1].compare("roc"))
				ti.rms = false;
			else if(!args[argNo + 1].compare("rms"))
				ti.rms = true;
			else
				throw INPUT_ERR;
		}
		else if(!args[argNo].compare("-h"))
#ifndef _WIN32 
			threadN = atoiExt(argv[argNo + 1]);
#else
			throw WIN_ERR;
#endif
		else
			throw INPUT_ERR;
	}//end for(int argNo = 1; argNo < argc; argNo += 2) //parse and save input parameters

	if(!(hasTrain && hasVal && hasAttr))
		throw INPUT_ERR;

	if((ti.alpha < 0) || (ti.alpha > 1))
		throw ALPHA_ERR;

//1.a) Set log file
	LogStream clog;
	LogStream::init(true);
	clog << "\n-----\ngbt_train ";
	for(int argNo = 1; argNo < argc; argNo++)
		clog << argv[argNo] << " ";
	clog << "\n\n";

//1.b) Initialize random number generator. 
	srand(ti.seed);

//2. Load data
	INDdata data(ti.trainFName.c_str(), ti.validFName.c_str(), ti.testFName.c_str(), 
				 ti.attrFName.c_str());
	CTree::setData(data);
	CTreeNode::setData(data);

//2.a) Start thread pool
#ifndef _WIN32
	TThreadPool pool(threadN);
	CTree::setPool(pool);
#endif

//------------------
	int attrN = data.getAttrN();
	if(topAttrN == -1)
		topAttrN = attrN;
	idpairv attrCounts;	//counts of attribute importance
	bool doFS = (topAttrN != 0);	//whether feature selection is requested
	if(doFS)
	{//initialize attrCounts
		attrCounts.resize(attrN);
		for(int attrNo = 0; attrNo < attrN; attrNo++)
		{
			attrCounts[attrNo].first = attrNo;	//number of attribute	
			attrCounts[attrNo].second = 0;		//counts
		}
	}

	fstream frmscurve("boosting_rms.txt", ios_base::out); //bagging curve (rms)
	frmscurve.close();
	fstream froccurve;
	if(!ti.rms)
	{
		froccurve.open("boosting_roc.txt", ios_base::out); //bagging curve (roc) 
		froccurve.close();
	}

	doublev validTar;
	int validN = data.getTargets(validTar, VALID);

	doublev trainTar;
	int trainN = data.getTargets(trainTar, TRAIN);

	int sampleN;
	if(subsample == -1)
		sampleN = trainN;
	else
		sampleN = (int) (trainN * subsample);
	
	doublev validPreds(validN, 0);
	doublev trainPreds(trainN, 0);
	
	for(int treeNo = 0; treeNo < treeN; treeNo++)
	{
		if(treeNo % 10 == 0)
			cout << "\titeration " << treeNo + 1 << " out of " << treeN << endl;

		if(subsample == -1)
			data.newBag();
		else
			data.newSample(sampleN);

		CTree tree(ti.alpha);
		tree.setRoot();
		tree.resetRoot(trainPreds);
		idpairv stub;
		tree.grow(doFS, attrCounts);

		//update predictions
		for(int itemNo = 0; itemNo < trainN; itemNo++)
			trainPreds[itemNo] += shrinkage * tree.predict(itemNo, TRAIN);
		for(int itemNo = 0; itemNo < validN; itemNo++)
			validPreds[itemNo] += shrinkage * tree.predict(itemNo, VALID);

		//output
		frmscurve.open("boosting_rms.txt", ios_base::out | ios_base::app); 
		frmscurve << rmse(validPreds, validTar) << endl;
		frmscurve.close();
		
		if(!ti.rms)
		{
			froccurve.open("boosting_roc.txt", ios_base::out | ios_base::app); 
			froccurve << roc(validPreds, validTar) << endl;
			froccurve.close();
		}

	}

	//output feature selection results
	if(doFS)
	{
		sort(attrCounts.begin(), attrCounts.end(), idGreater);
		if(topAttrN > attrN)
			topAttrN = attrN;

		fstream ffeatures("feature_scores.txt", ios_base::out);
		ffeatures << "Top " << topAttrN << " features\n";
		for(int attrNo = 0; attrNo < topAttrN; attrNo++)
			ffeatures << data.getAttrName(attrCounts[attrNo].first) << "\t"
			<< attrCounts[attrNo].second / ti.bagN / trainN << "\n";
		ffeatures << "\n\nColumn numbers (beginning with 1)\n";
		for(int attrNo = 0; attrNo < topAttrN; attrNo++)
			ffeatures << data.getColNo(attrCounts[attrNo].first) + 1 << " ";
		ffeatures << "\nLabel column number: " << data.getTarColNo() + 1;
		ffeatures.close();

		//output new attribute file
		for(int attrNo = topAttrN; attrNo < attrN; attrNo++)
			data.ignoreAttr(attrCounts[attrNo].first);
		data.outAttr(ti.attrFName);
	}

	//output predictions
	fstream fpreds;
	fpreds.open("preds.txt", ios_base::out);
	for(int itemNo = 0; itemNo < validN; itemNo++)
		fpreds << validPreds[itemNo] << endl;
	fpreds.close();

//------------------

	}catch(TE_ERROR err){
// コード例 #2 (Code example #2)
// ファイル (File): bt_train.cpp — プロジェクト (Project): Jinyoyyo/additive-groves
int main(int argc, char* argv[])
{	
	try{

//1. Analyze input parameters
	//convert input parameters to string from char*
	stringv args(argc); 
	for(int argNo = 0; argNo < argc; argNo++)
		args[argNo] = string(argv[argNo]);
	
	//check that the number of arguments is even (flags + value pairs)
	if(argc % 2 == 0)
		throw INPUT_ERR;

#ifndef _WIN32
	int threadN = 6;	//number of threads
#endif

	TrainInfo ti; //model training parameters
	string modelFName = "model.bin";	//name of the output file for the model
	int topAttrN = 0;  //how many top attributes to output and keep in the cut data 
							//(0 = do not do feature selection)
							//(-1 = output all available features)
	bool doOut = true; //whether to output log information to stdout

	//parse and save input parameters
	//indicators of presence of required flags in the input
	bool hasTrain = false;
	bool hasVal = false; 
	bool hasAttr = false; 

	for(int argNo = 1; argNo < argc; argNo += 2)
	{
		if(!args[argNo].compare("-t"))
		{
			ti.trainFName = args[argNo + 1];
			hasTrain = true;
		}
		else if(!args[argNo].compare("-v"))
		{
			ti.validFName = args[argNo + 1];
			hasVal = true;
		}
		else if(!args[argNo].compare("-r"))
		{
			ti.attrFName = args[argNo + 1];
			hasAttr = true;
		}
		else if(!args[argNo].compare("-a"))
			ti.alpha = atofExt(argv[argNo + 1]);
		else if(!args[argNo].compare("-b"))
			ti.bagN = atoiExt(argv[argNo + 1]);
		else if(!args[argNo].compare("-i"))
			ti.seed = atoiExt(argv[argNo + 1]);
		else if(!args[argNo].compare("-k"))
			topAttrN = atoiExt(argv[argNo + 1]);
		else if(!args[argNo].compare("-m"))
		{
			modelFName = args[argNo + 1];
			if(modelFName.empty())
				throw EMPTY_MODEL_NAME_ERR;
		}
		else if(!args[argNo].compare("-l"))
		{
			if(!args[argNo + 1].compare("log"))
				doOut = true;
			else if(!args[argNo + 1].compare("nolog"))
				doOut = false;
			else
				throw INPUT_ERR;
		}
		else if(!args[argNo].compare("-c"))
		{
			if(!args[argNo + 1].compare("roc"))
				ti.rms = false;
			else if(!args[argNo + 1].compare("rms"))
				ti.rms = true;
			else
				throw INPUT_ERR;
		}
		else if(!args[argNo].compare("-h"))
#ifndef _WIN32 
			threadN = atoiExt(argv[argNo + 1]);
#else
			throw WIN_ERR;
#endif
		else
			throw INPUT_ERR;
	}//end for(int argNo = 1; argNo < argc; argNo += 2) //parse and save input parameters

	if(!(hasTrain && hasVal && hasAttr))
		throw INPUT_ERR;

	if((ti.alpha < 0) || (ti.alpha > 1))
		throw ALPHA_ERR;
	
//1.a) Set log file
	LogStream clog;
	LogStream::init(doOut);
	clog << "\n-----\nbt_train ";
	for(int argNo = 1; argNo < argc; argNo++)
		clog << argv[argNo] << " ";
	clog << "\n\n";

//1.b) Initialize random number generator. 
	srand(ti.seed);

//2. Load data
	INDdata data(ti.trainFName.c_str(), ti.validFName.c_str(), ti.testFName.c_str(), 
				 ti.attrFName.c_str());
	CTree::setData(data);
	CTreeNode::setData(data);

//2.a) Start thread pool
#ifndef _WIN32
	TThreadPool pool(threadN);
	CTree::setPool(pool);
#endif

//3. Train models
	doublev validTar;
	int validN = data.getTargets(validTar, VALID);
	int itemN = data.getTrainN();

	//adjust minAlpha, if needed
	double newAlpha = adjustAlpha(ti.alpha, itemN);
	if(ti.alpha != newAlpha)
	{
		if(newAlpha == 0)
			clog << "Warning: due to small train set size value of alpha was changed to 0"; 
		else 
			clog << "Warning: alpha value was rounded to the closest valid value " << newAlpha;
		clog << ".\n\n";
		ti.alpha = newAlpha;	
	}
	clog << "Alpha = " << ti.alpha << "\n" 
		<< ti.bagN << " bagging iterations\n";

	doublev rmsV(ti.bagN, 0); 				//bagging curve of rms values for validation set
	doublev rocV;							 
	if(!ti.rms)
		rocV.resize(ti.bagN, 0);			//bagging curve of roc values for validation set
	doublev predsumsV(validN, 0); 			//sums of predictions for each data point

	int attrN = data.getAttrN();
	if(topAttrN == -1)
		topAttrN = attrN;
	idpairv attrCounts;	//counts of attribute importance
	bool doFS = (topAttrN != 0);	//whether feature selection is requested
	if(doFS)
	{//initialize attrCounts
		attrCounts.resize(attrN);
		for(int attrNo = 0; attrNo < attrN; attrNo++)
		{
			attrCounts[attrNo].first = attrNo;	//number of attribute	
			attrCounts[attrNo].second = 0;		//counts
		}
	}
	fstream fmodel(modelFName.c_str(), ios_base::binary | ios_base::out);
	//header for compatibility with Additive Groves model
	AG_TRAIN_MODE modeStub = SLOW;
	fmodel.write((char*) &modeStub, sizeof(enum AG_TRAIN_MODE));
	int tigNStub = 1;
	fmodel.write((char*) &tigNStub, sizeof(int));
	fmodel.write((char*) &ti.alpha, sizeof(double));
	fmodel.close();
	
	fstream fbagrms("bagging_rms.txt", ios_base::out); //bagging curve (rms)
	fbagrms.close();
	fstream fbagroc;
	if(!ti.rms)
	{
		fbagroc.open("bagging_roc.txt", ios_base::out); //bagging curve (roc) 
		fbagroc.close();
	}

	//make bags, build trees, collect predictions
	for(int bagNo = 0; bagNo < ti.bagN; bagNo++)
	{
		if(doOut)
			cout << "Iteration " << bagNo + 1 << " out of " << ti.bagN << endl;

		data.newBag();
		CTree tree(ti.alpha);
		tree.setRoot();
		tree.grow(doFS, attrCounts);
		tree.save(modelFName.c_str());

		//generate predictions for validation set
		doublev predictions(validN);
		for(int itemNo = 0; itemNo < validN; itemNo++)
		{
			predsumsV[itemNo] += tree.predict(itemNo, VALID);
			predictions[itemNo] = predsumsV[itemNo] / (bagNo + 1);
		}
		rmsV[bagNo] = rmse(predictions, validTar);
		if(!ti.rms)
			rocV[bagNo] = roc(predictions, validTar);

		//output an element of bagging curve 
		fbagrms.open("bagging_rms.txt", ios_base::out | ios_base::app); 
		fbagrms << rmsV[bagNo] << endl;
		fbagrms.close();

		//same for roc, if needed
		if(!ti.rms)
		{
			fbagroc.open("bagging_roc.txt", ios_base::out | ios_base::app); 
			fbagroc << rocV[bagNo] << endl;
			fbagroc.close();
		}
	}

	if(doFS)	//sort attributes by counts
		sort(attrCounts.begin(), attrCounts.end(), idGreater);
	
//4. Output
		
	//output results and recommendations
	if(ti.rms)
		clog << "RMSE on validation set = " << rmsV[ti.bagN - 1] << "\n";
	else
		clog << "ROC on validation set = " << rocV[ti.bagN - 1] << "\n";


	//analyze whether more bagging should be recommended based on the curve in the best point
	if(moreBag(rmsV))
	{
		int recBagN = ti.bagN + 100;
		clog << "\nRecommendation: a greater number of bagging iterations might produce a better model.\n"
			<< "Suggested action: bt_train -b " << recBagN << "\n";
	}
	else
		clog << "\nThe bagging curve shows good convergence. \n"; 
	clog << "\n";

	//standard output in case of turned off log output: final performance on validation set only
	if(!doOut)
		if(ti.rms)
			cout << rmsV[ti.bagN - 1] << endl;
		else
			cout << rocV[ti.bagN - 1] << endl;

	//output feature selection results
	if(doFS)
	{
		if(topAttrN > attrN)
			topAttrN = attrN;

		fstream ffeatures("feature_scores.txt", ios_base::out);	
		ffeatures << "Top " << topAttrN << " features\n";
		for(int attrNo = 0; attrNo < topAttrN; attrNo++)
			ffeatures << data.getAttrName(attrCounts[attrNo].first) << "\t" 
				<< attrCounts[attrNo].second / ti.bagN / itemN << "\n";
		ffeatures << "\n\nColumn numbers (beginning with 1)\n";
		for(int attrNo = 0; attrNo < topAttrN; attrNo++)
			ffeatures << data.getColNo(attrCounts[attrNo].first) + 1 << " ";
		ffeatures << "\nLabel column number: " << data.getTarColNo() + 1;
		ffeatures.close();

		//output new attribute file
		for(int attrNo = topAttrN; attrNo < attrN; attrNo++)
			data.ignoreAttr(attrCounts[attrNo].first);
		data.outAttr(ti.attrFName);
	}

	}catch(TE_ERROR err){