// inclusive_detached_dilepton/train_bdt.cpp
#include "TH1D.h"
#include "TH2D.h"
#include "THStack.h"
#include "TGraph.h"
#include "TTree.h"
#include "TChain.h"
#include "TFile.h"
#include "TCanvas.h"
#include "TROOT.h"
#include "TStyle.h"
#include "TColor.h"
#include "TLorentzVector.h"
#include "TRandom3.h"
#include "TLorentzVector.h"
#include "TMVA/Factory.h"
#include "TMVA/DataLoader.h"
#include "TMVA/Reader.h"
#include "RooDataHist.h"
#include "RooRealVar.h"
#include "RooPlot.h"
#include "RooGaussian.h"
#include "RooExponential.h"
#include "RooRealConstant.h"
#include "RooAddPdf.h"
#include "RooFitResult.h"
#include "RooProduct.h"
#include "RooCrystalBall.h"
#include <string>
#include <iostream>
#include <cmath>
// ----- Steering flags and shared constants ---------------------------------

// When true, train_bdt() builds the signal/background trees and trains the BDT.
const bool TRAIN_TMVA = false;
// When true, train_bdt() applies the booked BDT to the data chain and fits the B+ mass.
const bool EVALUATE_TMVA = true;

// Number of bins used for the histograms filled in train_bdt() and for the
// RooFit data binning in CreateRooFit().
const int N_BINS = 130;

// Reference masses used as window centres in the simulation selection.
// NOTE(review): values look like rounded masses in MeV/c^2 - confirm units
// against the ntuple content.
const double B_PLUS_MASS = 5279.;
const double J_PSI_MASS = 3096.;
// Forward declaration: builds the signal + background fit of the given mass
// histogram and returns the RooPlot frame; defined after train_bdt().
RooPlot *CreateRooFit(TH1D *hist);
/// Tests whether `value` lies strictly inside the open, possibly asymmetric
/// window (center - low_intvl, center + up_intvl).
bool inRange(double value, double center, double low_intvl, double up_intvl) {
    const double lower_edge = center - low_intvl;
    const double upper_edge = center + up_intvl;
    // Written as a negated "<" so NaN inputs are rejected exactly like a plain
    // chained comparison would reject them.
    if (!(lower_edge < value)) {
        return false;
    }
    return value < upper_edge;
}

/// Symmetric convenience overload: open window (center - intvl, center + intvl).
bool inRange(double value, double center, double intvl) {
    const double half_width = intvl;
    return inRange(value, center, half_width, half_width);
}
/// Buffer for the four momentum components (PX, PY, PZ, ENERGY) of one
/// particle, filled by a TChain once Connect() has been called.
class FourVector {
private:
    // Zero-initialised so that LV() and Print() are well defined even before
    // the first TChain::GetEntry() call has filled the buffers.
    Float_t px = 0.f;
    Float_t py = 0.f;
    Float_t pz = 0.f;
    Float_t energy = 0.f;

public:
    /// Returns the current buffer content as a TLorentzVector(px, py, pz, energy).
    TLorentzVector LV() const {
        return TLorentzVector(px, py, pz, energy);
    }

    /// Prints the current components to std::cout.
    void Print() const {
        std::cout << TString::Format("(PX: %f, PY: %f, PZ: %f, E: %f)", px, py, pz, energy) << std::endl;
    }

