tracking-parametrisation-tuner/parameterisations/train_matching_ghost_mlps_electron.py

228 lines
7.8 KiB
Python
Raw Normal View History

2023-12-19 13:00:59 +01:00
# flake8: noqa
# ruff: noqa
2023-12-19 13:00:59 +01:00
import os
import argparse
import ROOT
from ROOT import TMVA, TList, TTree
def train_matching_ghost_mlp(
input_file: str = "data/ghost_data_B.root",
tree_name: str = "PrMatchNN.PrMCDebugMatchToolNN/MVAInputAndOutput",
exclude_electrons: bool = False,
only_electrons: bool = True,
filter_velos: bool = False,
2024-02-15 16:45:55 +01:00
filter_seeds: bool = False,
n_train_signal: int = 50e3,
n_train_bkg: int = 50e3,
2023-12-19 13:00:59 +01:00
n_test_signal: int = 10e3,
n_test_bkg: int = 20e3,
prepare_data: bool = True,
outdir: str = "nn_electron_training",
):
"""Trains an MLP to classify the match between Velo and Seed track.
Args:
input_file (str, optional): Defaults to "data/ghost_data.root".
tree_name (str, optional): Defaults to "PrMatchNN.PrMCDebugMatchToolNN/Tuple".
exclude_electrons (bool, optional): Defaults to False.
only_electrons (bool, optional): Signal only of electrons, but bkg of all particles. Defaults to True.
filter_velos (bool, optional): Background only electron VELO tracks. Defaults to False.
filter_seeds (bool, optional): Background only electron T tracks. Defaults to False.
2023-12-19 13:00:59 +01:00
n_train_signal (int, optional): Number of true matches to train on. Defaults to 200e3.
n_train_bkg (int, optional): Number of fake matches to train on. Defaults to 200e3.
n_test_signal (int, optional): Number of true matches to test on. Defaults to 75e3.
n_test_bkg (int, optional): Number of fake matches to test on. Defaults to 75e3.
prepare_data (bool, optional): Split data into signal and background file. Defaults to True.
outdir (str, optional): specify the output directory path.
2023-12-19 13:00:59 +01:00
"""
if prepare_data:
rdf = ROOT.RDataFrame(tree_name, input_file)
if exclude_electrons:
print("signal data: exclude electrons.")
2023-12-19 13:00:59 +01:00
rdf_signal = rdf.Filter(
2024-02-08 17:42:15 +01:00
"quality == 1", # -1 elec, 0 ghost, 1 all part wo elec
2023-12-19 13:00:59 +01:00
"Signal is defined as one label (excluding electrons)",
)
print("background data: default ghosts: quality == 0")
2023-12-19 13:00:59 +01:00
rdf_bkg = rdf.Filter(
2024-02-08 17:42:15 +01:00
"quality == 0",
2023-12-19 13:00:59 +01:00
"Ghosts are defined as zero label",
)
else:
if only_electrons:
print("signal data: only electrons.")
2023-12-19 13:00:59 +01:00
rdf_signal = rdf.Filter(
2024-03-27 09:23:35 +01:00
"quality == -1", # && zMag_electron<5800 && zMag_electron>5000", # electron that is true match
2023-12-19 13:00:59 +01:00
"Signal is defined as negative one label (only electrons)",
)
else:
print("signal data: all particles.")
2023-12-19 13:00:59 +01:00
rdf_signal = rdf.Filter(
2024-02-08 17:42:15 +01:00
"abs(quality) > 0",
2023-12-19 13:00:59 +01:00
"Signal is defined as non-zero label",
)
2024-03-27 09:23:35 +01:00
bkg_selection = "(quality == 0) || (quality == 1 && chi2 > 1)" # && zMag_electron<5800 && zMag_electron>5000"
if filter_velos:
bkg_selection += " && velo_isElectron == 1"
2024-02-15 16:45:55 +01:00
if filter_seeds:
bkg_selection += " && scifi_isElectron == 1"
print("background data: selection cuts = " + bkg_selection)
rdf_bkg = rdf.Filter(
bkg_selection,
"Ghosts are defined as zero label",
)
2023-12-19 13:00:59 +01:00
rdf_signal.Snapshot(
"Signal",
outdir + "/" + input_file.strip(".root") + "_matching_signal.root",
)
rdf_bkg.Snapshot(
"Bkg",
outdir + "/" + input_file.strip(".root") + "_matching_bkg.root",
)
signal_file = ROOT.TFile.Open(
outdir + "/" + input_file.strip(".root") + "_matching_signal.root",
"READ",
)
signal_tree = signal_file.Get("Signal")
bkg_file = ROOT.TFile.Open(
outdir + "/" + input_file.strip(".root") + "_matching_bkg.root"
)
bkg_tree = bkg_file.Get("Bkg")
os.chdir(outdir + "/result")
output = ROOT.TFile(
"matching_ghost_mlp_training.root",
"RECREATE",
)
factory = TMVA.Factory(
"TMVAClassification",
output,
"V:!Silent:Color:DrawProgressBar:AnalysisType=Classification",
)
factory.SetVerbose(True)
dataloader = TMVA.DataLoader("MatchNNDataSet")
2024-02-08 17:42:15 +01:00
dataloader.AddVariable("chi2", "F")
dataloader.AddVariable("teta2", "F")
dataloader.AddVariable("distX", "F")
dataloader.AddVariable("distY", "F")
dataloader.AddVariable("dSlope", "F")
2024-02-19 15:41:09 +01:00
dataloader.AddVariable("dSlopeY", "F")
# dataloader.AddVariable("zMag_electron", "F")
2024-03-27 09:23:35 +01:00
# dataloader.AddVariable("yCorr_electron", "F")
# dataloader.AddVariable("std::abs(zMag_electron - zMag_default)", "F")
2024-02-19 15:41:09 +01:00
# dataloader.AddVariable("eta", "F")
2024-02-15 16:45:55 +01:00
# dataloader.AddVariable("dEta", "F")
2023-12-19 13:00:59 +01:00
dataloader.AddSignalTree(signal_tree, 1.0)
dataloader.AddBackgroundTree(bkg_tree, 1.0)
# these cuts are also applied in the algorithm
preselectionCuts = ROOT.TCut(
2024-03-27 09:23:35 +01:00
"chi2<15 && distX<300 && distY<300 && dSlope<2.0 && dSlopeY<0.15",
# "chi2<15 && distX<250 && distY<250 && dSlope<1.5 && dSlopeY<0.15",
2023-12-19 13:00:59 +01:00
)
dataloader.PrepareTrainingAndTestTree(
preselectionCuts,
f"SplitMode=random:V:nTrain_Signal={n_train_signal}:nTrain_Background={n_train_bkg}:nTest_Signal={n_test_signal}:nTest_Background={n_test_bkg}",
# normmode default is EqualNumEvents
)
factory.BookMethod(
dataloader,
TMVA.Types.kMLP,
"matching_mlp",
2024-03-27 09:23:35 +01:00
"!H:V:TrainingMethod=BP:NeuronType=ReLU:EstimatorType=CE:VarTransform=Norm:NCycles=700:HiddenLayers=N+2,N:TestRate=50:Sampling=1.0:SamplingImportance=1.0:LearningRate=0.02:DecayRate=0.01:UseRegulator:!CreateMVAPdfs",
2023-12-19 13:00:59 +01:00
)
factory.TrainAllMethods()
factory.TestAllMethods()
factory.EvaluateAllMethods()
output.Close()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--input-file",
type=str,
help="Path to the input file.",
2023-12-19 13:00:59 +01:00
required=False,
)
parser.add_argument(
"--tree-name",
type=str,
help="Path to the tree.",
required=False,
)
parser.add_argument(
"--exclude-electrons",
2023-12-19 13:00:59 +01:00
action="store_true",
help="Excludes electrons from training.",
required=False,
)
parser.add_argument(
"--only-electrons",
2023-12-19 13:00:59 +01:00
action="store_true",
help="Only electrons for signal training.",
required=False,
)
parser.add_argument(
"--filter-velos",
action="store_true",
help="Only background with electron VELO tracks.",
required=False,
)
parser.add_argument(
"--filter-seeds",
action="store_true",
help="Only background with electron T tracks.",
required=False,
)
2023-12-19 13:00:59 +01:00
parser.add_argument(
"--n-train-signal",
type=int,
help="Number of training tracks for signal.",
required=False,
)
parser.add_argument(
"--n-train-bkg",
type=int,
help="Number of training tracks for bkg.",
required=False,
)
parser.add_argument(
"--n-test-signal",
type=int,
help="Number of testing tracks for signal.",
required=False,
)
parser.add_argument(
"--n-test-bkg",
type=int,
help="Number of testing tracks for bkg.",
required=False,
)
parser.add_argument(
"--prepare-data",
action="store_true",
help="Create signal and background samples.",
required=False,
)
parser.add_argument(
"--outdir",
type=str,
help="Path to the output directory.",
required=False,
)
2023-12-19 13:00:59 +01:00
args = parser.parse_args()
args_dict = {arg: val for arg, val in vars(args).items() if val is not None}
train_matching_ghost_mlp(**args_dict)