190 lines
6.3 KiB
Python
190 lines
6.3 KiB
Python
|
# flake8: noqa
|
||
|
|
||
|
import os
|
||
|
import argparse
|
||
|
import ROOT
|
||
|
from ROOT import TMVA, TList, TTree, TMath
|
||
|
|
||
|
|
||
|
def _strip_root_suffix(path: str) -> str:
    """Return *path* without a trailing ".root" extension (unchanged otherwise).

    Note: ``str.strip(".root")`` is NOT a suffix removal; it strips any of the
    characters ".", "r", "o", "t" from both ends (e.g. "output.root" -> "utpu").
    """
    suffix = ".root"
    return path[: -len(suffix)] if path.endswith(suffix) else path


def _open_tree(path: str, tree_name: str):
    """Open the ROOT file at *path* read-only and fetch *tree_name* from it.

    Returns:
        tuple: ``(file, tree)``. The caller must keep ``file`` referenced for as
        long as ``tree`` is used, since the tree is read through that file.

    Raises:
        OSError: if the file cannot be opened.
        LookupError: if the file contains no object named *tree_name*.
    """
    root_file = ROOT.TFile.Open(path, "READ")
    if not root_file or root_file.IsZombie():
        raise OSError(f"Cannot open ROOT file {path!r}")
    tree = root_file.Get(tree_name)
    if not tree:
        raise LookupError(f"No object {tree_name!r} in ROOT file {path!r}")
    return root_file, tree