    /// Binds the branches "<name>_PX", "<name>_PY", "<name>_PZ" and
    /// "<name>_ENERGY" of `chain` to this object's buffers.
    /// The object must stay at the same address for as long as the chain is read.
    void Connect(TChain* chain, std::string name) {
        const char *prefix = name.c_str();
        chain->SetBranchAddress(TString::Format("%s_PX", prefix), &px);
        chain->SetBranchAddress(TString::Format("%s_PY", prefix), &py);
        chain->SetBranchAddress(TString::Format("%s_PZ", prefix), &pz);
        chain->SetBranchAddress(TString::Format("%s_ENERGY", prefix), &energy);
    }
};
/// Storage type tag of a training variable.
/// Scoped (enum class) so that the enumerators Double/Float/Int do not leak
/// into the global namespace; always refer to them as TVT::Double etc.
enum class TVT
{
    Double, // value is read into a Double_t buffer
    Float,  // value is read into a Float_t buffer
    Int     // declared, but no code in this file creates or handles it
};
/// One BDT training variable.
///
/// Stores the branch name of the variable in the data sample, in the
/// simulation (MC) sample and in the training trees, plus one value buffer per
/// sample and storage type. The buffers are meant to be handed to
/// TChain::SetBranchAddress / TTree::Branch through the Get*Ref() accessors,
/// so a TV must not be moved or copied after its addresses have been handed out.
///
/// Instances are created through the static factories Float(...) / Double(...).
class TV
{
private:
    std::string data_name;  // branch name in the data sample
    std::string mc_name;    // branch name in the simulation sample
    std::string train_name; // name used for the training branch / TMVA variable
    TVT type;               // which pair of buffers below is in use

    // Zero-initialised so the Is*Finite()/Print*() methods never read an
    // indeterminate value before the first GetEntry() has filled the buffers.
    Double_t mc_double_value = 0.;
    Float_t mc_float_value = 0.f;
    Double_t data_double_value = 0.;
    Float_t data_float_value = 0.f;

    TV(std::string data_name, std::string mc_name, std::string train_name, TVT type)
        : data_name{data_name}, mc_name{mc_name}, train_name{train_name}, type{type}
    {
    }

public:
    /// Float variable with explicit data, MC and training names.
    static TV Float(std::string data_name, std::string mc_name, std::string train_name)
    {
        return TV(data_name, mc_name, train_name, TVT::Float);
    }
    /// Float variable; the training name defaults to the data name.
    static TV Float(std::string data_name, std::string mc_name)
    {
        return TV(data_name, mc_name, data_name, TVT::Float);
    }
    /// Float variable with the same name in data, MC and training.
    static TV Float(std::string data_name)
    {
        return TV(data_name, data_name, data_name, TVT::Float);
    }
    /// Double variable with explicit data, MC and training names.
    static TV Double(std::string data_name, std::string mc_name, std::string train_name)
    {
        return TV(data_name, mc_name, train_name, TVT::Double);
    }
    /// Double variable; the training name defaults to the data name.
    static TV Double(std::string data_name, std::string mc_name)
    {
        return TV(data_name, mc_name, data_name, TVT::Double);
    }
    /// Double variable with the same name in data, MC and training.
    static TV Double(std::string data_name)
    {
        return TV(data_name, data_name, data_name, TVT::Double);
    }

    // ----- names (pointers stay valid as long as this object is alive) -----
    const char *GetDataName() const
    {
        return data_name.c_str();
    }
    const char *GetMCName() const
    {
        return mc_name.c_str();
    }
    const char *GetTrainName() const
    {
        return train_name.c_str();
    }

    // ----- writable buffer addresses for SetBranchAddress / Branch -----
    Double_t *GetMCDoubleRef()
    {
        return &mc_double_value;
    }
    Float_t *GetMCFloatRef()
    {
        return &mc_float_value;
    }
    Double_t *GetDataDoubleRef()
    {
        return &data_double_value;
    }
    Float_t *GetDataFloatRef()
    {
        return &data_float_value;
    }

    // ----- current data values -----
    Double_t GetDataDouble() const
    {
        return data_double_value;
    }
    Float_t GetDataFloat() const
    {
        return data_float_value;
    }

    /// Prints "<data_name> (<entry>): <value>" using the buffer matching the type.
    /// For a type that is neither Double nor Float no value is printed.
    void PrintDataValue(int entry) const
    {
        std::cout << data_name << " (" << entry << "): ";
        if (IsDouble())
        {
            std::cout << data_double_value;
        }
        else if (IsFloat())
        {
            std::cout << data_float_value;
        }
        std::cout << std::endl;
    }

    /// Prints "<mc_name> (<entry>): <value>" using the buffer matching the type.
    /// For a type that is neither Double nor Float no value is printed.
    void PrintMCValue(int entry) const
    {
        std::cout << mc_name << " (" << entry << "): ";
        if (IsDouble())
        {
            std::cout << mc_double_value;
        }
        else if (IsFloat())
        {
            std::cout << mc_float_value;
        }
        std::cout << std::endl;
    }

