コード例 #1
0
ファイル: getneighbor_main.cpp プロジェクト: pswgoo/lshtc
int SaveNeighbor()
{
	int rtn = 0;
	LhtcDocumentSet lshtcTrainSet, lshtcTestSet;
	UniGramFeature uniGrams;
	string trainsetFile = "../data/loc_train.bin";
	string testsetFile = "../data/loc_test.bin";
	vector<Feature> lshtcTrainFeatureSet, lshtcTestFeatureSet;
	vector<int> lshtcTrainFeatureID, lshtcTestFeatureID;
	Feature tempFeature;
	lshtcTrainFeatureSet.clear();
	lshtcTestFeatureSet.clear();
	lshtcTrainFeatureID.clear();
	lshtcTestFeatureID.clear();

	clog << "Load Unigram Dictionary" << endl;
	rtn = uniGrams.Load("lshtc_unigram_dictionary_loctrain.bin");
	CHECK_RTN(rtn);
	clog << "Total " << uniGrams.mDictionary.size() << " unigrams" << endl;

	rtn = lshtcTrainSet.LoadBin(trainsetFile, FULL_LOG);
	CHECK_RTN(rtn);

	int trainSize = (int)lshtcTrainSet.Size();
	for (std::map<int, LhtcDocument>::iterator it = lshtcTrainSet.mLhtcDocuments.begin(); it != lshtcTrainSet.mLhtcDocuments.end(); ++it)
		lshtcTrainFeatureID.push_back(it->first);

	vector<LhtcDocument*> vecTrainDocument;
	vecTrainDocument.reserve(lshtcTrainSet.Size());
	for (map<int, LhtcDocument>::iterator it = lshtcTrainSet.mLhtcDocuments.begin(); it != lshtcTrainSet.mLhtcDocuments.end(); ++it)
		vecTrainDocument.push_back(&(it->second));

	clog << "Prepare for Extract Features" << endl;
	FeatureSet allTrainFeatures;
	allTrainFeatures.mFeatures.resize(vecTrainDocument.size());
#pragma omp parallel for schedule(dynamic)
	for (int i = 0; i < (int)vecTrainDocument.size(); i++)
	{
		uniGrams.ExtractLhtc(*vecTrainDocument[i], allTrainFeatures.mFeatures[i]);
		if (allTrainFeatures.mFeatures[i].size() == 0) printf("%d Warning!!\n", i);
	}
	allTrainFeatures.Normalize();//get traindata feature

	rtn = lshtcTestSet.LoadBin(testsetFile, FULL_LOG);
	CHECK_RTN(rtn);

	int testSize = (int)lshtcTestSet.Size();
	for (std::map<int, LhtcDocument>::iterator it = lshtcTestSet.mLhtcDocuments.begin(); it != lshtcTestSet.mLhtcDocuments.end(); ++it)
		lshtcTestFeatureID.push_back(it->first);

	vector<LhtcDocument*> vecTestDocument;
	vecTestDocument.reserve(lshtcTestSet.Size());
	for (map<int, LhtcDocument>::iterator it = lshtcTestSet.mLhtcDocuments.begin(); it != lshtcTestSet.mLhtcDocuments.end(); ++it)
		vecTestDocument.push_back(&(it->second));

	clog << "Prepare for Extract Features" << endl;
	FeatureSet allTestFeatures;
	allTestFeatures.mFeatures.resize(vecTestDocument.size());
#pragma omp parallel for schedule(dynamic)
	for (int i = 0; i < (int)vecTestDocument.size(); i++)
	{
		uniGrams.ExtractLhtc(*vecTestDocument[i], allTestFeatures.mFeatures[i]);
		if (allTestFeatures.mFeatures[i].size() == 0) printf("%d Warning!!\n", i);
	}
	allTestFeatures.Normalize();//get testdata feature

	int sigSize = allTestFeatures.Size() / 5;
	for (int i = 0; i < 5; ++i)
	{
		string filename = "../data/lshtc_neighbor" + intToString(i) + ".bin";
		if (FileExist(filename))
			continue;
		clog << i << "th, sigSize = " << sigSize << endl;
		FeatureSet locFeatures;
		vector<int> locIds;
		for (int j = sigSize*i; j < sigSize*(i + 1); ++j)
		{
			locFeatures.AddInstance(allTestFeatures[j]);
			locIds.push_back(lshtcTestFeatureID[j]);
		}
		FeatureNeighbor featureneighbor;
		rtn = featureneighbor.Build(allTrainFeatures.mFeatures, locFeatures.mFeatures, lshtcTrainFeatureID, locIds);
		CHECK_RTN(rtn);

		rtn = featureneighbor.SaveBin(filename, STATUS_ONLY);
		CHECK_RTN(rtn);
		clog << "Save bin completed" << endl;
	}

	return 0;
}