You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

154 lines
5.2 KiB

10 months ago
  1. # flake8: noqa
  2. import os
  3. import argparse
  4. import ROOT
  5. from ROOT import TMVA
  6. def train_matching_ghost_mlp(
  7. input_file: str = "data/ghost_data_B.root",
  8. tree_name: str = "PrMatchNN.PrMCDebugMatchToolNN/MVAInputAndOutput",
  9. exclude_electrons: bool = True,
  10. n_train_signal: int = 50e3,
  11. n_train_bkg: int = 500e3,
  12. n_test_signal: int = 10e3,
  13. n_test_bkg: int = 100e3,
  14. prepare_data: bool = True,
  15. outdir: str = "neural_net_training",
  16. ):
  17. """Trains an MLP to classify the match between Velo and Seed track.
  18. Args:
  19. input_file (str, optional): Defaults to "data/ghost_data_B.root".
  20. tree_name (str, optional): Defaults to "PrMatchNN.PrMCDebugMatchToolNN/Tuple".
  21. exclude_electrons (bool, optional): Defaults to True.
  22. n_train_signal (int, optional): Number of true matches to train on. Defaults to 200e3.
  23. n_train_bkg (int, optional): Number of fake matches to train on. Defaults to 200e3.
  24. n_test_signal (int, optional): Number of true matches to test on. Defaults to 75e3.
  25. n_test_bkg (int, optional): Number of fake matches to test on. Defaults to 75e3.
  26. prepare_data (bool, optional): Split data into signal and background file. Defaults to False.
  27. """
  28. if prepare_data:
  29. rdf = ROOT.RDataFrame(tree_name, input_file)
  30. if exclude_electrons:
  31. rdf_signal = rdf.Filter(
  32. "quality == 1", # -1 elec, 0 ghost, 1 all part wo elec
  33. "Signal is defined as one label (excluding electrons)",
  34. )
  35. rdf_bkg = rdf.Filter("quality == 0", "Ghosts are defined as zero label")
  36. else:
  37. rdf_signal = rdf.Filter(
  38. "abs(quality) > 0",
  39. "Signal is defined as non-zero label",
  40. )
  41. rdf_bkg = rdf.Filter("quality == 0", "Ghosts are defined as zero label")
  42. rdf_signal.Snapshot(
  43. "Signal",
  44. outdir + "/" + input_file.strip(".root") + "_matching_signal.root",
  45. )
  46. rdf_bkg.Snapshot(
  47. "Bkg", outdir + "/" + input_file.strip(".root") + "_matching_bkg.root"
  48. )
  49. signal_file = ROOT.TFile.Open(
  50. outdir + "/" + input_file.strip(".root") + "_matching_signal.root",
  51. "READ",
  52. )
  53. signal_tree = signal_file.Get("Signal")
  54. bkg_file = ROOT.TFile.Open(
  55. outdir + "/" + input_file.strip(".root") + "_matching_bkg.root"
  56. )
  57. bkg_tree = bkg_file.Get("Bkg")
  58. os.chdir(outdir + "/result")
  59. output = ROOT.TFile(
  60. "matching_ghost_mlp_training.root",
  61. "RECREATE",
  62. )
  63. factory = TMVA.Factory(
  64. "TMVAClassification",
  65. output,
  66. "V:!Silent:Color:DrawProgressBar:AnalysisType=Classification",
  67. )
  68. factory.SetVerbose(True)
  69. dataloader = TMVA.DataLoader("MatchNNDataSet")
  70. dataloader.AddVariable("chi2", "F")
  71. dataloader.AddVariable("teta2", "F")
  72. dataloader.AddVariable("distX", "F")
  73. dataloader.AddVariable("distY", "F")
  74. dataloader.AddVariable("dSlope", "F")
  75. dataloader.AddVariable("dSlopeY", "F")
  76. # dataloader.AddVariable("tx", "F")
  77. # dataloader.AddVariable("ty", "F")
  78. # dataloader.AddVariable("tx_scifi", "F")
  79. # dataloader.AddVariable("ty_scifi", "F")
  80. dataloader.AddSignalTree(signal_tree, 1.0)
  81. dataloader.AddBackgroundTree(bkg_tree, 1.0)
  82. # these cuts are also applied in the algorithm
  83. preselectionCuts = ROOT.TCut(
  84. "chi2<15 && distX<250 && distY<400 && dSlope<1.5 && dSlopeY<0.15", #### ganz raus für elektronen
  85. )
  86. dataloader.PrepareTrainingAndTestTree(
  87. preselectionCuts,
  88. f"SplitMode=random:V:nTrain_Signal={n_train_signal}:nTrain_Background={n_train_bkg}:nTest_Signal={n_test_signal}:nTest_Background={n_test_bkg}",
  89. )
  90. factory.BookMethod(
  91. dataloader,
  92. TMVA.Types.kMLP,
  93. "matching_mlp",
  94. "!H:V:TrainingMethod=BP:NeuronType=ReLU:EstimatorType=CE:VarTransform=Norm:NCycles=700:HiddenLayers=N+2,N:TestRate=50:Sampling=1.0:SamplingImportance=1.0:LearningRate=0.02:DecayRate=0.01:!UseRegulator",
  95. )
  96. factory.TrainAllMethods()
  97. factory.TestAllMethods()
  98. factory.EvaluateAllMethods()
  99. output.Close()
  100. if __name__ == "__main__":
  101. parser = argparse.ArgumentParser()
  102. parser.add_argument(
  103. "--input-file",
  104. type=str,
  105. help="Path to the input file",
  106. required=False,
  107. )
  108. parser.add_argument(
  109. "--exclude_electrons",
  110. action="store_true",
  111. help="Excludes electrons from training.",
  112. required=False,
  113. )
  114. parser.add_argument(
  115. "--n-train-signal",
  116. type=int,
  117. help="Number of training tracks for signal.",
  118. required=False,
  119. )
  120. parser.add_argument(
  121. "--n-train-bkg",
  122. type=int,
  123. help="Number of training tracks for bkg.",
  124. required=False,
  125. )
  126. parser.add_argument(
  127. "--n-test-signal",
  128. type=int,
  129. help="Number of testing tracks for signal.",
  130. required=False,
  131. )
  132. parser.add_argument(
  133. "--n-test-bkg",
  134. type=int,
  135. help="Number of testing tracks for bkg.",
  136. required=False,
  137. )
  138. args = parser.parse_args()
  139. args_dict = {arg: val for arg, val in vars(args).items() if val is not None}
  140. train_matching_ghost_mlp(**args_dict)