You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

159 lines
5.6 KiB

10 months ago
9 months ago
10 months ago
9 months ago
9 months ago
10 months ago
9 months ago
10 months ago
# flake8: noqa
  2. import os
  3. import subprocess
  4. import argparse
  5. from parameterisations.parameterise_magnet_kink import parameterise_magnet_kink
  6. from parameterisations.parameterise_track_model import parameterise_track_model
  7. from parameterisations.parameterise_search_window import parameterise_search_window
  8. from parameterisations.parameterise_field_integral import parameterise_field_integral
  9. from parameterisations.parameterise_hough_histogram import parameterise_hough_histogram
  10. from parameterisations.utils.preselection import preselection
  11. from parameterisations.train_forward_ghost_mlps import (
  12. train_default_forward_ghost_mlp,
  13. train_veloUT_forward_ghost_mlp,
  14. )
  15. from parameterisations.train_matching_ghost_mlps import train_matching_ghost_mlp
  16. from parameterisations.utils.parse_tmva_matrix_to_array import (
  17. parse_tmva_matrix_to_array,
  18. )
  19. parser = argparse.ArgumentParser()
  20. parser.add_argument(
  21. "--field-params",
  22. action="store_true",
  23. help="Enables determination of magnetic field parameterisations.",
  24. )
  25. parser.add_argument(
  26. "--forward-weights",
  27. action="store_true",
  28. help="Enables determination of weights used by neural networks.",
  29. )
  30. parser.add_argument(
  31. "--matching-weights",
  32. action="store_true",
  33. default=True,
  34. help="Enables determination of weights used by neural networks.",
  35. )
  36. parser.add_argument(
  37. "-p",
  38. "--prepare",
  39. action="store_true",
  40. help="Enables preparation of data for matching.",
  41. )
  42. parser.add_argument(
  43. "--prepare-params-data",
  44. action="store_true",
  45. help="Enables preparation of data for magnetic field parameterisations.",
  46. )
  47. parser.add_argument(
  48. "--prepare-weights-data",
  49. action="store_true",
  50. help="Enables preparation of data for NN weight determination.",
  51. )
  52. args = parser.parse_args()
  53. selected = "neural_net_training/data/param_data_selected.root"
  54. if args.prepare_params_data:
  55. selection = "chi2_comb < 5 && pt > 10 && p > 1500 && p < 100000 && pid != 11"
  56. print("Run selection cuts =", selection)
  57. selected_md = preselection(
  58. cuts=selection,
  59. input_file="data/param_data_MD.root",
  60. )
  61. selected_mu = preselection(
  62. cuts=selection,
  63. input_file="data/param_data_MU.root",
  64. )
  65. merge_cmd = ["hadd", "-fk", selected, selected_md, selected_mu]
  66. print("Concatenate polarities ...")
  67. subprocess.run(merge_cmd, check=True)
  68. cpp_files = []
  69. if args.field_params:
  70. print("Parameterise magnet kink position ...")
  71. cpp_files.append(parameterise_magnet_kink(input_file=selected))
  72. print("Parameterise track model ...")
  73. cpp_files.append(parameterise_track_model(input_file=selected))
  74. selected_all_p = "neural_net_training/data/param_data_selected_all_p.root"
  75. if args.prepare_params_data:
  76. selection_all_momenta = "chi2_comb < 5 && pid != 11"
  77. print()
  78. print("Run selection cuts =", selection_all_momenta)
  79. selected_md_all_p = preselection(
  80. cuts=selection_all_momenta,
  81. outfile_postfix="selected_all_p",
  82. input_file="data/param_data_MD.root",
  83. )
  84. selected_mu_all_p = preselection(
  85. cuts=selection_all_momenta,
  86. outfile_postfix="selected_all_p",
  87. input_file="data/param_data_MU.root",
  88. )
  89. merge_cmd = ["hadd", "-fk", selected_all_p, selected_md_all_p, selected_mu_all_p]
  90. print("Concatenate polarities ...")
  91. subprocess.run(merge_cmd, check=True)
  92. if args.field_params:
  93. print("Parameterise search window ...")
  94. cpp_files.append(parameterise_search_window(input_file=selected_all_p))
  95. print("Parameterise magnetic field integral ...")
  96. cpp_files.append(parameterise_field_integral(input_file=selected_all_p))
  97. print("Parameterise Hough histogram binning ...")
  98. cpp_files.append(parameterise_hough_histogram(input_file=selected_all_p))
  99. ###>>>
  100. ghost_data = "neural_net_training/data/ghost_data.root"
  101. if args.prepare_weights_data:
  102. merge_cmd = [
  103. "hadd",
  104. "-fk",
  105. ghost_data,
  106. "data/ghost_data_MD.root",
  107. "data/ghost_data_MU.root",
  108. ]
  109. print("Concatenate polarities for neural network training ...")
  110. subprocess.run(merge_cmd, check=True)
  111. ###<<<
  112. if args.forward_weights:
  113. train_default_forward_ghost_mlp(prepare_data=args.prepare_weights_data)
  114. # FIXME: use env variable instead
  115. os.chdir(os.path.dirname(os.path.realpath(__file__)))
  116. train_veloUT_forward_ghost_mlp(prepare_data=args.prepare_weights_data)
  117. # this ensures that the directory is correct
  118. os.chdir(os.path.dirname(os.path.realpath(__file__)))
  119. cpp_files += parse_tmva_matrix_to_array(
  120. [
  121. "neural_net_training/result/GhostNNDataSet/weights/TMVAClassification_default_forward_ghost_mlp.class.C",
  122. "neural_net_training/result/GhostNNDataSet/weights/TMVAClassification_veloUT_forward_ghost_mlp.class.C",
  123. ],
  124. )
  125. ###>>>
  126. if args.matching_weights:
  127. os.chdir(os.path.dirname(os.path.realpath(__file__)))
  128. train_matching_ghost_mlp(
  129. prepare_data=args.prepare,
  130. input_file="data/ghost_data_D_default_phi_eta.root",
  131. tree_name="PrMatchNN_3e224c41.PrMCDebugMatchToolNN/MVAInputAndOutput",
  132. outdir="neural_net_training",
  133. exclude_electrons=False,
  134. only_electrons=True,
  135. )
  136. # this ensures that the directory is correct
  137. os.chdir(os.path.dirname(os.path.realpath(__file__)))
  138. cpp_files += parse_tmva_matrix_to_array(
  139. [
  140. "neural_net_training/result/MatchNNDataSet/weights/TMVAClassification_matching_mlp.class.C",
  141. ],
  142. simd_type=True,
  143. )
  144. ###<<<
  145. for file in cpp_files:
  146. subprocess.run(
  147. [
  148. "clang-format",
  149. "-i",
  150. f"{file}",
  151. ],
  152. )