ROOT Analysis for the Inclusive Detached Dilepton Trigger Lines
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

268 lines
6.7 KiB

6 months ago
6 months ago
  1. #ifndef BDT_CLASSIFICATION
  2. #define BDT_CLASSIFICATION
#include <algorithm>
#include <cmath>
#include <filesystem>
#include <iostream>
#include <string>
#include <string_view>
#include <utility>
#include <vector>

#include "RtypesCore.h"
  10. enum TVT
  11. {
  12. Double,
  13. Float,
  14. Int
  15. };
  16. class TV
  17. {
  18. private:
  19. std::string data_name;
  20. std::string mc_name;
  21. std::string train_name;
  22. TVT type;
  23. Double_t mc_double_value;
  24. Float_t mc_float_value;
  25. Double_t data_double_value;
  26. Float_t data_float_value;
  27. TV(std::string data_name, std::string mc_name, std::string train_name, TVT type)
  28. : data_name{data_name}, mc_name{mc_name}, train_name{train_name}, type{type}
  29. {
  30. }
  31. public:
  32. static TV *Float(std::string data_name, std::string mc_name, std::string train_name)
  33. {
  34. return new TV(data_name, mc_name, train_name, TVT::Float);
  35. }
  36. static TV *Float(std::string data_name, std::string mc_name)
  37. {
  38. return new TV(data_name, mc_name, data_name, TVT::Float);
  39. }
  40. static TV *Float(std::string data_name)
  41. {
  42. return new TV(data_name, data_name, data_name, TVT::Float);
  43. }
  44. static TV *Double(std::string data_name, std::string mc_name, std::string train_name)
  45. {
  46. return new TV(data_name, mc_name, train_name, TVT::Double);
  47. }
  48. static TV *Double(std::string data_name, std::string mc_name)
  49. {
  50. return new TV(data_name, mc_name, data_name, TVT::Double);
  51. }
  52. static TV *Double(std::string data_name)
  53. {
  54. return new TV(data_name, data_name, data_name, TVT::Double);
  55. }
  56. const char *GetDataName()
  57. {
  58. return data_name.c_str();
  59. }
  60. const char *GetMCName()
  61. {
  62. return mc_name.c_str();
  63. }
  64. const char *GetTrainName()
  65. {
  66. return train_name.c_str();
  67. }
  68. Double_t *GetMCDoubleRef()
  69. {
  70. return &mc_double_value;
  71. }
  72. Float_t *GetMCFloatRef()
  73. {
  74. return &mc_float_value;
  75. }
  76. Double_t *GetDataDoubleRef()
  77. {
  78. return &data_double_value;
  79. }
  80. Float_t *GetDataFloatRef()
  81. {
  82. return &data_float_value;
  83. }
  84. Double_t GetDataDouble()
  85. {
  86. return data_double_value;
  87. }
  88. Float_t GetDataFloat()
  89. {
  90. return data_float_value;
  91. }
  92. void PrintDataValue(int entry)
  93. {
  94. std::cout << data_name << " (" << entry << "): ";
  95. if (IsDouble())
  96. {
  97. std::cout << data_double_value;
  98. }
  99. else if (IsFloat())
  100. {
  101. std::cout << data_float_value;
  102. }
  103. std::cout << std::endl;
  104. }
  105. void PrintMCValue(int entry)
  106. {
  107. std::cout << mc_name << " (" << entry << "): ";
  108. if (IsDouble())
  109. {
  110. std::cout << mc_double_value;
  111. }
  112. else if (IsFloat())
  113. {
  114. std::cout << mc_float_value;
  115. }
  116. std::cout << std::endl;
  117. }
  118. bool IsDataFinite()
  119. {
  120. if (IsDouble())
  121. {
  122. return std::isfinite(data_double_value);
  123. }
  124. else if (IsFloat())
  125. {
  126. return std::isfinite(data_float_value);
  127. }
  128. return false;
  129. }
  130. bool IsMCFinite()
  131. {
  132. if (IsDouble())
  133. {
  134. return std::isfinite(mc_double_value);
  135. }
  136. else if (IsFloat())
  137. {
  138. return std::isfinite(mc_float_value);
  139. }
  140. return false;
  141. }
  142. bool IsDouble()
  143. {
  144. return type == TVT::Double;
  145. }
  146. bool IsFloat()
  147. {
  148. return type == TVT::Float;
  149. }
  150. };
  151. void ConnectVarsToData(std::vector<TV *> vars, TChain *data_chain, TChain *mc_chain, TTree *sig_tree, TTree *bkg_tree)
  152. {
  153. for (size_t i = 0; i < vars.size(); i++)
  154. {
  155. if (vars[i]->IsDouble())
  156. {
  157. data_chain->SetBranchAddress(vars[i]->GetDataName(), vars[i]->GetDataDoubleRef());
  158. mc_chain->SetBranchAddress(vars[i]->GetMCName(), vars[i]->GetMCDoubleRef());
  159. sig_tree->Branch(vars[i]->GetTrainName(), vars[i]->GetMCDoubleRef(), TString::Format("%s/D", vars[i]->GetTrainName()));
  160. bkg_tree->Branch(vars[i]->GetTrainName(), vars[i]->GetDataDoubleRef(), TString::Format("%s/D", vars[i]->GetTrainName()));
  161. }
  162. else if (vars[i]->IsFloat())
  163. {
  164. data_chain->SetBranchAddress(vars[i]->GetDataName(), vars[i]->GetDataFloatRef());
  165. mc_chain->SetBranchAddress(vars[i]->GetMCName(), vars[i]->GetMCFloatRef());
  166. sig_tree->Branch(vars[i]->GetTrainName(), vars[i]->GetMCFloatRef(), TString::Format("%s/F", vars[i]->GetTrainName()));
  167. bkg_tree->Branch(vars[i]->GetTrainName(), vars[i]->GetDataFloatRef(), TString::Format("%s/F", vars[i]->GetTrainName()));
  168. }
  169. }
  170. }
  171. void TrainBDT(std::vector<TV *> vars, const char* unique_id, TTree *sig_tree, TTree *bkg_tree)
  172. {
  173. TString outfile_name = TString::Format("%s_tmva_out.root", unique_id);
  174. TFile *output_file = TFile::Open(outfile_name, "RECREATE");
  175. TString factory_options("V:!Silent:Color:DrawProgressBar:Transformations=I;D;P;G,D:AnalysisType=Auto");
  176. TMVA::Factory *factory = new TMVA::Factory(TString::Format("%s_factory", unique_id), output_file, factory_options);
  177. TMVA::DataLoader *data_loader = new TMVA::DataLoader(TString::Format("%s_dataloader", unique_id));
  178. for (int i = 0; i < vars.size(); i++)
  179. {
  180. std::cout << "@TMVA: Adding Branch: " << vars[i]->GetTrainName() << std::endl;
  181. if (vars[i]->IsDouble())
  182. {
  183. data_loader->AddVariable(vars[i]->GetTrainName(), 'D');
  184. }
  185. else if (vars[i]->IsFloat())
  186. {
  187. data_loader->AddVariable(vars[i]->GetTrainName(), 'F');
  188. }
  189. }
  190. Double_t signal_weight = 1.0, background_weight = 1.0;
  191. data_loader->AddSignalTree(sig_tree, signal_weight);
  192. data_loader->AddBackgroundTree(bkg_tree, background_weight);
  193. data_loader->PrepareTrainingAndTestTree("", "", "nTrain_Signal=0:nTrain_Background=0:SplitMode=Random:NormMode=NumEvents:V");
  194. factory->BookMethod(data_loader, TMVA::Types::kBDT, "BDT", "!H:!V:NTrees=400:MinNodeSize=2.5%:CreateMVAPdfs:MaxDepth=3:BoostType=AdaBoost:AdaBoostBeta=0.5:UseBaggedBoost:BaggedSampleFraction=0.5:SeparationType=GiniIndex:nCuts=20");
  195. factory->TrainAllMethods();
  196. factory->TestAllMethods();
  197. factory->EvaluateAllMethods();
  198. output_file->Close();
  199. }
  200. TMVA::Reader* SetupReader(std::vector<TV *> vars, Float_t* train_vars, const char* unique_id) {
  201. TMVA::Reader *reader = new TMVA::Reader("!Color:!Silent");
  202. for (size_t i = 0; i < vars.size(); i++)
  203. {
  204. reader->AddVariable(vars[i]->GetTrainName(), &train_vars[i]);
  205. }
  206. reader->BookMVA("BDT", TString::Format("./%s_dataloader/weights/%s_factory_BDT.weights.xml", unique_id, unique_id));
  207. return reader;
  208. }
  209. void DrawBDTProbs(TH1D *histogram, const double cut_value, const char *folder)
  210. {
  211. std::filesystem::create_directory(TString::Format("output_files/analysis/%s", folder).Data());
  212. TString name = TString::Format("%s_canvas", histogram->GetName());
  213. TCanvas *c = new TCanvas(name, histogram->GetName(), 0, 0, 800, 600);
  214. histogram->SetStats(0);
  215. histogram->Draw();
  216. TLine* line = new TLine(cut_value, 0, cut_value, histogram->GetMaximum());
  217. line->SetLineColor(kRed);
  218. line->SetLineStyle(kDashed);
  219. line->Draw();
  220. c->Draw();
  221. c->SaveAs(TString::Format("output_files/analysis/%s/%s.pdf", folder, name.Data()).Data());
  222. }
  223. #endif