You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

99 lines
2.9 KiB

10 months ago
10 months ago
10 months ago
10 months ago
7 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
  1. # flake8: noqaq
  2. # ruff: noqa
  3. import os
  4. import subprocess
  5. import argparse
  6. from parameterisations.parameterise_magnet_kink import parameterise_magnet_kink
  7. from parameterisations.parameterise_track_model_electron import parameterise_track_model
  8. from parameterisations.parameterise_search_window import parameterise_search_window
  9. from parameterisations.parameterise_field_integral import parameterise_field_integral
  10. from parameterisations.parameterise_hough_histogram import parameterise_hough_histogram
  11. from parameterisations.utils.preselection import preselection
  12. from parameterisations.train_forward_ghost_mlps import (
  13. train_default_forward_ghost_mlp,
  14. train_veloUT_forward_ghost_mlp,
  15. )
  16. from parameterisations.train_matching_ghost_mlps_electron import (
  17. train_matching_ghost_mlp,
  18. )
  19. from parameterisations.utils.parse_tmva_matrix_to_array_electron import (
  20. parse_tmva_matrix_to_array,
  21. )
  22. parser = argparse.ArgumentParser()
  23. parser.add_argument(
  24. "--matching-weights",
  25. action="store_true",
  26. default=True,
  27. help="Enables determination of weights used by neural networks.",
  28. )
  29. parser.add_argument(
  30. "-p",
  31. "--prepare",
  32. action="store_true",
  33. default=True,
  34. help="Enables preparation of data for matching.",
  35. )
  36. parser.add_argument(
  37. "--prepare-weights-data",
  38. action="store_true",
  39. help="Enables preparation of data for NN weight determination.",
  40. )
  41. args = parser.parse_args()
  42. cpp_files = []
  43. ghost_data = "data/ghost_data_B_BJpsi.root"
  44. if args.prepare_weights_data:
  45. merge_cmd = [
  46. "hadd",
  47. "-fk",
  48. ghost_data,
  49. "data/ghost_data_B_NewParamsM.root",
  50. "data/ghost_data_BJpsi_NewParamsM.root",
  51. ]
  52. print("Concatenate decays for neural network training ...")
  53. subprocess.run(merge_cmd, check=True)
  54. file_name = "newpars"
  55. tree_names = {}
  56. tree_names["seed"] = "PrMatchNN_b60a058d.PrMCDebugMatchToolNN/MVAInputAndOutput"
  57. tree_names["newpars"] = "PrMatchNN_b826666c.PrMCDebugMatchToolNN/MVAInputAndOutput"
  58. if args.matching_weights:
  59. os.chdir(os.path.dirname(os.path.realpath(__file__)))
  60. train_matching_ghost_mlp(
  61. input_file="data/ghost_data_B_NewParamsM.root",
  62. tree_name=tree_names[file_name],
  63. exclude_electrons=False,
  64. only_electrons=True,
  65. filter_velos=False,
  66. filter_seeds=False,
  67. n_train_signal=115e3, # 115e3,
  68. n_train_bkg=115e3, # 115e3,
  69. n_test_signal=8e3,
  70. n_test_bkg=8e3,
  71. prepare_data=True,
  72. outdir="nn_electron_training",
  73. )
  74. # this ensures that the directory is correct
  75. os.chdir(os.path.dirname(os.path.realpath(__file__)))
  76. cpp_files += parse_tmva_matrix_to_array(
  77. [
  78. "nn_electron_training/result/MatchNNDataSet/weights/TMVAClassification_matching_mlp.class.C",
  79. ],
  80. simd_type=True,
  81. )
  82. for file in cpp_files:
  83. subprocess.run(
  84. [
  85. "clang-format",
  86. "-i",
  87. f"{file}",
  88. ],
  89. )