# flake8: noqa
# ruff: noqa
import os
import argparse
import ROOT
from ROOT import TMVA, TList, TTree

_ROOT_SUFFIX = ".root"

# Input variables fed to the MLP; each is declared to the DataLoader as "F".
# Further candidates that are currently disabled: "zMag_electron", "eta", "dEta".
_MVA_VARIABLES = ("chi2", "teta2", "distX", "distY", "dSlope", "dSlopeY")

# These cuts are also applied in the algorithm; keep the two in sync.
# Disabled extra cut: "&& zMag_electron<6000 && zMag_electron>4500"
_PRESELECTION = "chi2<15 && distX<250 && distY<250 && dSlope<1.5 && dSlopeY<0.15"

_MLP_OPTIONS = (
    "!H:V:TrainingMethod=BP:NeuronType=ReLU:EstimatorType=CE:VarTransform=Norm:"
    "NCycles=700:HiddenLayers=N+2,N:TestRate=50:Sampling=1.0:"
    "SamplingImportance=1.0:LearningRate=0.02:DecayRate=0.01:!UseRegulator"
)


def _strip_root_suffix(path: str) -> str:
    """Return ``path`` without a trailing ``.root`` extension.

    ``str.strip(".root")`` would remove any of the characters ``. r o t`` from
    both ends (e.g. ``"ghost_data_root.root"`` -> ``"ghost_data_"``), so only
    the literal suffix is removed here.
    """
    if path.endswith(_ROOT_SUFFIX):
        return path[: -len(_ROOT_SUFFIX)]
    return path


def _sample_paths(input_file: str, outdir: str) -> tuple:
    """Return the ``(signal_path, bkg_path)`` snapshot files for ``input_file``."""
    stem = _strip_root_suffix(input_file)
    return (
        f"{outdir}/{stem}_matching_signal.root",
        f"{outdir}/{stem}_matching_bkg.root",
    )


def _prepare_samples(
    input_file: str,
    tree_name: str,
    signal_path: str,
    bkg_path: str,
    exclude_electrons: bool,
    only_electrons: bool,
    filter_velos: bool,
    filter_seeds: bool,
) -> None:
    """Split ``tree_name`` into a "Signal" and a "Bkg" tree written to disk.

    quality labels: -1 electron true match, 0 ghost, 1 non-electron true match.
    """
    rdf = ROOT.RDataFrame(tree_name, input_file)
    if exclude_electrons:
        print("signal data: exclude electrons.")
        rdf_signal = rdf.Filter(
            "quality == 1",  # -1 elec, 0 ghost, 1 all part wo elec
            "Signal is defined as one label (excluding electrons)",
        )
        print("background data: default ghosts: quality == 0")
        rdf_bkg = rdf.Filter(
            "quality == 0",
            "Ghosts are defined as zero label",
        )
    else:
        if only_electrons:
            print("signal data: only electrons.")
            rdf_signal = rdf.Filter(
                "quality == -1",  # electron that is true match
                "Signal is defined as negative one label (only electrons)",
            )
            # Signal is electrons only, so every non-electron entry
            # (ghosts and quality == 1 matches) is usable as background.
            bkg_selection = "quality >= 0"
        else:
            print("signal data: all particles.")
            rdf_signal = rdf.Filter(
                "abs(quality) > 0",
                "Signal is defined as non-zero label",
            )
            # quality == 1 is already signal here; "quality >= 0" would put
            # the same matches into both samples, so keep only ghosts.
            bkg_selection = "quality == 0"
        if filter_velos:
            bkg_selection += " && velo_isElectron == 1"
        if filter_seeds:
            bkg_selection += " && scifi_isElectron == 1"
        print("background data: selection cuts = " + bkg_selection)
        rdf_bkg = rdf.Filter(
            bkg_selection,
            "Ghosts are defined as zero label",
        )

    # The snapshot paths mirror input_file's directory below outdir
    # (e.g. outdir/data/...), which must exist before ROOT can write there.
    os.makedirs(os.path.dirname(signal_path) or ".", exist_ok=True)
    os.makedirs(os.path.dirname(bkg_path) or ".", exist_ok=True)
    rdf_signal.Snapshot("Signal", signal_path)
    rdf_bkg.Snapshot("Bkg", bkg_path)


