You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

161 lines
5.6 KiB

10 months ago
9 months ago
10 months ago
9 months ago
10 months ago
9 months ago
10 months ago
9 months ago
10 months ago
9 months ago
10 months ago
9 months ago
10 months ago
9 months ago
10 months ago
# flake8: noqa
  2. import os
  3. import argparse
  4. import ROOT
  5. from ROOT import TMVA
  6. def train_matching_ghost_mlp(
  7. input_file: str = "data/ghost_data_B.root",
  8. tree_name: str = "PrMatchNN.PrMCDebugMatchToolNN/MVAInputAndOutput",
  9. exclude_electrons: bool = True,
  10. only_electrons: bool = False,
  11. n_train_signal: int = 100e3,
  12. n_train_bkg: int = 100e3,
  13. n_test_signal: int = 50e3,
  14. n_test_bkg: int = 50e3,
  15. prepare_data: bool = True,
  16. outdir: str = "neural_net_training",
  17. ):
  18. """Trains an MLP to classify the match between Velo and Seed track.
  19. Args:
  20. input_file (str, optional): Defaults to "data/ghost_data_B.root".
  21. tree_name (str, optional): Defaults to "PrMatchNN.PrMCDebugMatchToolNN/Tuple".
  22. exclude_electrons (bool, optional): Defaults to True.
  23. n_train_signal (int, optional): Number of true matches to train on. Defaults to 200e3.
  24. n_train_bkg (int, optional): Number of fake matches to train on. Defaults to 200e3.
  25. n_test_signal (int, optional): Number of true matches to test on. Defaults to 75e3.
  26. n_test_bkg (int, optional): Number of fake matches to test on. Defaults to 75e3.
  27. prepare_data (bool, optional): Split data into signal and background file. Defaults to False.
  28. """
  29. if prepare_data:
  30. rdf = ROOT.RDataFrame(tree_name, input_file)
  31. if exclude_electrons:
  32. rdf_signal = rdf.Filter(
  33. "mc_quality == 1", # -1 elec, 0 ghost, 1 all part wo elec
  34. "Signal is defined as one label (excluding electrons)",
  35. )
  36. rdf_bkg = rdf.Filter("mc_quality == 0", "Ghosts are defined as zero label")
  37. elif only_electrons:
  38. rdf_signal = rdf.Filter(
  39. "mc_quality == -1", # -1 elec, 0 ghost, 1 all part wo elec
  40. "Signal is defined as one label (excluding electrons)",
  41. )
  42. rdf_bkg = rdf.Filter("mc_quality == 0", "Ghosts are defined as zero label")
  43. else:
  44. rdf_signal = rdf.Filter(
  45. "abs(mc_quality) > 0",
  46. "Signal is defined as non-zero label",
  47. )
  48. rdf_bkg = rdf.Filter("mc_quality == 0", "Ghosts are defined as zero label")
  49. rdf_signal.Snapshot(
  50. "Signal",
  51. outdir + "/" + input_file.strip(".root") + "_matching_signal.root",
  52. )
  53. rdf_bkg.Snapshot(
  54. "Bkg", outdir + "/" + input_file.strip(".root") + "_matching_bkg.root"
  55. )
  56. signal_file = ROOT.TFile.Open(
  57. outdir + "/" + input_file.strip(".root") + "_matching_signal.root",
  58. "READ",
  59. )
  60. signal_tree = signal_file.Get("Signal")
  61. bkg_file = ROOT.TFile.Open(
  62. outdir + "/" + input_file.strip(".root") + "_matching_bkg.root"
  63. )
  64. bkg_tree = bkg_file.Get("Bkg")
  65. os.chdir(outdir + "/result")
  66. output = ROOT.TFile(
  67. "matching_ghost_mlp_training.root",
  68. "RECREATE",
  69. )
  70. factory = TMVA.Factory(
  71. "TMVAClassification",
  72. output,
  73. "V:!Silent:Color:DrawProgressBar:AnalysisType=Classification",
  74. )
  75. factory.SetVerbose(True)
  76. dataloader = TMVA.DataLoader("MatchNNDataSet")
  77. dataloader.AddVariable("mc_chi2", "F")
  78. dataloader.AddVariable("mc_teta2", "F")
  79. dataloader.AddVariable("mc_distX", "F")
  80. dataloader.AddVariable("mc_distY", "F")
  81. dataloader.AddVariable("mc_dSlope", "F")
  82. dataloader.AddVariable("mc_dSlopeY", "F")
  83. # dataloader.AddVariable("tx", "F")
  84. # dataloader.AddVariable("ty", "F")
  85. # dataloader.AddVariable("tx_scifi", "F")
  86. # dataloader.AddVariable("ty_scifi", "F")
  87. dataloader.AddSignalTree(signal_tree, 1.0)
  88. dataloader.AddBackgroundTree(bkg_tree, 1.0)
  89. # these cuts are also applied in the algorithm
  90. preselectionCuts = ROOT.TCut(
  91. "mc_chi2<30 && mc_distX<500 && mc_distY<500 && mc_dSlope<3.0 && mc_dSlopeY<0.3", #### ganz raus für elektronen
  92. )
  93. dataloader.PrepareTrainingAndTestTree(
  94. preselectionCuts,
  95. f"SplitMode=random:V:nTrain_Signal={n_train_signal}:nTrain_Background={n_train_bkg}:nTest_Signal={n_test_signal}:nTest_Background={n_test_bkg}",
  96. )
  97. factory.BookMethod(
  98. dataloader,
  99. TMVA.Types.kMLP,
  100. "matching_mlp",
  101. "!H:V:TrainingMethod=BP:NeuronType=ReLU:EstimatorType=CE:VarTransform=Norm:NCycles=700:HiddenLayers=N+2,N:TestRate=50:Sampling=1.0:SamplingImportance=1.0:LearningRate=0.02:DecayRate=0.01:!UseRegulator",
  102. )
  103. factory.TrainAllMethods()
  104. factory.TestAllMethods()
  105. factory.EvaluateAllMethods()
  106. output.Close()
  107. if __name__ == "__main__":
  108. parser = argparse.ArgumentParser()
  109. parser.add_argument(
  110. "--input-file",
  111. type=str,
  112. help="Path to the input file",
  113. required=False,
  114. )
  115. parser.add_argument(
  116. "--exclude_electrons",
  117. action="store_true",
  118. help="Excludes electrons from training.",
  119. required=False,
  120. )
  121. parser.add_argument(
  122. "--n-train-signal",
  123. type=int,
  124. help="Number of training tracks for signal.",
  125. required=False,
  126. )
  127. parser.add_argument(
  128. "--n-train-bkg",
  129. type=int,
  130. help="Number of training tracks for bkg.",
  131. required=False,
  132. )
  133. parser.add_argument(
  134. "--n-test-signal",
  135. type=int,
  136. help="Number of testing tracks for signal.",
  137. required=False,
  138. )
  139. parser.add_argument(
  140. "--n-test-bkg",
  141. type=int,
  142. help="Number of testing tracks for bkg.",
  143. required=False,
  144. )
  145. args = parser.parse_args()
  146. args_dict = {arg: val for arg, val in vars(args).items() if val is not None}
  147. train_matching_ghost_mlp(**args_dict)