You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

189 lines
6.3 KiB

# flake8: noqa
import os
import argparse
import ROOT
from ROOT import TMVA, TList, TTree, TMath
def _matching_output_stem(outdir: str, input_file: str) -> str:
    """Return the path prefix of the prepared signal/background files.

    The prefix is ``outdir + "/" + input_file`` with one trailing ``.root``
    removed. Only the exact suffix is removed: ``str.strip(".root")`` would
    instead strip any of the characters ``.``, ``r``, ``o``, ``t`` from both
    ends (e.g. ``"data/test.root"`` -> ``"data/tes"``).
    """
    suffix = ".root"
    if input_file.endswith(suffix):
        input_file = input_file[: -len(suffix)]
    return outdir + "/" + input_file


def _open_tree(path: str, tree_name: str):
    """Open the ROOT file at ``path`` read-only and fetch ``tree_name`` from it.

    Returns:
        tuple: ``(root_file, tree)``; the caller is responsible for closing
        ``root_file`` once the tree is no longer needed.

    Raises:
        OSError: if the file cannot be opened or does not contain ``tree_name``.
    """
    root_file = ROOT.TFile.Open(path, "READ")
    # PyROOT null pointers are falsy; a zombie file failed to open properly.
    if not root_file or root_file.IsZombie():
        raise OSError(f"Could not open ROOT file {path!r}")
    tree = root_file.Get(tree_name)
    if not tree:
        root_file.Close()
        raise OSError(f"Object {tree_name!r} not found in ROOT file {path!r}")
    return root_file, tree


def train_matching_ghost_mlp(
    input_file: str = "data/tracking_losses_ntuple_B.root",
    tree_name: str = "PrDebugTrackingLosses.PrDebugTrackingTool/Tuple",
    b_input_file: str = "data/ghost_data_B.root",
    b_tree_name: str = "PrMatchNN_3e224c41.PrMCDebugMatchToolNN/MVAInputAndOutput",
    only_electrons: bool = True,
    n_train_signal: int = 2_000,  # 50e3
    n_train_bkg: int = 5_000,  # 500e3
    n_test_signal: int = 1_000,
    n_test_bkg: int = 2_000,
    prepare_data: bool = True,
    outdir: str = "nn_trackinglosses_training",
):
    """Trains an MLP to classify the match between Velo and Seed track.

    Signal candidates come from ``tree_name`` in ``input_file`` and background
    (ghost) candidates from ``b_tree_name`` in ``b_input_file``. The training
    output file ``matching_ghost_mlp_training.root`` is written into
    ``outdir + "/result"``. The process working directory is switched there
    during training and restored afterwards.

    Args:
        input_file (str, optional): ROOT file with the signal candidates.
            Defaults to "data/tracking_losses_ntuple_B.root".
        tree_name (str, optional): Tree in ``input_file``.
            Defaults to "PrDebugTrackingLosses.PrDebugTrackingTool/Tuple".
        b_input_file (str, optional): ROOT file with the background candidates.
            Defaults to "data/ghost_data_B.root".
        b_tree_name (str, optional): Tree in ``b_input_file``.
            Defaults to "PrMatchNN_3e224c41.PrMCDebugMatchToolNN/MVAInputAndOutput".
        only_electrons (bool, optional): If True, signal requires
            ``mc_quality == -1`` in addition to ``lost == 0 && fromSignal == 1``;
            background is taken from all particles either way. Defaults to True.
        n_train_signal (int, optional): Number of true matches to train on. Defaults to 2000.
        n_train_bkg (int, optional): Number of fake matches to train on. Defaults to 5000.
        n_test_signal (int, optional): Number of true matches to test on. Defaults to 1000.
        n_test_bkg (int, optional): Number of fake matches to test on. Defaults to 2000.
        prepare_data (bool, optional): If True, filter the inputs and write the
            signal/background snapshot files first; otherwise reuse snapshot
            files from a previous run. Defaults to True.
        outdir (str, optional): Directory for the snapshot files and the
            ``result`` subdirectory. Defaults to "nn_trackinglosses_training".

    Raises:
        OSError: if a prepared signal/background file or its tree cannot be read.
    """
    # Columns kept in the signal snapshot.
    col_list = [
        "mc_chi2",
        "mc_teta2",
        "mc_distX",
        "mc_distY",
        "mc_dSlope",
        "mc_dSlopeY",
        "mc_quality",
        "mc_end_velo_qop",
        "mc_end_velo_tx",
        "mc_end_velo_ty",
        "mc_end_t_qop",
        "mc_end_t_tx",
        "mc_end_t_ty",
    ]
    stem = _matching_output_stem(outdir, input_file)
    signal_path = stem + "_matching_signal.root"
    bkg_path = stem + "_matching_bkg.root"

    if prepare_data:
        # The snapshot directory (outdir plus the input file's own
        # subdirectories) may not exist yet.
        os.makedirs(os.path.dirname(signal_path) or ".", exist_ok=True)
        rdf = ROOT.RDataFrame(tree_name, input_file)
        rdf_b = ROOT.RDataFrame(b_tree_name, b_input_file)
        if only_electrons:
            rdf_signal = rdf.Filter(
                "mc_quality == -1 && lost == 0 && fromSignal == 1",  # electron that is true match but mlp said no match
                "Signal is defined as one label (only electrons)",
            )
        else:
            rdf_signal = rdf.Filter(
                "lost == 0 && fromSignal == 1",
                "Signal is defined as non-zero label",
            )
        rdf_bkg = rdf_b.Filter(
            "mc_quality == 0",
            "Ghosts are defined as zero label",
        )
        rdf_signal.Snapshot("Signal", signal_path, col_list)
        rdf_bkg.Snapshot("Bkg", bkg_path)

    # Open both inputs before changing directory: the paths may be relative.
    signal_file, signal_tree = _open_tree(signal_path, "Signal")
    try:
        bkg_file, bkg_tree = _open_tree(bkg_path, "Bkg")
    except OSError:
        signal_file.Close()
        raise

    result_dir = outdir + "/result"
    os.makedirs(result_dir, exist_ok=True)
    previous_cwd = os.getcwd()
    # TMVA writes the output file and the "MatchNNDataSet" directory relative
    # to the current working directory.
    os.chdir(result_dir)
    output = None
    try:
        output = ROOT.TFile(
            "matching_ghost_mlp_training.root",
            "RECREATE",
        )
        factory = TMVA.Factory(
            "TMVAClassification",
            output,
            "V:!Silent:Color:DrawProgressBar:AnalysisType=Classification",
        )
        factory.SetVerbose(True)
        dataloader = TMVA.DataLoader("MatchNNDataSet")
        for variable in (
            "mc_chi2",
            "mc_teta2",
            "mc_distX",
            "mc_distY",
            "mc_dSlope",
            "mc_dSlopeY",
        ):
            dataloader.AddVariable(variable, "F")
        dataloader.AddSignalTree(signal_tree, 1.0)
        dataloader.AddBackgroundTree(bkg_tree, 1.0)
        # these cuts are also applied in the algorithm
        # The range cuts "mc_chi2<30 && mc_distX<500 && mc_distY<500 &&
        # mc_dSlope<2.0 && mc_dSlopeY<0.15" are dropped entirely for electrons.
        preselection_cuts = ROOT.TCut(
            "!TMath::IsNaN(mc_chi2) && !TMath::IsNaN(mc_distX) && !TMath::IsNaN(mc_distY) && !TMath::IsNaN(mc_dSlope) && !TMath::IsNaN(mc_dSlopeY) && !TMath::IsNaN(mc_teta2)",
        )
        # Event counts are cast to int so that float values such as 2e3 are
        # rendered as "2000" rather than "2000.0" in the TMVA option string.
        dataloader.PrepareTrainingAndTestTree(
            preselection_cuts,
            f"SplitMode=random:V:nTrain_Signal={int(n_train_signal)}:nTrain_Background={int(n_train_bkg)}:nTest_Signal={int(n_test_signal)}:nTest_Background={int(n_test_bkg)}",
            # normmode default is EqualNumEvents
        )
        factory.BookMethod(
            dataloader,
            TMVA.Types.kMLP,
            "matching_mlp",
            "!H:V:TrainingMethod=BP:NeuronType=ReLU:EstimatorType=CE:VarTransform=Norm:NCycles=700:HiddenLayers=N+2,N:TestRate=50:Sampling=1.0:SamplingImportance=1.0:LearningRate=0.02:DecayRate=0.01:UseRegulator",
        )
        factory.TrainAllMethods()
        factory.TestAllMethods()
        factory.EvaluateAllMethods()
    finally:
        if output is not None:
            output.Close()
        signal_file.Close()
        bkg_file.Close()
        os.chdir(previous_cwd)