def _train_mlp(
    signal_tree,
    bkg_tree,
    n_train_signal: int,
    n_train_bkg: int,
    n_test_signal: int,
    n_test_bkg: int,
) -> None:
    """Book, train, test and evaluate the TMVA MLP in the current directory.

    Writes ``matching_ghost_mlp_training.root`` into the current working
    directory; the file is closed even if training raises.
    """
    output = ROOT.TFile(
        "matching_ghost_mlp_training.root",
        "RECREATE",
    )
    try:
        factory = TMVA.Factory(
            "TMVAClassification",
            output,
            "V:!Silent:Color:DrawProgressBar:AnalysisType=Classification",
        )
        factory.SetVerbose(True)

        dataloader = TMVA.DataLoader("MatchNNDataSet")
        for variable in _MVA_VARIABLES:
            dataloader.AddVariable(variable, "F")
        dataloader.AddSignalTree(signal_tree, 1.0)
        dataloader.AddBackgroundTree(bkg_tree, 1.0)

        # Event counts are coerced to int so float arguments such as 50e3 are
        # written as "50000" rather than "50000.0" in the TMVA option string.
        split_options = (
            "SplitMode=random:V"
            f":nTrain_Signal={int(n_train_signal)}"
            f":nTrain_Background={int(n_train_bkg)}"
            f":nTest_Signal={int(n_test_signal)}"
            f":nTest_Background={int(n_test_bkg)}"
        )
        # NormMode is not set, so TMVA's default applies
        # (the original note says EqualNumEvents).
        dataloader.PrepareTrainingAndTestTree(ROOT.TCut(_PRESELECTION), split_options)

        factory.BookMethod(
            dataloader,
            TMVA.Types.kMLP,
            "matching_mlp",
            _MLP_OPTIONS,
        )
        factory.TrainAllMethods()
        factory.TestAllMethods()
        factory.EvaluateAllMethods()
    finally:
        output.Close()


