// Example #1
// 0
/*
 * Interactive entry point.
 *
 * Loads the vocabulary ("vocabulary.yml"), the inverted index
 * ("inverted_index.yml", VOCAB_SIZE entries named inverted_index_<i>) and the
 * labels ("labels.yml") from files_dir into file-level globals, installs the
 * vocabulary in bowDE, then repeatedly prompts for an image file name inside
 * image_dir, runs test() on it and prints bill_values[result] for a positive
 * result.
 *
 * Returns 0 when the user answers 'n' or standard input ends, and 1 when a
 * model file cannot be opened or holds no "vocabulary" node.
 */
int main()
{
    //Change the default position of files!!!
    string files_dir = "/home/shushman/Desktop/Codes/ml-project/files/";

    //Change the default position of images!
    string image_dir = "/home/shushman/Desktop/Codes/ml-project/test-images/";

    // keypath is not declared in this function; presumably a file-level global.
    keypath = files_dir+"keypoints";

    const string vocabpath = files_dir+"vocabulary.yml";
    FileStorage fs_vocab(vocabpath, FileStorage::READ );
    if(!fs_vocab.isOpened())
    {
        cerr<<"Could not open vocabulary file "<<vocabpath<<endl;
        return 1;
    }
    cout<<"vocabulary file loaded!"<<endl;

    const string indexpath = files_dir+"inverted_index.yml";
    FileStorage fs_inv(indexpath, FileStorage::READ );
    if(!fs_inv.isOpened())
    {
        cerr<<"Could not open index file "<<indexpath<<endl;
        return 1;
    }
    cout<<"index file loaded!"<<endl;

    const string labelpath = files_dir+"labels.yml";
    FileStorage fs_labels(labelpath, FileStorage::READ );
    if(!fs_labels.isOpened())
    {
        cerr<<"Could not open labels file "<<labelpath<<endl;
        return 1;
    }
    cout<<"labels file loaded!"<<endl;

    // vocabulary, labels and inverted_index are not declared here;
    // presumably file-level globals.
    fs_vocab["vocabulary"] >> vocabulary;
    if(vocabulary.empty())
    {
        // Without a vocabulary every later test() call would fail.
        cerr<<"No vocabulary found in "<<vocabpath<<endl;
        return 1;
    }
    vocabulary.convertTo(vocabulary,CV_32F);
    fs_labels["labels"] >> labels;

    for(int i=0;i<VOCAB_SIZE;i++)
    {
        char n[50];
        snprintf(n,sizeof(n),"inverted_index_%d",i);
        fs_inv[n] >> inverted_index[i];
    }

    fs_labels.release();
    fs_inv.release();
    fs_vocab.release();

    bowDE.setVocabulary(vocabulary);

    char ch = 'y';
    string image_fname;

    do{
        cout<<"Enter the name of the image to test: "<<endl;
        // Stop on EOF / stream failure instead of looping forever on stale input.
        if(!(cin>>image_fname))
            break;

        const string image_path = image_dir+image_fname;
        Mat img = imread(image_path,1);
        if(img.empty())
        {
            // imread signals failure with an empty Mat.
            cerr<<"Could not read image "<<image_path<<endl;
        }
        else
        {
            int result = test(img);
            // Non-positive results are reported as unknown.
            // NOTE(review): confirm whether index 0 of bill_values is meaningful.
            if(result>0)
                cout<<"Result is "<<bill_values[result]<<endl;
            else
                cout<<"Can't say!"<<endl;
        }

        cout<<"Do you wish to continue? (y/n)";
        if(!(cin>>ch))
            break;
    }while(ch!='n');

    return 0;
}
void loadVocabulary()
{
	FileStorage fs(DATABASE_FILENAME, FileStorage::READ);
	Mat vocabulary;
	fs["Vocabulary"] >> vocabulary;
	bowExtractor.setVocabulary(vocabulary);
	fs.release();
}
/*
 * Runs video retrieval for a single query image, then shot retrieval inside
 * each of the (at most three) best candidate videos, printing the candidate
 * names and the time spent on feature extraction plus retrieval.
 *
 * @param imagePath     path of the query image (read with imread)
 * @param categoryName  category whose database and BOW dictionary under Data/ are used
 * @param database_type 1 = BOW features (dictionary read from
 *                      Data/Training_Data/<categoryName>_BOW_dictionary.xml);
 *                      any other value = ExtractMPEGFeature
 */
