You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

171 lines
5.8 KiB

10 months ago
# flake8: noqa
  2. import os
  3. import subprocess
  4. import argparse
  5. from parameterisations.parameterise_magnet_kink import parameterise_magnet_kink
  6. from parameterisations.parameterise_track_model import parameterise_track_model
  7. from parameterisations.parameterise_search_window import parameterise_search_window
  8. from parameterisations.parameterise_field_integral import parameterise_field_integral
  9. from parameterisations.parameterise_hough_histogram import parameterise_hough_histogram
  10. from parameterisations.utils.preselection import preselection
  11. from parameterisations.train_forward_ghost_mlps import (
  12. train_default_forward_ghost_mlp,
  13. train_veloUT_forward_ghost_mlp,
  14. )
  15. from parameterisations.losses_train_matching_ghost_mlps import (
  16. train_matching_ghost_mlp,
  17. )
  18. from parameterisations.utils.parse_tmva_matrix_to_array_TrLo import (
  19. parse_tmva_matrix_to_array,
  20. )
  21. parser = argparse.ArgumentParser()
  22. parser.add_argument(
  23. "--field-params",
  24. action="store_true",
  25. help="Enables determination of magnetic field parameterisations.",
  26. )
  27. parser.add_argument(
  28. "--forward-weights",
  29. action="store_true",
  30. help="Enables determination of weights used by neural networks.",
  31. )
  32. parser.add_argument(
  33. "--matching-weights",
  34. action="store_true",
  35. default=True,
  36. help="Enables determination of weights used by neural networks.",
  37. )
  38. parser.add_argument(
  39. "-p",
  40. "--prepare",
  41. action="store_true",
  42. # default=True,
  43. help="Enables preparation of data for matching.",
  44. )
  45. parser.add_argument(
  46. "--prepare-params-data",
  47. action="store_true",
  48. help="Enables preparation of data for magnetic field parameterisations.",
  49. )
  50. parser.add_argument(
  51. "--prepare-weights-data",
  52. action="store_true",
  53. help="Enables preparation of data for NN weight determination.",
  54. )
  55. args = parser.parse_args()
  56. selected = "nn_electron_training/data/param_data_selected.root"
  57. if args.prepare_params_data:
  58. selection = "chi2_comb < 5 && pt > 10 && p > 1500 && p < 100000 && pid != 11"
  59. print("Run selection cuts =", selection)
  60. selected_md = preselection(
  61. cuts=selection,
  62. input_file="data/param_data_MD.root",
  63. )
  64. selected_mu = preselection(
  65. cuts=selection,
  66. input_file="data/param_data_MU.root",
  67. )
  68. merge_cmd = ["hadd", "-fk", selected, selected_md, selected_mu]
  69. print("Concatenate polarities ...")
  70. subprocess.run(merge_cmd, check=True)
  71. cpp_files = []
  72. if args.field_params:
  73. print("Parameterise magnet kink position ...")
  74. cpp_files.append(parameterise_magnet_kink(input_file=selected))
  75. print("Parameterise track model ...")
  76. cpp_files.append(parameterise_track_model(input_file=selected))
  77. selected_all_p = "nn_electron_training/data/param_data_selected_all_p.root"
  78. if args.prepare_params_data:
  79. selection_all_momenta = "chi2_comb < 5 && pid != 11"
  80. print()
  81. print("Run selection cuts =", selection_all_momenta)
  82. selected_md_all_p = preselection(
  83. cuts=selection_all_momenta,
  84. outfile_postfix="selected_all_p",
  85. input_file="data/param_data_MD.root",
  86. )
  87. selected_mu_all_p = preselection(
  88. cuts=selection_all_momenta,
  89. outfile_postfix="selected_all_p",
  90. input_file="data/param_data_MU.root",
  91. )
  92. merge_cmd = ["hadd", "-fk", selected_all_p, selected_md_all_p, selected_mu_all_p]
  93. print("Concatenate polarities ...")
  94. subprocess.run(merge_cmd, check=True)
  95. if args.field_params:
  96. print("Parameterise search window ...")
  97. cpp_files.append(parameterise_search_window(input_file=selected_all_p))
  98. print("Parameterise magnetic field integral ...")
  99. cpp_files.append(parameterise_field_integral(input_file=selected_all_p))
  100. print("Parameterise Hough histogram binning ...")
  101. cpp_files.append(parameterise_hough_histogram(input_file=selected_all_p))
  102. ###>>>
  103. ghost_data = "data/ghost_data.root"
  104. if args.prepare_weights_data:
  105. merge_cmd = [
  106. "hadd",
  107. "-fk",
  108. ghost_data,
  109. "data/ghost_data_B.root",
  110. "data/ghost_data_D.root",
  111. ]
  112. print("Concatenate decays for neural network training ...")
  113. subprocess.run(merge_cmd, check=True)
  114. ###<<<
  115. if args.forward_weights:
  116. train_default_forward_ghost_mlp(prepare_data=args.prepare_weights_data)
  117. # FIXME: use env variable instead
  118. os.chdir(os.path.dirname(os.path.realpath(__file__)))
  119. train_veloUT_forward_ghost_mlp(prepare_data=args.prepare_weights_data)
  120. # this ensures that the directory is correct
  121. os.chdir(os.path.dirname(os.path.realpath(__file__)))
  122. cpp_files += parse_tmva_matrix_to_array(
  123. [
  124. "nn_trackinglosses_training/result/GhostNNDataSet/weights/TMVAClassification_default_forward_ghost_mlp.class.C",
  125. "nn_trackinglosses_training/result/GhostNNDataSet/weights/TMVAClassification_veloUT_forward_ghost_mlp.class.C",
  126. ],
  127. )
  128. ###
  129. ###>>>
  130. ###
  131. if args.matching_weights:
  132. os.chdir(os.path.dirname(os.path.realpath(__file__)))
  133. train_matching_ghost_mlp(
  134. prepare_data=args.prepare,
  135. input_file="data/tracking_losses_ntuple_B.root",
  136. tree_name="PrDebugTrackingLosses.PrDebugTrackingTool/Tuple", # e6feac0d, B: 3e224c41, B res: 1e13cc7e, D: 8cb154ca
  137. b_input_file="data/ghost_data_B.root",
  138. b_tree_name="PrMatchNN_3e224c41.PrMCDebugMatchToolNN/MVAInputAndOutput",
  139. only_electrons=True,
  140. outdir="nn_trackinglosses_training",
  141. n_train_signal=20e3,
  142. n_train_bkg=40e3,
  143. n_test_signal=5e3,
  144. n_test_bkg=10e3,
  145. )
  146. # this ensures that the directory is correct
  147. os.chdir(os.path.dirname(os.path.realpath(__file__)))
  148. cpp_files += parse_tmva_matrix_to_array(
  149. [
  150. "nn_trackinglosses_training/result/MatchNNDataSet/weights/TMVAClassification_matching_mlp.class.C",
  151. ],
  152. simd_type=True,
  153. )
  154. ###
  155. ###<<<
  156. ###
  157. for file in cpp_files:
  158. subprocess.run(
  159. [
  160. "clang-format",
  161. "-i",
  162. f"{file}",
  163. ],
  164. )