You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

211 lines
7.1 KiB

10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
  1. # flake8: noqa
  2. import os
  3. import subprocess
  4. import argparse
  5. from parameterisations.parameterise_magnet_kink import parameterise_magnet_kink
  6. from parameterisations.parameterise_track_model import parameterise_track_model
  7. from parameterisations.parameterise_search_window import parameterise_search_window
  8. from parameterisations.parameterise_field_integral import parameterise_field_integral
  9. from parameterisations.parameterise_hough_histogram import parameterise_hough_histogram
  10. from parameterisations.utils.preselection import preselection
  11. from parameterisations.train_forward_ghost_mlps import (
  12. train_default_forward_ghost_mlp,
  13. train_veloUT_forward_ghost_mlp,
  14. )
  15. from parameterisations.residual_train_matching_ghost_mlps_electron import (
  16. res_train_matching_ghost_mlp,
  17. )
  18. from parameterisations.train_matching_ghost_mlps_electron import (
  19. train_matching_ghost_mlp,
  20. )
  21. from parameterisations.utils.parse_tmva_matrix_to_array_electron import (
  22. parse_tmva_matrix_to_array,
  23. )
  24. parser = argparse.ArgumentParser()
  25. parser.add_argument(
  26. "--field-params",
  27. action="store_true",
  28. help="Enables determination of magnetic field parameterisations.",
  29. )
  30. parser.add_argument(
  31. "--forward-weights",
  32. action="store_true",
  33. help="Enables determination of weights used by neural networks.",
  34. )
  35. parser.add_argument(
  36. "--matching-weights",
  37. action="store_true",
  38. default=True,
  39. help="Enables determination of weights used by neural networks.",
  40. )
  41. parser.add_argument(
  42. "-r",
  43. "--residuals",
  44. action="store_true",
  45. help="Trains neural network with residual tracks.",
  46. )
  47. parser.add_argument(
  48. "-p",
  49. "--prepare",
  50. action="store_true",
  51. default=True,
  52. help="Enables preparation of data for matching.",
  53. )
  54. parser.add_argument(
  55. "--prepare-params-data",
  56. action="store_true",
  57. help="Enables preparation of data for magnetic field parameterisations.",
  58. )
  59. parser.add_argument(
  60. "--prepare-weights-data",
  61. action="store_true",
  62. help="Enables preparation of data for NN weight determination.",
  63. )
  64. args = parser.parse_args()
  65. selected = "nn_electron_training/data/param_data_selected.root"
  66. if args.prepare_params_data:
  67. selection = "chi2_comb < 5 && pt > 10 && p > 1500 && p < 100000 && pid != 11"
  68. print("Run selection cuts =", selection)
  69. selected_md = preselection(
  70. cuts=selection,
  71. input_file="data/param_data_MD.root",
  72. )
  73. selected_mu = preselection(
  74. cuts=selection,
  75. input_file="data/param_data_MU.root",
  76. )
  77. merge_cmd = ["hadd", "-fk", selected, selected_md, selected_mu]
  78. print("Concatenate polarities ...")
  79. subprocess.run(merge_cmd, check=True)
  80. cpp_files = []
  81. if args.field_params:
  82. print("Parameterise magnet kink position ...")
  83. cpp_files.append(parameterise_magnet_kink(input_file=selected))
  84. print("Parameterise track model ...")
  85. cpp_files.append(parameterise_track_model(input_file=selected))
  86. selected_all_p = "nn_electron_training/data/param_data_selected_all_p.root"
  87. if args.prepare_params_data:
  88. selection_all_momenta = "chi2_comb < 5 && pid != 11"
  89. print()
  90. print("Run selection cuts =", selection_all_momenta)
  91. selected_md_all_p = preselection(
  92. cuts=selection_all_momenta,
  93. outfile_postfix="selected_all_p",
  94. input_file="data/param_data_MD.root",
  95. )
  96. selected_mu_all_p = preselection(
  97. cuts=selection_all_momenta,
  98. outfile_postfix="selected_all_p",
  99. input_file="data/param_data_MU.root",
  100. )
  101. merge_cmd = ["hadd", "-fk", selected_all_p, selected_md_all_p, selected_mu_all_p]
  102. print("Concatenate polarities ...")
  103. subprocess.run(merge_cmd, check=True)
  104. if args.field_params:
  105. print("Parameterise search window ...")
  106. cpp_files.append(parameterise_search_window(input_file=selected_all_p))
  107. print("Parameterise magnetic field integral ...")
  108. cpp_files.append(parameterise_field_integral(input_file=selected_all_p))
  109. print("Parameterise Hough histogram binning ...")
  110. cpp_files.append(parameterise_hough_histogram(input_file=selected_all_p))
  111. ###>>>
  112. ghost_data = "data/ghost_data.root"
  113. if args.prepare_weights_data:
  114. merge_cmd = [
  115. "hadd",
  116. "-fk",
  117. ghost_data,
  118. "data/ghost_data_B.root",
  119. "data/ghost_data_D.root",
  120. ]
  121. print("Concatenate decays for neural network training ...")
  122. subprocess.run(merge_cmd, check=True)
  123. ###<<<
  124. if args.forward_weights:
  125. train_default_forward_ghost_mlp(prepare_data=args.prepare_weights_data)
  126. # FIXME: use env variable instead
  127. os.chdir(os.path.dirname(os.path.realpath(__file__)))
  128. train_veloUT_forward_ghost_mlp(prepare_data=args.prepare_weights_data)
  129. # this ensures that the directory is correct
  130. os.chdir(os.path.dirname(os.path.realpath(__file__)))
  131. cpp_files += parse_tmva_matrix_to_array(
  132. [
  133. "nn_electron_training/result/GhostNNDataSet/weights/TMVAClassification_default_forward_ghost_mlp.class.C",
  134. "nn_electron_training/result/GhostNNDataSet/weights/TMVAClassification_veloUT_forward_ghost_mlp.class.C",
  135. ],
  136. )
  137. ###
  138. ###>>>
  139. ###
  140. if args.matching_weights and args.residuals:
  141. os.chdir(os.path.dirname(os.path.realpath(__file__)))
  142. res_train_matching_ghost_mlp(
  143. prepare_data=args.prepare,
  144. input_file="data/ghost_data_B_default_only_e_as_seed.root",
  145. tree_name="PrMatchNN_b60a058d.PrMCDebugMatchToolNN/MVAInputAndOutput", # e6feac0d, B: 3e224c41, B res: 1e13cc7e, D: 8cb154ca
  146. exclude_electrons=False,
  147. only_electrons=True,
  148. residuals="PrMatchNN_1e13cc7e.PrMCDebugMatchToolNN/MVAInputAndOutput",
  149. outdir="nn_electron_training",
  150. n_train_signal=0,
  151. n_train_bkg=20e3,
  152. n_test_signal=1e3,
  153. n_test_bkg=5e3,
  154. )
  155. # this ensures that the directory is correct
  156. os.chdir(os.path.dirname(os.path.realpath(__file__)))
  157. cpp_files += parse_tmva_matrix_to_array(
  158. [
  159. "nn_electron_training/result/MatchNNDataSet/weights/TMVAClassification_matching_mlp.class.C",
  160. ],
  161. simd_type=True,
  162. )
  163. file_name = "seed"
  164. tree_names = {}
  165. tree_names["seed"] = "PrMatchNN_b60a058d.PrMCDebugMatchToolNN/MVAInputAndOutput"
  166. tree_names["def"] = "PrMatchNN.PrMCDebugMatchToolNN/MVAInputAndOutput"
  167. if args.matching_weights and not args.residuals:
  168. os.chdir(os.path.dirname(os.path.realpath(__file__)))
  169. train_matching_ghost_mlp(
  170. prepare_data=args.prepare,
  171. input_file="data/ghost_data_B_vars_thesis.root",
  172. tree_name=tree_names[file_name],
  173. exclude_electrons=False,
  174. only_electrons=True,
  175. filter_seeds=True,
  176. outdir="nn_electron_training",
  177. n_train_signal=100e3,
  178. n_train_bkg=100e3,
  179. n_test_signal=10e3,
  180. n_test_bkg=10e3,
  181. )
  182. # this ensures that the directory is correct
  183. os.chdir(os.path.dirname(os.path.realpath(__file__)))
  184. cpp_files += parse_tmva_matrix_to_array(
  185. [
  186. "nn_electron_training/result/MatchNNDataSet/weights/TMVAClassification_matching_mlp.class.C",
  187. ],
  188. simd_type=True,
  189. )
  190. ###
  191. ###<<<
  192. ###
  193. for file in cpp_files:
  194. subprocess.run(
  195. [
  196. "clang-format",
  197. "-i",
  198. f"{file}",
  199. ],
  200. )