# flake8: noqa
import os
import argparse
import ROOT
from ROOT import TMVA, TList, TTree

# Branches of the matching tuple used as MLP input features, in booking order.
_INPUT_VARIABLES = (
    "mc_chi2",
    "mc_teta2",
    "mc_distX",
    "mc_distY",
    "mc_dSlope",
    "mc_dSlopeY",
    "mc_zMag",
)


def _matching_file_path(outdir: str, input_file: str, label: str) -> str:
    """Return the snapshot path ``<outdir>/<input_file without .root>_matching_<label>.root``.

    Directory components of ``input_file`` are kept, so e.g. "data/x.root"
    maps to "<outdir>/data/x_matching_<label>.root".
    """
    suffix = ".root"
    # Remove the suffix explicitly: str.strip(".root") would strip any of the
    # characters ".", "r", "o", "t" from BOTH ends of the path
    # (e.g. "root_files/test.root" -> "_files/tes").
    stem = input_file[: -len(suffix)] if input_file.endswith(suffix) else input_file
    return f"{outdir}/{stem}_matching_{label}.root"


def _split_signal_and_background(
    input_file: str,
    tree_name: str,
    exclude_electrons: bool,
    only_electrons: bool,
    signal_path: str,
    bkg_path: str,
) -> None:
    """Write the signal ("Signal" tree) and ghost ("Bkg" tree) subsets of the tuple to two ROOT files."""
    rdf = ROOT.RDataFrame(tree_name, input_file)
    # mc_quality labels: -1 electrons, 0 ghosts, 1 all particles without electrons.
    # exclude_electrons takes precedence over only_electrons.
    if exclude_electrons:
        rdf_signal = rdf.Filter(
            "mc_quality == 1",
            "Signal is defined as one label (excluding electrons)",
        )
    elif only_electrons:
        # electron that is a true match but the MLP said "no match"
        rdf_signal = rdf.Filter(
            "mc_quality == -1",
            "Signal is defined as negative one label (only electrons)",
        )
    else:
        rdf_signal = rdf.Filter(
            "abs(mc_quality) > 0",
            "Signal is defined as non-zero label",
        )
    # The background selection is the same in every mode.
    rdf_bkg = rdf.Filter(
        "mc_quality == 0",
        "Ghosts are defined as zero label",
    )

    # The snapshot files cannot be created in a directory that does not exist.
    for path in (signal_path, bkg_path):
        directory = os.path.dirname(path)
        if directory:
            os.makedirs(directory, exist_ok=True)

    rdf_signal.Snapshot("Signal", signal_path)
    rdf_bkg.Snapshot("Bkg", bkg_path)


def _open_tree(path: str, tree: str):
    """Open ``path`` read-only and return ``(file, tree)``; raise OSError if the file cannot be opened."""
    root_file = ROOT.TFile.Open(path, "READ")
    # A failed TFile.Open yields a null (falsy) object in PyROOT.
    if not root_file or root_file.IsZombie():
        raise OSError(f"Could not open ROOT file {path!r}")
    return root_file, root_file.Get(tree)


