You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

102 lines
3.0 KiB

10 months ago
10 months ago
10 months ago
10 months ago
7 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
  1. # flake8: noqa
  2. # ruff: noqa
  3. import os
  4. import subprocess
  5. import argparse
  6. from parameterisations.parameterise_magnet_kink import parameterise_magnet_kink
  7. from parameterisations.parameterise_track_model_electron import parameterise_track_model
  8. from parameterisations.parameterise_search_window import parameterise_search_window
  9. from parameterisations.parameterise_field_integral import parameterise_field_integral
  10. from parameterisations.parameterise_hough_histogram import parameterise_hough_histogram
  11. from parameterisations.utils.preselection import preselection
  12. from parameterisations.train_forward_ghost_mlps import (
  13. train_default_forward_ghost_mlp,
  14. train_veloUT_forward_ghost_mlp,
  15. )
  16. from parameterisations.train_matching_ghost_mlps_electron import (
  17. train_matching_ghost_mlp,
  18. )
  19. from parameterisations.utils.parse_tmva_matrix_to_array_electron import (
  20. parse_tmva_matrix_to_array,
  21. )
  22. parser = argparse.ArgumentParser()
  23. parser.add_argument(
  24. "--matching-weights",
  25. action="store_true",
  26. default=True,
  27. help="Enables determination of weights used by neural networks.",
  28. )
  29. parser.add_argument(
  30. "-p",
  31. "--prepare",
  32. action="store_true",
  33. default=True,
  34. help="Enables preparation of data for matching.",
  35. )
  36. parser.add_argument(
  37. "--prepare-weights-data",
  38. action="store_true",
  39. help="Enables preparation of data for NN weight determination.",
  40. )
  41. args = parser.parse_args()
  42. cpp_files = []
  43. ghost_data = "data/ghost_data_B_BJpsi.root"
  44. if args.prepare_weights_data:
  45. merge_cmd = [
  46. "hadd",
  47. "-fk",
  48. ghost_data,
  49. "data/ghost_data_B_NewParamsM.root",
  50. "data/ghost_data_BJpsi_NewParamsM.root",
  51. ]
  52. print("Concatenate decays for neural network training ...")
  53. subprocess.run(merge_cmd, check=True)
  54. file_name = "new"
  55. tree_names = {}
  56. tree_names["true"] = "PrMatchNN_b9ce4699.PrMCDebugMatchToolNN/MVAInputAndOutput"
  57. tree_names["new"] = "PrMatchNN_410ce396.PrMCDebugMatchToolNN/MVAInputAndOutput"
  58. tree_names["loose"] = "PrMatchNN_40474434.PrMCDebugMatchToolNN/MVAInputAndOutput"
  59. tree_names["base"] = "PrMatchNN_c0bf8e8b.PrMCDebugMatchToolNN/MVAInputAndOutput"
  60. if args.matching_weights:
  61. os.chdir(os.path.dirname(os.path.realpath(__file__)))
  62. train_matching_ghost_mlp(
  63. input_file="data/ghost_data_B_EDef.root",
  64. tree_name=tree_names[file_name],
  65. exclude_electrons=False,
  66. only_electrons=True,
  67. filter_velos=False,
  68. filter_seeds=False,
  69. n_train_signal=115e3, # 115e3,
  70. n_train_bkg=115e3, # 115e3,
  71. n_test_signal=8e3,
  72. n_test_bkg=8e3,
  73. prepare_data=True,
  74. outdir="nn_electron_training",
  75. )
  76. # this ensures that the directory is correct
  77. os.chdir(os.path.dirname(os.path.realpath(__file__)))
  78. cpp_files += parse_tmva_matrix_to_array(
  79. [
  80. "nn_electron_training/result/MatchNNDataSet/weights/TMVAClassification_matching_mlp.class.C",
  81. ],
  82. simd_type=True,
  83. )
  84. for file in cpp_files:
  85. subprocess.run(
  86. [
  87. "clang-format",
  88. "-i",
  89. f"{file}",
  90. ],
  91. )