You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

262 lines
8.6 KiB

10 months ago
  1. # flake8: noqaq
  2. import os
  3. import argparse
  4. import ROOT
  5. from ROOT import TMVA, TList, TTree
  6. def res_train_matching_ghost_mlp(
  7. input_file: str = "data/ghost_data_B.root",
  8. tree_name: str = "PrMatchNN.PrMCDebugMatchToolNN/MVAInputAndOutput",
  9. exclude_electrons: bool = False,
  10. only_electrons: bool = True,
  11. residuals: str = "None",
  12. n_train_signal: int = 2e3, # 50e3
  13. n_train_bkg: int = 5e3, # 500e3
  14. n_test_signal: int = 1e3,
  15. n_test_bkg: int = 2e3,
  16. prepare_data: bool = True,
  17. outdir: str = "nn_electron_training",
  18. ):
  19. """Trains an MLP to classify the match between Velo and Seed track.
  20. Args:
  21. input_file (str, optional): Defaults to "data/ghost_data.root".
  22. tree_name (str, optional): Defaults to "PrMatchNN.PrMCDebugMatchToolNN/Tuple".
  23. exclude_electrons (bool, optional): Defaults to False.
  24. only_electrons (bool, optional): Signal only of electrons, but bkg of all particles. Defaults to True.
  25. residuals (bool, optional): Signal only of mlp<0.215. Defaults to False.
  26. n_train_signal (int, optional): Number of true matches to train on. Defaults to 200e3.
  27. n_train_bkg (int, optional): Number of fake matches to train on. Defaults to 200e3.
  28. n_test_signal (int, optional): Number of true matches to test on. Defaults to 75e3.
  29. n_test_bkg (int, optional): Number of fake matches to test on. Defaults to 75e3.
  30. prepare_data (bool, optional): Split data into signal and background file. Defaults to False.
  31. """
  32. if prepare_data and not residuals == "None":
  33. resrdf = ROOT.RDataFrame(residuals, input_file)
  34. rdf = ROOT.RDataFrame(tree_name, input_file)
  35. if exclude_electrons:
  36. rdf_signal = rdf.Filter(
  37. "quality == 1 && mlp<0.215", # -1 elec, 0 ghost, 1 all part wo elec
  38. "Signal is defined as one label (excluding electrons)",
  39. )
  40. rdf_bkg = rdf.Filter(
  41. "quality == 0 && mlp<0.215",
  42. "Ghosts are defined as zero label",
  43. )
  44. resrdf_signal = resrdf.Filter(
  45. "quality == 1", # -1 elec, 0 ghost, 1 all part wo elec
  46. "Signal is defined as one label (excluding electrons)",
  47. )
  48. resrdf_bkg = resrdf.Filter(
  49. "quality == 0",
  50. "Ghosts are defined as zero label",
  51. )
  52. else:
  53. if only_electrons:
  54. rdf_signal = rdf.Filter(
  55. "quality == -1 && mlp<0.215", # electron that is true match but mlp said no match
  56. "Signal is defined as one label (only electrons)",
  57. )
  58. resrdf_signal = resrdf.Filter(
  59. "quality == -1", # electron that is true match but mlp said no match
  60. "Signal is defined as one label (only electrons)",
  61. )
  62. else:
  63. rdf_signal = rdf.Filter(
  64. "abs(quality) > 0 && mlp<0.215",
  65. "Signal is defined as non-zero label",
  66. )
  67. resrdf_signal = resrdf.Filter(
  68. "abs(quality) > 0",
  69. "Signal is defined as non-zero label",
  70. )
  71. rdf_bkg = rdf.Filter(
  72. "quality == 0 && mlp<0.215",
  73. "Ghosts are defined as zero label",
  74. )
  75. resrdf_bkg = resrdf.Filter(
  76. "quality == 0",
  77. "Ghosts are defined as zero label",
  78. )
  79. rdf_signal.Snapshot(
  80. "Signal",
  81. outdir + "/" + input_file.strip(".root") + "_mlp_matching_signal.root",
  82. )
  83. rdf_bkg.Snapshot(
  84. "Bkg",
  85. outdir + "/" + input_file.strip(".root") + "_mlp_matching_bkg.root",
  86. )
  87. resrdf_signal.Snapshot(
  88. "Signal",
  89. outdir + "/" + input_file.strip(".root") + "_res_matching_signal.root",
  90. )
  91. resrdf_bkg.Snapshot(
  92. "Bkg",
  93. outdir + "/" + input_file.strip(".root") + "_res_matching_bkg.root",
  94. )
  95. mlp_signal_file = ROOT.TFile.Open(
  96. outdir + "/" + input_file.strip(".root") + "_mlp_matching_signal.root",
  97. "READ",
  98. )
  99. mlp_signal_tree = mlp_signal_file.Get("Signal")
  100. mlp_bkg_file = ROOT.TFile.Open(
  101. outdir + "/" + input_file.strip(".root") + "_mlp_matching_bkg.root",
  102. "READ",
  103. )
  104. mlp_bkg_tree = mlp_bkg_file.Get("Bkg")
  105. ####
  106. res_signal_file = ROOT.TFile.Open(
  107. outdir + "/" + input_file.strip(".root") + "_res_matching_signal.root",
  108. "READ",
  109. )
  110. res_signal_tree = res_signal_file.Get("Signal")
  111. outputsignalFile = ROOT.TFile(
  112. outdir + "/" + input_file.strip(".root") + "_merged_matching_signal.root",
  113. "RECREATE",
  114. )
  115. signaltreeList = TList()
  116. signaltreeList.Add(mlp_signal_tree)
  117. signaltreeList.Add(res_signal_tree)
  118. mergedsignalTree = TTree.MergeTrees(signaltreeList)
  119. mergedsignalTree.Write()
  120. outputsignalFile.Close()
  121. res_bkg_file = ROOT.TFile.Open(
  122. outdir + "/" + input_file.strip(".root") + "_res_matching_bkg.root",
  123. "READ",
  124. )
  125. res_bkg_tree = res_bkg_file.Get("Bkg")
  126. outputbkgFile = ROOT.TFile(
  127. outdir + "/" + input_file.strip(".root") + "_merged_matching_bkg.root",
  128. "RECREATE",
  129. )
  130. bkgtreeList = TList()
  131. bkgtreeList.Add(mlp_bkg_tree)
  132. bkgtreeList.Add(res_bkg_tree)
  133. mergedbkgTree = TTree.MergeTrees(bkgtreeList)
  134. mergedbkgTree.Write()
  135. outputbkgFile.Close()
  136. #####
  137. signal_file = ROOT.TFile.Open(
  138. outdir + "/" + input_file.strip(".root") + "_merged_matching_signal.root",
  139. "READ",
  140. )
  141. signal_tree = signal_file.Get("Signal")
  142. bkg_file = ROOT.TFile.Open(
  143. outdir + "/" + input_file.strip(".root") + "_merged_matching_bkg.root"
  144. )
  145. bkg_tree = bkg_file.Get("Bkg")
  146. ###
  147. os.chdir(outdir + "/result")
  148. output = ROOT.TFile(
  149. "matching_ghost_mlp_training.root",
  150. "RECREATE",
  151. )
  152. factory = TMVA.Factory(
  153. "TMVAClassification",
  154. output,
  155. "V:!Silent:Color:DrawProgressBar:AnalysisType=Classification",
  156. )
  157. factory.SetVerbose(True)
  158. dataloader = TMVA.DataLoader("MatchNNDataSet")
  159. dataloader.AddVariable("chi2", "F")
  160. dataloader.AddVariable("teta2", "F")
  161. dataloader.AddVariable("distX", "F")
  162. dataloader.AddVariable("distY", "F")
  163. dataloader.AddVariable("dSlope", "F")
  164. dataloader.AddVariable("dSlopeY", "F")
  165. dataloader.AddSignalTree(signal_tree, 1.0)
  166. dataloader.AddBackgroundTree(bkg_tree, 1.0)
  167. # these cuts are also applied in the algorithm
  168. preselectionCuts = ROOT.TCut(
  169. # "chi2<30 && distX<500 && distY<500 && dSlope<2.0 && dSlopeY<0.15", #### ganz raus für elektronen
  170. )
  171. dataloader.PrepareTrainingAndTestTree(
  172. preselectionCuts,
  173. f"SplitMode=random:V:nTrain_Signal={n_train_signal}:nTrain_Background={n_train_bkg}:nTest_Signal={n_test_signal}:nTest_Background={n_test_bkg}",
  174. # normmode default is EqualNumEvents
  175. )
  176. factory.BookMethod(
  177. dataloader,
  178. TMVA.Types.kMLP,
  179. "matching_mlp",
  180. "!H:V:TrainingMethod=BP:NeuronType=ReLU:EstimatorType=CE:VarTransform=Norm:NCycles=700:HiddenLayers=N+2,N:TestRate=50:Sampling=1.0:SamplingImportance=1.0:LearningRate=0.02:DecayRate=0.01:!UseRegulator",
  181. )
  182. factory.TrainAllMethods()
  183. factory.TestAllMethods()
  184. factory.EvaluateAllMethods()
  185. output.Close()
  186. if __name__ == "__main__":
  187. parser = argparse.ArgumentParser()
  188. parser.add_argument(
  189. "--input-file",
  190. type=str,
  191. help="Path to the input file",
  192. required=False,
  193. )
  194. parser.add_argument(
  195. "--exclude_electrons",
  196. action="store_true",
  197. help="Excludes electrons from training.",
  198. required=False,
  199. )
  200. parser.add_argument(
  201. "--only_electrons",
  202. action="store_true",
  203. help="Only electrons for signal training.",
  204. required=False,
  205. )
  206. parser.add_argument(
  207. "--n-train-signal",
  208. type=int,
  209. help="Number of training tracks for signal.",
  210. required=False,
  211. )
  212. parser.add_argument(
  213. "--n-train-bkg",
  214. type=int,
  215. help="Number of training tracks for bkg.",
  216. required=False,
  217. )
  218. parser.add_argument(
  219. "--n-test-signal",
  220. type=int,
  221. help="Number of testing tracks for signal.",
  222. required=False,
  223. )
  224. parser.add_argument(
  225. "--n-test-bkg",
  226. type=int,
  227. help="Number of testing tracks for bkg.",
  228. required=False,
  229. )
  230. args = parser.parse_args()
  231. args_dict = {arg: val for arg, val in vars(args).items() if val is not None}
  232. res_train_matching_ghost_mlp(**args_dict)