// Combinatorial PDFs
//
// For every state of the "DataCat" category found in the workspace, creates
// an exponential pdf "bkg_pdf_<label>" in B_s0_DTF_B_s0_M with its own slope
// parameter "bkg_exp_p1_<label>" and then calls defineParamSet() for that pdf.
//
//   w   workspace that is read ("DataCat") and filled (one pdf per state)
//
// Side effect: the category is stepped through its states with setIndex(),
// so it is left on the last state visited.
void makeCombinatorialPdf( RooWorkspace *w ) {
  // one in each cat
  RooCategory *cat = (RooCategory*)w->cat("DataCat");
  if ( !cat ) {
    // nothing to loop over - bail out instead of dereferencing a null pointer
    std::cout << "makeCombinatorialPdf: category DataCat not found in workspace" << std::endl;
    return;
  }
  for ( int i=0; i < cat->numTypes(); i++ ) {
    // NOTE(review): assumes the state indices are 0..numTypes()-1 - confirm
    cat->setIndex(i);
    // keep a private copy of the label for the rest of this iteration
    const TString label( cat->getLabel() );
    w->factory( Form("Exponential::bkg_pdf_%s( B_s0_DTF_B_s0_M, bkg_exp_p1_%s[-0.002,-0.005,0.] )", label.Data(), label.Data() ) );
    defineParamSet( w, Form("bkg_pdf_%s",label.Data()) );
  }
}
// total PDF
//
// Builds, for every state <label> of the "DataCat" category in the workspace:
//   - the yield variables, with the bs2phikst and bd2rhokst yields expressed
//     as (constrained ratio) x (bd2phikst yield),
//   - a RooAddPdf "pdf_<label>" of the component shapes with those yields,
//   - a RooProdPdf "constrained_pdf_<label>" = pdf_<label> times the two
//     Gaussian ratio constraints,
//   - the named sets "pdf_<label>_yield_params" and
//     "constrained_pdf_<label>_yield_params",
// and finally the RooSimultaneous pdfs "pdf" and "constrained_pdf" over DataCat.
//
//   w   workspace that is read and filled. The component shapes
//       (bkg_pdf_<label>, part_reco_pdf and the *_mc_pdf shapes) are looked up
//       by name and dereferenced, so they must already be in the workspace.
//
// Side effect: the category is stepped through its states with setIndex(),
// so it is left on the last state visited.
void makeTotalPdf( RooWorkspace *w ) {

  // check the category first so that the workspace is not left half filled
  RooCategory *cat = (RooCategory*)w->cat("DataCat");
  if ( !cat ) {
    std::cout << "makeTotalPdf: category DataCat not found in workspace" << std::endl;
    return;
  }

  // constrain the bs->phikst / bd->phikst ratio
  w->factory( "yield_ratio_bs2phikst_o_bd2phikst[0.,1.]" );
  w->factory( "Gaussian::yield_ratio_bs2phikst_o_bd2phikst_constraint( yield_ratio_bs2phikst_o_bd2phikst, 0.113, 0.0287 )" );
  // constrain the bd->rhokst / bd->phikst ratio
  w->factory("yield_ratio_bd2rhokst_o_bd2phikst[0.,1.]" );
  w->factory( "Gaussian::yield_ratio_bd2rhokst_o_bd2phikst_constraint( yield_ratio_bd2rhokst_o_bd2phikst, 0.390, 0.130 )" ); // PDG err is 0.130 (relax this for eff)

  // make a yield for each category
  for ( int i=0; i < cat->numTypes(); i++ ) {
    cat->setIndex(i);
    const TString label( cat->getLabel() );
    const char *lab = label.Data();

    w->factory( Form("bkg_y_%s[200,400e3]",       lab));
    w->factory( Form("part_reco_y_%s[100,200e3]", lab));
    w->factory( Form("bs2kstkst_y_%s[0,20e3]", lab));
    w->factory( Form("bd2kstkst_y_%s[0,3000]", lab));
    w->factory( Form("bd2phikst_y_%s[10,5000]", lab));
    // add bs2phikst yield as constrained ratio
    w->factory( Form("prod::bs2phikst_y_%s( yield_ratio_bs2phikst_o_bd2phikst, bd2phikst_y_%s )", lab, lab) );
    //w->factory( Form("bs2phikst_y_%s[10,5000]", lab));
    // add bd2rhokst yield as constrained ratio
    w->factory( Form("prod::bd2rhokst_y_%s( yield_ratio_bd2rhokst_o_bd2phikst, bd2phikst_y_%s )", lab, lab) );
    //w->factory( Form("bd2rhokst_y_%s[5,250]", lab));
    w->factory( Form("lb2pkpipi_y_%s[0,4000]", lab));
    // still created although the lb2ppipipi component is left out of the sum below
    w->factory( Form("lb2ppipipi_y_%s[0,4000]", lab));
  }

  // construct the pdf for each category
  for ( int i=0; i < cat->numTypes(); i++ ) {
    cat->setIndex(i);
    const TString label( cat->getLabel() );
    const char *lab = label.Data();

    // The collections only reference objects held by the workspace, so they
    // live on the stack and are released at the end of each iteration.
    // The order of yields must match the order of pdfs below.
    RooArgList yields;
    yields.add(*w->var( Form("bkg_y_%s"      , lab) ));
    yields.add(*w->var( Form("part_reco_y_%s", lab) ));
    yields.add(*w->var( Form("bs2kstkst_y_%s", lab) ));
    yields.add(*w->var( Form("bd2kstkst_y_%s", lab) ));
    yields.add(*w->var( Form("bd2phikst_y_%s", lab) ));
    yields.add(*w->function( Form("bs2phikst_y_%s", lab) ));
    //yields.add(*w->var( Form("bs2phikst_y_%s", lab) ));
    yields.add(*w->function( Form("bd2rhokst_y_%s", lab) ));
    //yields.add(*w->var( Form("bd2rhokst_y_%s", lab) ));
    yields.add(*w->var( Form("lb2pkpipi_y_%s", lab) ));
    //yields.add(*w->var( Form("lb2ppipipi_y_%s", lab) )); // this guy we scrap

    RooArgList pdfs;
    pdfs.add(*w->pdf( Form("bkg_pdf_%s", lab) ));
    pdfs.add(*w->pdf("part_reco_pdf" ));
    pdfs.add(*w->pdf("bs2kstkst_mc_pdf"  ));
    pdfs.add(*w->pdf("bd2kstkst_mc_pdf" ));
    pdfs.add(*w->pdf("bd2phikst_mc_pdf" ));
    pdfs.add(*w->pdf("bs2phikst_mc_pdf" ));
    pdfs.add(*w->pdf("bd2rhokst_mc_pdf" ));
    pdfs.add(*w->pdf("lb2pkpipi_mc_pdf" ));
    //pdfs.add(*w->pdf("lb2ppipipi_mc_pdf")); // this guy we scrap

    {
      // temporary sum pdf: only needed until it has been imported
      RooAddPdf sumPdf( Form("pdf_%s",lab), "pdf" , pdfs, yields);
      w->import(sumPdf);
    }

    // then make the constrained pdf (built from the imported pdf_<label>)
    {
      RooArgSet prodpdfs;
      prodpdfs.add( *w->pdf(Form("pdf_%s",lab)) );
      prodpdfs.add( *w->pdf("yield_ratio_bs2phikst_o_bd2phikst_constraint") );
      prodpdfs.add( *w->pdf("yield_ratio_bd2rhokst_o_bd2phikst_constraint") );
      RooProdPdf constrainedPdf( Form("constrained_pdf_%s",lab), "constrained_pdf", prodpdfs );
      w->import(constrainedPdf);
    }

    // the same yields are registered under both set names
    w->defineSet(Form("pdf_%s_yield_params",lab), yields);
    w->defineSet(Form("constrained_pdf_%s_yield_params",lab), yields);
  }

  // now make simultaneous pdf
  RooSimultaneous cpdf( "constrained_pdf", "constrained_pdf", *cat );
  RooSimultaneous pdf( "pdf", "pdf", *cat );
  for ( int i=0; i < cat->numTypes(); i++ ) {
    cat->setIndex(i);
    const TString label( cat->getLabel() );
    cpdf.addPdf( *w->pdf( Form("constrained_pdf_%s", label.Data() )), label.Data() );
    pdf.addPdf( *w->pdf( Form("pdf_%s", label.Data() )), label.Data() );
  }
  w->import(cpdf);
  w->import(pdf);

}
void StandardHistFactoryPlotsWithCategories(const char* infile = "",
                                            const char* workspaceName = "combined",
                                            const char* modelConfigName = "ModelConfig",
                                            const char* dataName = "obsData"){


   double nSigmaToVary=5.;
   double muVal=0;
   bool doFit=false;

   // -------------------------------------------------------
   // First part is just to access a user-defined file
   // or create the standard example file if it doesn't exist
   const char* filename = "";
   if (!strcmp(infile,"")) {
      filename = "results/example_combined_GaussExample_model.root";
      bool fileExist = !gSystem->AccessPathName(filename); // note opposite return code
                                                           // if file does not exists generate with histfactory
      if (!fileExist) {
#ifdef _WIN32
         cout << "HistFactory file cannot be generated on Windows - exit" << endl;
         return;
#endif
         // Normally this would be run on the command line
         cout <<"will run standard hist2workspace example"<<endl;
         gROOT->ProcessLine(".! prepareHistFactory .");
         gROOT->ProcessLine(".! hist2workspace config/example.xml");
         cout <<"\n\n---------------------"<<endl;
         cout <<"Done creating example input"<<endl;
         cout <<"---------------------\n\n"<<endl;
      }

   }
   else
      filename = infile;

   // Try to open the file
   TFile *file = TFile::Open(filename);

   // if input file was specified byt not found, quit
   if(!file ){
      cout <<"StandardRooStatsDemoMacro: Input file " << filename << " is not found" << endl;
      return;
   }

   // -------------------------------------------------------
   // Tutorial starts here
   // -------------------------------------------------------

   // get the workspace out of the file
   RooWorkspace* w = (RooWorkspace*) file->Get(workspaceName);
   if(!w){
      cout <<"workspace not found" << endl;
      return;
   }

   // get the modelConfig out of the file
   ModelConfig* mc = (ModelConfig*) w->obj(modelConfigName);

   // get the modelConfig out of the file
   RooAbsData* data = w->data(dataName);

   // make sure ingredients are found
   if(!data || !mc){
      w->Print();
      cout << "data or ModelConfig was not found" <<endl;
      return;
   }

   // -------------------------------------------------------
   // now use the profile inspector

   RooRealVar* obs = (RooRealVar*)mc->GetObservables()->first();
   TList* list = new TList();


   RooRealVar * firstPOI = dynamic_cast<RooRealVar*>(mc->GetParametersOfInterest()->first());

   firstPOI->setVal(muVal);
   //  firstPOI->setConstant();
   if(doFit){
      mc->GetPdf()->fitTo(*data);
   }

   // -------------------------------------------------------


   mc->GetNuisanceParameters()->Print("v");
   int  nPlotsMax = 1000;
   cout <<" check expectedData by category"<<endl;
   RooDataSet* simData=NULL;
   RooSimultaneous* simPdf = NULL;
   if(strcmp(mc->GetPdf()->ClassName(),"RooSimultaneous")==0){
      cout <<"Is a simultaneous PDF"<<endl;
      simPdf = (RooSimultaneous *)(mc->GetPdf());
   } else {
      cout <<"Is not a simultaneous PDF"<<endl;
   }



   if(doFit) {
      RooCategory* channelCat = (RooCategory*) (&simPdf->indexCat());
      TIterator* iter = channelCat->typeIterator() ;
      RooCatType* tt = NULL;
      tt=(RooCatType*) iter->Next();
      RooAbsPdf* pdftmp = ((RooSimultaneous*)mc->GetPdf())->getPdf(tt->GetName()) ;
      RooArgSet* obstmp = pdftmp->getObservables(*mc->GetObservables()) ;
      obs = ((RooRealVar*)obstmp->first());
      RooPlot* frame = obs->frame();
      cout <<Form("%s==%s::%s",channelCat->GetName(),channelCat->GetName(),tt->GetName())<<endl;
      cout << tt->GetName() << " " << channelCat->getLabel() <<endl;
      data->plotOn(frame,MarkerSize(1),Cut(Form("%s==%s::%s",channelCat->GetName(),channelCat->GetName(),tt->GetName())),DataError(RooAbsData::None));

      Double_t normCount = data->sumEntries(Form("%s==%s::%s",channelCat->GetName(),channelCat->GetName(),tt->GetName())) ;

      pdftmp->plotOn(frame,LineWidth(2.),Normalization(normCount,RooAbsReal::NumEvent)) ;
      frame->Draw();
      cout <<"expected events = " << mc->GetPdf()->expectedEvents(*data->get()) <<endl;
      return;
   }



   int nPlots=0;
   if(!simPdf){

      TIterator* it = mc->GetNuisanceParameters()->createIterator();
      RooRealVar* var = NULL;
      while( (var = (RooRealVar*) it->Next()) != NULL){
         RooPlot* frame = obs->frame();
         frame->SetYTitle(var->GetName());
         data->plotOn(frame,MarkerSize(1));
         var->setVal(0);
         mc->GetPdf()->plotOn(frame,LineWidth(1.));
         var->setVal(1);
         mc->GetPdf()->plotOn(frame,LineColor(kRed),LineStyle(kDashed),LineWidth(1));
         var->setVal(-1);
         mc->GetPdf()->plotOn(frame,LineColor(kGreen),LineStyle(kDashed),LineWidth(1));
         list->Add(frame);
         var->setVal(0);
      }


   } else {
      RooCategory* channelCat = (RooCategory*) (&simPdf->indexCat());
      //    TIterator* iter = simPdf->indexCat().typeIterator() ;
      TIterator* iter = channelCat->typeIterator() ;
      RooCatType* tt = NULL;
      while(nPlots<nPlotsMax && (tt=(RooCatType*) iter->Next())) {

         cout << "on type " << tt->GetName() << " " << endl;
         // Get pdf associated with state from simpdf
         RooAbsPdf* pdftmp = simPdf->getPdf(tt->GetName()) ;

         // Generate observables defined by the pdf associated with this state
         RooArgSet* obstmp = pdftmp->getObservables(*mc->GetObservables()) ;
         //      obstmp->Print();


         obs = ((RooRealVar*)obstmp->first());

         TIterator* it = mc->GetNuisanceParameters()->createIterator();
         RooRealVar* var = NULL;
         while(nPlots<nPlotsMax && (var = (RooRealVar*) it->Next())){
            TCanvas* c2 = new TCanvas("c2");
            RooPlot* frame = obs->frame();
            frame->SetName(Form("frame%d",nPlots));
            frame->SetYTitle(var->GetName());

            cout <<Form("%s==%s::%s",channelCat->GetName(),channelCat->GetName(),tt->GetName())<<endl;
            cout << tt->GetName() << " " << channelCat->getLabel() <<endl;
            data->plotOn(frame,MarkerSize(1),Cut(Form("%s==%s::%s",channelCat->GetName(),channelCat->GetName(),tt->GetName())),DataError(RooAbsData::None));

            Double_t normCount = data->sumEntries(Form("%s==%s::%s",channelCat->GetName(),channelCat->GetName(),tt->GetName())) ;

            if(strcmp(var->GetName(),"Lumi")==0){
               cout <<"working on lumi"<<endl;
               var->setVal(w->var("nominalLumi")->getVal());
               var->Print();
            } else{
               var->setVal(0);
            }
            // w->allVars().Print("v");
            // mc->GetNuisanceParameters()->Print("v");
            // pdftmp->plotOn(frame,LineWidth(2.));
            // mc->GetPdf()->plotOn(frame,LineWidth(2.),Slice(*channelCat,tt->GetName()),ProjWData(*data));
            //pdftmp->plotOn(frame,LineWidth(2.),Slice(*channelCat,tt->GetName()),ProjWData(*data));
            normCount = pdftmp->expectedEvents(*obs);
            pdftmp->plotOn(frame,LineWidth(2.),Normalization(normCount,RooAbsReal::NumEvent)) ;

            if(strcmp(var->GetName(),"Lumi")==0){
               cout <<"working on lumi"<<endl;
               var->setVal(w->var("nominalLumi")->getVal()+0.05);
               var->Print();
            } else{
               var->setVal(nSigmaToVary);
            }
            // pdftmp->plotOn(frame,LineColor(kRed),LineStyle(kDashed),LineWidth(2));
            // mc->GetPdf()->plotOn(frame,LineColor(kRed),LineStyle(kDashed),LineWidth(2.),Slice(*channelCat,tt->GetName()),ProjWData(*data));
            //pdftmp->plotOn(frame,LineColor(kRed),LineStyle(kDashed),LineWidth(2.),Slice(*channelCat,tt->GetName()),ProjWData(*data));
            normCount = pdftmp->expectedEvents(*obs);
            pdftmp->plotOn(frame,LineWidth(2.),LineColor(kRed),LineStyle(kDashed),Normalization(normCount,RooAbsReal::NumEvent)) ;

            if(strcmp(var->GetName(),"Lumi")==0){
               cout <<"working on lumi"<<endl;
               var->setVal(w->var("nominalLumi")->getVal()-0.05);
               var->Print();
            } else{
               var->setVal(-nSigmaToVary);
            }
            // pdftmp->plotOn(frame,LineColor(kGreen),LineStyle(kDashed),LineWidth(2));
            // mc->GetPdf()->plotOn(frame,LineColor(kGreen),LineStyle(kDashed),LineWidth(2),Slice(*channelCat,tt->GetName()),ProjWData(*data));
            //pdftmp->plotOn(frame,LineColor(kGreen),LineStyle(kDashed),LineWidth(2),Slice(*channelCat,tt->GetName()),ProjWData(*data));
            normCount = pdftmp->expectedEvents(*obs);
            pdftmp->plotOn(frame,LineWidth(2.),LineColor(kGreen),LineStyle(kDashed),Normalization(normCount,RooAbsReal::NumEvent)) ;



            // set them back to normal
            if(strcmp(var->GetName(),"Lumi")==0){
               cout <<"working on lumi"<<endl;
               var->setVal(w->var("nominalLumi")->getVal());
               var->Print();
            } else{
               var->setVal(0);
            }

            list->Add(frame);

            // quit making plots
            ++nPlots;

            frame->Draw();
            c2->SaveAs(Form("%s_%s_%s.pdf",tt->GetName(),obs->GetName(),var->GetName()));
            delete c2;
         }
      }
   }



   // -------------------------------------------------------


   // now make plots
   TCanvas* c1 = new TCanvas("c1","ProfileInspectorDemo",800,200);
   if(list->GetSize()>4){
      double n = list->GetSize();
      int nx = (int)sqrt(n) ;
      int ny = TMath::CeilNint(n/nx);
      nx = TMath::CeilNint( sqrt(n) );
      c1->Divide(ny,nx);
   } else
      c1->Divide(list->GetSize());
   for(int i=0; i<list->GetSize(); ++i){
      c1->cd(i+1);
      list->At(i)->Draw();
   }





}