cetin
7c2194df23
and GEC Filter to eff options files trained network with correct parameterisation sample4 new effs with sample4 NN weights
226 lines
7.5 KiB
Python
226 lines
7.5 KiB
Python
# flake8: noqa
|
|
|
|
import os
|
|
import argparse
|
|
import ROOT
|
|
from ROOT import TMVA, TList, TTree
|
|
|
|
|
|
def _strip_root_suffix(path: str) -> str:
    """Return *path* without a trailing ``.root`` extension, if present."""
    suffix = ".root"
    # NOTE: str.strip(".root") would remove any of the characters ".", "r",
    # "o", "t" from BOTH ends of the string, which mangles names such as
    # "tracks.root"; only the literal suffix must be removed.
    return path[: -len(suffix)] if path.endswith(suffix) else path


def train_matching_ghost_mlp(
    input_file: str = "data/ghost_data_B.root",
    tree_name: str = "PrMatchNN.PrMCDebugMatchToolNN/MVAInputAndOutput",
    exclude_electrons: bool = False,
    only_electrons: bool = True,
    filter_velos: bool = False,
    filter_seeds: bool = False,
    n_train_signal: int = 50_000,
    n_train_bkg: int = 50_000,
    n_test_signal: int = 10_000,
    n_test_bkg: int = 20_000,
    prepare_data: bool = True,
    outdir: str = "nn_electron_training",
):
    """Trains an MLP to classify the match between Velo and Seed track.

    The signal and background samples are read from
    ``<outdir>/<input_file without .root>_matching_{signal,bkg}.root``.
    With ``prepare_data`` they are first created from ``input_file``.
    The TMVA output is written to ``<outdir>/result``; the process working
    directory is changed to that directory and not restored.

    Args:
        input_file (str, optional): Defaults to "data/ghost_data_B.root".
        tree_name (str, optional): Defaults to
            "PrMatchNN.PrMCDebugMatchToolNN/MVAInputAndOutput".
        exclude_electrons (bool, optional): Signal is ``quality == 1`` and
            background is ``quality == 0``. Takes precedence over
            ``only_electrons``, ``filter_velos`` and ``filter_seeds``.
            Defaults to False.
        only_electrons (bool, optional): Signal only of electrons
            (``quality == -1``), but bkg of all particles. If False, signal
            is every non-zero label. Defaults to True.
        filter_velos (bool, optional): Background only electron VELO tracks.
            Defaults to False.
        filter_seeds (bool, optional): Background only electron T tracks.
            Defaults to False.
        n_train_signal (int, optional): Number of true matches to train on.
            Defaults to 50e3.
        n_train_bkg (int, optional): Number of fake matches to train on.
            Defaults to 50e3.
        n_test_signal (int, optional): Number of true matches to test on.
            Defaults to 10e3.
        n_test_bkg (int, optional): Number of fake matches to test on.
            Defaults to 20e3.
        prepare_data (bool, optional): Split data into signal and background
            file. Defaults to True.
        outdir (str, optional): specify the output directory path.
            Defaults to "nn_electron_training".
    """
    # Build the sample paths once so the writer and the reader cannot drift.
    sample_base = outdir + "/" + _strip_root_suffix(input_file)
    signal_path = sample_base + "_matching_signal.root"
    bkg_path = sample_base + "_matching_bkg.root"

    if prepare_data:
        rdf = ROOT.RDataFrame(tree_name, input_file)
        # quality labels: -1 electron, 0 ghost, 1 all particles without electrons
        if exclude_electrons:
            print("signal data: exclude electrons.")
            rdf_signal = rdf.Filter(
                "quality == 1",
                "Signal is defined as one label (excluding electrons)",
            )
            print("background data: default ghosts: quality == 0")
            rdf_bkg = rdf.Filter(
                "quality == 0",
                "Ghosts are defined as zero label",
            )
        else:
            if only_electrons:
                print("signal data: only electrons.")
                rdf_signal = rdf.Filter(
                    "quality == -1",  # electron that is true match
                    "Signal is defined as negative one label (only electrons)",
                )
            else:
                print("signal data: all particles.")
                rdf_signal = rdf.Filter(
                    "abs(quality) > 0",
                    "Signal is defined as non-zero label",
                )
            bkg_selection = "quality == 0"
            if filter_velos:
                bkg_selection += " && velo_isElectron == 1"
            if filter_seeds:
                bkg_selection += " && scifi_isElectron == 1"
            print("background data: selection cuts = " + bkg_selection)
            rdf_bkg = rdf.Filter(
                bkg_selection,
                "Ghosts are defined as zero label",
            )

        # The snapshot path inherits any directory part of input_file, so
        # make sure that directory exists before ROOT tries to write into it.
        os.makedirs(os.path.dirname(signal_path), exist_ok=True)
        rdf_signal.Snapshot("Signal", signal_path)
        rdf_bkg.Snapshot("Bkg", bkg_path)

    # Open the inputs BEFORE changing directory: the paths may be relative.
    signal_file = ROOT.TFile.Open(signal_path, "READ")
    signal_tree = signal_file.Get("Signal")

    bkg_file = ROOT.TFile.Open(bkg_path)
    bkg_tree = bkg_file.Get("Bkg")

    result_dir = outdir + "/result"
    os.makedirs(result_dir, exist_ok=True)
    os.chdir(result_dir)
    output = ROOT.TFile(
        "matching_ghost_mlp_training.root",
        "RECREATE",
    )

    factory = TMVA.Factory(
        "TMVAClassification",
        output,
        "V:!Silent:Color:DrawProgressBar:AnalysisType=Classification",
    )
    factory.SetVerbose(True)
    dataloader = TMVA.DataLoader("MatchNNDataSet")

    dataloader.AddVariable("chi2", "F")
    dataloader.AddVariable("teta2", "F")
    dataloader.AddVariable("distX", "F")
    dataloader.AddVariable("distY", "F")
    dataloader.AddVariable("dSlope", "F")
    dataloader.AddVariable("dSlopeY", "F")
    # dataloader.AddVariable("zmag", "F")
    # dataloader.AddVariable("eta", "F")
    # dataloader.AddVariable("dEta", "F")

    dataloader.AddSignalTree(signal_tree, 1.0)
    dataloader.AddBackgroundTree(bkg_tree, 1.0)

    # these cuts are also applied in the algorithm
    preselectionCuts = ROOT.TCut(
        # "chi2<30 && distX<500 && distY<500 && dSlope<2.0 && dSlopeY<0.15",  #### removed entirely for electrons
        # "eta>2 && eta<5 && eta_scifi>2 && eta_scifi<5"
    )
    # Format the event counts as integers: callers (and the historical
    # defaults such as 50e3) may pass floats, which would render as "50000.0".
    split_options = (
        "SplitMode=random:V"
        f":nTrain_Signal={int(n_train_signal)}"
        f":nTrain_Background={int(n_train_bkg)}"
        f":nTest_Signal={int(n_test_signal)}"
        f":nTest_Background={int(n_test_bkg)}"
    )
    dataloader.PrepareTrainingAndTestTree(
        preselectionCuts,
        split_options,
        # normmode default is EqualNumEvents
    )

    factory.BookMethod(
        dataloader,
        TMVA.Types.kMLP,
        "matching_mlp",
        "!H:V:TrainingMethod=BP:NeuronType=ReLU:EstimatorType=CE:VarTransform=Norm:NCycles=700:HiddenLayers=N+2,N:TestRate=50:Sampling=1.0:SamplingImportance=1.0:LearningRate=0.02:DecayRate=0.01:!UseRegulator",
    )
    factory.TrainAllMethods()
    factory.TestAllMethods()
    factory.EvaluateAllMethods()
    output.Close()
    # Release the input samples once training and evaluation are finished.
    signal_file.Close()
    bkg_file.Close()
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line options, in the order they appear in ``--help``. Each
    # entry is (flag, keyword arguments for ``add_argument``); every option
    # is optional.
    option_specs = [
        ("--input-file", {"type": str, "help": "Path to the input file."}),
        ("--tree-name", {"type": str, "help": "Path to the tree."}),
        (
            "--exclude-electrons",
            {"action": "store_true", "help": "Excludes electrons from training."},
        ),
        (
            "--only-electrons",
            {"action": "store_true", "help": "Only electrons for signal training."},
        ),
        (
            "--filter-velos",
            {
                "action": "store_true",
                "help": "Only background with electron VELO tracks.",
            },
        ),
        (
            "--filter-seeds",
            {
                "action": "store_true",
                "help": "Only background with electron T tracks.",
            },
        ),
        (
            "--n-train-signal",
            {"type": int, "help": "Number of training tracks for signal."},
        ),
        (
            "--n-train-bkg",
            {"type": int, "help": "Number of training tracks for bkg."},
        ),
        (
            "--n-test-signal",
            {"type": int, "help": "Number of testing tracks for signal."},
        ),
        (
            "--n-test-bkg",
            {"type": int, "help": "Number of testing tracks for bkg."},
        ),
        (
            "--prepare-data",
            {
                "action": "store_true",
                "help": "Create signal and background samples.",
            },
        ),
        ("--outdir", {"type": str, "help": "Path to the output directory."}),
    ]

    cli = argparse.ArgumentParser()
    for flag, spec in option_specs:
        cli.add_argument(flag, required=False, **spec)

    parsed = vars(cli.parse_args())

    # Forward only the options that carry a value, so that omitted valued
    # options fall back to the defaults of train_matching_ghost_mlp.
    # NOTE: "store_true" flags are False (never None) when omitted, so the
    # boolean options are always forwarded and override the function defaults.
    overrides = {}
    for name, value in parsed.items():
        if value is not None:
            overrides[name] = value

    train_matching_ghost_mlp(**overrides)
|