bool KfoldTimeSeriesData::spiltDataIntoKFolds(const GRT::UINT K) { kFoldValue = K; //K can not be zero if( K == 0 ){ std::cout << "spiltDataIntoKFolds(UINT K) - K can not be zero!" << std::endl; return false; } //K can not be larger than the number of examples if( K > inputDataset.getNumSamples()){ std::cout << "spiltDataIntoKFolds(UINT K,bool useStratifiedSampling) - K can not be 0!" << std::endl; return false; } //K can not be larger than the number of examples in a specific class if the stratified sampling option is true for(UINT c=0; c < inputDataset.getNumClasses(); c++) { if( K > classTracker[c].counter ){ cout << "spiltDataIntoKFolds(UINT K,bool useStratifiedSampling) - K can not be larger than the number of samples in any given class!" << std::endl; return false; } } //Setup the dataset for k-fold cross validation kFoldValue = K; vector< UINT > indexs( inputDataset.getNumSamples() ); //Work out how many samples are in each fold, the last fold might have more samples than the others UINT numSamplesPerFold = (UINT) floor( inputDataset.getNumSamples() / double(K) ); //Create the random partion indexs Random random; UINT randomIndex = 0; //Break the data into seperate classes vector< vector< UINT > > classData( inputDataset.getNumClasses() ); //Add the indexs to their respective classes for(UINT i = 0; i < inputDataset.getNumSamples(); i++) { classData[ inputDataset.getClassLabelIndexValue( inputDataset[i].getClassLabel() ) ].push_back( i ); } //Randomize the order of the indexs in each of the class index buffers for(UINT c = 0; c < inputDataset.getNumClasses(); c++) { UINT numSamples = (UINT)classData[c].size(); for(UINT x = 0; x < numSamples; x++) { //Pick a random index randomIndex = random.getRandomNumberInt(0, numSamples); //Swap the indexs SWAP( classData[c][ x ] , classData[c][ randomIndex ] ); } } //Resize the cross validation indexs buffer crossValidationIndexs.resize( K ); for (UINT k = 0; k < K; k++) { crossValidationIndexs[k].resize(inputDataset.getNumClasses()); } //Loop over each of the classes and add the data equally to each of the k folds until there is no data left vector< UINT >::iterator iter; for(UINT c = 0; c < inputDataset.getNumClasses(); c++){ iter = classData[ c ].begin(); UINT k = 0; while( iter != classData[c].end() ){ crossValidationIndexs[ k ][c].push_back( *iter ); iter++; k = ++k % K; } } crossValidationSetup = true; return true; }