// EWP-BplusToKstMuMu-AngAna/Code/Selection/MVA.cpp

//training and testing of BDT for B2Kstplusmumu
//different BDTs are created for the subdecays
//1) Kst -> Kplus pi0
//2) Kst -> Kshort pi+
//the latter is further split into DD and LL tracks
//David Gerick
//Renata Kopecna
#include "GlobalFunctions.hh"
#include "Paths.hpp"
#include "MVAclass.hpp"
#include "PlotTMVA.cpp"
#include "TMVA/MethodBase.h"
#include "TMVA/MethodCategory.h"
//#include "../TMVA/TMVA-v4.2.0/TMVA/Data.h"
#include "TMVA/TMVAGui.h"
#if not defined(__CINT__) || defined(__MAKECINT__)
#include "TMVA/Factory.h"
#include "TMVA/Tools.h"
#endif
//////////////////////////////////////////////////////
/// MVA_b2kmm()
/// Training and testing of the BDT for the final signal selection.
/// All settings are taken from the global MVAconfig.
/// The input variables are read via MVA_variables (separately for DD
/// and LL Kshort tracks if SplitDDandLL is set for the Kshort channel).
/// The years can be trained individually (MVAconfig.SplitYears, exactly one
/// entry in MVAconfig.years) or all years of a Run (1, 2 or 12 = both Runs)
/// can be combined into one data set. The Kshort channel may be split up
/// into DD and LL tracks.
/// The B+ -> J/psi K*+ MC (reference channel) can optionally be added to
/// the signal sample (IncludeRefMC, experimental).
/// The results of the BDT training are saved in an .XML file and
/// read in by TMVAClassificationApplication_b2kmm.cc
/// Returns 1 on success, 0 if the configuration is invalid or the output
/// file cannot be created.
///
///
/// RunMVA()
/// Function to perform the complete MVA training and testing
/// with the configurations set in the Run* functions of this file.
/// Can choose between Kshort and PizeroResolved subdecays.
///
Int_t MVA_b2kmm() {
    //limit the number of signal events; default = 0 : no limit
    Int_t nSignal = 0;
    //use also the reference channel MC (B -> K*+ J/Psi)
    bool IncludeRefMC = false;
    //Should we cut away outliers in the training?
    bool cutOutliners = false;
    //Should we cut away multiple candidates?
    bool cutMultipleCandidates = false;
    //MVAconfig.KShortDecaysInVelo separates the dataset by whether the kshort decays within the velo or not
    //LL tracks: 1
    //DD tracks: 0

    //---- Check MVAconfig BEFORE anything is written to disk ----
    if(MVAconfig.Run != 1 && MVAconfig.Run != 2 && MVAconfig.Run != 12){
        coutERROR("Invalid Run ID chosen: >> " + to_string(MVAconfig.Run) + " << . Exit!");
        return 0;
    }
    if(MVAconfig.Run == 12) coutInfo("Evaluate both Run 1 & 2 in one GO (train together Run 1 and Run 2)!");
    if(MVAconfig.SplitYears){
        if(MVAconfig.years.size() == 0){
            coutERROR("No year given for the MVA configuration! Please fix before executing the MVA_b2kmm() function.");
            return 0;
        }
        if(MVAconfig.years.size() > 1){
            coutERROR("Vector with years cannot be larger 1 for the SplitYears configuration!");
            return 0;
        }
    }
    else{
        //use all years belonging to the chosen Run
        MVAconfig.years = yearsVectorInt(false, false, false, MVAconfig.Run);
        if(MVAconfig.years.empty()){
            coutERROR("No years found for Run " + to_string(MVAconfig.Run) + ". Exit!");
            return 0;
        }
        coutInfo("Load files for Run ");
        std::cout << (MVAconfig.Run == 12 ? std::string("1 & 2") : std::to_string(MVAconfig.Run)) << ": Years ";
        for(UInt_t y = 0; y < MVAconfig.years.size(); y++){
            if(y + 1 == MVAconfig.years.size()) std::cout << (y > 0 ? "and " : "") << MVAconfig.years.at(y) << "." << std::endl;
            else std::cout << MVAconfig.years.at(y) << ", ";
        }
    }
    if (MVAconfig.customTMbranch == "") MVAconfig.customTMbranch = TMbranch;

    //Year handed to the file-name helpers (config file AND plotting below, so both agree).
    //Without SplitYears this is the first year of the Run.
    //NOTE(review): presumably ignored by the helpers in that case — confirm.
    const int configYear = MVAconfig.years.at(0);

    TMVA::Tools::Instance();
    TString signalTree= "DecayTreeTruthMatched";
    TString bkgTree="DecayTree";
    TString targetFile = GetBDTConfigFile(MVAconfig.SplitYears,configYear,MVAconfig.Run,MVAconfig.KShortDecaysInVelo,MVAconfig.nConfiguration,MVAconfig.UseLowQ2Range, MVAconfig.customTMbranch, MVAconfig.gammaTM);
    TFile* outputFile = TFile::Open(targetFile,"RECREATE" );
    if(!outputFile || outputFile->IsZombie()){
        coutERROR("Cannot create output file " + string(targetFile.Data()) + ". Exit!");
        delete outputFile;
        return 0;
    }
    TChain* signal = new TChain(signalTree);
    TChain* background= new TChain(bkgTree);

    // Factory
    TString factoryOptions = "!V:!Silent:Color:Transformations=I;N:AnalysisType=Classification:DrawProgressBar";
    TString factoryName = (MVAconfig.SplitYears ? (to_string(configYear) + "_") : "") + "B2Kstmumu_" + TheDecay + (SplitDDandLL ? (MVAconfig.KShortDecaysInVelo ? "_LL" : "_DD") : "")
                        + (!MVAconfig.SplitYears ? ("_Run" + to_string(MVAconfig.Run)) : "") + (MVAconfig.SplitInQ2Range ? (MVAconfig.UseLowQ2Range ? "_lowQ2" : "_highQ2") : "");
    TMVA::Factory *factory = new TMVA::Factory( factoryName, outputFile , factoryOptions);

    //add variables used in training: read them from file and feed them to the factory
    string DL = "";
    if(Kst2Kspiplus && SplitDDandLL) DL = MVAconfig.KShortDecaysInVelo ? "LL" : "DD";
    MVA_variables InputVariables(DL);
    InputVariables.print();
    for (const MVA_def &var : InputVariables.AllVariables){
        factory->AddVariable(var.ReaderName, var.LaTeXName, var.Unit, var.DataType);
    }

    //load data to trees
    for(UInt_t y = 0; y < MVAconfig.years.size(); y++){
        const int year = MVAconfig.years.at(y);
        //hard-coded: the 2015 signal MC is skipped for the pi0 channel only;
        //the Kshort channel uses its signal MC for every year
        if(!(Kst2Kpluspi0Resolved && year == 2015)){
            const string signalFile = GetBDTinputFile(year,true,false,false,MVAconfig.KShortDecaysInVelo);
            coutDebug("Opening signal file " + signalFile);
            signal->Add(signalFile.c_str());
        }
        if(IncludeRefMC){
            const string refFile = GetBDTinputFile(year,true,true,false,MVAconfig.KShortDecaysInVelo);
            coutDebug("Opening signal file " + refFile);
            signal->Add(refFile.c_str());
        }
        const string bkgFile = GetBDTinputFile(year,false,false,false,MVAconfig.KShortDecaysInVelo);
        coutDebug("Opening background file " + bkgFile);
        background->Add(bkgFile.c_str());
    }
    factory->AddSignalTree(signal,1.);
    factory->AddBackgroundTree(background,1.);

    //Signal weights: the expression depends only on the TM branch and gammaTM.
    //There is no separate 1D-weight option here.
    string weightName = getWeightName(MVAconfig.customTMbranch,MVAconfig.gammaTM);
    factory->SetSignalWeightExpression(weightName.c_str());

    //q2 windows; the ranges 8.68e6 < Q2 < 10.09e6 and 12.9e6 < Q2 < 14.4e6 are vetoed
    //(presumably the charmonium resonances)
    const char* lowQ2Cut  = "Q2 < 8.68e6";
    const char* highQ2Cut = "((Q2 > 10.09e6 && Q2 < 12.9e6) || Q2 > 14.4e6)";
    TCut cutsS;
    TCut cutsB;
    string sVariable = UseDTF ? "B_plus_M_DTF" : "B_plus_M";
    if(Kst2Kspiplus){
        //mass range:
        cutsS = Form("%s < 5379 && %s > 5179", sVariable.c_str(), sVariable.c_str()); //cut +/- 100MeV on signal MC
        cutsB = Form("%s > 5400", sVariable.c_str()); //upper mass sideband of data
        //DD and LL split?
        if(SplitDDandLL){
            cutsS += Form("KshortDecayInVeLo == %i", static_cast<int>(MVAconfig.KShortDecaysInVelo));
            cutsB += Form("KshortDecayInVeLo == %i", static_cast<int>(MVAconfig.KShortDecaysInVelo));
        }
        //Q2 range:
        if(MVAconfig.SplitInQ2Range){
            const char* q2Cut = MVAconfig.UseLowQ2Range ? lowQ2Cut : highQ2Cut;
            cutsS += q2Cut;
            cutsB += q2Cut;
        }
        else{
            cutsS += "(Q2 < 8.68e6 || (Q2 > 10.09e6 && Q2 < 12.9e6) || Q2 > 14.4e6)";
            cutsB += "(Q2 < 8.68e6 || (Q2 > 10.09e6 && Q2 < 12.9e6) || Q2 > 14.4e6)";
        }
    }
    else{ //pi0 channel
        cutsS = Form("%s < 5379 && %s > 5179", sVariable.c_str(), sVariable.c_str()); // for signal use MC data from B mass window only (+-100MeV)
        cutsB = Form("%s > 5700", sVariable.c_str()); // for background far upper sideband of data
        //Q2 range: //TODO: check
        if(MVAconfig.SplitInQ2Range){
            const char* q2Cut = MVAconfig.UseLowQ2Range ? lowQ2Cut : highQ2Cut;
            cutsS += q2Cut;
            cutsB += q2Cut;
        }
        else{
            cutsS += TCut(getMuMucut().c_str()); //TODO: check
            cutsB += TCut(getMuMucut().c_str()); //TODO: check
        }
        //cut outliers
        if (cutOutliners){
            TCut outliners = "";
            if (MVAconfig.Run == 1){
                outliners += "B_plus_IP_OWNPV<0.1";
                outliners += "TMath::Abs(pi_zero_resolved_ETA_DTF-K_plus_ETA_DTF)<1";
                outliners += "TMath::Log(B_plus_PT_DTF)<10";
                outliners += "TMath::Log(1.0-B_plus_DIRA_OWNPV)>-18";
                outliners += "TMath::Log(K_plus_PT_DTF)<9.5";
                outliners += "TMath::Log(mu_minus_IPCHI2_OWNPV)<11";
                outliners += "gamma1_PT_DTF<3000";
                outliners += "gamma2_PT_DTF<4000";
            }
            else if (MVAconfig.Run == 2){
                outliners += "TMath::Max(TMath::Log(gamma1_PT_DTF),TMath::Log(gamma2_PT_DTF))<8.5";
                outliners += "B_plus_IP_OWNPV<0.1";
                outliners += "TMath::Abs(pi_zero_resolved_ETA_DTF-K_plus_ETA_DTF)<1";
                outliners += "TMath::Log(B_plus_PT_DTF)<10.25";
                outliners += "TMath::Log(K_plus_PT_DTF)<9.0";
                outliners += "TMath::Log(K_plus_PT_DTF)>6.0";
                outliners += "TMath::Log(1.0-B_plus_DIRA_OWNPV)>-18";
                outliners += "TMath::Log(mu_minus_IPCHI2_OWNPV)<10";
                outliners += "gamma1_PT_DTF<3000";
                outliners += "gamma2_PT_DTF<4000";
            }
            cutsS+= outliners;
            cutsB+= outliners;
        }
        if (cutMultipleCandidates){
            cutsS += "totCandidates == 1";
            cutsB += "totCandidates == 1";
        }
    }
    //TM cut from getTMcut() on the signal sample (applied in both channels)
    cutsS += TCut(getTMcut(true,false,MVAconfig.customTMbranch,MVAconfig.gammaTM).c_str());
    coutDebug("Using cut ");
    if (verboseLevel < 3) cutsS.Print();

    coutInfo("nEntries in signal tree: " + to_string(signal->GetEntries()));
    coutInfo("nEntries in bckgnd tree: " + to_string(background->GetEntries()));
    // Tell the factory how to use the training and testing events
    factory->PrepareTrainingAndTestTree(cutsS, cutsB, Form("nTrain_Signal=%i:nTrain_Background=0:nTest_Signal=%i:nTest_Background=0:SplitMode=Random:NormMode=NumEvents:SplitSeed=500:!V", nSignal, nSignal));
    // MVA methods: book BDT as default
    if (Kst2Kspiplus){
        factory->BookMethod( TMVA::Types::kBDT, "BDT", "!H:!V:NTrees=300:MaxDepth=3:BoostType=AdaBoost:AdaBoostBeta=0.5:SeparationType=GiniIndex:nCuts=12:PruneMethod=NoPruning");
        factory->BookMethod( TMVA::Types::kBDT, "BDTG", "!H:!V:NTrees=450:MaxDepth=2:MinNodeSize=1.5%:BoostType=Grad:Shrinkage=0.10:UseBaggedBoost:BaggedSampleFraction=0.5:nCuts=14" );
    }
    else if (Kst2Kpluspi0Resolved){
        factory->BookMethod( TMVA::Types::kBDT, "BDT", "!H:!V:NTrees=300:MaxDepth=2:BoostType=AdaBoost:AdaBoostBeta=0.26:SeparationType=GiniIndex:nCuts=12:PruneMethod=NoPruning");
        factory->BookMethod( TMVA::Types::kBDT, "BDTG", "!H:!V:NTrees=450:MaxDepth=2:MinNodeSize=1.5%:BoostType=Grad:Shrinkage=0.10:UseBaggedBoost:BaggedSampleFraction=0.5:nCuts=14" );
        factory->BookMethod( TMVA::Types::kMLP, "MLP", "H:!V:VarTransform=N:NCycles=750:HiddenLayers=N+5:TestRate=10:!UseRegulator" );//IgnoreNegWeightsInTraining=True
    }
    // Train and test MVAs
    factory->TrainAllMethods();
    factory->TestAllMethods();
    // Evaluate performances
    factory->EvaluateAllMethods();
    outputFile->Close();
    delete factory;
    //release the file handle and the input chains (the factory that used them is gone)
    delete outputFile;
    delete signal;
    delete background;
    // Launch the GUI for the root macros
    //if (!gROOT->IsBatch()) TMVA::TMVAGui( targetFile );
    //Plot everything, using the same year as the config file written above
    SaveAllFromOneFile(configYear,MVAconfig.Run,MVAconfig.SplitYears,MVAconfig.KShortDecaysInVelo,MVAconfig.nConfiguration,MVAconfig.UseLowQ2Range, MVAconfig.customTMbranch, MVAconfig.gammaTM);
    coutInfo("MVA training done!");
    return 1;
}
//Kshort channel only: trains the LL and DD BDTs for three q2 configurations
// configuration 1: no q2 split
// configuration 2: low q2 range only
// configuration 3: high q2 range only
//Stops and returns 0 as soon as one training fails, returns 1 otherwise.
Int_t RunMore(Int_t Run = 1){
    //train LL (KShortDecaysInVelo = 1) first, then DD (= 0), with the current MVAconfig
    auto trainLLthenDD = []() -> bool {
        for (int inVelo = 1; inVelo >= 0; --inVelo){
            MVAconfig.KShortDecaysInVelo = inVelo;
            if (MVA_b2kmm() == 0) return false;
        }
        return true;
    };

    MVAconfig.Run = Run;
    MVAconfig.SplitYears = false;

    //configuration 1: full q2 range
    MVAconfig.SplitInQ2Range = false;
    MVAconfig.nConfiguration = 1;
    if (!trainLLthenDD()) return 0;

    //configuration 2: low q2 range
    MVAconfig.SplitInQ2Range = true;
    MVAconfig.UseLowQ2Range = true;
    MVAconfig.nConfiguration = 2;
    if (!trainLLthenDD()) return 0;

    //configuration 3: high q2 range
    MVAconfig.UseLowQ2Range = false;
    MVAconfig.nConfiguration = 3;
    if (!trainLLthenDD()) return 0;

    return 1;
}
//Kshort channel, full q2 range (configuration 1), all years of the given Run:
//trains the LL (KShortDecaysInVelo = 1) and then the DD (= 0) BDT.
//Returns 0 as soon as a training fails, 1 otherwise.
Int_t RunDDandLLKshort(Int_t Run = 1){
    MVAconfig.Run = Run;
    MVAconfig.SplitYears = false;
    MVAconfig.SplitInQ2Range = false;
    MVAconfig.nConfiguration = 1;
    for (int inVelo = 1; inVelo >= 0; --inVelo){
        MVAconfig.KShortDecaysInVelo = inVelo;
        if (MVA_b2kmm() == 0) return 0;
    }
    return 1;
}
//K+ pi0 (resolved) channel: one training on the full q2 range with all years
//of the given Run combined.
// config         -> MVAconfig.nConfiguration
// customTMbranch -> MVAconfig.customTMbranch (presumably the truth-matching branch name;
//                   an empty string makes MVA_b2kmm() fall back to TMbranch)
// gammaTM        -> MVAconfig.gammaTM
//Returns 1 on success, 0 if the training failed.
Int_t RunKplusPizeroResolved(Int_t Run = 1, int config=0, string customTMbranch ="", bool gammaTM = false){
    //data-set selection
    MVAconfig.Run = Run;
    MVAconfig.SplitYears = false;
    MVAconfig.KShortDecaysInVelo = false;

    //q2 handling: no split
    MVAconfig.SplitInQ2Range = false;
    MVAconfig.UseLowQ2Range = false;

    //configuration / truth matching
    MVAconfig.nConfiguration = config;
    MVAconfig.customTMbranch = customTMbranch;
    MVAconfig.gammaTM = gammaTM;

    return (MVA_b2kmm() == 0) ? 0 : 1;
}
//Runs the complete MVA training and testing for the channel selected by the
//global flags Kst2Kspiplus / Kst2Kpluspi0Resolved.
//Returns the result of the training (1 = success, 0 = failure), or 0 if no
//channel is selected.
Int_t RunMVA(Int_t Run = 1){
    if(Kst2Kspiplus) return RunDDandLLKshort(Run);
    if(Kst2Kpluspi0Resolved){
        //alternative configurations:
        //return RunKplusPizeroResolved(Run,1,"TMedBKGCAT",false);
        //return RunKplusPizeroResolved(Run,2,"TMed",true);
        //propagate the result (previously it was dropped and 0 was always returned)
        return RunKplusPizeroResolved(Run,0,"TMed",false);
    }
    return 0;
}