if __name__ == "__main__":
    # Command-line entry point: every option left unset (None) falls back to
    # the train_matching_ghost_mlp() default.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input-file",
        type=str,
        help="Path to the input file",
        required=False,
    )
    parser.add_argument(
        "--tree-name",
        type=str,
        help="Name of the signal tree in the input file.",
        required=False,
    )
    parser.add_argument(
        "--b-input-file",
        type=str,
        help="Path to the background (ghost) input file.",
        required=False,
    )
    parser.add_argument(
        "--b-tree-name",
        type=str,
        help="Name of the background tree in the background input file.",
        required=False,
    )
    parser.add_argument(
        "--outdir",
        type=str,
        help="Output directory for prepared data and training results.",
        required=False,
    )
    # Kept so that existing command lines still parse. The training function
    # has no such option, so requesting it is reported as a usage error.
    parser.add_argument(
        "--exclude_electrons",
        action="store_true",
        help="Excludes electrons from training.",
        required=False,
    )
    # store_true yields False when the flag is absent, so the CLI default is
    # signal from all particles (only_electrons=False).
    parser.add_argument(
        "--only_electrons",
        action="store_true",
        help="Only electrons for signal training.",
        required=False,
    )
    parser.add_argument(
        "--n-train-signal",
        type=int,
        help="Number of training tracks for signal.",
        required=False,
    )
    parser.add_argument(
        "--n-train-bkg",
        type=int,
        help="Number of training tracks for bkg.",
        required=False,
    )
    parser.add_argument(
        "--n-test-signal",
        type=int,
        help="Number of testing tracks for signal.",
        required=False,
    )
    parser.add_argument(
        "--n-test-bkg",
        type=int,
        help="Number of testing tracks for bkg.",
        required=False,
    )
    args = parser.parse_args()
    args_dict = {arg: val for arg, val in vars(args).items() if val is not None}
    # Forwarding exclude_electrons (always True/False, never None) made every
    # run fail with a TypeError for an unexpected keyword argument.
    if args_dict.pop("exclude_electrons", False):
        parser.error("--exclude_electrons is not supported by train_matching_ghost_mlp")
    train_matching_ghost_mlp(**args_dict)