# flake8: noqa
import os
import argparse
import warnings

import ROOT
from ROOT import TMVA


def _strip_root_extension(path: str) -> str:
    """Return *path* without a trailing ``.root`` extension.

    ``str.strip(".root")`` must not be used for this: it removes any of the
    characters ``.``, ``r``, ``o``, ``t`` from BOTH ends of the string.
    """
    suffix = ".root"
    if path.endswith(suffix):
        return path[: -len(suffix)]
    return path


def _matching_file_paths(input_file: str, outdir: str):
    """Return the (signal, background) snapshot paths for *input_file*."""
    stem = outdir + "/" + _strip_root_extension(input_file)
    return stem + "_matching_signal.root", stem + "_matching_bkg.root"


def _ensure_parent_dir(path: str) -> None:
    """Create the directory that will contain *path* if it is missing."""
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)


def _open_root_file(path: str):
    """Open *path* read-only and raise ``OSError`` if ROOT cannot open it."""
    root_file = ROOT.TFile.Open(path, "READ")
    if not root_file or root_file.IsZombie():
        raise OSError(f"Could not open ROOT file '{path}'.")
    return root_file


def train_matching_ghost_mlp(
    input_file: str = "data/ghost_data_B.root",
    tree_name: str = "PrMatchNN.PrMCDebugMatchToolNN/MVAInputAndOutput",
    exclude_electrons: bool = True,
    only_electrons: bool = False,
    n_train_signal: int = 165_000,
    n_train_bkg: int = 165_000,
    n_test_signal: int = 5_000,
    n_test_bkg: int = 5_000,
    prepare_data: bool = True,
    outdir: str = "neural_net_training",
):
    """Trains an MLP to classify the match between Velo and Seed track.

    When ``prepare_data`` is set, the input tree is split by ``mc_quality``
    into a signal and a background file below ``outdir``. Those two files are
    then read back and used to train, test and evaluate a TMVA MLP.

    Note:
        The function changes the current working directory to
        ``<outdir>/result`` and does not change it back. The TMVA output file
        ``matching_ghost_mlp_training.root`` is created there.

    Args:
        input_file (str, optional): Input ROOT file. Its name (without the
            ``.root`` extension) is also used to build the names of the
            signal/background files. Defaults to "data/ghost_data_B.root".
        tree_name (str, optional): Name of the tree in ``input_file``.
            Defaults to "PrMatchNN.PrMCDebugMatchToolNN/MVAInputAndOutput".
        exclude_electrons (bool, optional): Use only ``mc_quality == 1`` as
            signal. Takes precedence over ``only_electrons``.
            Defaults to True.
        only_electrons (bool, optional): Use only ``mc_quality == -1`` as
            signal. Only has an effect if ``exclude_electrons`` is False.
            Defaults to False.
        n_train_signal (int, optional): Number of true matches to train on.
            Defaults to 165000.
        n_train_bkg (int, optional): Number of fake matches to train on.
            Defaults to 165000.
        n_test_signal (int, optional): Number of true matches to test on.
            Defaults to 5000.
        n_test_bkg (int, optional): Number of fake matches to test on.
            Defaults to 5000.
        prepare_data (bool, optional): Split data into signal and background
            file. If False, the two files must already exist.
            Defaults to True.
        outdir (str, optional): Directory below which the signal/background
            files and the ``result`` directory live.
            Defaults to "neural_net_training".

    Raises:
        OSError: If the signal or background file cannot be opened.
    """
    signal_path, bkg_path = _matching_file_paths(input_file, outdir)

    if prepare_data:
        print("Preparing Data for MVA Training: ")
        if exclude_electrons and only_electrons:
            warnings.warn(
                "Both exclude_electrons and only_electrons are set; "
                "exclude_electrons takes precedence. Pass "
                "exclude_electrons=False to train on electrons only.",
                stacklevel=2,
            )
        rdf = ROOT.RDataFrame(tree_name, input_file)

        # mc_quality labels: -1 elec, 0 ghost, 1 all part wo elec
        if exclude_electrons:
            signal_cut = "mc_quality == 1"
            signal_label = "Signal is defined as one label (excluding electrons)"
        elif only_electrons:
            signal_cut = "mc_quality == -1"
            signal_label = "Signal is defined as minus one label (electrons only)"
        else:
            signal_cut = "abs(mc_quality) > 0"
            signal_label = "Signal is defined as non-zero label"
        rdf_signal = rdf.Filter(signal_cut, signal_label)
        rdf_bkg = rdf.Filter("mc_quality == 0", "Ghosts are defined as zero label")

        # The snapshot files cannot be created in a missing directory.
        _ensure_parent_dir(signal_path)
        _ensure_parent_dir(bkg_path)

        print("Create Signal data file. ")
        rdf_signal.Snapshot("Signal", signal_path)
        print("Create Background data file. ")
        rdf_bkg.Snapshot("Bkg", bkg_path)
        print("Data preparation terminated successfully.\n")

    # Open the inputs BEFORE changing directory: the paths may be relative.
    signal_file = _open_root_file(signal_path)
    signal_tree = signal_file.Get("Signal")
    bkg_file = _open_root_file(bkg_path)
    bkg_tree = bkg_file.Get("Bkg")

    result_dir = outdir + "/result"
    os.makedirs(result_dir, exist_ok=True)
    os.chdir(result_dir)
    output = ROOT.TFile(
        "matching_ghost_mlp_training.root",
        "RECREATE",
    )

    factory = TMVA.Factory(
        "TMVAClassification",
        output,
        "V:!Silent:Color:DrawProgressBar:AnalysisType=Classification",
    )
    factory.SetVerbose(True)

    dataloader = TMVA.DataLoader("MatchNNDataSet")
    input_variables = (
        "mc_chi2",
        "mc_teta2",
        "mc_distX",
        "mc_distY",
        "mc_dSlope",
        "mc_dSlopeY",
        "mc_end_velo_phi",
        "mc_end_velo_eta",
    )
    for variable in input_variables:
        dataloader.AddVariable(variable, "F")

    dataloader.AddSignalTree(signal_tree, 1.0)
    dataloader.AddBackgroundTree(bkg_tree, 1.0)

    # these cuts are also applied in the algorithm
    # NOTE: (translated from the original author's remark) drop these cuts
    # entirely for electrons.
    preselectionCuts = ROOT.TCut(
        "mc_chi2<15 && mc_distX<250 && mc_distY<250 && mc_dSlope<1.5 && mc_dSlopeY<0.15 && std::abs(mc_end_velo_phi)<3.142 && mc_end_velo_eta>1.99 && mc_end_velo_eta<5.01",
    )

    # Cast to int so that float inputs such as 165e3 are written as
    # "165000" and not "165000.0" into the option string.
    dataloader.PrepareTrainingAndTestTree(
        preselectionCuts,
        f"SplitMode=random:V:nTrain_Signal={int(n_train_signal)}:nTrain_Background={int(n_train_bkg)}:nTest_Signal={int(n_test_signal)}:nTest_Background={int(n_test_bkg)}",
    )

    factory.BookMethod(
        dataloader,
        TMVA.Types.kMLP,
        "matching_mlp",
        "!H:V:TrainingMethod=BP:NeuronType=ReLU:EstimatorType=CE:VarTransform=Norm:NCycles=700:HiddenLayers=N+2,N:TestRate=50:Sampling=1.0:SamplingImportance=1.0:LearningRate=0.02:DecayRate=0.01:UseRegulator",
    )

    factory.TrainAllMethods()
    factory.TestAllMethods()
    factory.EvaluateAllMethods()

    output.Close()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input-file",
        type=str,
        help="Path to the input file",
    )
    # "--exclude_electrons" is the historical spelling and is kept as an
    # alias. Because this is a store_true flag its value is never None, so
    # the training function always receives it: electrons are only excluded
    # when the flag is given on the command line.
    parser.add_argument(
        "--exclude-electrons",
        "--exclude_electrons",
        dest="exclude_electrons",
        action="store_true",
        help="Excludes electrons from training.",
    )
    parser.add_argument(
        "--n-train-signal",
        type=int,
        help="Number of training tracks for signal.",
    )
    parser.add_argument(
        "--n-train-bkg",
        type=int,
        help="Number of training tracks for bkg.",
    )
    parser.add_argument(
        "--n-test-signal",
        type=int,
        help="Number of testing tracks for signal.",
    )
    parser.add_argument(
        "--n-test-bkg",
        type=int,
        help="Number of testing tracks for bkg.",
    )
    args = parser.parse_args()
    # Options that were not given stay None and are dropped, so that the
    # defaults of train_matching_ghost_mlp apply.
    args_dict = {arg: val for arg, val in vars(args).items() if val is not None}
    train_matching_ghost_mlp(**args_dict)