# tracking-parametrisation-tuner/parameterisations/train_forward_ghost_mlps.py
#
# 262 lines
# 10 KiB
# Python
# Raw Normal View History
#
# 2023-12-19 13:00:59 +01:00
import os
import argparse
import ROOT
from ROOT import TMVA
# Directory (relative to the working directory at call time) in which the TMVA
# training is run and its output file is written.
_RESULT_DIR = "neural_net_training/result"

# TMVA input-variable expressions shared by both forward ghost-rejection MLPs.
_COMMON_VARIABLES = (
    "redChi2",
    "distXMatch := abs((x + ( zMagMatch - 770.0 ) * tx) - (xEndT + ( zMagMatch - 9410.0 ) * txEndT))",
    "distY := abs(ySeedMatch - yEndT)",
    "abs(yParam0Final-yParam0Init)",
    "abs(yParam1Final-yParam1Init)",
    "abs(ty)",
    "abs(qop)",
    "abs(tx)",
    "abs(xParam1Final-xParam1Init)",
)

# The VeloUT-seeded network additionally uses the momentum difference w.r.t. the
# upstream (UT) estimate, as its first input variable.
_VELOUT_VARIABLES = ("dMom := log(abs((1.0/qop) - (1.0/qopUT) ))",) + _COMMON_VARIABLES

# Preselection applied by TMVA to both signal and background before splitting.
_PRESELECTION_CUTS = "redChi2 < 8 && abs((x + ( zMagMatch - 770.0 ) * tx) - (xEndT + ( zMagMatch - 9410.0 ) * txEndT)) < 140 && abs(ySeedMatch - yEndT) < 500 && abs(yParam0Final-yParam0Init) < 140 && abs(yParam1Final-yParam1Init) < 0.055"


def _strip_root_suffix(path: str) -> str:
    """Return ``path`` without its trailing ``.root`` extension, if it has one.

    Deliberately not ``path.strip(".root")``: ``str.strip`` removes any of the
    characters ``.``, ``r``, ``o``, ``t`` from both ends, which mangles names
    such as ``output.root`` (-> ``utpu``).
    """
    suffix = ".root"
    return path[: -len(suffix)] if path.endswith(suffix) else path


def _split_signal_and_bkg(
    input_file: str,
    tree_name: str,
    exclude_electrons: bool,
    signal_path: str,
    bkg_path: str,
) -> None:
    """Split the MVA input tree into a signal file and a ghost (background) file.

    Args:
        input_file: ROOT file holding ``tree_name``.
        tree_name: Name of the MVA input tree inside ``input_file``.
        exclude_electrons: If True, signal is ``label == 1``; otherwise any
            ``label > 0`` counts as signal.
        signal_path: Output file for the signal tree (written as tree "Signal").
        bkg_path: Output file for the ghost tree (written as tree "Bkg",
            selected by ``label == 0``).
    """
    rdf = ROOT.RDataFrame(tree_name, input_file)
    if exclude_electrons:
        rdf_signal = rdf.Filter(
            "label == 1",
            "Signal is defined as one label (excluding electrons)",
        )
    else:
        rdf_signal = rdf.Filter("label > 0", "Signal is defined as non-zero label")
    rdf_bkg = rdf.Filter("label == 0", "Ghosts are defined as zero label")
    rdf_signal.Snapshot("Signal", signal_path)
    rdf_bkg.Snapshot("Bkg", bkg_path)


