Exemplo n.º 1
0
/**
 * Runs BitReduction compression over a normalized copy of 'trainData' and
 * reports how many examples remain afterwards.
 *
 * The work is done on the list returned by trainData->DuplicateListAndContents(),
 * so normalization is not applied to 'trainData' itself (assuming that call
 * returns copies of the examples -- its name suggests so; confirm in
 * FeatureVectorList).  The copy is regrouped class by class, normalized with a
 * NormalizationParms built from 'config', and then passed to
 * BitReduction::compress.  Every temporary list created here is deleted before
 * returning.
 *
 * @param trainData  Source examples; supplies the FileDesc, the duplicate that
 *                   is processed, and AllFeatures() for BitReduction.
 * @param config     Passed to NormalizationParms; its SVMparamREF() is handed
 *                   to BitReduction.
 * @return The 'num_images_after' field of the CompressionStats returned by
 *         BitReduction::compress.
 */
int   RandomSampleJobList::DetermineCompressedImageCount (FeatureVectorListPtr       trainData,
                                                          TrainingConfigurationPtr   config
                                                         )
{
  FileDescPtr  const  desc = trainData->FileDesc ();

  FeatureVectorListPtr  workingCopy = trainData->DuplicateListAndContents ();

  // Second ctor argument is 'false' -- presumably a non-owning list, with the
  // examples remaining owned by 'workingCopy'.
  FeatureVectorListPtr  byClass = new FeatureVectorList (desc, false, log, 10000);

  MLClassListPtr  classList = workingCopy->ExtractListOfClasses ();

  // Rebuild the example list so that examples of the same class are contiguous.
  for  (MLClassList::const_iterator  it = classList->begin ();  it != classList->end ();  ++it)
  {
    FeatureVectorListPtr  examplesOfClass = workingCopy->ExtractImagesForAGivenClass (*it);
    byClass->AddQueue (*examplesOfClass);
    delete  examplesOfClass;
  }

  NormalizationParms  normalizer (config, *byClass, log);
  normalizer.NormalizeImages (byClass);

  ClassAssignments      assignments (*classList, log);
  FeatureVectorListPtr  compressed = new FeatureVectorList (desc, true, log, 10000);

  BitReduction      reducer (config->SVMparamREF (), desc, trainData->AllFeatures ());
  CompressionStats  stats = reducer.compress (*byClass, compressed, assignments);

  int  result = stats.num_images_after;

  log.Level (10) << "DetermineCompressedImageCount  compressedImageCount[" << result << "]" << endl;

  // Same release order as before: compressed output, class list, grouped view, then the copy.
  delete  compressed;
  delete  classList;
  delete  byClass;
  delete  workingCopy;

  return  result;
}  /* DetermineCompressedImageCount */
Exemplo n.º 2
0
/**
 * Evaluates one random-sampling job: trains on a prefix of an ordering and
 * scores the result against 'validationData'.
 *
 * Steps:
 *   1. Copies this job's kernelType, encodingMethod, c and gamma into 'config'
 *      and sets its compression method to BRnoCompression.
 *   2. Takes the first 'numExamplesToKeep' examples of ordering 'orderingNum'.
 *      If the ordering holds fewer examples than that, it logs an error, calls
 *      osWaitForEnter() and terminates the process with exit(-1).
 *   3. Logs an error if the classes present in the sample differ from
 *      'classes'.  This is diagnostic only; evaluation still proceeds.
 *   4. Replaces 'crossValidation' with a new CrossValidation over the sample
 *      (10 folds, features not yet normalized) and calls RunValidationOnly
 *      with 'validationData'.
 *   5. Copies Accuracy, TrainTimeMean, TestTimeMean and SupportPointsMean into
 *      accuracy, trainTime, testTime and supportVectors.
 *
 * 'status' is set to rjStarted on entry and rjDone on normal completion.
 *
 * @param validationData  Examples passed to CrossValidation::RunValidationOnly.
 * @param classes         Expected class list; passed to CrossValidation and
 *                        compared with the classes found in the sample.
 */
void  RandomSampleJob::EvaluteNode (FeatureVectorListPtr  validationData,
                                    MLClassListPtr     classes
                                   )
{
  log.Level (9) << "  " << endl;
  log.Level (9) << "  " << endl;
  log.Level (9) << "RandomSampleJob::EvaluteNode JobId[" << jobId << "] Ordering[" << orderingNum << "]" << endl;

  status = rjStarted;

  // Apply this job's parameters to 'config'; compression is always disabled for this evaluation.
  config->CompressionMethod (BRnoCompression);
  config->KernalType        (kernelType);
  config->EncodingMethod    (encodingMethod);
  config->C_Param           (c);
  config->Gamma             (gamma);

  const FeatureVectorListPtr  srcExamples = orderings->Ordering (orderingNum);

  if  (numExamplesToKeep > srcExamples->QueueSize ())
  {
    log.Level (-1) << endl << endl << endl
                   << "RandomSampleJob::EvaluteNode     *** ERROR ***    RandomExamples to large" << endl
                   << endl
                   << "                     RandomExamples > num in Training set." << endl
                   << endl;
    osWaitForEnter ();
    exit (-1);
  }

  // The random sample is the first 'numExamplesToKeep' examples of this ordering.  The list is
  // created with 'false' as its second argument -- presumably non-owning, so deleting it below
  // should leave the examples themselves intact (confirm against FeatureVectorList).
  FeatureVectorListPtr  trainingData = new FeatureVectorList (srcExamples->FileDesc (), false, log, 10000);
  for  (int x = 0;  x < numExamplesToKeep;  x++)
  {
    trainingData->PushOnBack (srcExamples->IdxToPtr (x));
  }

  // Diagnostic only: a sample whose class set differs from 'classes' is reported but still evaluated.
  {
    MLClassListPtr  classesInRandomSample = trainingData->ExtractListOfClasses ();
    if  (*classesInRandomSample != (*classes))
    {
      log.Level (-1) << endl << endl
                     << "RandomSampling    *** ERROR ***" << endl
                     << endl
                     << "                  Missing Classes From Random Sample." << endl
                     << endl
                     << "MLClasses[" << classes->ToCommaDelimitedStr               () << "]" << endl
                     << "Found       [" << classesInRandomSample->ToCommaDelimitedStr () << "]" << endl
                     << endl;
    }

    delete  classesInRandomSample;  classesInRandomSample = NULL;
  }

  delete  crossValidation;  crossValidation = NULL;

  compMethod = config->CompressionMethod ();

  bool  cancelFlag = false;

  // NOTE(review): 'crossValidation' is not a local -- it survives this call -- yet it is built from
  // 'trainingData' (deleted at the end of this function) and from the local 'cancelFlag'.  If
  // CrossValidation stores that pointer or a reference to the flag, any use of 'crossValidation'
  // after EvaluteNode returns would touch freed/out-of-scope storage.  Confirm against
  // CrossValidation before relying on the member later.
  crossValidation = new CrossValidation 
                            (config,
                             trainingData,
                             classes,
                             10,
                             false,   //  False = Features are not normalized already.
                             trainingData->FileDesc (),
                             log,
                             cancelFlag
                            );

  crossValidation->RunValidationOnly (validationData, NULL);

  accuracy       = crossValidation->Accuracy ();
  trainTime      = crossValidation->TrainTimeMean ();
  testTime       = crossValidation->TestTimeMean ();
  supportVectors = crossValidation->SupportPointsMean ();

  delete  trainingData;

  status = rjDone;
}  /* EvaluteNode */