You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

98 lines
2.8 KiB

10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
10 months ago
# flake8: noqa
import argparse
import os
import shutil
import subprocess

from parameterisations.parameterise_magnet_kink import parameterise_magnet_kink
from parameterisations.parameterise_track_model_electron import parameterise_track_model
from parameterisations.parameterise_search_window import parameterise_search_window
from parameterisations.parameterise_field_integral import parameterise_field_integral
from parameterisations.parameterise_hough_histogram import parameterise_hough_histogram
from parameterisations.utils.preselection import preselection
from parameterisations.train_forward_ghost_mlps import (
    train_default_forward_ghost_mlp,
    train_veloUT_forward_ghost_mlp,
)
from parameterisations.train_matching_ghost_mlps_electron import (
    train_matching_ghost_mlp,
)
from parameterisations.utils.parse_tmva_matrix_to_array_electron import (
    parse_tmva_matrix_to_array,
)
  21. parser = argparse.ArgumentParser()
  22. parser.add_argument(
  23. "--matching-weights",
  24. action="store_true",
  25. default=True,
  26. help="Enables determination of weights used by neural networks.",
  27. )
  28. parser.add_argument(
  29. "-p",
  30. "--prepare",
  31. action="store_true",
  32. default=True,
  33. help="Enables preparation of data for matching.",
  34. )
  35. parser.add_argument(
  36. "--prepare-weights-data",
  37. action="store_true",
  38. help="Enables preparation of data for NN weight determination.",
  39. )
  40. args = parser.parse_args()
  41. cpp_files = []
  42. ghost_data = "data/ghost_data.root"
  43. if args.prepare_weights_data:
  44. merge_cmd = [
  45. "hadd",
  46. "-fk",
  47. ghost_data,
  48. "data/ghost_data_B.root",
  49. "data/ghost_data_D.root",
  50. ]
  51. print("Concatenate decays for neural network training ...")
  52. subprocess.run(merge_cmd, check=True)
  53. file_name = "newpars"
  54. tree_names = {}
  55. tree_names["seed"] = "PrMatchNN_b60a058d.PrMCDebugMatchToolNN/MVAInputAndOutput"
  56. tree_names["newpars"] = "PrMatchNN_b826666c.PrMCDebugMatchToolNN/MVAInputAndOutput"
  57. if args.matching_weights:
  58. os.chdir(os.path.dirname(os.path.realpath(__file__)))
  59. train_matching_ghost_mlp(
  60. input_file="data/ghost_data_B_sample4.root",
  61. tree_name=tree_names[file_name],
  62. exclude_electrons=False,
  63. only_electrons=True,
  64. filter_velos=False,
  65. filter_seeds=True,
  66. n_train_signal=150e3,
  67. n_train_bkg=150e3,
  68. n_test_signal=10e3,
  69. n_test_bkg=10e3,
  70. prepare_data=True,
  71. outdir="nn_electron_training",
  72. )
  73. # this ensures that the directory is correct
  74. os.chdir(os.path.dirname(os.path.realpath(__file__)))
  75. cpp_files += parse_tmva_matrix_to_array(
  76. [
  77. "nn_electron_training/result/MatchNNDataSet/weights/TMVAClassification_matching_mlp.class.C",
  78. ],
  79. simd_type=True,
  80. )
  81. for file in cpp_files:
  82. subprocess.run(
  83. [
  84. "clang-format",
  85. "-i",
  86. f"{file}",
  87. ],
  88. )