def _run_tmva_training(
    *,
    signal_path: str,
    bkg_path: str,
    output_name: str,
    method_name: str,
    variables,
    mlp_options: str,
    n_train_signal,
    n_train_bkg,
    n_test_signal,
    n_test_bkg,
) -> None:
    """Train, test and evaluate one TMVA MLP on pre-split signal/background files.

    Args:
        signal_path: ROOT file containing the tree "Signal".
        bkg_path: ROOT file containing the tree "Bkg".
        output_name: Name of the TMVA output file, created inside ``_RESULT_DIR``.
        method_name: Name under which the MLP is booked in TMVA.
        variables: Iterable of TMVA variable expressions (all booked as "F").
        mlp_options: TMVA option string for the kMLP method.
        n_train_signal, n_train_bkg, n_test_signal, n_test_bkg: Event counts
            passed to ``PrepareTrainingAndTestTree`` (cast to int).

    The working directory is switched to ``_RESULT_DIR`` for the duration of the
    training and always restored afterwards, so relative paths stay valid for
    subsequent calls (e.g. when both networks are trained in one process).
    """
    signal_file = ROOT.TFile.Open(signal_path, "READ")
    signal_tree = signal_file.Get("Signal")
    bkg_file = ROOT.TFile.Open(bkg_path)
    bkg_tree = bkg_file.Get("Bkg")

    previous_cwd = os.getcwd()
    # Run TMVA from the result directory, presumably so that TMVA's dataset/weights
    # directory ends up there next to the output file.
    os.chdir(_RESULT_DIR)
    try:
        output = ROOT.TFile(output_name, "RECREATE")
        try:
            factory = TMVA.Factory(
                "TMVAClassification",
                output,
                "V:!Silent:Color:DrawProgressBar:AnalysisType=Classification",
            )
            factory.SetVerbose(True)
            dataloader = TMVA.DataLoader("GhostNNDataSet")
            for expression in variables:
                dataloader.AddVariable(expression, "F")
            dataloader.AddSignalTree(signal_tree, 1.0)
            dataloader.AddBackgroundTree(bkg_tree, 1.0)
            preselection_cuts = ROOT.TCut(_PRESELECTION_CUTS)
            # TMVA expects integer event counts; cast so that float inputs such
            # as 300e3 do not end up as "300000.0" in the option string.
            dataloader.PrepareTrainingAndTestTree(
                preselection_cuts,
                "NormMode=NumEvents:SplitMode=random:V:"
                f"nTrain_Signal={int(n_train_signal)}:"
                f"nTrain_Background={int(n_train_bkg)}:"
                f"nTest_Signal={int(n_test_signal)}:"
                f"nTest_Background={int(n_test_bkg)}",
            )
            factory.BookMethod(dataloader, TMVA.Types.kMLP, method_name, mlp_options)
            factory.TrainAllMethods()
            factory.TestAllMethods()
            factory.EvaluateAllMethods()
        finally:
            output.Close()
    finally:
        os.chdir(previous_cwd)
        signal_file.Close()
        bkg_file.Close()


def train_default_forward_ghost_mlp(
    input_file: str = "data/ghost_data.root",
    tree_name: str = "PrForwardTrackingVelo.PrMCDebugForwardTool/MVAInput",
    exclude_electrons: bool = False,
    n_train_signal: int = 300_000,
    n_train_bkg: int = 300_000,
    n_test_signal: int = 50_000,
    n_test_bkg: int = 50_000,
    prepare_data: bool = False,
):
    """Trains an MLP to classify track candidates from PrForwardTrackingVelo as ghost or Long Track.

    Args:
        input_file (str, optional): ROOT file with the MVA input tree; also the stem
            of the split files "<stem>_forward_signal.root" / "<stem>_forward_bkg.root".
            Defaults to "data/ghost_data.root".
        tree_name (str, optional): MVA input tree, only read when ``prepare_data`` is True.
            Defaults to "PrForwardTrackingVelo.PrMCDebugForwardTool/MVAInput".
        exclude_electrons (bool, optional): Use only ``label == 1`` as signal instead of
            ``label > 0``; only used when ``prepare_data`` is True. Defaults to False.
        n_train_signal (int, optional): Number of true tracks to train on. Defaults to 300000.
        n_train_bkg (int, optional): Number of fake tracks to train on. Defaults to 300000.
        n_test_signal (int, optional): Number of true tracks to test on. Defaults to 50000.
        n_test_bkg (int, optional): Number of fake tracks to test on. Defaults to 50000.
        prepare_data (bool, optional): Split data into signal and background file first;
            if False, the split files must already exist. Defaults to False.

    The TMVA output "default_forward_ghost_mlp_training.root" is written inside
    "neural_net_training/result"; the working directory is restored afterwards.
    """
    stem = _strip_root_suffix(input_file)
    signal_path = stem + "_forward_signal.root"
    bkg_path = stem + "_forward_bkg.root"
    if prepare_data:
        _split_signal_and_bkg(input_file, tree_name, exclude_electrons, signal_path, bkg_path)
    _run_tmva_training(
        signal_path=signal_path,
        bkg_path=bkg_path,
        output_name="default_forward_ghost_mlp_training.root",
        method_name="default_forward_ghost_mlp",
        variables=_COMMON_VARIABLES,
        mlp_options="!H:V:TrainingMethod=BP:NeuronType=ReLU:EstimatorType=CE:VarTransform=Norm:NCycles=750:HiddenLayers=N+4,N+2:TestRate=50:Sampling=1.0:SamplingImportance=1.0:LearningRate=0.02:DecayRate=0.005:!UseRegulator",
        n_train_signal=n_train_signal,
        n_train_bkg=n_train_bkg,
        n_test_signal=n_test_signal,
        n_test_bkg=n_test_bkg,
    )


