You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

189 lines
6.3 KiB

10 months ago
# flake8: noqa
  2. import os
  3. import argparse
  4. import ROOT
  5. from ROOT import TMVA, TList, TTree, TMath
  6. def train_matching_ghost_mlp(
  7. input_file: str = "data/tracking_losses_ntuple_B.root",
  8. tree_name: str = "PrDebugTrackingLosses.PrDebugTrackingTool/Tuple",
  9. b_input_file: str = "data/ghost_data_B.root",
  10. b_tree_name: str = "PrMatchNN_3e224c41.PrMCDebugMatchToolNN/MVAInputAndOutput",
  11. only_electrons: bool = True,
  12. n_train_signal: int = 2e3, # 50e3
  13. n_train_bkg: int = 5e3, # 500e3
  14. n_test_signal: int = 1e3,
  15. n_test_bkg: int = 2e3,
  16. prepare_data: bool = True,
  17. outdir: str = "nn_trackinglosses_training",
  18. ):
  19. """Trains an MLP to classify the match between Velo and Seed track.
  20. Args:
  21. input_file (str, optional): Defaults to "data/ghost_data.root".
  22. tree_name (str, optional): Defaults to "PrMatchNN.PrMCDebugMatchToolNN/Tuple".
  23. exclude_electrons (bool, optional): Defaults to False.
  24. only_electrons (bool, optional): Signal only of electrons, but bkg of all particles. Defaults to True.
  25. residuals (bool, optional): Signal only of mlp<0.215. Defaults to False.
  26. n_train_signal (int, optional): Number of true matches to train on. Defaults to 200e3.
  27. n_train_bkg (int, optional): Number of fake matches to train on. Defaults to 200e3.
  28. n_test_signal (int, optional): Number of true matches to test on. Defaults to 75e3.
  29. n_test_bkg (int, optional): Number of fake matches to test on. Defaults to 75e3.
  30. prepare_data (bool, optional): Split data into signal and background file. Defaults to False.
  31. """
  32. # vec = ROOT.std.vector("string")(13)
  33. colList = [
  34. "mc_chi2",
  35. "mc_teta2",
  36. "mc_distX",
  37. "mc_distY",
  38. "mc_dSlope",
  39. "mc_dSlopeY",
  40. "mc_quality",
  41. "mc_end_velo_qop",
  42. "mc_end_velo_tx",
  43. "mc_end_velo_ty",
  44. "mc_end_t_qop",
  45. "mc_end_t_tx",
  46. "mc_end_t_ty",
  47. ]
  48. # for i in range(13):
  49. # vec[i] = colList[i]
  50. if prepare_data:
  51. rdf = ROOT.RDataFrame(tree_name, input_file)
  52. rdf_b = ROOT.RDataFrame(b_tree_name, b_input_file)
  53. if only_electrons:
  54. rdf_signal = rdf.Filter(
  55. "mc_quality == -1 && lost == 0 && fromSignal == 1", # electron that is true match but mlp said no match
  56. "Signal is defined as one label (only electrons)",
  57. )
  58. else:
  59. rdf_signal = rdf.Filter(
  60. "lost == 0 && fromSignal == 1",
  61. "Signal is defined as non-zero label",
  62. )
  63. rdf_bkg = rdf_b.Filter(
  64. "mc_quality == 0",
  65. "Ghosts are defined as zero label",
  66. )
  67. rdf_signal.Snapshot(
  68. "Signal",
  69. outdir + "/" + input_file.strip(".root") + "_matching_signal.root",
  70. colList,
  71. )
  72. rdf_bkg.Snapshot(
  73. "Bkg",
  74. outdir + "/" + input_file.strip(".root") + "_matching_bkg.root",
  75. )
  76. signal_file = ROOT.TFile.Open(
  77. outdir + "/" + input_file.strip(".root") + "_matching_signal.root",
  78. "READ",
  79. )
  80. signal_tree = signal_file.Get("Signal")
  81. bkg_file = ROOT.TFile.Open(
  82. outdir + "/" + input_file.strip(".root") + "_matching_bkg.root"
  83. )
  84. bkg_tree = bkg_file.Get("Bkg")
  85. os.chdir(outdir + "/result")
  86. output = ROOT.TFile(
  87. "matching_ghost_mlp_training.root",
  88. "RECREATE",
  89. )
  90. factory = TMVA.Factory(
  91. "TMVAClassification",
  92. output,
  93. "V:!Silent:Color:DrawProgressBar:AnalysisType=Classification",
  94. )
  95. factory.SetVerbose(True)
  96. dataloader = TMVA.DataLoader("MatchNNDataSet")
  97. dataloader.AddVariable("mc_chi2", "F")
  98. dataloader.AddVariable("mc_teta2", "F")
  99. dataloader.AddVariable("mc_distX", "F")
  100. dataloader.AddVariable("mc_distY", "F")
  101. dataloader.AddVariable("mc_dSlope", "F")
  102. dataloader.AddVariable("mc_dSlopeY", "F")
  103. dataloader.AddSignalTree(signal_tree, 1.0)
  104. dataloader.AddBackgroundTree(bkg_tree, 1.0)
  105. # these cuts are also applied in the algorithm
  106. preselectionCuts = ROOT.TCut(
  107. "!TMath::IsNaN(mc_chi2) && !TMath::IsNaN(mc_distX) && !TMath::IsNaN(mc_distY) && !TMath::IsNaN(mc_dSlope) && !TMath::IsNaN(mc_dSlopeY) && !TMath::IsNaN(mc_teta2)",
  108. # "mc_chi2<30 && mc_distX<500 && mc_distY<500 && mc_dSlope<2.0 && mc_dSlopeY<0.15", #### ganz raus für elektronen
  109. )
  110. dataloader.PrepareTrainingAndTestTree(
  111. preselectionCuts,
  112. f"SplitMode=random:V:nTrain_Signal={n_train_signal}:nTrain_Background={n_train_bkg}:nTest_Signal={n_test_signal}:nTest_Background={n_test_bkg}",
  113. # normmode default is EqualNumEvents
  114. )
  115. factory.BookMethod(
  116. dataloader,
  117. TMVA.Types.kMLP,
  118. "matching_mlp",
  119. "!H:V:TrainingMethod=BP:NeuronType=ReLU:EstimatorType=CE:VarTransform=Norm:NCycles=700:HiddenLayers=N+2,N:TestRate=50:Sampling=1.0:SamplingImportance=1.0:LearningRate=0.02:DecayRate=0.01:UseRegulator",
  120. )
  121. factory.TrainAllMethods()
  122. factory.TestAllMethods()
  123. factory.EvaluateAllMethods()
  124. output.Close()
  125. if __name__ == "__main__":
  126. parser = argparse.ArgumentParser()
  127. parser.add_argument(
  128. "--input-file",
  129. type=str,
  130. help="Path to the input file",
  131. required=False,
  132. )
  133. parser.add_argument(
  134. "--exclude_electrons",
  135. action="store_true",
  136. help="Excludes electrons from training.",
  137. required=False,
  138. )
  139. parser.add_argument(
  140. "--only_electrons",
  141. action="store_true",
  142. help="Only electrons for signal training.",
  143. required=False,
  144. )
  145. parser.add_argument(
  146. "--n-train-signal",
  147. type=int,
  148. help="Number of training tracks for signal.",
  149. required=False,
  150. )
  151. parser.add_argument(
  152. "--n-train-bkg",
  153. type=int,
  154. help="Number of training tracks for bkg.",
  155. required=False,
  156. )
  157. parser.add_argument(
  158. "--n-test-signal",
  159. type=int,
  160. help="Number of testing tracks for signal.",
  161. required=False,
  162. )
  163. parser.add_argument(
  164. "--n-test-bkg",
  165. type=int,
  166. help="Number of testing tracks for bkg.",
  167. required=False,
  168. )
  169. args = parser.parse_args()
  170. args_dict = {arg: val for arg, val in vars(args).items() if val is not None}
  171. train_matching_ghost_mlp(**args_dict)