def train_matching_ghost_mlp(
    input_file: str = "data/tracking_losses_ntuple_B.root",
    tree_name: str = "PrDebugTrackingLosses.PrDebugTrackingTool/Tuple",
    b_input_file: str = "data/ghost_data_B.root",
    b_tree_name: str = "PrMatchNN_3e224c41.PrMCDebugMatchToolNN/MVAInputAndOutput",
    only_electrons: bool = True,
    n_train_signal: int = 2e3,  # 50e3
    n_train_bkg: int = 5e3,  # 500e3
    n_test_signal: int = 1e3,
    n_test_bkg: int = 2e3,
    prepare_data: bool = True,
    outdir: str = "nn_trackinglosses_training",
):
    """Trains an MLP to classify the match between Velo and Seed track.

    Signal candidates come from ``tree_name`` in ``input_file``; background
    (ghost) candidates come from ``b_tree_name`` in ``b_input_file``. With
    ``prepare_data`` the two samples are filtered and snapshotted to
    ``<outdir>/<input_file without .root>_matching_signal.root`` and
    ``..._matching_bkg.root`` (both named after ``input_file``); otherwise these
    files must already exist. The TMVA output file
    ``matching_ghost_mlp_training.root`` is written inside ``<outdir>/result``,
    which is used as working directory during training; the previous working
    directory is restored afterwards.

    Args:
        input_file (str, optional): ROOT file with the signal candidates.
            Defaults to "data/tracking_losses_ntuple_B.root".
        tree_name (str, optional): Signal tree inside ``input_file``.
            Defaults to "PrDebugTrackingLosses.PrDebugTrackingTool/Tuple".
        b_input_file (str, optional): ROOT file with the background candidates.
            Defaults to "data/ghost_data_B.root".
        b_tree_name (str, optional): Background tree inside ``b_input_file``.
            Defaults to "PrMatchNN_3e224c41.PrMCDebugMatchToolNN/MVAInputAndOutput".
        only_electrons (bool, optional): Signal only of electrons (selected with
            ``mc_quality == -1``), but bkg of all particles. Defaults to True.
        n_train_signal (int, optional): Number of true matches to train on.
            Defaults to 2e3.
        n_train_bkg (int, optional): Number of fake matches to train on.
            Defaults to 5e3.
        n_test_signal (int, optional): Number of true matches to test on.
            Defaults to 1e3.
        n_test_bkg (int, optional): Number of fake matches to test on.
            Defaults to 2e3.
        prepare_data (bool, optional): Split data into signal and background
            file. Defaults to True.
        outdir (str, optional): Output directory for snapshots and results.
            Defaults to "nn_trackinglosses_training".

        The event counts are converted with ``int()`` before being passed to
        TMVA, so float values such as the defaults (``2e3``) are accepted.

    Raises:
        OSError: if a snapshot file or the TMVA output file cannot be opened.
        LookupError: if a snapshot file lacks its "Signal"/"Bkg" tree.
    """
    # Columns written to the signal snapshot.
    snapshot_columns = [
        "mc_chi2",
        "mc_teta2",
        "mc_distX",
        "mc_distY",
        "mc_dSlope",
        "mc_dSlopeY",
        "mc_quality",
        "mc_end_velo_qop",
        "mc_end_velo_tx",
        "mc_end_velo_ty",
        "mc_end_t_qop",
        "mc_end_t_tx",
        "mc_end_t_ty",
    ]
    # Subset of the snapshot columns used as (float) MLP input variables.
    training_variables = [
        "mc_chi2",
        "mc_teta2",
        "mc_distX",
        "mc_distY",
        "mc_dSlope",
        "mc_dSlopeY",
    ]

    snapshot_stem = outdir + "/" + _strip_root_suffix(input_file)
    signal_path = snapshot_stem + "_matching_signal.root"
    # The background snapshot is named after input_file as well, not b_input_file.
    bkg_path = snapshot_stem + "_matching_bkg.root"

    if prepare_data:
        rdf = ROOT.RDataFrame(tree_name, input_file)
        rdf_b = ROOT.RDataFrame(b_tree_name, b_input_file)
        if only_electrons:
            # Electron that is a true match but the MLP said no match.
            rdf_signal = rdf.Filter(
                "mc_quality == -1 && lost == 0 && fromSignal == 1",
                "Signal is defined as one label (only electrons)",
            )
        else:
            rdf_signal = rdf.Filter(
                "lost == 0 && fromSignal == 1",
                "Signal is defined as non-zero label",
            )
        rdf_bkg = rdf_b.Filter(
            "mc_quality == 0",
            "Ghosts are defined as zero label",
        )

        # Make sure the snapshot directory exists before writing into it.
        os.makedirs(os.path.dirname(signal_path), exist_ok=True)
        rdf_signal.Snapshot("Signal", signal_path, snapshot_columns)
        # All background columns are kept (no column list given).
        rdf_bkg.Snapshot("Bkg", bkg_path)

    # The file handles are kept referenced for the whole training because the
    # trees are read through them.
    signal_file, signal_tree = _open_tree(signal_path, "Signal")
    bkg_file, bkg_tree = _open_tree(bkg_path, "Bkg")

    result_dir = outdir + "/result"
    os.makedirs(result_dir, exist_ok=True)
    original_cwd = os.getcwd()
    # Train inside result_dir so that anything written relative to the working
    # directory (presumably TMVA's "MatchNNDataSet" weights directory) lands there.
    os.chdir(result_dir)
    try:
        output = ROOT.TFile(
            "matching_ghost_mlp_training.root",
            "RECREATE",
        )
        if output.IsZombie():
            raise OSError(
                f"Cannot create {result_dir}/matching_ghost_mlp_training.root"
            )
        try:
            factory = TMVA.Factory(
                "TMVAClassification",
                output,
                "V:!Silent:Color:DrawProgressBar:AnalysisType=Classification",
            )
            factory.SetVerbose(True)
            dataloader = TMVA.DataLoader("MatchNNDataSet")
            for variable in training_variables:
                dataloader.AddVariable(variable, "F")

            dataloader.AddSignalTree(signal_tree, 1.0)
            dataloader.AddBackgroundTree(bkg_tree, 1.0)

            # These cuts are also applied in the algorithm: reject entries where
            # any input variable is NaN. The former kinematic preselection
            # (mc_chi2<30 && mc_distX<500 && mc_distY<500 && mc_dSlope<2.0
            # && mc_dSlopeY<0.15) is dropped entirely for electrons.
            preselection_cuts = ROOT.TCut(
                " && ".join(
                    f"!TMath::IsNaN({variable})" for variable in training_variables
                )
            )
            # TMVA expects integer event counts; the defaults are floats (2e3, ...).
            split_options = (
                "SplitMode=random:V"
                f":nTrain_Signal={int(n_train_signal)}"
                f":nTrain_Background={int(n_train_bkg)}"
                f":nTest_Signal={int(n_test_signal)}"
                f":nTest_Background={int(n_test_bkg)}"
            )
            # NormMode is left at its default (EqualNumEvents).
            dataloader.PrepareTrainingAndTestTree(preselection_cuts, split_options)

            factory.BookMethod(
                dataloader,
                TMVA.Types.kMLP,
                "matching_mlp",
                "!H:V:TrainingMethod=BP:NeuronType=ReLU:EstimatorType=CE:VarTransform=Norm:NCycles=700:HiddenLayers=N+2,N:TestRate=50:Sampling=1.0:SamplingImportance=1.0:LearningRate=0.02:DecayRate=0.01:UseRegulator",
            )
            factory.TrainAllMethods()
            factory.TestAllMethods()
            factory.EvaluateAllMethods()
        finally:
            # Closed while factory/dataloader are still alive, as before.
            output.Close()
    finally:
        os.chdir(original_cwd)
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
    # Command-line entry point: every option left unset falls back to the
    # corresponding default of train_matching_ghost_mlp.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input-file",
        type=str,
        help="Path to the input file",
        required=False,
    )
    parser.add_argument(
        "--tree-name",
        type=str,
        help="Name of the signal tree in the input file.",
        required=False,
    )
    parser.add_argument(
        "--b-input-file",
        type=str,
        help="Path to the background (ghost) input file.",
        required=False,
    )
    parser.add_argument(
        "--b-tree-name",
        type=str,
        help="Name of the background tree in the background input file.",
        required=False,
    )
    parser.add_argument(
        "--outdir",
        type=str,
        help="Output directory for snapshots and training results.",
        required=False,
    )
    parser.add_argument(
        "--exclude_electrons",
        action="store_true",
        help="Not supported by train_matching_ghost_mlp; passing it is an error.",
        required=False,
    )
    parser.add_argument(
        "--only_electrons",
        action="store_true",
        help="Only electrons for signal training.",
        required=False,
    )
    parser.add_argument(
        "--n-train-signal",
        type=int,
        help="Number of training tracks for signal.",
        required=False,
    )
    parser.add_argument(
        "--n-train-bkg",
        type=int,
        help="Number of training tracks for bkg.",
        required=False,
    )
    parser.add_argument(
        "--n-test-signal",
        type=int,
        help="Number of testing tracks for signal.",
        required=False,
    )
    parser.add_argument(
        "--n-test-bkg",
        type=int,
        help="Number of testing tracks for bkg.",
        required=False,
    )

    args = parser.parse_args()
    args_dict = dict(vars(args))
    # store_true flags default to False (not None), so they are always present.
    # train_matching_ghost_mlp has no exclude_electrons parameter; forwarding it
    # would fail with TypeError, so it is never forwarded.
    if args_dict.pop("exclude_electrons"):
        parser.error(
            "--exclude_electrons is not supported by train_matching_ghost_mlp"
        )
    # NOTE: omitting --only_electrons passes only_electrons=False, which
    # overrides the function's default of True.
    args_dict = {arg: val for arg, val in args_dict.items() if val is not None}

    train_matching_ghost_mlp(**args_dict)
|