    /// True if the current data value is finite (neither NaN nor +-inf).
    /// Returns false for a type that is neither Double nor Float.
    bool IsDataFinite() const
    {
        if (IsDouble())
        {
            return std::isfinite(data_double_value);
        }
        else if (IsFloat())
        {
            return std::isfinite(data_float_value);
        }
        return false;
    }

    /// True if the current MC value is finite (neither NaN nor +-inf).
    /// Returns false for a type that is neither Double nor Float.
    bool IsMCFinite() const
    {
        if (IsDouble())
        {
            return std::isfinite(mc_double_value);
        }
        else if (IsFloat())
        {
            return std::isfinite(mc_float_value);
        }
        return false;
    }

    bool IsDouble() const
    {
        return type == TVT::Double;
    }
    bool IsFloat() const
    {
        return type == TVT::Float;
    }
};
int train_bdt()
{
std::cout << TString::Format("Starting Up With TRAIN_TMVA=%d AND EVALUATE_TMVA=%d.", TRAIN_TMVA, EVALUATE_TMVA) << std::endl;
// files to load
std::vector<std::string> data_filenames =
{
"/auto/data/pfeiffer/inclusive_detached_dilepton/data_samples/BuToKpMuMu_Collision23_Beam6800GeV-VeloClosed-MagDown-Excl-UT_RealData_Sprucing23r1_90000000_RD.root"};
TChain *data_chain = new TChain("SpruceRD_BuToKpMuMu/DecayTree");
for (unsigned int i = 0; i < data_filenames.size(); i++)
{
data_chain->Add(data_filenames.at(i).c_str());
}
// files to load
std::vector<std::string> mc_filenames =
{
2023-11-17 13:18:07 +01:00
"/auto/data/pfeiffer/inclusive_detached_dilepton/MC/BuToKpMuMu_rd_btoxll_simulation_12143001_MagDown_v0r0p6316365_FULLSTREAM.root"};
2023-11-15 16:49:50 +01:00
2023-11-17 13:18:07 +01:00
TChain *mc_chain = new TChain("BuToKpMuMu_noPID/DecayTree");
2023-11-15 16:49:50 +01:00
for (unsigned int i = 0; i < mc_filenames.size(); i++)
{
mc_chain->Add(mc_filenames.at(i).c_str());
}
std::vector<TV> vars{
TV::Float("Bplus_PT", "B_PT"),
TV::Float("Bplus_BPVFDCHI2", "B_BPVFDCHI2"),
2023-11-17 13:18:07 +01:00
TV::Float("Bplus_BPVDIRA", "B_BPVDIRA"),
2023-11-15 16:49:50 +01:00
TV::Float("Jpsi_BPVIPCHI2", "Jpsi_BPVIPCHI2"),
2023-11-17 13:18:07 +01:00
TV::Float("Jpsi_BPVDIRA", "Jpsi_BPVDIRA"),
2023-11-15 16:49:50 +01:00
TV::Float("Jpsi_PT", "Jpsi_PT"),
TV::Float("Kplus_BPVIPCHI2", "K_BPVIPCHI2"),
TV::Float("Kplus_PT", "K_PT"),
// TV::Double("Kplus_PID_K", "K_PID_K"),
TV::Double("Kplus_PROBNN_K", "K_PROBNN_K"),
TV::Float("muplus_BPVIPCHI2", "L1_BPVIPCHI2"),
TV::Float("muminus_BPVIPCHI2", "L2_BPVIPCHI2"),
};
TTree *sig_tree = new TTree("TreeS", "tree containing signal data");
TTree *bkg_tree = new TTree("TreeB", "tree containing background data");
Double_t Bplus_M, B_M, K_PID_K;
data_chain->SetBranchAddress("Bplus_M", &Bplus_M);
mc_chain->SetBranchAddress("B_M", &B_M);
mc_chain->SetBranchAddress("K_PID_K", &K_PID_K);
FourVector muplus_4v, muminus_4v;
muplus_4v.Connect(mc_chain, "L1");
muminus_4v.Connect(mc_chain, "L2");
for (size_t i = 0; i < vars.size(); i++)
{
if (vars[i].IsDouble())
{
data_chain->SetBranchAddress(vars[i].GetDataName(), vars[i].GetDataDoubleRef());
mc_chain->SetBranchAddress(vars[i].GetMCName(), vars[i].GetMCDoubleRef());
sig_tree->Branch(vars[i].GetTrainName(), vars[i].GetMCDoubleRef(), TString::Format("%s/D", vars[i].GetTrainName()));
bkg_tree->Branch(vars[i].GetTrainName(), vars[i].GetDataDoubleRef(), TString::Format("%s/D", vars[i].GetTrainName()));
}
else if (vars[i].IsFloat())
{
data_chain->SetBranchAddress(vars[i].GetDataName(), vars[i].GetDataFloatRef());
mc_chain->SetBranchAddress(vars[i].GetMCName(), vars[i].GetMCFloatRef());
sig_tree->Branch(vars[i].GetTrainName(), vars[i].GetMCFloatRef(), TString::Format("%s/F", vars[i].GetTrainName()));
bkg_tree->Branch(vars[i].GetTrainName(), vars[i].GetDataFloatRef(), TString::Format("%s/F", vars[i].GetTrainName()));
}
}
unsigned int data_entries = data_chain->GetEntries();
unsigned int mc_entries = mc_chain->GetEntries();
if (TRAIN_TMVA)
{
std::cout << "----- Start TMVA Setup -----" << std::endl;
std::cout << "----- Setting Up Signal and Background Tree -----" << std::endl;
unsigned int added_bck_entries = 0;
unsigned int added_sig_entries = 0;
std::cout << "----- Processing Data -----" << std::endl;
for (unsigned int i = 0; i < data_entries; i++)
{
data_chain->GetEntry(i);
bool skip = false;
for (size_t j = 0; j < vars.size(); j++)
{
if (!vars[j].IsDataFinite())
{
vars[j].PrintDataValue(i);
skip = true;
}
}
if (skip)
{
continue;
}
if (Bplus_M > 5500.)
{
bkg_tree->Fill();
added_bck_entries++;
}
}
std::cout << "----- Processing Simulation -----" << std::endl;
for (unsigned int i = 0; i < mc_entries; i++)
{
mc_chain->GetEntry(i);
bool skip = false;
for (size_t j = 0; j < vars.size(); j++)
{
if (!vars[j].IsMCFinite())
{
vars[j].PrintMCValue(i);
skip = true;
}
}
Double_t mumu_inv_mass = (muplus_4v.LV() + muminus_4v.LV()).M();
if (skip || !inRange(B_M, B_PLUS_MASS, 100.) || !inRange(mumu_inv_mass, J_PSI_MASS, 100.) || K_PID_K < 0.)
{
continue;
}
sig_tree->Fill();
added_sig_entries++;
2023-11-17 13:18:07 +01:00
if (added_sig_entries >= added_bck_entries * 2) {
2023-11-15 16:49:50 +01:00
break;
}
}
std::cout << "----- Start TMVA Training -----" << std::endl;
std::cout << TString::Format("With %d Signal and %d Background Events.", added_sig_entries, added_bck_entries) << std::endl;
TString outfile_name("tmva_butokpmumu_out.root");
TFile *output_file = TFile::Open(outfile_name, "RECREATE");
TString factory_options("V:Silent:Color:DrawProgressBar:Transformations=I;D;P;G,D:AnalysisType=Auto");
TMVA::Factory *factory = new TMVA::Factory("tmva_butokpmumu", output_file, factory_options);
TMVA::DataLoader *data_loader = new TMVA::DataLoader("dataloader");
for (int i = 0; i < vars.size(); i++)
{
std::cout << "Adding Branch: " << vars[i].GetTrainName() << std::endl;
if (vars[i].IsDouble())
{
data_loader->AddVariable(vars[i].GetTrainName(), 'D');
}
else if (vars[i].IsFloat())
{
data_loader->AddVariable(vars[i].GetTrainName(), 'F');
}
}
Double_t signal_weight = 1.0, background_weight = 1.0;
data_loader->AddSignalTree(sig_tree, signal_weight);
data_loader->AddBackgroundTree(bkg_tree, background_weight);
data_loader->PrepareTrainingAndTestTree("", "", "nTrain_Signal=0:nTrain_Background=0:SplitMode=Random:NormMode=NumEvents:!V");
2023-11-17 13:18:07 +01:00
factory->BookMethod(data_loader, TMVA::Types::kBDT, "BDT", "!H:!V:NTrees=600:MinNodeSize=2.5%:CreateMVAPdfs:MaxDepth=3:BoostType=AdaBoost:AdaBoostBeta=0.5:UseBaggedBoost:BaggedSampleFraction=0.5:SeparationType=GiniIndex:nCuts=20");
2023-11-15 16:49:50 +01:00
factory->TrainAllMethods();
factory->TestAllMethods();
factory->EvaluateAllMethods();
output_file->Close();
std::cout << "----- Finished TMVA Setup & Training -----" << std::endl;
}
else
{
std::cout << "----- Skipped TMVA Setup & Training -----" << std::endl;
}
if (EVALUATE_TMVA)
{
std::cout << "----- Start TMVA Evaluation of Data -----" << std::endl;
TMVA::Reader *reader = new TMVA::Reader("!Color:!Silent");
Float_t *train_vars = new Float_t[vars.size()];
for (size_t i = 0; i < vars.size(); i++)
{
reader->AddVariable(vars[i].GetTrainName(), &train_vars[i]);
}
reader->BookMVA("BDT", "./dataloader/weights/tmva_butokpmumu_BDT.weights.xml");
TH1D *h1_probs = new TH1D("h1_probs", "BDT Probabilities", N_BINS, -1, 1);
TH1D *h1_Bplus_M = new TH1D("h1_Bplus_M", "B^{+} Mass", N_BINS, 4700., 6500.);
for (unsigned int i = 0; i < data_entries; i++)
{
data_chain->GetEntry(i);
bool skip = false;
for (size_t j = 0; j < vars.size(); j++)
{
if (!vars[j].IsDataFinite())
{
vars[j].PrintDataValue(i);
skip = true;
break;
}
if (vars[j].IsDouble())
{
train_vars[j] = vars[j].GetDataDouble();
}
else if (vars[j].IsFloat())
{
train_vars[j] = vars[j].GetDataFloat();
}
}
if (skip)
{
continue;
}
double mva_response = reader->EvaluateMVA("BDT");
h1_probs->Fill(mva_response);
2023-11-17 13:18:07 +01:00
const double mva_cut_value = 0.09; // -0.02;
2023-11-15 16:49:50 +01:00
if (mva_response > mva_cut_value)
{
h1_Bplus_M->Fill(Bplus_M);
}
}
std::cout << "----- Finished TMVA Evaluation of Data -----" << std::endl;
TCanvas *c1 = new TCanvas("c1", "c1", 0, 0, 1200, 800);
c1->Divide(2, 1);
c1->cd(1);
h1_probs->Draw();
c1->cd(2);
h1_Bplus_M->Draw();
c1->Draw();
2023-11-17 13:18:07 +01:00
TCanvas *c3 = new TCanvas("c3", "Canvas 3", 0, 0, 1000, 600);
2023-11-15 16:49:50 +01:00
2023-11-17 13:18:07 +01:00
auto fitFrame = CreateRooFit(h1_Bplus_M);
fitFrame->Draw();
2023-11-15 16:49:50 +01:00
2023-11-17 13:18:07 +01:00
c3->Draw();
2023-11-15 16:49:50 +01:00
}
else
{
std::cout << "----- Skipped TMVA Evaluation of Data -----" << std::endl;
}
return 0;
}
/// Fits a Gaussian signal plus exponential background (extended, with free
/// signal and background yields) to the given mass histogram and returns a
/// RooPlot frame showing the data, the fit with its error band, the
/// background component, the Gaussian parameters and the fitted yields.
///
/// The observable spans 4700-6500; the fit is performed on 4800-5800 and the
/// curves are drawn over the full 4700-6500 range.
///
/// @param hist  histogram to fit; must not be null.
/// @return newly created RooPlot frame, or nullptr if `hist` is null.
RooPlot *CreateRooFit(TH1D *hist)
{
    if (hist == nullptr)
    {
        std::cerr << "CreateRooFit: received a null histogram, nothing to fit." << std::endl;
        return nullptr;
    }

    RooRealVar roo_var_mass("roo_var_mass", "B+ Mass Variable", 4700., 6500.);
    roo_var_mass.setRange("fitting_range", 4800., 5800.);
    roo_var_mass.setRange("plot_range", 4700., 6500.);

    TString hist_name = "roohist_bplus_M";
    RooDataHist roohist_bplus_M(hist_name, "B Plus Mass Histogram", roo_var_mass, RooFit::Import(*hist));

    // Single Gaussian for Signal
    RooRealVar roo_sig_gauss_mean("roo_sig_gauss_mean", "Mass Gauss Mean", 5250., 5100., 5400.);
    RooRealVar roo_sig_gauss_sigma("roo_sig_gauss_sigma", "Mass Gauss Sigma", 30., 20., 40.);
    RooGaussian roo_sig_gauss("roo_sig_gauss", "B+ Mass Signal Gaussian", roo_var_mass, roo_sig_gauss_mean, roo_sig_gauss_sigma);

    // Crystal Ball for Signal
    // RooRealVar roo_sig_cb_x0("roo_sig_cry_x0", "Location", B_PLUS_MASS, 5100., 5400.);
    // RooRealVar roo_sig_cb_sigmaL("roo_sig_cry_sigmaL", "Sigma L", 30., 0., 60.);
    // RooRealVar roo_sig_cb_sigmaR("roo_sig_cry_sigmaR", "Sigma R", 30., 0., 60.);

    // RooRealVar roo_sig_cb_alphaL("roo_sig_cry_alphaL", "Alpha L", 15., 0., 30.);
    // RooRealVar roo_sig_cb_nL("roo_sig_cry_nL", "Exponent L", 0., -40., 40.);

    // RooRealVar roo_sig_cb_alphaR("roo_sig_cry_alphaR", "Alpha R", 15., 0., 30.);
    // RooRealVar roo_sig_cb_nR("roo_sig_cry_nR", "Exponent R", 0., -40., 40.);

    // RooCrystalBall roo_sig_cb("roo_sig_cb", "Signal Crystal Ball", roo_var_mass, roo_sig_cb_x0, roo_sig_cb_sigmaL, roo_sig_cb_sigmaR, roo_sig_cb_alphaL, roo_sig_cb_nL, roo_sig_cb_alphaR, roo_sig_cb_nR);

    // Double Gauss Signal
    // RooRealVar roo_sig_gauss1_mean("roo_sig_gauss1_mean", "Mass Gauss 1 Mean", 5250., 5200., 5300.);
    // RooRealVar roo_sig_gauss1_sigma("roo_sig_gauss1_sigma", "Mass Gauss 1 Sigma", 60., 0., 150.);
    // RooGaussian roo_sig_gauss1("roo_sig_gauss1", "B+ Mass Signal 1 Gaussian", roo_var_mass, roo_sig_gauss1_mean, roo_sig_gauss1_sigma);
    // RooRealVar roo_sig_gauss2_mean("roo_sig_gauss2_mean", "Mass Gauss 2 Mean", 5250., 5200., 5300.);
    // RooRealVar roo_sig_gauss2_sigma("roo_sig_gauss2_sigma", "Mass Gauss 2 Sigma", 60., 0., 150.);
    // RooGaussian roo_sig_gauss2("roo_sig_gauss2", "B+ Mass Signal 2 Gaussian", roo_var_mass, roo_sig_gauss2_mean, roo_sig_gauss2_sigma);
    // RooRealVar roo_sig_double_gauss_frac("roo_sig_double_gauss_frac", "B+ Mass Signal Double Gauss Frac", 0.5);
    // RooAddPdf roo_sig_double_gauss("roo_sig_double_gauss", "B+ Mass Signal Double Gauss", roo_sig_gauss1, roo_sig_gauss2, roo_sig_double_gauss_frac);

    // Exponential for Background (slope constrained to be non-positive)
    RooRealVar roo_bkg_exp_c("roo_bkg_exp_c", "Background C", -0.000693147, -0.002, 0.);
    RooExponential roo_bkg_exp("roo_bkg_exp", "B+ Mass Background Exp", roo_var_mass, roo_bkg_exp_c);

    // Yields, each allowed between 0 and the number of histogram entries.
    const double n_entries = hist->GetEntries();
    RooRealVar roo_var_mass_nsig("roo_var_mass_nsig", "B+ Mass N Signal", 0., n_entries);
    RooRealVar roo_var_mass_nbkg("roo_var_mass_nbkg", "B+ Mass N Background", 0., n_entries);

    TString pdf_name = "roo_pdf_sig_plus_bkg";
    RooAddPdf roo_pdf_sig_plus_bkg(pdf_name, "Sig + Bkg PDF",
                                   RooArgList(roo_sig_gauss, roo_bkg_exp),
                                   RooArgList(roo_var_mass_nsig, roo_var_mass_nbkg));

    RooPlot *roo_frame_mass = roo_var_mass.frame();
    roohist_bplus_M.plotOn(roo_frame_mass, RooFit::Binning(N_BINS), RooFit::Name(hist_name), RooFit::MarkerColor(15), RooFit::Range("plot_range"));

    RooFitResult *fitres = roo_pdf_sig_plus_bkg.fitTo(roohist_bplus_M, RooFit::Save(), RooFit::PrintLevel(1), RooFit::Range("fitting_range"));

    // The error band needs the fit result; skip it if no result was returned.
    if (fitres != nullptr)
    {
        roo_pdf_sig_plus_bkg.plotOn(roo_frame_mass, RooFit::VisualizeError(*fitres, 1), RooFit::Range("plot_range"), RooFit::FillColor(kOrange + 1), RooFit::FillStyle(3144));
    }
    roo_pdf_sig_plus_bkg.plotOn(roo_frame_mass, RooFit::LineColor(kRed), RooFit::LineStyle(kSolid), RooFit::Range("plot_range"), RooFit::Name(pdf_name));
    roo_pdf_sig_plus_bkg.plotOn(roo_frame_mass, RooFit::Components(RooArgSet(roo_bkg_exp)), RooFit::LineColor(kBlue), RooFit::LineStyle(kDashed), RooFit::Range("plot_range"));
    // roo_sig_cb.plotOn(roo_frame_mass, RooFit::LineColor(kAlpine), RooFit::LineStyle(kDashed), RooFit::Range("fitting_range"));
    // roo_bkg_exp.plotOn(roo_frame_mass, RooFit::LineColor(kOrange), RooFit::LineStyle(kDashed), RooFit::Range("fitting_range"));

    roo_sig_gauss.paramOn(roo_frame_mass, RooFit::Layout(0.60, 0.99, 0.90));
    roo_frame_mass->getAttText()->SetTextSize(0.027);

    // Yield labels, placed relative to the frame maximum.
    double ymax = roo_frame_mass->GetMaximum();
    TText *txt_n_sig = new TText(5800., ymax * 0.22, TString::Format("Signal Yield: %.2f +/- %.2f (Paper: %s)", roo_var_mass_nsig.getVal(), roo_var_mass_nsig.getError(), ""));
    txt_n_sig->SetTextSize(0.027);
    txt_n_sig->SetTextColor(kBlue + 4);
    TText *txt_n_bkg = new TText(5800., ymax * 0.18, TString::Format("Background Yield: %.2f +/- %.2f", roo_var_mass_nbkg.getVal(), roo_var_mass_nbkg.getError()));
    txt_n_bkg->SetTextSize(0.027);
    txt_n_bkg->SetTextColor(kBlue + 4);
    roo_frame_mass->addObject(txt_n_sig);
    roo_frame_mass->addObject(txt_n_bkg);

    // The error band has already been drawn into the frame, so the fit result
    // is not needed any more.
    delete fitres;

    return roo_frame_mass;
}