# flake8: noqa
import argparse
import contextlib
import os

import ROOT
from ROOT import TMVA, TList, TTree, TMath

# Columns written to the prepared signal file (the background file keeps all columns).
_SIGNAL_COLUMNS = [
    "mc_chi2",
    "mc_teta2",
    "mc_distX",
    "mc_distY",
    "mc_dSlope",
    "mc_dSlopeY",
    "mc_quality",
    "mc_end_velo_qop",
    "mc_end_velo_tx",
    "mc_end_velo_ty",
    "mc_end_t_qop",
    "mc_end_t_tx",
    "mc_end_t_ty",
]

# MLP input variables, in the order they are registered with the TMVA DataLoader.
_TRAINING_VARIABLES = (
    "mc_chi2",
    "mc_teta2",
    "mc_distX",
    "mc_distY",
    "mc_dSlope",
    "mc_dSlopeY",
)

# NaN guard on every training variable; these cuts are also applied in the algorithm.
# The kinematic window "mc_chi2<30 && mc_distX<500 && mc_distY<500 &&
# mc_dSlope<2.0 && mc_dSlopeY<0.15" is intentionally NOT applied here
# (it was removed entirely for electrons).
_PRESELECTION = (
    "!TMath::IsNaN(mc_chi2) && "
    "!TMath::IsNaN(mc_distX) && "
    "!TMath::IsNaN(mc_distY) && "
    "!TMath::IsNaN(mc_dSlope) && "
    "!TMath::IsNaN(mc_dSlopeY) && "
    "!TMath::IsNaN(mc_teta2)"
)


def _strip_root_suffix(path: str) -> str:
    """Return *path* without a trailing ".root" extension (unchanged otherwise).

    Deliberately not ``path.strip(".root")``: ``str.strip`` removes any of the
    characters ".", "r", "o", "t" from BOTH ends, mangling names such as
    "root_data.root" or "ntuple_bot.root".
    """
    suffix = ".root"
    return path[: -len(suffix)] if path.endswith(suffix) else path


def _prepare_data(
    input_file: str,
    tree_name: str,
    b_input_file: str,
    b_tree_name: str,
    only_electrons: bool,
    signal_path: str,
    bkg_path: str,
) -> None:
    """Write the selected signal and background candidates to separate ROOT files.

    Signal comes from ``tree_name`` in ``input_file`` (tree "Signal" in
    ``signal_path``, restricted to ``_SIGNAL_COLUMNS``); background comes from
    ``b_tree_name`` in ``b_input_file`` (tree "Bkg" in ``bkg_path``, all columns).
    """
    rdf = ROOT.RDataFrame(tree_name, input_file)
    rdf_b = ROOT.RDataFrame(b_tree_name, b_input_file)

    if only_electrons:
        rdf_signal = rdf.Filter(
            # mc_quality == -1: electron that is a true match but the mlp said no match
            "mc_quality == -1 && lost == 0 && fromSignal == 1",
            "Signal is defined as one label (only electrons)",
        )
    else:
        rdf_signal = rdf.Filter(
            "lost == 0 && fromSignal == 1",
            "Signal is defined as non-zero label",
        )
    # Background is never restricted to electrons.
    rdf_bkg = rdf_b.Filter(
        "mc_quality == 0",
        "Ghosts are defined as zero label",
    )

    # Snapshot does not create missing directories.
    for path in (signal_path, bkg_path):
        os.makedirs(os.path.dirname(path) or ".", exist_ok=True)

    rdf_signal.Snapshot("Signal", signal_path, _SIGNAL_COLUMNS)
    rdf_bkg.Snapshot("Bkg", bkg_path)


def _open_root_file(path: str):
    """Open the ROOT file at *path* read-only.

    Raises:
        OSError: if the file cannot be opened. TFile.Open reports failure
            with a null (falsy) pointer or a zombie file rather than an exception.
    """
    root_file = ROOT.TFile.Open(path, "READ")
    if not root_file or root_file.IsZombie():
        raise OSError(f"Could not open ROOT file {path!r}")
    return root_file


def _train_and_evaluate(
    signal_tree,
    bkg_tree,
    n_train_signal: int,
    n_train_bkg: int,
    n_test_signal: int,
    n_test_bkg: int,
) -> None:
    """Book, train, test and evaluate the matching MLP with TMVA.

    Creates "matching_ghost_mlp_training.root" relative to the current working
    directory and closes it when done, even on error. (TMVA presumably also
    writes its weight files below "MatchNNDataSet/" in that directory;
    confirm for the ROOT version in use.)
    """
    output = ROOT.TFile(
        "matching_ghost_mlp_training.root",
        "RECREATE",
    )
    try:
        factory = TMVA.Factory(
            "TMVAClassification",
            output,
            "V:!Silent:Color:DrawProgressBar:AnalysisType=Classification",
        )
        factory.SetVerbose(True)

        dataloader = TMVA.DataLoader("MatchNNDataSet")
        for variable in _TRAINING_VARIABLES:
            dataloader.AddVariable(variable, "F")
        dataloader.AddSignalTree(signal_tree, 1.0)
        dataloader.AddBackgroundTree(bkg_tree, 1.0)

        # The event counts are integer options; int() prevents e.g. "2000.0"
        # when a caller passes floats such as 2e3.
        split_options = (
            "SplitMode=random:V"
            f":nTrain_Signal={int(n_train_signal)}"
            f":nTrain_Background={int(n_train_bkg)}"
            f":nTest_Signal={int(n_test_signal)}"
            f":nTest_Background={int(n_test_bkg)}"
        )
        dataloader.PrepareTrainingAndTestTree(
            ROOT.TCut(_PRESELECTION),
            split_options,  # normmode default is EqualNumEvents
        )

        factory.BookMethod(
            dataloader,
            TMVA.Types.kMLP,
            "matching_mlp",
            "!H:V:TrainingMethod=BP:NeuronType=ReLU:EstimatorType=CE:VarTransform=Norm:NCycles=700:HiddenLayers=N+2,N:TestRate=50:Sampling=1.0:SamplingImportance=1.0:LearningRate=0.02:DecayRate=0.01:UseRegulator",
        )

        factory.TrainAllMethods()
        factory.TestAllMethods()
        factory.EvaluateAllMethods()
    finally:
        output.Close()


