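# tracking-parametrisation-tuner/parameterisations/train_matching_ghost_mlps_electron.py
# Captured from a repository web view: 190 lines, 6.2 KiB, Python
# (view options: Raw | Normal View | History).
# Commit timestamp (from the blame view): 2023-12-19 13:00:59 +01:00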
# flake8: noqa
import os
import argparse
import ROOT
from ROOT import TMVA, TList, TTree
def _strip_root_suffix(path: str) -> str:
    """Return *path* without a trailing ``.root`` extension, if it has one.

    ``str.strip(".root")`` is NOT a suffix removal: it strips any of the
    characters ``.``, ``r``, ``o`` and ``t`` from both ends, so e.g.
    ``"test.root"`` would become ``"es"``.
    """
    suffix = ".root"
    return path[: -len(suffix)] if path.endswith(suffix) else path


def _matching_selections(
    exclude_electrons: bool, only_electrons: bool, filter_seeds: bool
) -> tuple:
    """Return ``(signal_cut, signal_name, bkg_cut, bkg_name)`` for RDataFrame.Filter.

    ``quality`` labels: -1 electron true match, 0 ghost, 1 true match of a
    non-electron particle.
    """
    bkg_name = "Ghosts are defined as zero label"
    if exclude_electrons:
        # filter_seeds is not applied in this mode.
        return (
            "quality == 1",
            "Signal is defined as one label (excluding electrons)",
            "quality == 0",
            bkg_name,
        )
    if only_electrons:
        signal_cut = "quality == -1"
        signal_name = "Signal is defined as negative one label (only electrons)"
    else:
        signal_cut = "abs(quality) > 0"
        signal_name = "Signal is defined as non-zero label"
    # Optionally restrict the ghosts to those flagged as electrons on the SciFi side.
    bkg_cut = "quality == 0 && scifi_isElectron == 1" if filter_seeds else "quality == 0"
    return signal_cut, signal_name, bkg_cut, bkg_name


def _open_root_file(path: str):
    """Open *path* read-only via ``ROOT.TFile.Open``.

    Raises:
        OSError: if the file cannot be opened.
    """
    tfile = ROOT.TFile.Open(path, "READ")
    if not tfile or tfile.IsZombie():
        raise OSError(f"Could not open ROOT file: {path}")
    return tfile


def train_matching_ghost_mlp(
    input_file: str = "data/ghost_data_B.root",
    tree_name: str = "PrMatchNN.PrMCDebugMatchToolNN/MVAInputAndOutput",
    exclude_electrons: bool = False,
    only_electrons: bool = True,
    filter_seeds: bool = False,
    n_train_signal: int = 50e3,  # 50e3
    n_train_bkg: int = 50e3,  # 500e3
    n_test_signal: int = 10e3,
    n_test_bkg: int = 20e3,
    prepare_data: bool = True,
    outdir: str = "nn_electron_training",
):
    """Train a TMVA MLP that classifies Velo-Seed track matches as true or ghost.

    Optionally splits the input tuple into a signal and a background ROOT file
    (selected on the ``quality`` branch), then books, trains, tests and
    evaluates a TMVA ``kMLP`` classifier named "matching_mlp".

    Args:
        input_file (str, optional): Input ROOT file.
            Defaults to "data/ghost_data_B.root".
        tree_name (str, optional): Tree inside ``input_file``.
            Defaults to "PrMatchNN.PrMCDebugMatchToolNN/MVAInputAndOutput".
        exclude_electrons (bool, optional): Signal is ``quality == 1`` (true
            matches without electrons). Takes precedence over
            ``only_electrons`` and ``filter_seeds``. Defaults to False.
        only_electrons (bool, optional): Signal is ``quality == -1`` (electron
            true matches only); background is still all ghosts. If False, all
            non-zero labels are signal. Defaults to True.
        filter_seeds (bool, optional): Restrict the background to ghosts with
            ``scifi_isElectron == 1`` (ignored with ``exclude_electrons``).
            Defaults to False.
        n_train_signal (int, optional): Number of true matches to train on.
            Floats are truncated to int. Defaults to 50e3.
        n_train_bkg (int, optional): Number of fake matches to train on.
            Defaults to 50e3.
        n_test_signal (int, optional): Number of true matches to test on.
            Defaults to 10e3.
        n_test_bkg (int, optional): Number of fake matches to test on.
            Defaults to 20e3.
        prepare_data (bool, optional): Split the input into signal and
            background files below ``outdir``. If False, those files must
            already exist. Defaults to True.
        outdir (str, optional): Directory for the split files; the TMVA output
            file "matching_ghost_mlp_training.root" is written to
            ``<outdir>/result``. Defaults to "nn_electron_training".

    Raises:
        OSError: if a split signal/background file cannot be opened.
    """
    stem = _strip_root_suffix(input_file)
    # Plain concatenation (not os.path.join) so that an absolute input_file
    # still lands below outdir.
    signal_path = f"{outdir}/{stem}_matching_signal.root"
    bkg_path = f"{outdir}/{stem}_matching_bkg.root"

    if prepare_data:
        signal_cut, signal_name, bkg_cut, bkg_name = _matching_selections(
            exclude_electrons, only_electrons, filter_seeds
        )
        # Make sure the target directory exists before writing the snapshots.
        os.makedirs(os.path.dirname(signal_path) or ".", exist_ok=True)
        rdf = ROOT.RDataFrame(tree_name, input_file)
        rdf.Filter(signal_cut, signal_name).Snapshot("Signal", signal_path)
        rdf.Filter(bkg_cut, bkg_name).Snapshot("Bkg", bkg_path)

    # The TFile handles are kept as locals so the trees read from them stay valid.
    signal_file = _open_root_file(signal_path)
    signal_tree = signal_file.Get("Signal")
    bkg_file = _open_root_file(bkg_path)
    bkg_tree = bkg_file.Get("Bkg")

    result_dir = f"{outdir}/result"
    os.makedirs(result_dir, exist_ok=True)
    # Train from inside <outdir>/result so the relative output file (and,
    # presumably, TMVA's weight directory) ends up there; the caller's working
    # directory is restored afterwards.
    previous_cwd = os.getcwd()
    os.chdir(result_dir)
    output = None
    try:
        output = ROOT.TFile(
            "matching_ghost_mlp_training.root",
            "RECREATE",
        )
        factory = TMVA.Factory(
            "TMVAClassification",
            output,
            "V:!Silent:Color:DrawProgressBar:AnalysisType=Classification",
        )
        factory.SetVerbose(True)

        dataloader = TMVA.DataLoader("MatchNNDataSet")
        # Previously tried and currently disabled: dSlopeY, zmag, dEta, eta_scifi.
        for variable in ("chi2", "teta2", "distX", "distY", "dSlope", "eta"):
            dataloader.AddVariable(variable, "F")
        dataloader.AddSignalTree(signal_tree, 1.0)
        dataloader.AddBackgroundTree(bkg_tree, 1.0)

        # The original preselection mirrored cuts applied in the algorithm
        # ("chi2<30 && distX<500 && distY<500 && dSlope<2.0 && dSlopeY<0.15");
        # per the original note it is removed entirely for electrons, hence
        # an empty cut.
        preselection_cuts = ROOT.TCut()
        # TMVA expects integer event counts; the defaults are floats (50e3),
        # which would otherwise be formatted as e.g. "50000.0".
        split_options = (
            "SplitMode=random:V"
            f":nTrain_Signal={int(n_train_signal)}"
            f":nTrain_Background={int(n_train_bkg)}"
            f":nTest_Signal={int(n_test_signal)}"
            f":nTest_Background={int(n_test_bkg)}"
        )
        # NormMode is left at its default (EqualNumEvents).
        dataloader.PrepareTrainingAndTestTree(preselection_cuts, split_options)

        factory.BookMethod(
            dataloader,
            TMVA.Types.kMLP,
            "matching_mlp",
            "!H:V:TrainingMethod=BP:NeuronType=ReLU:EstimatorType=CE:VarTransform=Norm:NCycles=700:HiddenLayers=N+2,N:TestRate=50:Sampling=1.0:SamplingImportance=1.0:LearningRate=0.02:DecayRate=0.01:!UseRegulator",
        )
        factory.TrainAllMethods()
        factory.TestAllMethods()
        factory.EvaluateAllMethods()
    finally:
        if output is not None:
            output.Close()
        os.chdir(previous_cwd)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Train the Velo-Seed matching ghost MLP with TMVA."
    )
    parser.add_argument(
        "--input-file",
        type=str,
        help="Path to the input file",
        required=False,
    )
    # Boolean flags default to None (argparse would use False) so that flags
    # which are not given are dropped below and the function's own defaults
    # apply. With an implicit False, only_electrons=True could never take
    # effect from the command line.
    parser.add_argument(
        "--exclude_electrons",
        action="store_true",
        default=None,
        help="Excludes electrons from training.",
        required=False,
    )
    parser.add_argument(
        "--only_electrons",
        action="store_true",
        default=None,
        help="Only electrons for signal training (this is the default).",
        required=False,
    )
    parser.add_argument(
        "--no-only_electrons",
        dest="only_electrons",
        action="store_false",
        default=None,
        help="Use all non-ghost matches as signal instead of only electrons.",
        required=False,
    )
    parser.add_argument(
        "--filter_seeds",
        action="store_true",
        default=None,
        help="Only use ghosts with scifi_isElectron == 1 as background.",
        required=False,
    )
    parser.add_argument(
        "--outdir",
        type=str,
        help="Output directory for the split data and the training results.",
        required=False,
    )
    parser.add_argument(
        "--n-train-signal",
        type=int,
        help="Number of training tracks for signal.",
        required=False,
    )
    parser.add_argument(
        "--n-train-bkg",
        type=int,
        help="Number of training tracks for bkg.",
        required=False,
    )
    parser.add_argument(
        "--n-test-signal",
        type=int,
        help="Number of testing tracks for signal.",
        required=False,
    )
    parser.add_argument(
        "--n-test-bkg",
        type=int,
        help="Number of testing tracks for bkg.",
        required=False,
    )
    args = parser.parse_args()
    # Forward only the options that were actually given on the command line.
    args_dict = {arg: val for arg, val in vars(args).items() if val is not None}
    train_matching_ghost_mlp(**args_dict)