You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

224 lines
7.4 KiB

10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
7 months ago
10 months ago
7 months ago
10 months ago
7 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
  1. # flake8: noqa
  2. import os
  3. import argparse
  4. import ROOT
  5. from ROOT import TMVA, TList, TTree
  6. def train_matching_ghost_mlp(
  7. input_file: str = "data/ghost_data_B.root",
  8. tree_name: str = "PrMatchNN.PrMCDebugMatchToolNN/MVAInputAndOutput",
  9. exclude_electrons: bool = False,
  10. only_electrons: bool = True,
  11. filter_velos: bool = False,
  12. filter_seeds: bool = False,
  13. n_train_signal: int = 50e3,
  14. n_train_bkg: int = 50e3,
  15. n_test_signal: int = 10e3,
  16. n_test_bkg: int = 20e3,
  17. prepare_data: bool = True,
  18. outdir: str = "nn_electron_training",
  19. ):
  20. """Trains an MLP to classify the match between Velo and Seed track.
  21. Args:
  22. input_file (str, optional): Defaults to "data/ghost_data.root".
  23. tree_name (str, optional): Defaults to "PrMatchNN.PrMCDebugMatchToolNN/Tuple".
  24. exclude_electrons (bool, optional): Defaults to False.
  25. only_electrons (bool, optional): Signal only of electrons, but bkg of all particles. Defaults to True.
  26. filter_velos (bool, optional): Background only electron VELO tracks. Defaults to False.
  27. filter_seeds (bool, optional): Background only electron T tracks. Defaults to False.
  28. n_train_signal (int, optional): Number of true matches to train on. Defaults to 200e3.
  29. n_train_bkg (int, optional): Number of fake matches to train on. Defaults to 200e3.
  30. n_test_signal (int, optional): Number of true matches to test on. Defaults to 75e3.
  31. n_test_bkg (int, optional): Number of fake matches to test on. Defaults to 75e3.
  32. prepare_data (bool, optional): Split data into signal and background file. Defaults to True.
  33. outdir (str, optional): specify the output directory path.
  34. """
  35. if prepare_data:
  36. rdf = ROOT.RDataFrame(tree_name, input_file)
  37. if exclude_electrons:
  38. print("signal data: exclude electrons.")
  39. rdf_signal = rdf.Filter(
  40. "quality == 1", # -1 elec, 0 ghost, 1 all part wo elec
  41. "Signal is defined as one label (excluding electrons)",
  42. )
  43. print("background data: default ghosts: quality == 0")
  44. rdf_bkg = rdf.Filter(
  45. "quality == 0",
  46. "Ghosts are defined as zero label",
  47. )
  48. else:
  49. if only_electrons:
  50. print("signal data: only electrons.")
  51. rdf_signal = rdf.Filter(
  52. "quality == -1", # electron that is true match
  53. "Signal is defined as negative one label (only electrons)",
  54. )
  55. else:
  56. print("signal data: all particles.")
  57. rdf_signal = rdf.Filter(
  58. "abs(quality) > 0",
  59. "Signal is defined as non-zero label",
  60. )
  61. bkg_selection = "quality == 0"
  62. if filter_velos:
  63. bkg_selection += " && velo_isElectron == 1"
  64. if filter_seeds:
  65. bkg_selection += " && scifi_isElectron == 1"
  66. print("background data: selection cuts = " + bkg_selection)
  67. rdf_bkg = rdf.Filter(
  68. bkg_selection,
  69. "Ghosts are defined as zero label",
  70. )
  71. rdf_signal.Snapshot(
  72. "Signal",
  73. outdir + "/" + input_file.strip(".root") + "_matching_signal.root",
  74. )
  75. rdf_bkg.Snapshot(
  76. "Bkg",
  77. outdir + "/" + input_file.strip(".root") + "_matching_bkg.root",
  78. )
  79. signal_file = ROOT.TFile.Open(
  80. outdir + "/" + input_file.strip(".root") + "_matching_signal.root",
  81. "READ",
  82. )
  83. signal_tree = signal_file.Get("Signal")
  84. bkg_file = ROOT.TFile.Open(
  85. outdir + "/" + input_file.strip(".root") + "_matching_bkg.root"
  86. )
  87. bkg_tree = bkg_file.Get("Bkg")
  88. os.chdir(outdir + "/result")
  89. output = ROOT.TFile(
  90. "matching_ghost_mlp_training.root",
  91. "RECREATE",
  92. )
  93. factory = TMVA.Factory(
  94. "TMVAClassification",
  95. output,
  96. "V:!Silent:Color:DrawProgressBar:AnalysisType=Classification",
  97. )
  98. factory.SetVerbose(True)
  99. dataloader = TMVA.DataLoader("MatchNNDataSet")
  100. dataloader.AddVariable("chi2", "F")
  101. dataloader.AddVariable("teta2", "F")
  102. dataloader.AddVariable("distX", "F")
  103. dataloader.AddVariable("distY", "F")
  104. dataloader.AddVariable("dSlope", "F")
  105. dataloader.AddVariable("dSlopeY", "F")
  106. dataloader.AddVariable("zMag", "F")
  107. # dataloader.AddVariable("eta", "F")
  108. # dataloader.AddVariable("dEta", "F")
  109. dataloader.AddSignalTree(signal_tree, 1.0)
  110. dataloader.AddBackgroundTree(bkg_tree, 1.0)
  111. # these cuts are also applied in the algorithm
  112. preselectionCuts = ROOT.TCut(
  113. "chi2<15 && distX<250 && distY<250 && dSlope<1.5 && dSlopeY<0.15 && zMag<6000 && zMag>4500",
  114. )
  115. dataloader.PrepareTrainingAndTestTree(
  116. preselectionCuts,
  117. f"SplitMode=random:V:nTrain_Signal={n_train_signal}:nTrain_Background={n_train_bkg}:nTest_Signal={n_test_signal}:nTest_Background={n_test_bkg}",
  118. # normmode default is EqualNumEvents
  119. )
  120. factory.BookMethod(
  121. dataloader,
  122. TMVA.Types.kMLP,
  123. "matching_mlp",
  124. "!H:V:TrainingMethod=BP:NeuronType=ReLU:EstimatorType=CE:VarTransform=Norm:NCycles=700:HiddenLayers=N+4,N+2:TestRate=50:Sampling=1.0:SamplingImportance=1.0:LearningRate=0.02:DecayRate=0.01:!UseRegulator",
  125. )
  126. factory.TrainAllMethods()
  127. factory.TestAllMethods()
  128. factory.EvaluateAllMethods()
  129. output.Close()
  130. if __name__ == "__main__":
  131. parser = argparse.ArgumentParser()
  132. parser.add_argument(
  133. "--input-file",
  134. type=str,
  135. help="Path to the input file.",
  136. required=False,
  137. )
  138. parser.add_argument(
  139. "--tree-name",
  140. type=str,
  141. help="Path to the tree.",
  142. required=False,
  143. )
  144. parser.add_argument(
  145. "--exclude-electrons",
  146. action="store_true",
  147. help="Excludes electrons from training.",
  148. required=False,
  149. )
  150. parser.add_argument(
  151. "--only-electrons",
  152. action="store_true",
  153. help="Only electrons for signal training.",
  154. required=False,
  155. )
  156. parser.add_argument(
  157. "--filter-velos",
  158. action="store_true",
  159. help="Only background with electron VELO tracks.",
  160. required=False,
  161. )
  162. parser.add_argument(
  163. "--filter-seeds",
  164. action="store_true",
  165. help="Only background with electron T tracks.",
  166. required=False,
  167. )
  168. parser.add_argument(
  169. "--n-train-signal",
  170. type=int,
  171. help="Number of training tracks for signal.",
  172. required=False,
  173. )
  174. parser.add_argument(
  175. "--n-train-bkg",
  176. type=int,
  177. help="Number of training tracks for bkg.",
  178. required=False,
  179. )
  180. parser.add_argument(
  181. "--n-test-signal",
  182. type=int,
  183. help="Number of testing tracks for signal.",
  184. required=False,
  185. )
  186. parser.add_argument(
  187. "--n-test-bkg",
  188. type=int,
  189. help="Number of testing tracks for bkg.",
  190. required=False,
  191. )
  192. parser.add_argument(
  193. "--prepare-data",
  194. action="store_true",
  195. help="Create signal and background samples.",
  196. required=False,
  197. )
  198. parser.add_argument(
  199. "--outdir",
  200. type=str,
  201. help="Path to the output directory.",
  202. required=False,
  203. )
  204. args = parser.parse_args()
  205. args_dict = {arg: val for arg, val in vars(args).items() if val is not None}
  206. train_matching_ghost_mlp(**args_dict)