You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

261 lines
10 KiB

10 months ago
  1. import os
  2. import argparse
  3. import ROOT
  4. from ROOT import TMVA
  5. def train_default_forward_ghost_mlp(
  6. input_file: str = "data/ghost_data.root",
  7. tree_name: str = "PrForwardTrackingVelo.PrMCDebugForwardTool/MVAInput",
  8. exclude_electrons: bool = False,
  9. n_train_signal: int = 300e3,
  10. n_train_bkg: int = 300e3,
  11. n_test_signal: int = 50e3,
  12. n_test_bkg: int = 50e3,
  13. prepare_data: bool = False,
  14. ):
  15. """Trains an MLP to classify track candidates from PrForwardTrackingVelo as ghost or Long Track.
  16. Args:
  17. input_file (str, optional): Defaults to "data/ghost_data.root".
  18. tree_name (str, optional): Defaults to "PrForwardTrackingVelo.PrMCDebugForwardTool/MVAInput".
  19. exclude_electrons (bool, optional): Defaults to False.
  20. n_train_signal (int, optional): Number of true tracks to train on. Defaults to 750e3.
  21. n_train_bkg (int, optional): Number of fake tracks to train on. Defaults to 750e3.
  22. n_test_signal (int, optional): Number of true tracks to test on. Defaults to 50e3.
  23. n_test_bkg (int, optional): umber of fake tracks to test on. Defaults to 50e3.
  24. prepare_data (bool, optional): Split data into signal and background file. Defaults to False.
  25. """
  26. if prepare_data:
  27. rdf = ROOT.RDataFrame(tree_name, input_file)
  28. if exclude_electrons:
  29. rdf_signal = rdf.Filter(
  30. "label == 1",
  31. "Signal is defined as one label (excluding electrons)",
  32. )
  33. rdf_bkg = rdf.Filter("label == 0", "Ghosts are defined as zero label")
  34. else:
  35. rdf_signal = rdf.Filter("label > 0", "Signal is defined as non-zero label")
  36. rdf_bkg = rdf.Filter("label == 0", "Ghosts are defined as zero label")
  37. rdf_signal.Snapshot(
  38. "Signal",
  39. input_file.strip(".root") + "_forward_signal.root",
  40. )
  41. rdf_bkg.Snapshot("Bkg", input_file.strip(".root") + "_forward_bkg.root")
  42. signal_file = ROOT.TFile.Open(
  43. input_file.strip(".root") + "_forward_signal.root",
  44. "READ",
  45. )
  46. signal_tree = signal_file.Get("Signal")
  47. bkg_file = ROOT.TFile.Open(input_file.strip(".root") + "_forward_bkg.root")
  48. bkg_tree = bkg_file.Get("Bkg")
  49. os.chdir("neural_net_training/result")
  50. output = ROOT.TFile(
  51. "default_forward_ghost_mlp_training.root",
  52. "RECREATE",
  53. )
  54. factory = TMVA.Factory(
  55. "TMVAClassification",
  56. output,
  57. "V:!Silent:Color:DrawProgressBar:AnalysisType=Classification",
  58. )
  59. factory.SetVerbose(True)
  60. dataloader = TMVA.DataLoader("GhostNNDataSet")
  61. dataloader.AddVariable("redChi2", "F")
  62. dataloader.AddVariable(
  63. "distXMatch := abs((x + ( zMagMatch - 770.0 ) * tx) - (xEndT + ( zMagMatch - 9410.0 ) * txEndT))",
  64. "F",
  65. )
  66. dataloader.AddVariable(
  67. "distY := abs(ySeedMatch - yEndT)",
  68. "F",
  69. )
  70. dataloader.AddVariable("abs(yParam0Final-yParam0Init)", "F")
  71. dataloader.AddVariable("abs(yParam1Final-yParam1Init)", "F")
  72. dataloader.AddVariable("abs(ty)", "F")
  73. dataloader.AddVariable("abs(qop)", "F")
  74. dataloader.AddVariable("abs(tx)", "F")
  75. dataloader.AddVariable("abs(xParam1Final-xParam1Init)", "F")
  76. dataloader.AddSignalTree(signal_tree, 1.0)
  77. dataloader.AddBackgroundTree(bkg_tree, 1.0)
  78. preselectionCuts = ROOT.TCut(
  79. "redChi2 < 8 && abs((x + ( zMagMatch - 770.0 ) * tx) - (xEndT + ( zMagMatch - 9410.0 ) * txEndT)) < 140 && abs(ySeedMatch - yEndT) < 500 && abs(yParam0Final-yParam0Init) < 140 && abs(yParam1Final-yParam1Init) < 0.055",
  80. )
  81. dataloader.PrepareTrainingAndTestTree(
  82. preselectionCuts,
  83. f"NormMode=NumEvents:SplitMode=random:V:nTrain_Signal={n_train_signal}:nTrain_Background={n_train_bkg}:nTest_Signal={n_test_signal}:nTest_Background={n_test_bkg}",
  84. )
  85. factory.BookMethod(
  86. dataloader,
  87. TMVA.Types.kMLP,
  88. "default_forward_ghost_mlp",
  89. "!H:V:TrainingMethod=BP:NeuronType=ReLU:EstimatorType=CE:VarTransform=Norm:NCycles=750:HiddenLayers=N+4,N+2:TestRate=50:Sampling=1.0:SamplingImportance=1.0:LearningRate=0.02:DecayRate=0.005:!UseRegulator",
  90. )
  91. factory.TrainAllMethods()
  92. factory.TestAllMethods()
  93. factory.EvaluateAllMethods()
  94. output.Close()
  95. def train_veloUT_forward_ghost_mlp(
  96. input_file: str = "data/ghost_data.root",
  97. tree_name: str = "PrForwardTracking.PrMCDebugForwardTool/MVAInput",
  98. exclude_electrons: bool = False,
  99. n_train_signal: int = 300e3,
  100. n_train_bkg: int = 300e3,
  101. n_test_signal: int = 50e3,
  102. n_test_bkg: int = 50e3,
  103. prepare_data: bool = False,
  104. ):
  105. """Trains an MLP to classify track candidates from PrForwardTracking as ghost or Long Track.
  106. Args:
  107. input_file (str, optional): Defaults to "data/ghost_data.root".
  108. tree_name (str, optional): Defaults to "PrForwardTracking.PrMCDebugForwardTool/MVAInput".
  109. exclude_electrons (bool, optional): Defaults to False.
  110. n_train_signal (int, optional): Number of true tracks to train on. Defaults to 750e3.
  111. n_train_bkg (int, optional): Number of fake tracks to train on. Defaults to 750e3.
  112. n_test_signal (int, optional): Number of true tracks to test on. Defaults to 50e3.
  113. n_test_bkg (int, optional): umber of fake tracks to test on. Defaults to 50e3.
  114. prepare_data (bool, optional): Split data into signal and background file. Defaults to False.
  115. """
  116. if prepare_data:
  117. rdf = ROOT.RDataFrame(tree_name, input_file)
  118. if exclude_electrons:
  119. rdf_signal = rdf.Filter(
  120. "label == 1",
  121. "Signal is defined as one label (excluding electrons)",
  122. )
  123. rdf_bkg = rdf.Filter("label == 0", "Ghosts are defined as zero label")
  124. else:
  125. rdf_signal = rdf.Filter("label > 0", "Signal is defined as non-zero label")
  126. rdf_bkg = rdf.Filter("label == 0", "Ghosts are defined as zero label")
  127. rdf_signal.Snapshot(
  128. "Signal",
  129. input_file.strip(".root") + "_forward_velout_signal.root",
  130. )
  131. rdf_bkg.Snapshot("Bkg", input_file.strip(".root") + "_forward_velout_bkg.root")
  132. signal_file = ROOT.TFile.Open(
  133. input_file.strip(".root") + "_forward_velout_signal.root",
  134. "READ",
  135. )
  136. signal_tree = signal_file.Get("Signal")
  137. bkg_file = ROOT.TFile.Open(input_file.strip(".root") + "_forward_velout_bkg.root")
  138. bkg_tree = bkg_file.Get("Bkg")
  139. os.chdir("neural_net_training/result")
  140. output = ROOT.TFile(
  141. "veloUT_forward_ghost_mlp_training.root",
  142. "RECREATE",
  143. )
  144. factory = TMVA.Factory(
  145. "TMVAClassification",
  146. output,
  147. "V:!Silent:Color:DrawProgressBar:AnalysisType=Classification",
  148. )
  149. factory.SetVerbose(True)
  150. dataloader = TMVA.DataLoader("GhostNNDataSet")
  151. dataloader.AddVariable("dMom := log(abs((1.0/qop) - (1.0/qopUT) ))", "F")
  152. dataloader.AddVariable("redChi2", "F")
  153. dataloader.AddVariable(
  154. "distXMatch := abs((x + ( zMagMatch - 770.0 ) * tx) - (xEndT + ( zMagMatch - 9410.0 ) * txEndT))",
  155. "F",
  156. )
  157. dataloader.AddVariable(
  158. "distY := abs(ySeedMatch - yEndT)",
  159. "F",
  160. )
  161. dataloader.AddVariable("abs(yParam0Final-yParam0Init)", "F")
  162. dataloader.AddVariable("abs(yParam1Final-yParam1Init)", "F")
  163. dataloader.AddVariable("abs(ty)", "F")
  164. dataloader.AddVariable("abs(qop)", "F")
  165. dataloader.AddVariable("abs(tx)", "F")
  166. dataloader.AddVariable("abs(xParam1Final-xParam1Init)", "F")
  167. dataloader.AddSignalTree(signal_tree, 1.0)
  168. dataloader.AddBackgroundTree(bkg_tree, 1.0)
  169. preselectionCuts = ROOT.TCut(
  170. "redChi2 < 8 && abs((x + ( zMagMatch - 770.0 ) * tx) - (xEndT + ( zMagMatch - 9410.0 ) * txEndT)) < 140 && abs(ySeedMatch - yEndT) < 500 && abs(yParam0Final-yParam0Init) < 140 && abs(yParam1Final-yParam1Init) < 0.055",
  171. )
  172. dataloader.PrepareTrainingAndTestTree(
  173. preselectionCuts,
  174. f"NormMode=NumEvents:SplitMode=random:V:nTrain_Signal={n_train_signal}:nTrain_Background={n_train_bkg}:nTest_Signal={n_test_signal}:nTest_Background={n_test_bkg}",
  175. )
  176. factory.BookMethod(
  177. dataloader,
  178. TMVA.Types.kMLP,
  179. "veloUT_forward_ghost_mlp",
  180. "!H:V:TrainingMethod=BP:NeuronType=ReLU:EstimatorType=CE:VarTransform=Norm:NCycles=550:HiddenLayers=N+2:TestRate=50:Sampling=1.0:SamplingImportance=1.0:LearningRate=0.02:DecayRate=0.005:!UseRegulator",
  181. )
  182. factory.TrainAllMethods()
  183. factory.TestAllMethods()
  184. factory.EvaluateAllMethods()
  185. output.Close()
  186. if __name__ == "__main__":
  187. parser = argparse.ArgumentParser()
  188. parser.add_argument(
  189. "--input-file",
  190. type=str,
  191. help="Path to the input file",
  192. required=False,
  193. )
  194. parser.add_argument(
  195. "--exclude_electrons",
  196. action="store_true",
  197. help="Excludes electrons from training.",
  198. required=False,
  199. )
  200. parser.add_argument(
  201. "--n-train-signal",
  202. type=int,
  203. help="Number of training tracks for signal.",
  204. required=False,
  205. )
  206. parser.add_argument(
  207. "--n-train-bkg",
  208. type=int,
  209. help="Number of training tracks for bkg.",
  210. required=False,
  211. )
  212. parser.add_argument(
  213. "--n-test-signal",
  214. type=int,
  215. help="Number of testing tracks for signal.",
  216. required=False,
  217. )
  218. parser.add_argument(
  219. "--n-test-bkg",
  220. type=int,
  221. help="Number of testing tracks for bkg.",
  222. required=False,
  223. )
  224. parser.add_argument(
  225. "--veloUT",
  226. action="store_true",
  227. help="Toggle whether the NN for upstream tracks input is trained.",
  228. )
  229. parser.add_argument(
  230. "--all",
  231. action="store_true",
  232. help="Toggle whether both NNs are trained, for VELO and VeloUT input.",
  233. )
  234. args = parser.parse_args()
  235. args_dict = {
  236. arg: val
  237. for arg, val in vars(args).items()
  238. if val is not None and arg not in ["veloUT", "all"]
  239. }
  240. if args.all:
  241. train_default_forward_ghost_mlp(**args_dict)
  242. train_veloUT_forward_ghost_mlp(**args_dict)
  243. elif args.veloUT:
  244. train_veloUT_forward_ghost_mlp(**args_dict)
  245. else:
  246. train_default_forward_ghost_mlp(**args_dict)