def train_veloUT_forward_ghost_mlp(
    input_file: str = "data/ghost_data.root",
    tree_name: str = "PrForwardTracking.PrMCDebugForwardTool/MVAInput",
    exclude_electrons: bool = False,
    n_train_signal: int = 300_000,
    n_train_bkg: int = 300_000,
    n_test_signal: int = 50_000,
    n_test_bkg: int = 50_000,
    prepare_data: bool = False,
):
    """Trains an MLP to classify track candidates from PrForwardTracking as ghost or Long Track.

    Args:
        input_file (str, optional): ROOT file with the MVA input tree; also the stem
            of the split files "<stem>_forward_velout_signal.root" /
            "<stem>_forward_velout_bkg.root". Defaults to "data/ghost_data.root".
        tree_name (str, optional): MVA input tree, only read when ``prepare_data`` is True.
            Defaults to "PrForwardTracking.PrMCDebugForwardTool/MVAInput".
        exclude_electrons (bool, optional): Use only ``label == 1`` as signal instead of
            ``label > 0``; only used when ``prepare_data`` is True. Defaults to False.
        n_train_signal (int, optional): Number of true tracks to train on. Defaults to 300000.
        n_train_bkg (int, optional): Number of fake tracks to train on. Defaults to 300000.
        n_test_signal (int, optional): Number of true tracks to test on. Defaults to 50000.
        n_test_bkg (int, optional): Number of fake tracks to test on. Defaults to 50000.
        prepare_data (bool, optional): Split data into signal and background file first;
            if False, the split files must already exist. Defaults to False.

    The TMVA output "veloUT_forward_ghost_mlp_training.root" is written inside
    "neural_net_training/result"; the working directory is restored afterwards.
    """
    stem = _strip_root_suffix(input_file)
    signal_path = stem + "_forward_velout_signal.root"
    bkg_path = stem + "_forward_velout_bkg.root"
    if prepare_data:
        _split_signal_and_bkg(input_file, tree_name, exclude_electrons, signal_path, bkg_path)
    _run_tmva_training(
        signal_path=signal_path,
        bkg_path=bkg_path,
        output_name="veloUT_forward_ghost_mlp_training.root",
        method_name="veloUT_forward_ghost_mlp",
        variables=_VELOUT_VARIABLES,
        mlp_options="!H:V:TrainingMethod=BP:NeuronType=ReLU:EstimatorType=CE:VarTransform=Norm:NCycles=550:HiddenLayers=N+2:TestRate=50:Sampling=1.0:SamplingImportance=1.0:LearningRate=0.02:DecayRate=0.005:!UseRegulator",
        n_train_signal=n_train_signal,
        n_train_bkg=n_train_bkg,
        n_test_signal=n_test_signal,
        n_test_bkg=n_test_bkg,
    )
if __name__ == "__main__":
    # Command-line entry point: train the default (VELO-seeded) network, the
    # VeloUT-seeded network, or both. Options left unset fall back to the
    # training functions' own defaults.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input-file",
        type=str,
        help="Path to the input file",
        required=False,
    )
    parser.add_argument(
        # Hyphenated spelling for consistency with the other flags; the
        # original underscore spelling is kept for backward compatibility.
        "--exclude-electrons",
        "--exclude_electrons",
        dest="exclude_electrons",
        action="store_true",
        help="Excludes electrons from training.",
        required=False,
    )
    parser.add_argument(
        "--n-train-signal",
        type=int,
        help="Number of training tracks for signal.",
        required=False,
    )
    parser.add_argument(
        "--n-train-bkg",
        type=int,
        help="Number of training tracks for bkg.",
        required=False,
    )
    parser.add_argument(
        "--n-test-signal",
        type=int,
        help="Number of testing tracks for signal.",
        required=False,
    )
    parser.add_argument(
        "--n-test-bkg",
        type=int,
        help="Number of testing tracks for bkg.",
        required=False,
    )
    parser.add_argument(
        # Without this flag the CLI could never create the split signal/bkg files.
        "--prepare-data",
        action="store_true",
        help="Split the input file into signal and background files before training.",
    )
    parser.add_argument(
        "--veloUT",
        action="store_true",
        help="Toggle whether the NN for upstream tracks input is trained.",
    )
    parser.add_argument(
        "--all",
        action="store_true",
        help="Toggle whether both NNs are trained, for VELO and VeloUT input.",
    )
    args = parser.parse_args()
    # Forward only options that were actually given (store_true flags are always
    # forwarded as booleans); the mode selectors are not function parameters.
    args_dict = {
        arg: val
        for arg, val in vars(args).items()
        if val is not None and arg not in ["veloUT", "all"]
    }
    if args.all:
        train_default_forward_ghost_mlp(**args_dict)
        train_veloUT_forward_ghost_mlp(**args_dict)
    elif args.veloUT:
        train_veloUT_forward_ghost_mlp(**args_dict)
    else:
        train_default_forward_ghost_mlp(**args_dict)