// Example no. 1
// 0
/**
 Partitions the samples of inputDataset into K folds for k-fold cross validation.

 The sample indices are first grouped by class label index, each group is shuffled,
 and the indices of every class are then dealt round-robin over the K folds, so each
 fold receives samples from every class. On success crossValidationIndexs[k][c] holds
 the sample indices of class index c that were assigned to fold k, kFoldValue is set
 to K and crossValidationSetup is set to true.

 All validation happens before any member is modified, so a failed call leaves
 kFoldValue, crossValidationIndexs and crossValidationSetup untouched.

 @param K: the number of folds. Must be greater than zero, no larger than the number of
 samples in inputDataset and no larger than the number of samples in any single class
 @return true if the folds were built, false if K failed one of the checks above
*/
bool KfoldTimeSeriesData::spiltDataIntoKFolds(const GRT::UINT K) {

    //K can not be zero
    if( K == 0 ){
        std::cout << "spiltDataIntoKFolds(UINT K) - K can not be zero!" << std::endl;
        return false;
    }

    const UINT numSamples = inputDataset.getNumSamples();
    const UINT numClasses = inputDataset.getNumClasses();

    //K can not be larger than the number of examples
    if( K > numSamples ){
        std::cout << "spiltDataIntoKFolds(UINT K,bool useStratifiedSampling) - K can not be larger than the number of samples!" << std::endl;
        return false;
    }

    //K can not be larger than the number of examples in any single class, otherwise
    //at least one fold would end up with no samples of that class
    //NOTE(review): assumes classTracker has one entry per class, in class index order - confirm
    for(UINT c = 0; c < numClasses; c++){
        if( K > classTracker[c].counter ){
            std::cout << "spiltDataIntoKFolds(UINT K,bool useStratifiedSampling) - K can not be larger than the number of samples in any given class!" << std::endl;
            return false;
        }
    }

    //Break the data into separate classes: classData[c] lists the indices of the samples in class index c
    std::vector< std::vector< UINT > > classData( numClasses );
    for(UINT i = 0; i < numSamples; i++){
        const UINT classIndex = inputDataset.getClassLabelIndexValue( inputDataset[i].getClassLabel() );
        classData[ classIndex ].push_back( i );
    }

    //Randomize the order of the indices in each of the class index buffers
    Random random;
    for(UINT c = 0; c < numClasses; c++){
        const UINT classSize = static_cast< UINT >( classData[c].size() );
        for(UINT x = 0; x < classSize; x++){
            //Pick a random index
            //NOTE(review): the upper bound is assumed to be exclusive, otherwise the swap below
            //could index one past the end of the buffer - confirm against the Random class
            const UINT randomIndex = random.getRandomNumberInt(0, classSize);

            //Swap the indices
            SWAP( classData[c][ x ] , classData[c][ randomIndex ] );
        }
    }

    //Reset the cross validation buffer before resizing it: resize() on its own keeps any
    //existing inner buffers, so a second call would append to the previous partition
    crossValidationIndexs.clear();
    crossValidationIndexs.resize( K );
    for(UINT k = 0; k < K; k++){
        crossValidationIndexs[k].resize( numClasses );
    }

    //Deal the shuffled indices of each class round-robin over the K folds until there is no data left
    for(UINT c = 0; c < numClasses; c++){
        const UINT classSize = static_cast< UINT >( classData[c].size() );
        for(UINT i = 0; i < classSize; i++){
            crossValidationIndexs[ i % K ][ c ].push_back( classData[c][ i ] );
        }
    }

    //Only commit the new state once the folds have been fully built
    kFoldValue = K;
    crossValidationSetup = true;
    return true;

}