void TestIndividualImage(string imagePath, string categoryName, int database_type)
{
	//Load image
	Mat testImage = imread(imagePath);
	if (testImage.empty())
	{
		// imread returns an empty Mat on failure; stop before the database load
		// and before feature extraction runs on an empty image.
		cerr << "Could not read query image " << imagePath << endl;
		return;
	}

	//Load database
	vector<Mat> trainData, trainLabels;
	vector<float> videoFPS;
	vector<string> _listDataName = SetupModel(categoryName, trainData, trainLabels, videoFPS, database_type);

	//Load dictionary for BOW model
	if (database_type == 1)
		bowDE.setVocabulary(LoadBOWDictionaryFromFile("Data/Training_Data/" + categoryName + "_BOW_dictionary.xml"));

	// Timing starts after loading: it covers feature extraction and retrieval only.
	const clock_t start = clock();

	//---Extract feature based on database_type
	Mat queryFeature;
	if (database_type == 1)
		queryFeature = ExtractBOWFeature(bowDE, detector, testImage);
	else
		queryFeature = ExtractMPEGFeature(testImage);

	//Retrieve candidate videos, then matching shots inside the top candidates
	vector<int> predictLabels = RetrieveVideo(_listDataName.size(), queryFeature, trainData);
	if (!predictLabels.empty())
	{
		// Show at most the top 3 candidates; RetrieveVideo may return fewer,
		// and indexing past the end would be undefined behavior.
		const size_t maxShown = 3;
		const size_t numShown = predictLabels.size() < maxShown ? predictLabels.size() : maxShown;

		cout << "This image can belong to: " << endl;
		for (size_t j = 0; j < numShown; j++)
		{
			int predictLabel = predictLabels[j];
			cout << GetName(_listDataName[predictLabel]) << endl;

			RetrieveShot(videoFPS[predictLabel], trainData[predictLabel], trainLabels[predictLabel], queryFeature);
		}
		cout << endl;
	}
	else
	{
		cout << "No video match the query image" << endl << endl;
	}

	const clock_t elapsed = clock() - start;

	cout << "It took " << static_cast<float>(elapsed) / CLOCKS_PER_SEC << " seconds to complete all tasks" << endl;
}
void TestShotRetrieval(string categoryName, int database_type, int num_retrieve)
{
	string resultFile = "Data/Result_" + categoryName + "_ShotRetrieval.txt";
	ofstream out(resultFile);

	//Load database
	vector<Mat> trainData, trainLabels;
	vector<float> videoFPS;
	vector<string> _listDataName = SetupModel(categoryName, trainData, trainLabels, videoFPS, database_type);
	
	//Load dictionary for BOW model
	if (database_type == 1)
		bowDE.setVocabulary(LoadBOWDictionaryFromFile("Data/Training_Data/" + categoryName + "_BOW_dictionary.xml"));

	//Get list of test frames
	string pathTest = "Data/Test_images/" + categoryName + "/";
	vector<string> _listTestClass = ReadFileList(pathTest);


	vector<vector<float>> _listWeightScore;
	for (int index = 0; index < _listTestClass.size(); index++)
	{
		vector<float> avgWeigthScore;
		avgWeigthScore.resize(num_retrieve);

		string pathClass = pathTest + _listTestClass[index] + "/";
		vector<string> _listTestImage = ReadFileList(pathClass);
		for (int i = 0; i < _listTestImage.size(); i++)
		{
			string testpath = pathClass + _listTestImage[i];
			Mat testImage = imread(testpath);

			//---Extract feature based on database_type
			Mat queryFeature;
			if (database_type == 1)
				queryFeature = ExtractBOWFeature(bowDE, detector, testImage);
			else
				queryFeature = ExtractMPEGFeature(testImage);

			//For each test image, retrieve list of shot and check if there's any correct shot in candidate list
			int trueShotID = IdentifyShotFromKeyFrame(_listTestImage[i]);

			cout << "Retrieve shot in video " << GetName(_listDataName[index]) << " using " << _listTestImage[i] << endl;
			vector<int> _listShotID = RetrieveShot(videoFPS[index], trainData[index], trainLabels[index], queryFeature);

			//Calculate precision and recall
			for (int numRetrieve = 0; numRetrieve < num_retrieve; numRetrieve++)
			{
				float weightScore = ShotRetrievalPerformance(_listShotID, trueShotID, numRetrieve+1);
				avgWeigthScore[numRetrieve] += weightScore;
			}

			cout << endl;
		}

		for (int i = 0; i < num_retrieve; i++)
		{
			avgWeigthScore[i] /= _listTestImage.size();
		}

		_listWeightScore.push_back(avgWeigthScore);
	}

	vector<float> avgWeightScorePerNumRetrieval;
	avgWeightScorePerNumRetrieval.resize(num_retrieve);
	for (int i = 0; i < _listWeightScore.size(); i++)
	{
		float avgPrecision = 0.0f;
		out << _listTestClass[i].c_str() << endl;
		for (int j = 0; j < num_retrieve; j++)
		{
			out << j+1 << ". " << _listWeightScore[i][j] << endl;
			avgWeightScorePerNumRetrieval[j] += _listWeightScore[i][j];
		}
		out << endl;
	}

	out << endl << "Average weight score: " << endl;
	for (int i = 0; i < num_retrieve; i++)
	{ 
		avgWeightScorePerNumRetrieval[i] /= (float)_listWeightScore.size();
		out << i + 1 << ". " << avgWeightScorePerNumRetrieval[i] << endl;
	}

	out.close();
}
/*
 * Evaluates video retrieval on every test image under
 * Data/Test_images/<categoryName>/<class>/.
 *
 * It prints three values, each as a percentage:
 *   - the mean reciprocal rank of the correct video (printed as "MAP"; the two
 *     coincide because each query has a single relevant video);
 *   - the precision within the top 3 candidates;
 *   - the recall within the top 3 candidates.
 *
 * A candidate counts as correct when its label equals the index of the query's
 * test class folder.
 *
 * @param categoryName  category whose database, dictionary and test images are used
 * @param database_type 1 = BOW features, any other value = MPEG features
 */
