tracking-parametrisation-tuner/electron_main.py

103 lines
3.0 KiB
Python
Raw Normal View History

2023-12-19 13:00:59 +01:00
# flake8: noqa
# ruff: noqa
2023-12-19 13:00:59 +01:00
import os
import subprocess
import argparse
from parameterisations.parameterise_magnet_kink import parameterise_magnet_kink
2024-02-23 16:00:50 +01:00
from parameterisations.parameterise_track_model_electron import parameterise_track_model
2023-12-19 13:00:59 +01:00
from parameterisations.parameterise_search_window import parameterise_search_window
from parameterisations.parameterise_field_integral import parameterise_field_integral
from parameterisations.parameterise_hough_histogram import parameterise_hough_histogram
from parameterisations.utils.preselection import preselection
from parameterisations.train_forward_ghost_mlps import (
train_default_forward_ghost_mlp,
train_veloUT_forward_ghost_mlp,
)
from parameterisations.train_matching_ghost_mlps_electron import (
train_matching_ghost_mlp,
)
from parameterisations.utils.parse_tmva_matrix_to_array_electron import (
parse_tmva_matrix_to_array,
)
parser = argparse.ArgumentParser()
parser.add_argument(
"--matching-weights",
action="store_true",
default=True,
2023-12-19 13:00:59 +01:00
help="Enables determination of weights used by neural networks.",
)
parser.add_argument(
"-p",
"--prepare",
action="store_true",
default=True,
help="Enables preparation of data for matching.",
)
parser.add_argument(
"--prepare-weights-data",
action="store_true",
help="Enables preparation of data for NN weight determination.",
)
args = parser.parse_args()
# C++ sources generated by the steps below; the loop at the end of the script
# runs clang-format on each of them.
cpp_files = []

# Target ROOT file written by ``hadd`` when --prepare-weights-data is given.
ghost_data = "data/ghost_data_B_BJpsi.root"

if args.prepare_weights_data:
    # Per-decay inputs concatenated into ``ghost_data``. The "-fk" options are
    # hadd's force-overwrite / skip-bad-inputs switches (see ROOT hadd usage).
    decay_inputs = [
        "data/ghost_data_B_NewParamsM.root",
        "data/ghost_data_BJpsi_NewParamsM.root",
    ]
    print("Concatenate decays for neural network training ...")
    # check=True: abort the script if the merge fails.
    subprocess.run(["hadd", "-fk", ghost_data, *decay_inputs], check=True)
# Label (key of ``tree_names``) of the debug tree the training step reads.
file_name = "new"

# Label -> path of the PrMCDebugMatchToolNN "MVAInputAndOutput" tree inside the
# input ROOT file. The hexadecimal suffixes distinguish PrMatchNN instances
# (presumably different configurations — confirm against the job options).
tree_names = {
    "true": "PrMatchNN_b9ce4699.PrMCDebugMatchToolNN/MVAInputAndOutput",
    "new": "PrMatchNN_410ce396.PrMCDebugMatchToolNN/MVAInputAndOutput",
    "loose": "PrMatchNN_40474434.PrMCDebugMatchToolNN/MVAInputAndOutput",
    "base": "PrMatchNN_c0bf8e8b.PrMCDebugMatchToolNN/MVAInputAndOutput",
}
# Train the electron matching ghost-rejection MLP and convert the resulting
# TMVA class file into C++ array sources (collected in ``cpp_files``).
if args.matching_weights:
    # All relative paths below ("data/...", "nn_electron_training/...") are
    # resolved against the directory containing this script.
    script_dir = os.path.dirname(os.path.realpath(__file__))
    os.chdir(script_dir)
    train_matching_ghost_mlp(
        input_file="data/ghost_data_B_EDef.root",
        tree_name=tree_names[file_name],
        exclude_electrons=False,
        only_electrons=True,
        filter_velos=False,
        filter_seeds=False,
        n_train_signal=115e3,
        n_train_bkg=115e3,
        n_test_signal=8e3,
        n_test_bkg=8e3,
        # Honour -p/--prepare instead of hard-coding True. The flag defaults
        # to True, so the default behaviour is unchanged.
        prepare_data=args.prepare,
        outdir="nn_electron_training",
    )
    # Switch back to the script directory before reading the generated TMVA
    # class file (presumably the training step can change the working
    # directory).
    os.chdir(script_dir)
    cpp_files += parse_tmva_matrix_to_array(
        [
            "nn_electron_training/result/MatchNNDataSet/weights/TMVAClassification_matching_mlp.class.C",
        ],
        simd_type=True,
    )
# Reformat each generated C++ file in place. The call is made without
# check=True, so a non-zero exit status from clang-format is ignored.
for cpp_file in cpp_files:
    format_cmd = ["clang-format", "-i", f"{cpp_file}"]
    subprocess.run(format_cmd)