def train_matching_ghost_mlp(
    input_file: str = "data/tracking_losses_ntuple_B.root",
    tree_name: str = "PrDebugTrackingLosses.PrDebugTrackingTool/Tuple",
    b_input_file: str = "data/ghost_data_B.root",
    b_tree_name: str = "PrMatchNN_3e224c41.PrMCDebugMatchToolNN/MVAInputAndOutput",
    only_electrons: bool = True,
    n_train_signal: int = 2000,  # 50e3
    n_train_bkg: int = 5000,  # 500e3
    n_test_signal: int = 1000,
    n_test_bkg: int = 2000,
    prepare_data: bool = True,
    outdir: str = "nn_trackinglosses_training",
    exclude_electrons: bool = False,
):
    """Train an MLP that classifies the match between a Velo and a Seed track.

    Signal candidates come from ``tree_name`` in ``input_file``; background
    (ghost) candidates come from ``b_tree_name`` in ``b_input_file``.

    Args:
        input_file: ROOT file with the signal candidates. Its path without
            the ".root" suffix also names the prepared files:
            ``{outdir}/{input_file minus .root}_matching_signal.root`` and
            ``..._matching_bkg.root``.
        tree_name: Tree in ``input_file`` holding the signal candidates.
        b_input_file: ROOT file with the background candidates.
        b_tree_name: Tree in ``b_input_file`` holding the background candidates.
        only_electrons: If True, signal is additionally restricted to
            ``mc_quality == -1`` (electrons); background always contains all
            particles.
        n_train_signal: Number of true matches to train on.
        n_train_bkg: Number of fake matches to train on.
        n_test_signal: Number of true matches to test on.
        n_test_bkg: Number of fake matches to test on.
        prepare_data: If True, (re)write the prepared signal/background files
            before training; otherwise reuse existing ones.
        outdir: Output directory. Training output goes to ``{outdir}/result``;
            missing directories are created.
        exclude_electrons: Accepted for CLI compatibility. No electron veto is
            defined in this script, so passing True raises.

    Raises:
        NotImplementedError: if ``exclude_electrons`` is True.
        OSError: if a prepared signal/background file cannot be opened.
    """
    if exclude_electrons:
        raise NotImplementedError(
            "exclude_electrons is not supported: no electron veto is defined "
            "for the signal selection in this script."
        )

    base = _strip_root_suffix(input_file)
    signal_path = f"{outdir}/{base}_matching_signal.root"
    bkg_path = f"{outdir}/{base}_matching_bkg.root"

    if prepare_data:
        _prepare_data(
            input_file,
            tree_name,
            b_input_file,
            b_tree_name,
            only_electrons,
            signal_path,
            bkg_path,
        )

    result_dir = f"{outdir}/result"
    # Callbacks run in reverse order: restore cwd, then close bkg, then signal.
    with contextlib.ExitStack() as stack:
        signal_file = _open_root_file(signal_path)
        stack.callback(signal_file.Close)
        bkg_file = _open_root_file(bkg_path)
        stack.callback(bkg_file.Close)

        # TMVA output is written relative to the cwd, so switch into the
        # result directory for training and switch back afterwards.
        os.makedirs(result_dir, exist_ok=True)
        stack.callback(os.chdir, os.getcwd())
        os.chdir(result_dir)

        _train_and_evaluate(
            signal_file.Get("Signal"),
            bkg_file.Get("Bkg"),
            n_train_signal,
            n_train_bkg,
            n_test_signal,
            n_test_bkg,
        )


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input-file",
        type=str,
        help="Path to the input file",
        required=False,
    )
    parser.add_argument(
        "--exclude_electrons",
        action="store_true",
        help="Excludes electrons from training.",
        required=False,
    )
    parser.add_argument(
        "--only_electrons",
        action="store_true",
        help="Only electrons for signal training.",
        required=False,
    )
    parser.add_argument(
        "--no-only_electrons",
        dest="only_electrons",
        action="store_false",
        help="Use signal of all particles instead of only electrons.",
        required=False,
    )
    parser.add_argument(
        "--n-train-signal",
        type=int,
        help="Number of training tracks for signal.",
        required=False,
    )
    parser.add_argument(
        "--n-train-bkg",
        type=int,
        help="Number of training tracks for bkg.",
        required=False,
    )
    parser.add_argument(
        "--n-test-signal",
        type=int,
        help="Number of testing tracks for signal.",
        required=False,
    )
    parser.add_argument(
        "--n-test-bkg",
        type=int,
        help="Number of testing tracks for bkg.",
        required=False,
    )
    # store_true/store_false default to False/True, which would always be
    # forwarded below. None lets absent flags fall back to the function's defaults.
    parser.set_defaults(exclude_electrons=None, only_electrons=None)
    args = parser.parse_args()
    args_dict = {arg: val for arg, val in vars(args).items() if val is not None}
    train_matching_ghost_mlp(**args_dict)