# flake8: noqa
|
|
import os
|
|
import subprocess
|
|
import argparse
|
|
from parameterisations.parameterise_magnet_kink import parameterise_magnet_kink
|
|
from parameterisations.parameterise_track_model_electron import parameterise_track_model
|
|
from parameterisations.parameterise_search_window import parameterise_search_window
|
|
from parameterisations.parameterise_field_integral import parameterise_field_integral
|
|
from parameterisations.parameterise_hough_histogram import parameterise_hough_histogram
|
|
from parameterisations.utils.preselection import preselection
|
|
from parameterisations.train_forward_ghost_mlps import (
|
|
train_default_forward_ghost_mlp,
|
|
train_veloUT_forward_ghost_mlp,
|
|
)
|
|
from parameterisations.train_matching_ghost_mlps_electron import (
|
|
train_matching_ghost_mlp,
|
|
)
|
|
|
|
from parameterisations.utils.parse_tmva_matrix_to_array_electron import (
|
|
parse_tmva_matrix_to_array,
|
|
)
|
|
|
|
|
|
# Command-line interface.
#
# `--matching-weights` and `--prepare` are on by default. A flag declared with
# `action="store_true"` and `default=True` can never evaluate to False, so each
# of them also gets a `--no-...` switch that stores False into the same
# destination. This makes the step actually switchable from the command line;
# the original spellings keep working unchanged.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--matching-weights",
    action="store_true",
    default=True,
    help="Enables determination of weights used by neural networks.",
)
parser.add_argument(
    "--no-matching-weights",
    dest="matching_weights",
    action="store_false",
    help="Disables determination of weights used by neural networks.",
)
parser.add_argument(
    "-p",
    "--prepare",
    action="store_true",
    default=True,
    help="Enables preparation of data for matching.",
)
parser.add_argument(
    "--no-prepare",
    dest="prepare",
    action="store_false",
    help="Disables preparation of data for matching.",
)
parser.add_argument(
    "--prepare-weights-data",
    action="store_true",
    help="Enables preparation of data for NN weight determination.",
)
args = parser.parse_args()

# Paths of generated C++ sources; every entry is run through clang-format at
# the end of the script.
cpp_files = []

# Target file of the optional merge step below.
ghost_data = "data/ghost_data.root"
if args.prepare_weights_data:
    # All paths here are relative. Resolve them against this script's own
    # directory, as the training step does, rather than against whatever the
    # caller's working directory happens to be.
    os.chdir(os.path.dirname(os.path.realpath(__file__)))
    # ROOT `hadd`: -f recreates the target, -k skips unreadable inputs
    # (see `hadd --help`).
    merge_cmd = [
        "hadd",
        "-fk",
        ghost_data,
        "data/ghost_data_B.root",
        "data/ghost_data_D.root",
    ]
    print("Concatenate decays for neural network training ...")
    # check=True: abort the script if hadd exits with a non-zero status.
    # NOTE(review): the merged file is not read by the training call in this
    # script, which uses data/ghost_data_B_sample4.root; confirm intent.
    subprocess.run(merge_cmd, check=True)

# Key into `tree_names` selecting which tuple the matching NN is trained on.
file_name = "newpars"

# Label -> tree name handed to `train_matching_ghost_mlp(tree_name=...)`.
tree_names = {
    "seed": "PrMatchNN_b60a058d.PrMCDebugMatchToolNN/MVAInputAndOutput",
    "newpars": "PrMatchNN_b826666c.PrMCDebugMatchToolNN/MVAInputAndOutput",
}

if args.matching_weights:
    # Directory containing this script; every relative path below resolves
    # against it.
    script_dir = os.path.dirname(os.path.realpath(__file__))
    os.chdir(script_dir)
    # Train the matching ghost MLP on electron candidates only.
    train_matching_ghost_mlp(
        input_file="data/ghost_data_B_sample4.root",
        tree_name=tree_names[file_name],
        exclude_electrons=False,
        only_electrons=True,
        filter_velos=False,
        filter_seeds=True,
        n_train_signal=110e3,
        n_train_bkg=110e3,
        n_test_signal=10e3,
        n_test_bkg=10e3,
        # Honour -p/--prepare instead of always re-preparing the input data.
        prepare_data=args.prepare,
        outdir="nn_electron_training",
    )
    # this ensures that the directory is correct
    # (re-pin the working directory in case the training step changed it).
    os.chdir(script_dir)
    # Export the TMVA-generated classifier. The returned paths are queued in
    # `cpp_files` and clang-formatted at the end of the script.
    cpp_files += parse_tmva_matrix_to_array(
        [
            "nn_electron_training/result/MatchNNDataSet/weights/TMVAClassification_matching_mlp.class.C",
        ],
        simd_type=True,
    )

# Reformat each generated C++ file in place with clang-format.
# A non-zero exit status is deliberately not checked, so a formatting failure
# on one file does not stop the remaining ones.
for generated_file in cpp_files:
    clang_format_cmd = ["clang-format", "-i", f"{generated_file}"]
    subprocess.run(clang_format_cmd)