/**
 Returns the training data for one fold of a K-fold cross-validation split.

 The training set consists of every sample that is NOT assigned to fold foldIndex
 (the samples in fold foldIndex form the matching test set).

 @param foldIndex: the fold to hold out; must be less than the number of folds (kFoldValue)
 @return the training samples, or an empty, default-constructed dataset if cross validation has not been
         set up or foldIndex is out of range (an error is logged in both cases)
*/
LabelledRegressionData LabelledRegressionData::getTrainingFoldData(const UINT foldIndex) const{

    LabelledRegressionData trainingData;

    if( !crossValidationSetup ){
        errorLog << "getTrainingFoldData(UINT foldIndex) - Cross Validation has not been setup! You need to call the spiltDataIntoKFolds(UINT K,bool useStratifiedSampling) function first before calling this function!" << endl;
        return trainingData;
    }

    //Log the failure so callers can tell why the returned dataset is empty
    if( foldIndex >= kFoldValue ){
        errorLog << "getTrainingFoldData(UINT foldIndex) - Fold index " << foldIndex << " is out of range! It must be less than the number of folds (" << kFoldValue << ")." << endl;
        return trainingData;
    }

    trainingData.setInputAndTargetDimensions(numInputDimensions, numTargetDimensions);

    //Add the data to the training set, this will consist of all the data that is NOT in the foldIndex
    for(UINT k=0; k<kFoldValue; k++){
        if( k == foldIndex ) continue; //This fold is held out as the test set

        for(UINT i=0; i<crossValidationIndexs[k].size(); i++){
            const UINT index = crossValidationIndexs[k][i];
            trainingData.addSample( data[ index ].getInputVector(), data[ index ].getTargetVector() );
        }
    }

    return trainingData;
}
/**
 Converts this classification dataset into a regression dataset using a one-hot target encoding, so that
 regression algorithms (e.g. the MLP) can be used as classifiers.

 The regression data has numDimensions inputs and getNumClasses() targets. The target vector of each
 sample is all 0's, except for element (classLabel-1), which is set to 1.

 For this mapping to be valid, every class label must lie in the range [1, getNumClasses()]:
 a label of 0 has no target element, and a label larger than the number of classes (for example,
 non-contiguous labels such as {1,5}) would write past the end of the target vector. If any sample
 violates this, the regression data is cleared and returned empty.

 @return the converted regression data, or an empty dataset if there are no samples or any class label
         is outside [1, getNumClasses()]
*/
LabelledRegressionData LabelledClassificationData::reformatAsLabelledRegressionData() const{

    LabelledRegressionData regressionData;

    if( totalNumSamples == 0 ){
        return regressionData;
    }

    const UINT numInputDimensions = numDimensions;
    const UINT numTargetDimensions = getNumClasses();
    regressionData.setInputAndTargetDimensions(numInputDimensions, numTargetDimensions);

    for(UINT i=0; i<totalNumSamples; i++){
        const UINT classLabel = data[i].getClassLabel();

        //Reject labels that cannot be mapped to a target element: 0 has no slot, and a label above the
        //number of classes would index past the end of the target vector (undefined behavior)
        if( classLabel == 0 || classLabel > numTargetDimensions ){
            regressionData.clear();
            return regressionData;
        }

        //Set the class index in the target vector to 1 and all other values in the target vector to 0
        VectorDouble targetVector(numTargetDimensions,0);
        targetVector[ classLabel-1 ] = 1;

        regressionData.addSample(data[i].getSample(),targetVector);
    }

    return regressionData;
}
/**
 Returns the test data for one fold of a K-fold cross-validation split, i.e. only the samples
 assigned to fold foldIndex.

 @param foldIndex: the fold to return; must be less than the number of folds (kFoldValue)
 @return the test samples, or an empty, default-constructed dataset if cross validation has not been
         set up or foldIndex is out of range (an error is logged in both cases)
*/
LabelledRegressionData LabelledRegressionData::getTestFoldData(const UINT foldIndex) const{

    LabelledRegressionData testData;

    //Log failures (matching getTrainingFoldData) instead of silently returning an empty dataset
    if( !crossValidationSetup ){
        errorLog << "getTestFoldData(UINT foldIndex) - Cross Validation has not been setup! You need to call the spiltDataIntoKFolds(...) function first before calling this function!" << endl;
        return testData;
    }

    if( foldIndex >= kFoldValue ){
        errorLog << "getTestFoldData(UINT foldIndex) - Fold index " << foldIndex << " is out of range! It must be less than the number of folds (" << kFoldValue << ")." << endl;
        return testData;
    }

    testData.setInputAndTargetDimensions(numInputDimensions, numTargetDimensions);

    //Add the data to the test set, this will consist of only the data in the foldIndex
    for(UINT i=0; i<crossValidationIndexs[ foldIndex ].size(); i++){
        const UINT index = crossValidationIndexs[ foldIndex ][i];
        testData.addSample( data[ index ].getInputVector(), data[ index ].getTargetVector() );
    }

    return testData;
}