You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

177 lines
5.9 KiB

10 months ago
# flake8: noqa
  2. import os
  3. import argparse
  4. import ROOT
  5. from ROOT import TMVA, TList, TTree
  6. def train_matching_ghost_mlp(
  7. input_file: str = "data/ghost_data_B.root",
  8. tree_name: str = "PrMatchNN.PrMCDebugMatchToolNN/MVAInputAndOutput",
  9. exclude_electrons: bool = False,
  10. only_electrons: bool = True,
  11. n_train_signal: int = 20e3, # 50e3
  12. n_train_bkg: int = 50e3, # 500e3
  13. n_test_signal: int = 10e3,
  14. n_test_bkg: int = 20e3,
  15. prepare_data: bool = True,
  16. outdir: str = "nn_electron_training",
  17. ):
  18. """Trains an MLP to classify the match between Velo and Seed track.
  19. Args:
  20. input_file (str, optional): Defaults to "data/ghost_data.root".
  21. tree_name (str, optional): Defaults to "PrMatchNN.PrMCDebugMatchToolNN/Tuple".
  22. exclude_electrons (bool, optional): Defaults to False.
  23. only_electrons (bool, optional): Signal only of electrons, but bkg of all particles. Defaults to True.
  24. n_train_signal (int, optional): Number of true matches to train on. Defaults to 200e3.
  25. n_train_bkg (int, optional): Number of fake matches to train on. Defaults to 200e3.
  26. n_test_signal (int, optional): Number of true matches to test on. Defaults to 75e3.
  27. n_test_bkg (int, optional): Number of fake matches to test on. Defaults to 75e3.
  28. prepare_data (bool, optional): Split data into signal and background file. Defaults to False.
  29. """
  30. if prepare_data:
  31. rdf = ROOT.RDataFrame(tree_name, input_file)
  32. if exclude_electrons:
  33. rdf_signal = rdf.Filter(
  34. "mc_quality == 1", # -1 elec, 0 ghost, 1 all part wo elec
  35. "Signal is defined as one label (excluding electrons)",
  36. )
  37. rdf_bkg = rdf.Filter(
  38. "mc_quality == 0",
  39. "Ghosts are defined as zero label",
  40. )
  41. else:
  42. if only_electrons:
  43. rdf_signal = rdf.Filter(
  44. "mc_quality == -1", # electron that is true match but mlp said no match
  45. "Signal is defined as negative one label (only electrons)",
  46. )
  47. else:
  48. rdf_signal = rdf.Filter(
  49. "abs(mc_quality) > 0",
  50. "Signal is defined as non-zero label",
  51. )
  52. rdf_bkg = rdf.Filter(
  53. "mc_quality == 0",
  54. "Ghosts are defined as zero label",
  55. )
  56. rdf_signal.Snapshot(
  57. "Signal",
  58. outdir + "/" + input_file.strip(".root") + "_matching_signal.root",
  59. )
  60. rdf_bkg.Snapshot(
  61. "Bkg",
  62. outdir + "/" + input_file.strip(".root") + "_matching_bkg.root",
  63. )
  64. signal_file = ROOT.TFile.Open(
  65. outdir + "/" + input_file.strip(".root") + "_matching_signal.root",
  66. "READ",
  67. )
  68. signal_tree = signal_file.Get("Signal")
  69. bkg_file = ROOT.TFile.Open(
  70. outdir + "/" + input_file.strip(".root") + "_matching_bkg.root"
  71. )
  72. bkg_tree = bkg_file.Get("Bkg")
  73. os.chdir(outdir + "/result")
  74. output = ROOT.TFile(
  75. "matching_ghost_mlp_training.root",
  76. "RECREATE",
  77. )
  78. factory = TMVA.Factory(
  79. "TMVAClassification",
  80. output,
  81. "V:!Silent:Color:DrawProgressBar:AnalysisType=Classification",
  82. )
  83. factory.SetVerbose(True)
  84. dataloader = TMVA.DataLoader("MatchNNDataSet")
  85. dataloader.AddVariable("mc_chi2", "F")
  86. dataloader.AddVariable("mc_teta2", "F")
  87. dataloader.AddVariable("mc_distX", "F")
  88. dataloader.AddVariable("mc_distY", "F")
  89. dataloader.AddVariable("mc_dSlope", "F")
  90. dataloader.AddVariable("mc_dSlopeY", "F")
  91. dataloader.AddVariable("mc_zMag", "F")
  92. dataloader.AddSignalTree(signal_tree, 1.0)
  93. dataloader.AddBackgroundTree(bkg_tree, 1.0)
  94. # these cuts are also applied in the algorithm
  95. preselectionCuts = ROOT.TCut(
  96. # "chi2<30 && distX<500 && distY<500 && dSlope<2.0 && dSlopeY<0.15", #### ganz raus für elektronen
  97. )
  98. dataloader.PrepareTrainingAndTestTree(
  99. preselectionCuts,
  100. f"SplitMode=random:V:nTrain_Signal={n_train_signal}:nTrain_Background={n_train_bkg}:nTest_Signal={n_test_signal}:nTest_Background={n_test_bkg}",
  101. # normmode default is EqualNumEvents
  102. )
  103. factory.BookMethod(
  104. dataloader,
  105. TMVA.Types.kMLP,
  106. "matching_mlp",
  107. "!H:V:TrainingMethod=BP:NeuronType=ReLU:EstimatorType=CE:VarTransform=Norm:NCycles=700:HiddenLayers=N+2,N:TestRate=50:Sampling=1.0:SamplingImportance=1.0:LearningRate=0.02:DecayRate=0.01:!UseRegulator",
  108. )
  109. factory.TrainAllMethods()
  110. factory.TestAllMethods()
  111. factory.EvaluateAllMethods()
  112. output.Close()
  113. if __name__ == "__main__":
  114. parser = argparse.ArgumentParser()
  115. parser.add_argument(
  116. "--input-file",
  117. type=str,
  118. help="Path to the input file",
  119. required=False,
  120. )
  121. parser.add_argument(
  122. "--exclude_electrons",
  123. action="store_true",
  124. help="Excludes electrons from training.",
  125. required=False,
  126. )
  127. parser.add_argument(
  128. "--only_electrons",
  129. action="store_true",
  130. help="Only electrons for signal training.",
  131. required=False,
  132. )
  133. parser.add_argument(
  134. "--n-train-signal",
  135. type=int,
  136. help="Number of training tracks for signal.",
  137. required=False,
  138. )
  139. parser.add_argument(
  140. "--n-train-bkg",
  141. type=int,
  142. help="Number of training tracks for bkg.",
  143. required=False,
  144. )
  145. parser.add_argument(
  146. "--n-test-signal",
  147. type=int,
  148. help="Number of testing tracks for signal.",
  149. required=False,
  150. )
  151. parser.add_argument(
  152. "--n-test-bkg",
  153. type=int,
  154. help="Number of testing tracks for bkg.",
  155. required=False,
  156. )
  157. args = parser.parse_args()
  158. args_dict = {arg: val for arg, val in vars(args).items() if val is not None}
  159. train_matching_ghost_mlp(**args_dict)