You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

177 lines
5.8 KiB

10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
  1. # flake8: noqa
  2. import os
  3. import argparse
  4. import ROOT
  5. from ROOT import TMVA, TList, TTree
  6. def train_matching_ghost_mlp(
  7. input_file: str = "data/ghost_data_B.root",
  8. tree_name: str = "PrMatchNN.PrMCDebugMatchToolNN/MVAInputAndOutput",
  9. exclude_electrons: bool = False,
  10. only_electrons: bool = True,
  11. n_train_signal: int = 50e3, # 50e3
  12. n_train_bkg: int = 50e3, # 500e3
  13. n_test_signal: int = 10e3,
  14. n_test_bkg: int = 20e3,
  15. prepare_data: bool = True,
  16. outdir: str = "nn_electron_training",
  17. ):
  18. """Trains an MLP to classify the match between Velo and Seed track.
  19. Args:
  20. input_file (str, optional): Defaults to "data/ghost_data.root".
  21. tree_name (str, optional): Defaults to "PrMatchNN.PrMCDebugMatchToolNN/Tuple".
  22. exclude_electrons (bool, optional): Defaults to False.
  23. only_electrons (bool, optional): Signal only of electrons, but bkg of all particles. Defaults to True.
  24. n_train_signal (int, optional): Number of true matches to train on. Defaults to 200e3.
  25. n_train_bkg (int, optional): Number of fake matches to train on. Defaults to 200e3.
  26. n_test_signal (int, optional): Number of true matches to test on. Defaults to 75e3.
  27. n_test_bkg (int, optional): Number of fake matches to test on. Defaults to 75e3.
  28. prepare_data (bool, optional): Split data into signal and background file. Defaults to False.
  29. """
  30. if prepare_data:
  31. rdf = ROOT.RDataFrame(tree_name, input_file)
  32. if exclude_electrons:
  33. rdf_signal = rdf.Filter(
  34. "quality == 1", # -1 elec, 0 ghost, 1 all part wo elec
  35. "Signal is defined as one label (excluding electrons)",
  36. )
  37. rdf_bkg = rdf.Filter(
  38. "quality == 0",
  39. "Ghosts are defined as zero label",
  40. )
  41. else:
  42. if only_electrons:
  43. rdf_signal = rdf.Filter(
  44. "quality == -1", # electron that is true match
  45. "Signal is defined as negative one label (only electrons)",
  46. )
  47. else:
  48. rdf_signal = rdf.Filter(
  49. "abs(quality) > 0",
  50. "Signal is defined as non-zero label",
  51. )
  52. rdf_bkg = rdf.Filter(
  53. "quality == 0",
  54. "Ghosts are defined as zero label",
  55. )
  56. rdf_signal.Snapshot(
  57. "Signal",
  58. outdir + "/" + input_file.strip(".root") + "_matching_signal.root",
  59. )
  60. rdf_bkg.Snapshot(
  61. "Bkg",
  62. outdir + "/" + input_file.strip(".root") + "_matching_bkg.root",
  63. )
  64. signal_file = ROOT.TFile.Open(
  65. outdir + "/" + input_file.strip(".root") + "_matching_signal.root",
  66. "READ",
  67. )
  68. signal_tree = signal_file.Get("Signal")
  69. bkg_file = ROOT.TFile.Open(
  70. outdir + "/" + input_file.strip(".root") + "_matching_bkg.root"
  71. )
  72. bkg_tree = bkg_file.Get("Bkg")
  73. os.chdir(outdir + "/result")
  74. output = ROOT.TFile(
  75. "matching_ghost_mlp_training.root",
  76. "RECREATE",
  77. )
  78. factory = TMVA.Factory(
  79. "TMVAClassification",
  80. output,
  81. "V:!Silent:Color:DrawProgressBar:AnalysisType=Classification",
  82. )
  83. factory.SetVerbose(True)
  84. dataloader = TMVA.DataLoader("MatchNNDataSet")
  85. dataloader.AddVariable("chi2", "F")
  86. dataloader.AddVariable("teta2", "F")
  87. dataloader.AddVariable("distX", "F")
  88. dataloader.AddVariable("distY", "F")
  89. dataloader.AddVariable("dSlope", "F")
  90. dataloader.AddVariable("dSlopeY", "F")
  91. # dataloader.AddVariable("zMag", "F")
  92. dataloader.AddSignalTree(signal_tree, 1.0)
  93. dataloader.AddBackgroundTree(bkg_tree, 1.0)
  94. # these cuts are also applied in the algorithm
  95. preselectionCuts = ROOT.TCut(
  96. # "chi2<30 && distX<500 && distY<500 && dSlope<2.0 && dSlopeY<0.15", #### ganz raus für elektronen
  97. )
  98. dataloader.PrepareTrainingAndTestTree(
  99. preselectionCuts,
  100. f"SplitMode=random:V:nTrain_Signal={n_train_signal}:nTrain_Background={n_train_bkg}:nTest_Signal={n_test_signal}:nTest_Background={n_test_bkg}",
  101. # normmode default is EqualNumEvents
  102. )
  103. factory.BookMethod(
  104. dataloader,
  105. TMVA.Types.kMLP,
  106. "matching_mlp",
  107. "!H:V:TrainingMethod=BP:NeuronType=ReLU:EstimatorType=CE:VarTransform=Norm:NCycles=700:HiddenLayers=N+2,N:TestRate=50:Sampling=1.0:SamplingImportance=1.0:LearningRate=0.02:DecayRate=0.01:!UseRegulator",
  108. )
  109. factory.TrainAllMethods()
  110. factory.TestAllMethods()
  111. factory.EvaluateAllMethods()
  112. output.Close()
  113. if __name__ == "__main__":
  114. parser = argparse.ArgumentParser()
  115. parser.add_argument(
  116. "--input-file",
  117. type=str,
  118. help="Path to the input file",
  119. required=False,
  120. )
  121. parser.add_argument(
  122. "--exclude_electrons",
  123. action="store_true",
  124. help="Excludes electrons from training.",
  125. required=False,
  126. )
  127. parser.add_argument(
  128. "--only_electrons",
  129. action="store_true",
  130. help="Only electrons for signal training.",
  131. required=False,
  132. )
  133. parser.add_argument(
  134. "--n-train-signal",
  135. type=int,
  136. help="Number of training tracks for signal.",
  137. required=False,
  138. )
  139. parser.add_argument(
  140. "--n-train-bkg",
  141. type=int,
  142. help="Number of training tracks for bkg.",
  143. required=False,
  144. )
  145. parser.add_argument(
  146. "--n-test-signal",
  147. type=int,
  148. help="Number of testing tracks for signal.",
  149. required=False,
  150. )
  151. parser.add_argument(
  152. "--n-test-bkg",
  153. type=int,
  154. help="Number of testing tracks for bkg.",
  155. required=False,
  156. )
  157. args = parser.parse_args()
  158. args_dict = {arg: val for arg, val in vars(args).items() if val is not None}
  159. train_matching_ghost_mlp(**args_dict)