178 lines
5.9 KiB
Python
178 lines
5.9 KiB
Python
|
# flake8: noqa
|
||
|
|
||
|
import os
|
||
|
import argparse
|
||
|
import ROOT
|
||
|
from ROOT import TMVA, TList, TTree
|
||
|
|
||
|
|
||
|
def train_matching_ghost_mlp(
    input_file: str = "data/ghost_data_B.root",
    tree_name: str = "PrMatchNN.PrMCDebugMatchToolNN/MVAInputAndOutput",
    exclude_electrons: bool = False,
    only_electrons: bool = True,
    n_train_signal: int = 20_000,  # 50e3
    n_train_bkg: int = 50_000,  # 500e3
    n_test_signal: int = 10_000,
    n_test_bkg: int = 20_000,
    prepare_data: bool = True,
    outdir: str = "nn_electron_training",
) -> None:
    """Trains an MLP to classify the match between Velo and Seed track.

    When ``prepare_data`` is set, the input tree is split by ``mc_quality``
    into a "Signal" and a "Bkg" tree, written to
    ``<outdir>/<input_file without .root>_matching_signal.root`` and
    ``..._matching_bkg.root``. Those two files are then read back and used to
    book, train, test and evaluate a TMVA MLP.

    Side effects: changes the process working directory to ``<outdir>/result``
    (created if missing) and writes ``matching_ghost_mlp_training.root`` there.
    The working directory is not restored afterwards.

    Args:
        input_file (str, optional): Defaults to "data/ghost_data_B.root".
        tree_name (str, optional): Defaults to
            "PrMatchNN.PrMCDebugMatchToolNN/MVAInputAndOutput".
        exclude_electrons (bool, optional): Signal is ``mc_quality == 1``.
            Takes precedence over ``only_electrons``. Defaults to False.
        only_electrons (bool, optional): Signal only of electrons
            (``mc_quality == -1``), but bkg of all particles. If neither flag
            is set, signal is ``abs(mc_quality) > 0``. Defaults to True.
        n_train_signal (int, optional): Number of true matches to train on.
            Defaults to 20000.
        n_train_bkg (int, optional): Number of fake matches to train on.
            Defaults to 50000.
        n_test_signal (int, optional): Number of true matches to test on.
            Defaults to 10000.
        n_test_bkg (int, optional): Number of fake matches to test on.
            Defaults to 20000.
        prepare_data (bool, optional): Split data into signal and background
            file. Defaults to True.
        outdir (str, optional): Directory prefix for the split files and the
            ``result`` directory. Defaults to "nn_electron_training".

    Raises:
        OSError: If the signal or background file cannot be opened.
    """
    # Remove only a trailing ".root". str.strip(".root") would strip any of the
    # characters '.', 'r', 'o', 't' from BOTH ends of the path instead.
    root_suffix = ".root"
    if input_file.endswith(root_suffix):
        stem = input_file[: -len(root_suffix)]
    else:
        stem = input_file
    signal_path = outdir + "/" + stem + "_matching_signal.root"
    bkg_path = outdir + "/" + stem + "_matching_bkg.root"
    result_dir = outdir + "/result"

    if prepare_data:
        rdf = ROOT.RDataFrame(tree_name, input_file)
        # mc_quality labels: -1 elec, 0 ghost, 1 all part wo elec
        if exclude_electrons:
            rdf_signal = rdf.Filter(
                "mc_quality == 1",
                "Signal is defined as one label (excluding electrons)",
            )
        elif only_electrons:
            rdf_signal = rdf.Filter(
                "mc_quality == -1",  # electron that is true match but mlp said no match
                "Signal is defined as negative one label (only electrons)",
            )
        else:
            rdf_signal = rdf.Filter(
                "abs(mc_quality) > 0",
                "Signal is defined as non-zero label",
            )
        # The background definition is the same for every signal choice.
        rdf_bkg = rdf.Filter(
            "mc_quality == 0",
            "Ghosts are defined as zero label",
        )

        # Both snapshot files live in the same directory; make sure it exists.
        os.makedirs(os.path.dirname(signal_path), exist_ok=True)
        rdf_signal.Snapshot("Signal", signal_path)
        rdf_bkg.Snapshot("Bkg", bkg_path)

    # Open the inputs before changing directory: the paths may be relative.
    signal_file = ROOT.TFile.Open(signal_path, "READ")
    bkg_file = ROOT.TFile.Open(bkg_path)
    for path, handle in ((signal_path, signal_file), (bkg_path, bkg_file)):
        if not handle or handle.IsZombie():
            raise OSError(f"Could not open ROOT file: {path}")
    signal_tree = signal_file.Get("Signal")
    bkg_tree = bkg_file.Get("Bkg")

    os.makedirs(result_dir, exist_ok=True)
    os.chdir(result_dir)
    output = ROOT.TFile(
        "matching_ghost_mlp_training.root",
        "RECREATE",
    )
    try:
        factory = TMVA.Factory(
            "TMVAClassification",
            output,
            "V:!Silent:Color:DrawProgressBar:AnalysisType=Classification",
        )
        factory.SetVerbose(True)
        dataloader = TMVA.DataLoader("MatchNNDataSet")

        for variable in (
            "mc_chi2",
            "mc_teta2",
            "mc_distX",
            "mc_distY",
            "mc_dSlope",
            "mc_dSlopeY",
            "mc_zMag",
        ):
            dataloader.AddVariable(variable, "F")

        dataloader.AddSignalTree(signal_tree, 1.0)
        dataloader.AddBackgroundTree(bkg_tree, 1.0)

        # these cuts are also applied in the algorithm
        preselectionCuts = ROOT.TCut(
            # "chi2<30 && distX<500 && distY<500 && dSlope<2.0 && dSlopeY<0.15",  #### removed entirely for electrons
        )
        # Counts go through int() so float inputs such as 20e3 are written
        # as "20000" rather than "20000.0" in the option string.
        split_options = (
            "SplitMode=random:V"
            f":nTrain_Signal={int(n_train_signal)}"
            f":nTrain_Background={int(n_train_bkg)}"
            f":nTest_Signal={int(n_test_signal)}"
            f":nTest_Background={int(n_test_bkg)}"
        )
        # normmode default is EqualNumEvents
        dataloader.PrepareTrainingAndTestTree(preselectionCuts, split_options)

        mlp_options = ":".join(
            [
                "!H",
                "V",
                "TrainingMethod=BP",
                "NeuronType=ReLU",
                "EstimatorType=CE",
                "VarTransform=Norm",
                "NCycles=700",
                "HiddenLayers=N+2,N",
                "TestRate=50",
                "Sampling=1.0",
                "SamplingImportance=1.0",
                "LearningRate=0.02",
                "DecayRate=0.01",
                "!UseRegulator",
            ]
        )
        factory.BookMethod(
            dataloader,
            TMVA.Types.kMLP,
            "matching_mlp",
            mlp_options,
        )
        factory.TrainAllMethods()
        factory.TestAllMethods()
        factory.EvaluateAllMethods()
    finally:
        # Close the output file even if training fails part-way.
        output.Close()
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
    # Command-line entry point: every option is optional and maps onto the
    # keyword argument of train_matching_ghost_mlp with the same name.
    cli = argparse.ArgumentParser()
    cli.add_argument("--input-file", type=str, help="Path to the input file")
    cli.add_argument(
        "--exclude_electrons",
        action="store_true",
        help="Excludes electrons from training.",
    )
    cli.add_argument(
        "--only_electrons",
        action="store_true",
        help="Only electrons for signal training.",
    )

    # The four event-count options differ only in name and help text.
    count_options = (
        ("--n-train-signal", "Number of training tracks for signal."),
        ("--n-train-bkg", "Number of training tracks for bkg."),
        ("--n-test-signal", "Number of testing tracks for signal."),
        ("--n-test-bkg", "Number of testing tracks for bkg."),
    )
    for flag, description in count_options:
        cli.add_argument(flag, type=int, help=description)

    namespace = cli.parse_args()

    # Forward only the options that hold a value, so the function's own
    # defaults apply to the rest. The two store_true flags are False (not
    # None) when absent, so they are always forwarded.
    overrides = {
        name: value
        for name, value in vars(namespace).items()
        if value is not None
    }

    train_matching_ghost_mlp(**overrides)
|