def train_matching_ghost_mlp(
    input_file: str = "data/ghost_data_B.root",
    tree_name: str = "PrMatchNN.PrMCDebugMatchToolNN/MVAInputAndOutput",
    exclude_electrons: bool = False,
    only_electrons: bool = True,
    n_train_signal: int = 20_000,  # alternative used before: 50e3
    n_train_bkg: int = 50_000,  # alternative used before: 500e3
    n_test_signal: int = 10_000,
    n_test_bkg: int = 20_000,
    prepare_data: bool = True,
    outdir: str = "nn_electron_training",
) -> None:
    """Train a TMVA MLP that classifies the match between a Velo and a Seed track.

    Optionally splits the input tuple into a signal and a background file
    first. It then books, trains, tests and evaluates a ``kMLP`` via
    ``TMVA.Factory``. The TMVA output file ``matching_ghost_mlp_training.root``
    is written in ``<outdir>/result``. The working directory is switched there
    during training (presumably so TMVA's dataset/weights output lands there
    too) and restored afterwards.

    Args:
        input_file (str, optional): ROOT file with the matching tuple.
            Defaults to "data/ghost_data_B.root".
        tree_name (str, optional): Tree inside ``input_file``.
            Defaults to "PrMatchNN.PrMCDebugMatchToolNN/MVAInputAndOutput".
        exclude_electrons (bool, optional): Signal is ``mc_quality == 1``
            (true matches without electrons). Takes precedence over
            ``only_electrons``. Defaults to False.
        only_electrons (bool, optional): Signal is ``mc_quality == -1``
            (electrons only); background always contains all ghosts.
            Defaults to True.
        n_train_signal (int, optional): Number of true matches to train on.
            Defaults to 20000.
        n_train_bkg (int, optional): Number of fake matches to train on.
            Defaults to 50000.
        n_test_signal (int, optional): Number of true matches to test on.
            Defaults to 10000.
        n_test_bkg (int, optional): Number of fake matches to test on.
            Defaults to 20000.
        prepare_data (bool, optional): Split ``input_file`` into signal and
            background files before training. If False, the files produced by
            a previous run are reused. Defaults to True.
        outdir (str, optional): Directory for the signal/background files;
            training output goes to ``<outdir>/result``.
            Defaults to "nn_electron_training".

    Raises:
        OSError: If the signal or background file cannot be opened.
    """
    signal_path = _matching_file_path(outdir, input_file, "signal")
    bkg_path = _matching_file_path(outdir, input_file, "bkg")

    if prepare_data:
        _split_signal_and_background(
            input_file,
            tree_name,
            exclude_electrons,
            only_electrons,
            signal_path,
            bkg_path,
        )

    # Opened before the chdir below, so relative paths resolve against the
    # caller's working directory. The files stay open: the dataloader keeps
    # references to their trees.
    signal_file, signal_tree = _open_tree(signal_path, "Signal")
    bkg_file, bkg_tree = _open_tree(bkg_path, "Bkg")

    result_dir = outdir + "/result"
    os.makedirs(result_dir, exist_ok=True)
    previous_cwd = os.getcwd()
    os.chdir(result_dir)
    try:
        output = ROOT.TFile(
            "matching_ghost_mlp_training.root",
            "RECREATE",
        )
        try:
            factory = TMVA.Factory(
                "TMVAClassification",
                output,
                "V:!Silent:Color:DrawProgressBar:AnalysisType=Classification",
            )
            factory.SetVerbose(True)

            dataloader = TMVA.DataLoader("MatchNNDataSet")
            for variable in _INPUT_VARIABLES:
                dataloader.AddVariable(variable, "F")
            dataloader.AddSignalTree(signal_tree, 1.0)
            dataloader.AddBackgroundTree(bkg_tree, 1.0)

            # These cuts are also applied in the algorithm. They are disabled
            # entirely for the electron training; the original cut string was:
            # "chi2<30 && distX<500 && distY<500 && dSlope<2.0 && dSlopeY<0.15"
            preselectionCuts = ROOT.TCut()

            # Event counts are cast to int so the option string reads e.g.
            # "nTrain_Signal=20000" rather than "20000.0" when floats are passed.
            dataloader.PrepareTrainingAndTestTree(
                preselectionCuts,
                f"SplitMode=random:V:nTrain_Signal={int(n_train_signal)}:nTrain_Background={int(n_train_bkg)}:nTest_Signal={int(n_test_signal)}:nTest_Background={int(n_test_bkg)}",
                # normmode default is EqualNumEvents
            )

            factory.BookMethod(
                dataloader,
                TMVA.Types.kMLP,
                "matching_mlp",
                "!H:V:TrainingMethod=BP:NeuronType=ReLU:EstimatorType=CE:VarTransform=Norm:NCycles=700:HiddenLayers=N+2,N:TestRate=50:Sampling=1.0:SamplingImportance=1.0:LearningRate=0.02:DecayRate=0.01:!UseRegulator",
            )
            factory.TrainAllMethods()
            factory.TestAllMethods()
            factory.EvaluateAllMethods()
        finally:
            output.Close()
    finally:
        # Do not leak the directory change to the caller.
        os.chdir(previous_cwd)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input-file",
        type=str,
        help="Path to the input file",
        required=False,
    )
    parser.add_argument(
        "--exclude_electrons",
        action="store_true",
        help="Excludes electrons from training.",
        required=False,
    )
    parser.add_argument(
        "--only_electrons",
        action="store_true",
        help="Only electrons for signal training.",
        required=False,
    )
    parser.add_argument(
        "--n-train-signal",
        type=int,
        help="Number of training tracks for signal.",
        required=False,
    )
    parser.add_argument(
        "--n-train-bkg",
        type=int,
        help="Number of training tracks for bkg.",
        required=False,
    )
    parser.add_argument(
        "--n-test-signal",
        type=int,
        help="Number of testing tracks for signal.",
        required=False,
    )
    parser.add_argument(
        "--n-test-bkg",
        type=int,
        help="Number of testing tracks for bkg.",
        required=False,
    )
    args = parser.parse_args()
    # Drop options that were not given so the function defaults apply.
    # NOTE(review): the store_true flags default to False (not None), so they
    # are always forwarded. Without --only_electrons the CLI therefore trains
    # on all-particle signal, unlike the function default only_electrons=True.
    args_dict = {arg: val for arg, val in vars(args).items() if val is not None}
    train_matching_ghost_mlp(**args_dict)