Exemplo n.º 1
0
/**
 * Evaluates this job's configuration against the processor's validation data
 * and appends the outcome to the results files.
 *
 * Steps:
 *  1. Loads `configFileName`, then overrides its feature set, C/gamma/A
 *     parameters and selection method with this job's values.
 *     The SVM machine type is chosen from processor->ResultType():
 *      - OneVsOne for the Mfs and NoTuningAllFeatures result types;
 *      - BinaryCombos for the Bfs types;
 *      - any other type keeps whatever the loaded configuration specified.
 *  2. Builds a CrossValidation over processor->TrainingData() and runs it on
 *     processor->ValidationData(). Per-example results go into
 *     `classedCorrectly`, which is sized by validationData->QueueSize().
 *  3. Stores the following test metrics:
 *      - accuracy and normalized accuracy;
 *      - AvgPredProb() scaled by 100;
 *      - F-measure for processor->PositiveClass();
 *      - support-vector count;
 *      - the grade selected by processor->GradingMethod()
 *        (Accuracy is the fallback).
 *  4. Between processor->Block() and processor->EndBlock(), appends:
 *      - a human-readable report to "FinalResults.log";
 *      - a ValidationResults record to processor->SummaryResultsFileName().
 *
 * `status` is set to Started on entry and to Done on normal completion.
 *
 * NOTE(review): if anything between Block() and EndBlock() throws, the
 * processor is left blocked, and `config`/`crossValidation` leak. Consider
 * RAII owners/guards if these calls can throw.
 */
void   JobValidation::EvaluateNode ()
{
  log.Level (9) << "  " << endl;
  log.Level (9) << "JobValidation::EvaluateNode JobId[" << jobId << "]" << endl;
  status = BinaryJobStatus::Started;

  TrainingConfiguration2Ptr  config = new TrainingConfiguration2 ();
  config->Load (configFileName, false, log);

  // A badly formatted config file does not abort the evaluation.
  // It is only recorded here so it can be reported in FinalResults.log.
  bool  configFileFormatGood = true;
  if  (!config->FormatGood ())
    configFileFormatGood = false;

  config->SetFeatureNums (features);
  config->C_Param (cParm);
  config->Gamma   (gammaParm);
  config->A_Param (aParm);
  config->SelectionMethod (processor->SelectionMethod ());

  switch  (processor->ResultType ())
  {
  case  FinalResultType::MfsFeaturesSel:
  case  FinalResultType::NoTuningAllFeatures:
  case  FinalResultType::MfsParmsTuned:
  case  FinalResultType::MfsParmsTunedFeaturesSel:
           config->MachineType (SVM_MachineType::OneVsOne);
           break;

  case  FinalResultType::BfsFeaturesSel:
  case  FinalResultType::BfsParmsTuned:
  case  FinalResultType::BfsFeaturesSelParmsTuned:
           config->MachineType (SVM_MachineType::BinaryCombos);
           break;
  }

  // Handed to CrossValidation below; it lives until the end of this function,
  // i.e. past the point where crossValidation is deleted.
  bool  cancelFlag = false;

  FeatureVectorListPtr  trainData       = processor->TrainingData ();
  FeatureVectorListPtr  validationData  = processor->ValidationData ();

  VectorDouble  trainDataMeans      = trainData->ExtractMeanFeatureValues ();
  VectorDouble  validationDataMeans = validationData->ExtractMeanFeatureValues ();

  CrossValidationPtr  crossValidation = new CrossValidation
                                           (config,
                                            trainData,
                                            processor->MLClasses (),
                                            processor->NumOfFolds (),
                                            processor->AlreadyNormalized (),
                                            processor->FileDesc (),
                                            log,
                                            cancelFlag
                                           );

  // classedCorrectly is allocated with new[], so it must be released with
  // delete[]; plain delete on an array is undefined behavior.
  // It is nulled before reallocating so it never dangles if new[] throws.
  // NOTE(review): any other place that frees classedCorrectly (e.g. the
  // destructor, not visible here) must also use delete[].
  delete[]  classedCorrectly;
  classedCorrectly     = nullptr;
  classedCorrectlySize = validationData->QueueSize ();
  classedCorrectly     = new bool[classedCorrectlySize]();   // value-initialized (all false)

  crossValidation->RunValidationOnly (validationData, classedCorrectly, log);

  testAccuracy      = crossValidation->Accuracy ();
  testAccuracyNorm  = crossValidation->AccuracyNorm ();
  testAvgPredProb   = static_cast<float> (crossValidation->AvgPredProb ()) * 100.0f;
  testFMeasure      = static_cast<float> (crossValidation->ConfussionMatrix ()->FMeasure (processor->PositiveClass (), log));

  const auto  gradingMethod = processor->GradingMethod ();
  if  (gradingMethod == GradingMethodType::AccuracyNorm)
    testGrade = testAccuracyNorm;

  else if  (gradingMethod == GradingMethodType::FMeasure)
    testGrade = testFMeasure;

  else
    testGrade = testAccuracy;   // GradingMethodType::Accuracy and any other method

  testNumSVs  = crossValidation->NumOfSupportVectors ();

  {
    // Save the results of this split in the results files. The writes are
    // bracketed by Block()/EndBlock(), presumably to serialize appends from
    // concurrent jobs. Each stream is scoped so that it is closed (flushed)
    // before EndBlock() is called.
    processor->Block ();

    {
      ofstream  rl ("FinalResults.log", ios_base::app);
      rl << endl << endl
         << "ConfigFileName"          << "\t" << configFileName  << "\t" << "Format Good[" << (configFileFormatGood ? "Yes" : "No") << "]" << endl
         << "SummaryResultsFileName"  << "\t" << processor->SummaryResultsFileName () << endl
         << "Configuration CmdLine"   << "\t" << config->SVMparamREF (log).ToString ()   << endl
         << "ImagesPerClass"          << "\t" << config->ImagesPerClass ()            << endl
         << endl;

      rl << endl << endl
         << "Training Data Status" << endl
         << endl;
      trainData->PrintClassStatistics (rl);
      rl << endl << endl;

      rl << "TrainingDataMeans";
      for  (const auto& mean : trainDataMeans)
        rl << "\t" << mean;
      rl << endl;

      rl << "ValidationDataMeans";
      for  (const auto& mean : validationDataMeans)
        rl << "\t" << mean;
      rl << endl
         << endl;

      crossValidation->ConfussionMatrix ()->PrintConfusionMatrixTabDelimited (rl);
      rl << endl << endl << endl << endl;
    }

    {
      ofstream  f (processor->SummaryResultsFileName ().Str (), ios_base::app);
      ValidationResults r (processor->ResultType (),
                           config,
                           crossValidation,
                           trainData,
                           osGetHostName ().value_or ("*** unknown ***"),
                           classedCorrectlySize,
                           classedCorrectly,
                           this,
                           log
                          );
      r.Write (f);
    }

    processor->EndBlock ();
  }

  // crossValidation was constructed from config, so it is destroyed first.
  delete  crossValidation;
  delete  config;
  status = BinaryJobStatus::Done;
}  /* EvaluateNode */