matching alg for electrons
This commit is contained in:
parent
6cca1422d6
commit
4b7aeb35d2
@ -171,10 +171,6 @@ if args.matching_weights and args.residuals:
|
|||||||
|
|
||||||
file_name = "seed"
|
file_name = "seed"
|
||||||
|
|
||||||
input_files = {}
|
|
||||||
input_files["seed"] = "data/ghost_data_B_default_only_e_as_seed.root"
|
|
||||||
input_files["def"] = "data/ghost_data_B_default_thesis.root"
|
|
||||||
|
|
||||||
tree_names = {}
|
tree_names = {}
|
||||||
tree_names["seed"] = "PrMatchNN_b60a058d.PrMCDebugMatchToolNN/MVAInputAndOutput"
|
tree_names["seed"] = "PrMatchNN_b60a058d.PrMCDebugMatchToolNN/MVAInputAndOutput"
|
||||||
tree_names["def"] = "PrMatchNN.PrMCDebugMatchToolNN/MVAInputAndOutput"
|
tree_names["def"] = "PrMatchNN.PrMCDebugMatchToolNN/MVAInputAndOutput"
|
||||||
@ -183,15 +179,16 @@ if args.matching_weights and not args.residuals:
|
|||||||
os.chdir(os.path.dirname(os.path.realpath(__file__)))
|
os.chdir(os.path.dirname(os.path.realpath(__file__)))
|
||||||
train_matching_ghost_mlp(
|
train_matching_ghost_mlp(
|
||||||
prepare_data=args.prepare,
|
prepare_data=args.prepare,
|
||||||
input_file="data/ghost_data_B_default_thesis.root",
|
input_file="data/ghost_data_B_vars_thesis.root",
|
||||||
tree_name=tree_names[file_name], # B: 3e224c41
|
tree_name=tree_names[file_name],
|
||||||
exclude_electrons=True,
|
exclude_electrons=False,
|
||||||
only_electrons=False,
|
only_electrons=True,
|
||||||
|
filter_seeds=True,
|
||||||
outdir="nn_electron_training",
|
outdir="nn_electron_training",
|
||||||
n_train_signal=150e3,
|
n_train_signal=100e3,
|
||||||
n_train_bkg=150e3,
|
n_train_bkg=100e3,
|
||||||
n_test_signal=20e3,
|
n_test_signal=10e3,
|
||||||
n_test_bkg=20e3,
|
n_test_bkg=10e3,
|
||||||
)
|
)
|
||||||
# this ensures that the directory is correct
|
# this ensures that the directory is correct
|
||||||
os.chdir(os.path.dirname(os.path.realpath(__file__)))
|
os.chdir(os.path.dirname(os.path.realpath(__file__)))
|
||||||
|
@ -34,7 +34,7 @@ import glob
|
|||||||
options.evt_max = -1
|
options.evt_max = -1
|
||||||
|
|
||||||
decay = "B" # D, B
|
decay = "B" # D, B
|
||||||
options.ntuple_file = f"data/ghost_data_{decay}_default_thesis.root"
|
options.ntuple_file = f"data/ghost_data_{decay}_vars_thesis.root"
|
||||||
|
|
||||||
options.input_type = "ROOT"
|
options.input_type = "ROOT"
|
||||||
if decay == "B":
|
if decay == "B":
|
||||||
|
@ -38,7 +38,7 @@ tested by mc_matching_example.py
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
decay = "B"
|
decay = "test"
|
||||||
|
|
||||||
options.evt_max = -1
|
options.evt_max = -1
|
||||||
|
|
||||||
@ -61,7 +61,7 @@ options.dddb_tag = "dddb-20210617"
|
|||||||
options.simulation = True
|
options.simulation = True
|
||||||
options.input_type = "ROOT"
|
options.input_type = "ROOT"
|
||||||
|
|
||||||
options.ntuple_file = f"data/tracking_losses_ntuple_{decay}_zmag.root"
|
options.ntuple_file = f"data/tracking_losses_ntuple_{decay}_endVelo2endT.root"
|
||||||
|
|
||||||
|
|
||||||
def run_tracking_losses():
|
def run_tracking_losses():
|
||||||
|
@ -11,6 +11,7 @@ def train_matching_ghost_mlp(
|
|||||||
tree_name: str = "PrMatchNN.PrMCDebugMatchToolNN/MVAInputAndOutput",
|
tree_name: str = "PrMatchNN.PrMCDebugMatchToolNN/MVAInputAndOutput",
|
||||||
exclude_electrons: bool = False,
|
exclude_electrons: bool = False,
|
||||||
only_electrons: bool = True,
|
only_electrons: bool = True,
|
||||||
|
filter_seeds: bool = False,
|
||||||
n_train_signal: int = 50e3, # 50e3
|
n_train_signal: int = 50e3, # 50e3
|
||||||
n_train_bkg: int = 50e3, # 500e3
|
n_train_bkg: int = 50e3, # 500e3
|
||||||
n_test_signal: int = 10e3,
|
n_test_signal: int = 10e3,
|
||||||
@ -54,10 +55,16 @@ def train_matching_ghost_mlp(
|
|||||||
"abs(quality) > 0",
|
"abs(quality) > 0",
|
||||||
"Signal is defined as non-zero label",
|
"Signal is defined as non-zero label",
|
||||||
)
|
)
|
||||||
rdf_bkg = rdf.Filter(
|
if filter_seeds:
|
||||||
"quality == 0",
|
rdf_bkg = rdf.Filter(
|
||||||
"Ghosts are defined as zero label",
|
"quality == 0 && scifi_isElectron == 1",
|
||||||
)
|
"Ghosts are defined as zero label",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
rdf_bkg = rdf.Filter(
|
||||||
|
"quality == 0",
|
||||||
|
"Ghosts are defined as zero label",
|
||||||
|
)
|
||||||
|
|
||||||
rdf_signal.Snapshot(
|
rdf_signal.Snapshot(
|
||||||
"Signal",
|
"Signal",
|
||||||
@ -98,8 +105,13 @@ def train_matching_ghost_mlp(
|
|||||||
dataloader.AddVariable("distX", "F")
|
dataloader.AddVariable("distX", "F")
|
||||||
dataloader.AddVariable("distY", "F")
|
dataloader.AddVariable("distY", "F")
|
||||||
dataloader.AddVariable("dSlope", "F")
|
dataloader.AddVariable("dSlope", "F")
|
||||||
dataloader.AddVariable("dSlopeY", "F")
|
# dataloader.AddVariable("dSlopeY", "F")
|
||||||
# dataloader.AddVariable("zMag", "F")
|
# dataloader.AddVariable("zmag", "F")
|
||||||
|
dataloader.AddVariable("eta", "F")
|
||||||
|
# dataloader.AddVariable("dEta", "F")
|
||||||
|
# dataloader.AddVariable("eta_scifi", "F")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
dataloader.AddSignalTree(signal_tree, 1.0)
|
dataloader.AddSignalTree(signal_tree, 1.0)
|
||||||
dataloader.AddBackgroundTree(bkg_tree, 1.0)
|
dataloader.AddBackgroundTree(bkg_tree, 1.0)
|
||||||
|
Loading…
Reference in New Issue
Block a user