You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

163 lines
5.8 KiB

10 months ago
9 months ago
9 months ago
10 months ago
9 months ago
10 months ago
9 months ago
10 months ago
9 months ago
10 months ago
9 months ago
10 months ago
9 months ago
9 months ago
10 months ago
9 months ago
10 months ago
9 months ago
10 months ago
9 months ago
9 months ago
10 months ago
9 months ago
10 months ago
9 months ago
10 months ago
  1. # flake8: noqa
  2. import os
  3. import argparse
  4. import ROOT
  5. from ROOT import TMVA
  def train_matching_ghost_mlp(
      input_file: str = "data/ghost_data_B.root",
      tree_name: str = "PrMatchNN.PrMCDebugMatchToolNN/MVAInputAndOutput",
      exclude_electrons: bool = True,
      only_electrons: bool = False,
      n_train_signal: int = 165000,
      n_train_bkg: int = 165000,
      n_test_signal: int = 5000,
      n_test_bkg: int = 5000,
      prepare_data: bool = True,
      outdir: str = "neural_net_training",
  ):
      """Trains an MLP to classify the match between Velo and Seed track.

      The ``mc_quality`` branch is used as the truth label: -1 for electrons,
      0 for ghosts and 1 for all other particles. Ghosts (label 0) are always
      the background sample; the signal definition depends on the electron
      flags below.

      Note:
          The function changes the process working directory to
          ``<outdir>/result`` and does not restore it. The TMVA output file
          ``matching_ghost_mlp_training.root`` is created there.

      Args:
          input_file (str, optional): Path to the input ROOT file.
              Defaults to "data/ghost_data_B.root".
          tree_name (str, optional): Name of the tree inside ``input_file``.
              Defaults to "PrMatchNN.PrMCDebugMatchToolNN/MVAInputAndOutput".
          exclude_electrons (bool, optional): Use only ``mc_quality == 1`` as
              signal. Takes precedence over ``only_electrons``.
              Defaults to True.
          only_electrons (bool, optional): Use only ``mc_quality == -1`` as
              signal. Only considered if ``exclude_electrons`` is False.
              Defaults to False.
          n_train_signal (int, optional): Number of true matches to train on.
              Defaults to 165000.
          n_train_bkg (int, optional): Number of fake matches to train on.
              Defaults to 165000.
          n_test_signal (int, optional): Number of true matches to test on.
              Defaults to 5000.
          n_test_bkg (int, optional): Number of fake matches to test on.
              Defaults to 5000.
          prepare_data (bool, optional): Split the input into a signal and a
              background file before training. If False, the two files are
              expected to exist already. Defaults to True.
          outdir (str, optional): Directory below which the split files and
              the ``result`` sub-directory are located.
              Defaults to "neural_net_training".
      """
      # Accept float-like counts (e.g. 2e5) but hand plain integers to TMVA,
      # so the option string reads "nTrain_Signal=165000", not "165000.0".
      n_train_signal = int(n_train_signal)
      n_train_bkg = int(n_train_bkg)
      n_test_signal = int(n_test_signal)
      n_test_bkg = int(n_test_bkg)

      signal_path = _matching_file_path(outdir, input_file, "signal")
      bkg_path = _matching_file_path(outdir, input_file, "bkg")

      if prepare_data:
          print("Preparing Data for MVA Training: ")
          rdf = ROOT.RDataFrame(tree_name, input_file)

          # mc_quality: -1 elec, 0 ghost, 1 all part wo elec
          if exclude_electrons:
              rdf_signal = rdf.Filter(
                  "mc_quality == 1",
                  "Signal is defined as one label (excluding electrons)",
              )
          elif only_electrons:
              rdf_signal = rdf.Filter(
                  "mc_quality == -1",
                  "Signal is defined as minus-one label (electrons only)",
              )
          else:
              rdf_signal = rdf.Filter(
                  "abs(mc_quality) > 0",
                  "Signal is defined as non-zero label",
              )
          # The background definition is the same for every signal choice.
          rdf_bkg = rdf.Filter("mc_quality == 0", "Ghosts are defined as zero label")

          # Both snapshot files live in the same directory; make sure it
          # exists before the snapshots are written.
          snapshot_dir = os.path.dirname(signal_path)
          if snapshot_dir:
              os.makedirs(snapshot_dir, exist_ok=True)

          print("Create Signal data file. ")
          rdf_signal.Snapshot("Signal", signal_path)
          print("Create Background data file. ")
          rdf_bkg.Snapshot("Bkg", bkg_path)
          print("Data preparation terminated successfully.\n")

      # The input files are opened before the chdir below, so relative paths
      # are still resolved against the original working directory.
      signal_file = ROOT.TFile.Open(signal_path, "READ")
      signal_tree = signal_file.Get("Signal")
      bkg_file = ROOT.TFile.Open(bkg_path)
      bkg_tree = bkg_file.Get("Bkg")

      result_dir = outdir + "/result"
      os.makedirs(result_dir, exist_ok=True)
      os.chdir(result_dir)

      output = ROOT.TFile(
          "matching_ghost_mlp_training.root",
          "RECREATE",
      )

      factory = TMVA.Factory(
          "TMVAClassification",
          output,
          "V:!Silent:Color:DrawProgressBar:AnalysisType=Classification",
      )
      factory.SetVerbose(True)

      dataloader = TMVA.DataLoader("MatchNNDataSet")
      for variable in (
          "mc_chi2",
          "mc_teta2",
          "mc_distX",
          "mc_distY",
          "mc_dSlope",
          "mc_dSlopeY",
          "mc_end_velo_phi",
          "mc_end_velo_eta",
      ):
          dataloader.AddVariable(variable, "F")

      dataloader.AddSignalTree(signal_tree, 1.0)
      dataloader.AddBackgroundTree(bkg_tree, 1.0)

      # these cuts are also applied in the algorithm
      # NOTE (translated from the original German remark): drop these cuts
      # entirely for electrons.
      preselectionCuts = ROOT.TCut(
          " && ".join(
              [
                  "mc_chi2<15",
                  "mc_distX<250",
                  "mc_distY<250",
                  "mc_dSlope<1.5",
                  "mc_dSlopeY<0.15",
                  "std::abs(mc_end_velo_phi)<3.142",
                  "mc_end_velo_eta>1.99",
                  "mc_end_velo_eta<5.01",
              ]
          ),
      )
      dataloader.PrepareTrainingAndTestTree(
          preselectionCuts,
          f"SplitMode=random:V:nTrain_Signal={n_train_signal}"
          f":nTrain_Background={n_train_bkg}"
          f":nTest_Signal={n_test_signal}"
          f":nTest_Background={n_test_bkg}",
      )

      mlp_options = ":".join(
          [
              "!H",
              "V",
              "TrainingMethod=BP",
              "NeuronType=ReLU",
              "EstimatorType=CE",
              "VarTransform=Norm",
              "NCycles=700",
              "HiddenLayers=N+2,N",
              "TestRate=50",
              "Sampling=1.0",
              "SamplingImportance=1.0",
              "LearningRate=0.02",
              "DecayRate=0.01",
              "!UseRegulator",
          ]
      )
      factory.BookMethod(dataloader, TMVA.Types.kMLP, "matching_mlp", mlp_options)

      factory.TrainAllMethods()
      factory.TestAllMethods()
      factory.EvaluateAllMethods()
      output.Close()


  def _matching_file_path(outdir: str, input_file: str, label: str) -> str:
      """Builds ``<outdir>/<input_file without .root>_matching_<label>.root``."""
      stem = input_file
      # str.strip(".root") would remove any of the characters ".", "r", "o",
      # "t" from BOTH ends of the name; only the literal suffix must go.
      if stem.endswith(".root"):
          stem = stem[: -len(".root")]
      return outdir + "/" + stem + "_matching_" + label + ".root"
  def build_arg_parser() -> argparse.ArgumentParser:
      """Builds the command line parser of the training script.

      All options are optional. Value options that are not given stay None;
      the ``store_true`` flag is False when absent.
      """
      parser = argparse.ArgumentParser(
          description="Train the matching ghost-rejection MLP.",
      )
      parser.add_argument(
          "--input-file",
          type=str,
          help="Path to the input file",
          required=False,
      )
      # Dash spelling matches the other options; the original underscore
      # spelling is kept as an alias so existing command lines keep working.
      parser.add_argument(
          "--exclude-electrons",
          "--exclude_electrons",
          dest="exclude_electrons",
          action="store_true",
          help="Excludes electrons from training.",
          required=False,
      )
      for option, help_text in (
          ("--n-train-signal", "Number of training tracks for signal."),
          ("--n-train-bkg", "Number of training tracks for bkg."),
          ("--n-test-signal", "Number of testing tracks for signal."),
          ("--n-test-bkg", "Number of testing tracks for bkg."),
      ):
          parser.add_argument(option, type=int, help=help_text, required=False)
      return parser


  if __name__ == "__main__":
      args = build_arg_parser().parse_args()
      # Forward only the options that carry a value, so the function defaults
      # apply to everything else. The electron flag is always forwarded: it is
      # False (never None) when absent, which overrides the function default.
      args_dict = {arg: val for arg, val in vars(args).items() if val is not None}
      train_matching_ghost_mlp(**args_dict)