You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 

161 lines
5.6 KiB

# flake8: noqa
import os
import argparse
import ROOT
from ROOT import TMVA
def train_matching_ghost_mlp(
    input_file: str = "data/ghost_data_B.root",
    tree_name: str = "PrMatchNN.PrMCDebugMatchToolNN/MVAInputAndOutput",
    exclude_electrons: bool = True,
    only_electrons: bool = False,
    n_train_signal: int = 100e3,
    n_train_bkg: int = 100e3,
    n_test_signal: int = 50e3,
    n_test_bkg: int = 50e3,
    prepare_data: bool = True,
    outdir: str = "neural_net_training",
):
    """Train a TMVA MLP that classifies matches between Velo and Seed tracks.

    The input tree is optionally split into a signal and a background file
    (selected on ``mc_quality``), which are then fed to a TMVA factory. The
    training output is written to ``<outdir>/result``; note that the process
    working directory is changed to that directory and not restored.

    Args:
        input_file: ROOT file holding the input tree.
            Defaults to "data/ghost_data_B.root".
        tree_name: Name of the tree inside ``input_file``. Defaults to
            "PrMatchNN.PrMCDebugMatchToolNN/MVAInputAndOutput".
        exclude_electrons: Use only ``mc_quality == 1`` as signal.
            Defaults to True.
        only_electrons: Use only ``mc_quality == -1`` as signal. Takes
            precedence over ``exclude_electrons``. Defaults to False.
        n_train_signal: Number of true matches to train on. Defaults to 100e3.
        n_train_bkg: Number of fake matches to train on. Defaults to 100e3.
        n_test_signal: Number of true matches to test on. Defaults to 50e3.
        n_test_bkg: Number of fake matches to test on. Defaults to 50e3.
        prepare_data: Split the input into a signal and a background file
            before training. When False, those files must already exist.
            Defaults to True.
        outdir: Directory receiving the split files and the ``result``
            sub-directory. Defaults to "neural_net_training".

    Raises:
        OSError: If the signal or background file cannot be opened.
    """
    # Remove only a trailing ".root". str.strip(".root") would instead strip
    # every leading/trailing '.', 'r', 'o' or 't' character and mangle names
    # such as "train.root" -> "ain".
    extension = ".root"
    if input_file.endswith(extension):
        stem = input_file[: -len(extension)]
    else:
        stem = input_file
    signal_path = outdir + "/" + stem + "_matching_signal.root"
    bkg_path = outdir + "/" + stem + "_matching_bkg.root"

    # Event counts may arrive as floats (the defaults are written as 100e3);
    # cast them so the TMVA option string carries plain integers.
    n_train_signal = int(n_train_signal)
    n_train_bkg = int(n_train_bkg)
    n_test_signal = int(n_test_signal)
    n_test_bkg = int(n_test_bkg)

    if prepare_data:
        rdf = ROOT.RDataFrame(tree_name, input_file)
        # mc_quality labels: -1 electron, 0 ghost, 1 all particles w/o electrons
        if only_electrons:
            rdf_signal = rdf.Filter(
                "mc_quality == -1",
                "Signal is defined as minus-one label (electrons only)",
            )
        elif exclude_electrons:
            rdf_signal = rdf.Filter(
                "mc_quality == 1",
                "Signal is defined as one label (excluding electrons)",
            )
        else:
            rdf_signal = rdf.Filter(
                "abs(mc_quality) > 0",
                "Signal is defined as non-zero label",
            )
        # The background definition is the same for every signal choice.
        rdf_bkg = rdf.Filter("mc_quality == 0", "Ghosts are defined as zero label")

        # Make sure the target directory exists before writing the snapshots.
        os.makedirs(os.path.dirname(signal_path) or ".", exist_ok=True)
        rdf_signal.Snapshot("Signal", signal_path)
        rdf_bkg.Snapshot("Bkg", bkg_path)

    signal_file = ROOT.TFile.Open(signal_path, "READ")
    bkg_file = ROOT.TFile.Open(bkg_path)
    for path, handle in ((signal_path, signal_file), (bkg_path, bkg_file)):
        # Fail early with a clear message instead of passing a null tree on.
        if not handle or handle.IsZombie():
            raise OSError(f"Could not open ROOT file '{path}'")
    signal_tree = signal_file.Get("Signal")
    bkg_tree = bkg_file.Get("Bkg")

    # The input files above are opened with paths relative to the original
    # working directory, so the directory change has to come afterwards.
    result_dir = outdir + "/result"
    os.makedirs(result_dir, exist_ok=True)
    os.chdir(result_dir)

    output = ROOT.TFile(
        "matching_ghost_mlp_training.root",
        "RECREATE",
    )
    try:
        factory = TMVA.Factory(
            "TMVAClassification",
            output,
            "V:!Silent:Color:DrawProgressBar:AnalysisType=Classification",
        )
        factory.SetVerbose(True)

        dataloader = TMVA.DataLoader("MatchNNDataSet")
        # Currently unused candidate inputs: tx, ty, tx_scifi, ty_scifi
        for variable in (
            "mc_chi2",
            "mc_teta2",
            "mc_distX",
            "mc_distY",
            "mc_dSlope",
            "mc_dSlopeY",
        ):
            dataloader.AddVariable(variable, "F")

        dataloader.AddSignalTree(signal_tree, 1.0)
        dataloader.AddBackgroundTree(bkg_tree, 1.0)

        # these cuts are also applied in the algorithm
        # NOTE (translated from the original remark): drop these cuts
        # entirely for electrons.
        preselection_cuts = ROOT.TCut(
            "mc_chi2<30 && mc_distX<500 && mc_distY<500 && mc_dSlope<3.0 && mc_dSlopeY<0.3",
        )
        split_options = (
            "SplitMode=random:V"
            f":nTrain_Signal={n_train_signal}"
            f":nTrain_Background={n_train_bkg}"
            f":nTest_Signal={n_test_signal}"
            f":nTest_Background={n_test_bkg}"
        )
        dataloader.PrepareTrainingAndTestTree(preselection_cuts, split_options)

        mlp_options = ":".join(
            [
                "!H",
                "V",
                "TrainingMethod=BP",
                "NeuronType=ReLU",
                "EstimatorType=CE",
                "VarTransform=Norm",
                "NCycles=700",
                "HiddenLayers=N+2,N",
                "TestRate=50",
                "Sampling=1.0",
                "SamplingImportance=1.0",
                "LearningRate=0.02",
                "DecayRate=0.01",
                "!UseRegulator",
            ]
        )
        factory.BookMethod(
            dataloader,
            TMVA.Types.kMLP,
            "matching_mlp",
            mlp_options,
        )

        factory.TrainAllMethods()
        factory.TestAllMethods()
        factory.EvaluateAllMethods()
    finally:
        # Always flush and close the training output, even if TMVA fails.
        output.Close()
if __name__ == "__main__":
    # Command-line entry point: every option is optional, and only the ones
    # actually supplied are forwarded, so the function defaults apply otherwise.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--input-file",
        type=str,
        required=False,
        help="Path to the input file",
    )
    cli.add_argument(
        "--exclude_electrons",
        action="store_true",
        required=False,
        help="Excludes electrons from training.",
    )
    # The four event-count options differ only in flag name and help text.
    for flag, description in (
        ("--n-train-signal", "Number of training tracks for signal."),
        ("--n-train-bkg", "Number of training tracks for bkg."),
        ("--n-test-signal", "Number of testing tracks for signal."),
        ("--n-test-bkg", "Number of testing tracks for bkg."),
    ):
        cli.add_argument(flag, type=int, required=False, help=description)

    parsed = vars(cli.parse_args())
    # Drop options that were not given (argparse leaves them as None).
    # A store_true flag is never None (it is False when omitted), so
    # exclude_electrons is always forwarded and overrides the function default.
    overrides = {}
    for key, value in parsed.items():
        if value is not None:
            overrides[key] = value
    train_matching_ghost_mlp(**overrides)