示例#1
0
void TMVACrossValidation()
{
   // This loads the library
   TMVA::Tools::Instance();

   // Load data
   TFile *input(0);
   TString fname = "./tmva_class_example.root";
   if (!gSystem->AccessPathName( fname )) {
      input = TFile::Open( fname ); // check if file in local directory exists
   }
   else {
      TFile::SetCacheFileDir(".");
      input = TFile::Open("http://root.cern.ch/files/tmva_class_example.root", "CACHEREAD");
   }
   if (!input) {
      std::cout << "ERROR: could not open data file" << std::endl;
      exit(1);
   }

   TTree* signalTree = (TTree*)input->Get("TreeS");
   TTree* background = (TTree*)input->Get("TreeB");

   // Setup dataloader
   TMVA::DataLoader* dataloader = new TMVA::DataLoader("dataset");

   dataloader->AddSignalTree(signalTree);
   dataloader->AddBackgroundTree(background);

   dataloader->AddVariable("var1");
   dataloader->AddVariable("var2");
   dataloader->AddVariable("var3");
   dataloader->AddVariable("var4");

   dataloader->PrepareTrainingAndTestTree("", "SplitMode=Random:NormMode=NumEvents:!V");

   // Setup cross-validation with Fisher method
   TMVA::CrossValidation cv(dataloader);
   cv.BookMethod(TMVA::Types::kFisher, "Fisher", "!H:!V:Fisher");

   // Run cross-validation and print results
   cv.Evaluate();
   TMVA::CrossValidationResult results = cv.GetResults();
   results.Print();
}
示例#2
0
void TMVACrossValidation()
{
   // This loads the library
   TMVA::Tools::Instance();

   // Load data
   TString fname = "./tmva_class_example.root";
   if (gSystem->AccessPathName(fname))
      gSystem->Exec("curl -O http://root.cern.ch/files/tmva_class_example.root");
   TFile *input = TFile::Open(fname);

   TTree* signalTree = (TTree*)input->Get("TreeS");
   TTree* background = (TTree*)input->Get("TreeB");

   // Setup dataloader
   TMVA::DataLoader* dataloader = new TMVA::DataLoader("dataset");

   dataloader->AddSignalTree(signalTree);
   dataloader->AddBackgroundTree(background);

   dataloader->AddVariable("var1");
   dataloader->AddVariable("var2");
   dataloader->AddVariable("var3");
   dataloader->AddVariable("var4");

   dataloader->PrepareTrainingAndTestTree("", "SplitMode=Random:NormMode=NumEvents:!V");

   // Setup cross-validation with Fisher method
   TMVA::CrossValidation cv(dataloader);
   cv.BookMethod(TMVA::Types::kFisher, "Fisher", "!H:!V:Fisher");

   // Run cross-validation and print results
   cv.Evaluate();
   TMVA::CrossValidationResult results = cv.GetResults();
   results.Print();
}
示例#3
0
int testPyKerasRegression(){
   // Get data file
   std::cout << "Get test data..." << std::endl;
   TString fname = "./tmva_reg_example.root";
   if (gSystem->AccessPathName(fname))  // file does not exist in local directory
      gSystem->Exec("curl -O http://root.cern.ch/files/tmva_reg_example.root");
   TFile *input = TFile::Open(fname);

   // Build model from python file
   std::cout << "Generate keras model..." << std::endl;
   UInt_t ret;
   ret = gSystem->Exec("echo '"+pythonSrc+"' > generateKerasModelRegression.py");
   if(ret!=0){
       std::cout << "[ERROR] Failed to write python code to file" << std::endl;
       return 1;
   }
   ret = gSystem->Exec("python generateKerasModelRegression.py");
   if(ret!=0){
       std::cout << "[ERROR] Failed to generate model using python" << std::endl;
       return 1;
   }

   // Setup PyMVA and factory
   std::cout << "Setup TMVA..." << std::endl;
   TMVA::PyMethodBase::PyInitialize();
   TFile* outputFile = TFile::Open("ResultsTestPyKerasRegression.root", "RECREATE");
   TMVA::Factory *factory = new TMVA::Factory("testPyKerasRegression", outputFile,
      "!V:Silent:Color:!DrawProgressBar:AnalysisType=Regression");

   // Load data
   TMVA::DataLoader *dataloader = new TMVA::DataLoader("datasetTestPyKerasRegression");

   TTree *tree = (TTree*)input->Get("TreeR");
   dataloader->AddRegressionTree(tree);

   dataloader->AddVariable("var1");
   dataloader->AddVariable("var2");
   dataloader->AddTarget("fvalue");

   dataloader->PrepareTrainingAndTestTree("",
      "SplitMode=Random:NormMode=NumEvents:!V");

   // Book and train method
   factory->BookMethod(dataloader, TMVA::Types::kPyKeras, "PyKeras",
      "!H:!V:VarTransform=D,G:FilenameModel=kerasModelRegression.h5:FilenameTrainedModel=trainedKerasModelRegression.h5:NumEpochs=10:BatchSize=32:SaveBestOnly=false:Verbose=0");
   std::cout << "Train model..." << std::endl;
   factory->TrainAllMethods();

   // Clean-up
   delete factory;
   delete dataloader;
   delete outputFile;

   // Setup reader
   UInt_t numEvents = 100;
   std::cout << "Run reader and estimate target of " << numEvents << " events..." << std::endl;
   TMVA::Reader *reader = new TMVA::Reader("!Color:Silent");
   Float_t vars[3];
   reader->AddVariable("var1", vars+0);
   reader->AddVariable("var2", vars+1);
   reader->BookMVA("PyKeras", "datasetTestPyKerasRegression/weights/testPyKerasRegression_PyKeras.weights.xml");

   // Get mean squared error on events
   tree->SetBranchAddress("var1", vars+0);
   tree->SetBranchAddress("var2", vars+1);
   tree->SetBranchAddress("fvalue", vars+2);

   Float_t meanMvaError = 0;
   for(UInt_t i=0; i<numEvents; i++){
      tree->GetEntry(i);
      meanMvaError += std::pow(vars[2]-reader->EvaluateMVA("PyKeras"),2);
   }
   meanMvaError = meanMvaError/float(numEvents);

   // Check whether the response is obviously better than guessing
   std::cout << "Mean squared error: " << meanMvaError << std::endl;
   if(meanMvaError > 30.0){
      std::cout << "[ERROR] Mean squared error is " << meanMvaError << " (>30.0)" << std::endl;
      return 1;
   }

   return 0;
}
示例#4
0
int testPyGTBMulticlass(){
   // Get data file
   std::cout << "Get test data..." << std::endl;
   TString fname = "./tmva_example_multiple_background.root";
   if (gSystem->AccessPathName(fname)){  // file does not exist in local directory
      std::cout << "Create multiclass test data..." << std::endl;
      TString createDataMacro = TString(gROOT->GetTutorialsDir()) + "/tmva/createData.C";
      gROOT->ProcessLine(TString::Format(".L %s",createDataMacro.Data()));
      gROOT->ProcessLine("create_MultipleBackground(200)");
      std::cout << "Created " << fname << " for tests of the multiclass features" << std::endl;
   }
   TFile *input = TFile::Open(fname);

   // Setup PyMVA and factory
   std::cout << "Setup TMVA..." << std::endl;
   TMVA::PyMethodBase::PyInitialize();
   TFile* outputFile = TFile::Open("ResultsTestPyGTBMulticlass.root", "RECREATE");
   TMVA::Factory *factory = new TMVA::Factory("testPyGTBMulticlass", outputFile,
      "!V:Silent:Color:!DrawProgressBar:AnalysisType=multiclass");

   // Load data
   TMVA::DataLoader *dataloader = new TMVA::DataLoader("datasetTestPyGTBMulticlass");

   TTree *signal = (TTree*)input->Get("TreeS");
   TTree *background0 = (TTree*)input->Get("TreeB0");
   TTree *background1 = (TTree*)input->Get("TreeB1");
   TTree *background2 = (TTree*)input->Get("TreeB2");
   dataloader->AddTree(signal, "Signal");
   dataloader->AddTree(background0, "Background_0");
   dataloader->AddTree(background1, "Background_1");
   dataloader->AddTree(background2, "Background_2");

   dataloader->AddVariable("var1");
   dataloader->AddVariable("var2");
   dataloader->AddVariable("var3");
   dataloader->AddVariable("var4");

   dataloader->PrepareTrainingAndTestTree("",
      "SplitMode=Random:NormMode=NumEvents:!V");

   // Book and train method
   factory->BookMethod(dataloader, TMVA::Types::kPyGTB, "PyGTB",
      "!H:!V:VarTransform=None:NEstimators=100:Verbose=0");
   std::cout << "Train classifier..." << std::endl;
   factory->TrainAllMethods();

   // Clean-up
   delete factory;
   delete dataloader;
   delete outputFile;

   // Setup reader
   UInt_t numEvents = 100;
   std::cout << "Run reader and classify " << numEvents << " events..." << std::endl;
   TMVA::Reader *reader = new TMVA::Reader("!Color:Silent");
   Float_t vars[4];
   reader->AddVariable("var1", vars+0);
   reader->AddVariable("var2", vars+1);
   reader->AddVariable("var3", vars+2);
   reader->AddVariable("var4", vars+3);
   reader->BookMVA("PyGTB", "datasetTestPyGTBMulticlass/weights/testPyGTBMulticlass_PyGTB.weights.xml");

   // Get mean response of method on signal and background events
   signal->SetBranchAddress("var1", vars+0);
   signal->SetBranchAddress("var2", vars+1);
   signal->SetBranchAddress("var3", vars+2);
   signal->SetBranchAddress("var4", vars+3);

   background0->SetBranchAddress("var1", vars+0);
   background0->SetBranchAddress("var2", vars+1);
   background0->SetBranchAddress("var3", vars+2);
   background0->SetBranchAddress("var4", vars+3);

   background1->SetBranchAddress("var1", vars+0);
   background1->SetBranchAddress("var2", vars+1);
   background1->SetBranchAddress("var3", vars+2);
   background1->SetBranchAddress("var4", vars+3);

   background2->SetBranchAddress("var1", vars+0);
   background2->SetBranchAddress("var2", vars+1);
   background2->SetBranchAddress("var3", vars+2);
   background2->SetBranchAddress("var4", vars+3);

   Float_t meanMvaSignal = 0;
   Float_t meanMvaBackground0 = 0;
   Float_t meanMvaBackground1 = 0;
   Float_t meanMvaBackground2 = 0;
   for(UInt_t i=0; i<numEvents; i++){
      signal->GetEntry(i);
      meanMvaSignal += reader->EvaluateMulticlass("PyGTB")[0];
      background0->GetEntry(i);
      meanMvaBackground0 += reader->EvaluateMulticlass("PyGTB")[1];
      background1->GetEntry(i);
      meanMvaBackground1 += reader->EvaluateMulticlass("PyGTB")[2];
      background2->GetEntry(i);
      meanMvaBackground2 += reader->EvaluateMulticlass("PyGTB")[3];
   }
   meanMvaSignal = meanMvaSignal/float(numEvents);
   meanMvaBackground0 = meanMvaBackground0/float(numEvents);
   meanMvaBackground1 = meanMvaBackground1/float(numEvents);
   meanMvaBackground2 = meanMvaBackground2/float(numEvents);

   // Check whether the response is obviously better than guessing
   std::cout << "Mean MVA response on signal: " << meanMvaSignal << std::endl;
   if(meanMvaSignal < 0.3){
      std::cout << "[ERROR] Mean response on signal is " << meanMvaSignal << " (<0.3)" << std::endl;
      return 1;
   }
   std::cout << "Mean MVA response on background 0: " << meanMvaBackground0 << std::endl;
   if(meanMvaBackground0 < 0.3){
      std::cout << "[ERROR] Mean response on background 0 is " << meanMvaBackground0 << " (<0.3)" << std::endl;
      return 1;
   }
   std::cout << "Mean MVA response on background 1: " << meanMvaBackground1 << std::endl;
   if(meanMvaBackground0 < 0.3){
      std::cout << "[ERROR] Mean response on background 1 is " << meanMvaBackground1 << " (<0.3)" << std::endl;
      return 1;
   }
   std::cout << "Mean MVA response on background 2: " << meanMvaBackground2 << std::endl;
   if(meanMvaBackground0 < 0.3){
      std::cout << "[ERROR] Mean response on background 2 is " << meanMvaBackground2 << " (<0.3)" << std::endl;
      return 1;
   }

   return 0;
}