def train_matching_ghost_mlp(
    input_file: str = "data/ghost_data_B.root",
    tree_name: str = "PrMatchNN.PrMCDebugMatchToolNN/MVAInputAndOutput",
    exclude_electrons: bool = False,
    only_electrons: bool = True,
    filter_velos: bool = False,
    filter_seeds: bool = False,
    n_train_signal: int = 50_000,
    n_train_bkg: int = 50_000,
    n_test_signal: int = 10_000,
    n_test_bkg: int = 20_000,
    prepare_data: bool = True,
    outdir: str = "nn_electron_training",
):
    """Trains an MLP to classify the match between Velo and Seed track.

    If ``prepare_data`` is set, the input tree is first split into signal and
    background snapshot files below ``outdir``. The MLP is then trained inside
    ``outdir/result`` (created if missing). The previous working directory is
    restored afterwards, and the input files are closed.

    Args:
        input_file (str, optional): Input ROOT file. Defaults to "data/ghost_data_B.root".
        tree_name (str, optional): Tree inside input_file.
            Defaults to "PrMatchNN.PrMCDebugMatchToolNN/MVAInputAndOutput".
        exclude_electrons (bool, optional): Signal is quality == 1 (non-electron
            true matches), background is quality == 0 (ghosts). Defaults to False.
        only_electrons (bool, optional): Signal only of electrons (quality == -1),
            background is every entry with quality >= 0. Used only when
            exclude_electrons is False. Defaults to True.
            When False, the signal is abs(quality) > 0 and the background is
            quality == 0.
        filter_velos (bool, optional): Background only electron VELO tracks.
            Used only when exclude_electrons is False. Defaults to False.
        filter_seeds (bool, optional): Background only electron T tracks.
            Used only when exclude_electrons is False. Defaults to False.
        n_train_signal (int, optional): Number of true matches to train on. Defaults to 50000.
        n_train_bkg (int, optional): Number of fake matches to train on. Defaults to 50000.
        n_test_signal (int, optional): Number of true matches to test on. Defaults to 10000.
        n_test_bkg (int, optional): Number of fake matches to test on. Defaults to 20000.
        prepare_data (bool, optional): Split data into signal and background files;
            otherwise previously written files are reused. Defaults to True.
        outdir (str, optional): Output directory path. Defaults to "nn_electron_training".
    """
    signal_path, bkg_path = _sample_paths(input_file, outdir)
    if prepare_data:
        _prepare_samples(
            input_file,
            tree_name,
            signal_path,
            bkg_path,
            exclude_electrons=exclude_electrons,
            only_electrons=only_electrons,
            filter_velos=filter_velos,
            filter_seeds=filter_seeds,
        )

    # Opened before changing directory so relative paths still resolve.
    signal_file = ROOT.TFile.Open(signal_path, "READ")
    signal_tree = signal_file.Get("Signal")
    bkg_file = ROOT.TFile.Open(bkg_path)
    bkg_tree = bkg_file.Get("Bkg")

    result_dir = f"{outdir}/result"
    os.makedirs(result_dir, exist_ok=True)
    previous_cwd = os.getcwd()
    # Training runs inside outdir/result, presumably so TMVA's relative
    # output (weights directory) lands there too -- confirm.
    os.chdir(result_dir)
    try:
        _train_mlp(
            signal_tree,
            bkg_tree,
            n_train_signal,
            n_train_bkg,
            n_test_signal,
            n_test_bkg,
        )
    finally:
        os.chdir(previous_cwd)
        bkg_file.Close()
        signal_file.Close()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input-file",
        type=str,
        help="Path to the input file.",
        required=False,
    )
    parser.add_argument(
        "--tree-name",
        type=str,
        help="Path to the tree.",
        required=False,
    )
    parser.add_argument(
        "--exclude-electrons",
        action="store_true",
        help="Excludes electrons from training.",
        required=False,
    )
    parser.add_argument(
        "--only-electrons",
        action="store_true",
        help="Only electrons for signal training.",
        required=False,
    )
    parser.add_argument(
        "--filter-velos",
        action="store_true",
        help="Only background with electron VELO tracks.",
        required=False,
    )
    parser.add_argument(
        "--filter-seeds",
        action="store_true",
        help="Only background with electron T tracks.",
        required=False,
    )
    parser.add_argument(
        "--n-train-signal",
        type=int,
        help="Number of training tracks for signal.",
        required=False,
    )
    parser.add_argument(
        "--n-train-bkg",
        type=int,
        help="Number of training tracks for bkg.",
        required=False,
    )
    parser.add_argument(
        "--n-test-signal",
        type=int,
        help="Number of testing tracks for signal.",
        required=False,
    )
    parser.add_argument(
        "--n-test-bkg",
        type=int,
        help="Number of testing tracks for bkg.",
        required=False,
    )
    parser.add_argument(
        "--prepare-data",
        action="store_true",
        help="Create signal and background samples.",
        required=False,
    )
    parser.add_argument(
        "--outdir",
        type=str,
        help="Path to the output directory.",
        required=False,
    )
    args = parser.parse_args()
    # Only options that were actually given are forwarded, so the function
    # defaults apply to the rest.
    # NOTE(review): store_true flags default to False (not None), so they are
    # always forwarded. For example, only_electrons and prepare_data are False
    # unless their flag is passed, which overrides the function's True
    # defaults. Confirm this is intended.
    args_dict = {arg: val for arg, val in vars(args).items() if val is not None}
    train_matching_ghost_mlp(**args_dict)