void TestVideoRetrieval(string categoryName, int database_type)
{
	//Load database
	vector<Mat> trainData, trainLabels;
	vector<float> videoFPS;
	vector<string> _listDataName = SetupModel(categoryName, trainData, trainLabels, videoFPS, database_type);

	//Load dictionary for BOW model
	if (database_type == 1)
		bowDE.setVocabulary(LoadBOWDictionaryFromFile("Data/Training_Data/" + categoryName + "_BOW_dictionary.xml"));

	// Cut-off used for precision and recall.
	const size_t numMatchRetrieve = 3;

	float avgReciprocalRank = 0.0f;
	float avgPrecision = 0.0f;
	float avgRecall = 0.0f;
	string pathTest = "Data/Test_images/" + categoryName + "/";
	vector<string> _listTestClass = ReadFileList(pathTest);

	size_t countTotal = 0;
	for (size_t index = 0; index < _listTestClass.size(); index++)
	{
		float avgClassRR = 0.0f;

		string pathClass = pathTest + _listTestClass[index] + "/";
		vector<string> _listTestImage = ReadFileList(pathClass);
		for (size_t i = 0; i < _listTestImage.size(); i++)
		{
			string testpath = pathClass + _listTestImage[i];
			Mat testImage = imread(testpath);

			//---Extract feature based on database_type
			Mat queryFeature;
			if (database_type == 1)
				queryFeature = ExtractBOWFeature(bowDE, detector, testImage);
			else
				queryFeature = ExtractMPEGFeature(testImage);

			// NOTE(review): passes the number of test classes as the database size,
			// whereas TestIndividualImage passes _listDataName.size() — confirm they agree.
			vector<int> predictLabels = RetrieveVideo(_listTestClass.size(), queryFeature, trainData);

			float reciprocalrank = 0.0f;
			float precision = 0.0f;
			float recall = 0.0f;
			if (!predictLabels.empty())
			{
				cout << _listTestImage[i] << " in video " << _listTestClass[index] << " can belong to: " << endl;

				int countMatch = 0;          // becomes 1 once the correct video is found
				int countMatchRank = 0;      // 1-based rank at which the scan stopped
				int countRetrievedMatch = 0; // 1 if the correct video is within the cut-off
				for (size_t j = 0; j < predictLabels.size(); j++)
				{
					int predictLabel = predictLabels[j];
					cout << _listTestClass[predictLabel] << endl;

					countMatchRank++;
					if (predictLabel == static_cast<int>(index))
					{
						if (j < numMatchRetrieve)
							countRetrievedMatch++;
						countMatch++;
						break;
					}
				}
				cout << endl;

				// countMatchRank >= 1 here because predictLabels is non-empty.
				reciprocalrank = static_cast<float>(countMatch) / static_cast<float>(countMatchRank);
				precision = static_cast<float>(countRetrievedMatch) / static_cast<float>(numMatchRetrieve);
				recall = static_cast<float>(countRetrievedMatch); // one relevant video per query
			}
			else
			{
				cout << "No video match the query image" << endl << endl;
			}

			avgClassRR += reciprocalrank;
			avgPrecision += precision;
			avgRecall += recall;
		}

		countTotal += _listTestImage.size();

		avgReciprocalRank += avgClassRR;
	}

	// With no test images the sums stay 0 rather than becoming 0/0 = NaN.
	if (countTotal > 0)
	{
		const float total = static_cast<float>(countTotal);
		avgReciprocalRank = avgReciprocalRank / total * 100.0f;
		avgPrecision = avgPrecision / total * 100.0f;
		avgRecall = avgRecall / total * 100.0f;
	}

	cout << "MAP = " << avgReciprocalRank << endl;
	cout << "Average Precision = " << avgPrecision << endl;
	cout << "Average Recall = " << avgRecall << endl;
}
void clusterFeatures()
{
	Mat dictionary = bowTrainer.cluster();
	bowExtractor.setVocabulary(dictionary);
}