training of nn

This commit is contained in:
cetin 2024-01-15 16:16:12 +01:00
parent c0c20a8af0
commit fad2d95a1c
4 changed files with 104 additions and 29 deletions

View File

@ -134,7 +134,7 @@ if args.matching_weights:
os.chdir(os.path.dirname(os.path.realpath(__file__)))
train_matching_ghost_mlp(
prepare_data=args.prepare,
input_file="data/ghost_data_D_default_phi_eta.root",
input_file="data/ghost_data_B_default_phi_eta.root",
tree_name="PrMatchNN_3e224c41.PrMCDebugMatchToolNN/MVAInputAndOutput",
outdir="neural_net_training",
exclude_electrons=False,

View File

@ -1,28 +1,73 @@
const auto fMin = std::array<simd::float_v, 8>{
{1.4334048501e-05, 1.63528045505e-06, 9.53674316406e-06, 3.0517578125e-05,
7.06594437361e-06, 1.16415321827e-09, -3.14159274101, 1.99012887478}};
{2.32376150961e-05, 2.85693909063e-06, 3.0517578125e-05, 0.0001220703125,
3.00072133541e-05, 1.86264514923e-08, -3.14159274101, 1.99001383781}};
const auto fMax = std::array<simd::float_v, 8>{
{14.9999303818, 0.150984346867, 249.944519043, 249.72227478, 1.2982006073,
0.14879232645, 3.14159274101, 5.00999403}};
{14.9998626709, 0.123768046498, 249.997253418, 249.93522644, 1.33115208149,
0.14322412014, 3.14159274101, 5.00999116898}};
const auto fWeightMatrix0to1 = std::array<std::array<simd::float_v, 9>, 10>{
{{-nan, -nan, -nan, -nan, -nan, -nan, -nan, -nan, -nan},
{-nan, -nan, -nan, -nan, -nan, -nan, -nan, -nan, -nan},
{-nan, -nan, -nan, -nan, -nan, -nan, -nan, -nan, -nan},
{-nan, -nan, -nan, -nan, -nan, -nan, -nan, -nan, -nan},
{-nan, -nan, -nan, -nan, -nan, -nan, -nan, -nan, -nan},
{-nan, -nan, -nan, -nan, -nan, -nan, -nan, -nan, -nan},
{-nan, -nan, -nan, -nan, -nan, -nan, -nan, -nan, -nan},
{-nan, -nan, -nan, -nan, -nan, -nan, -nan, -nan, -nan},
{-nan, -nan, -nan, -nan, -nan, -nan, -nan, -nan, -nan},
{-nan, -nan, -nan, -nan, -nan, -nan, -nan, -nan, -nan}}};
{{-0.995728805317666, 1.76111236349049, -6.59156372267116,
-1.97552437720502, 6.1099278217908, 1.53100155205014, 0.0108712803445021,
-0.377003374986967, -0.463996851734649},
{-0.995611128642738, 1.85732454978616, -0.103489755042309,
2.14670091457823, 5.05298718406462, 1.57276959085594, 0.0353247464108457,
-1.86211461574573, 4.27461653614615},
{-0.796326797853357, -7.68130747901943, 7.19311634432287, 9.03180964447068,
-7.63418437895213, 11.9297597348633, 0.114073700050958, -1.25510020672501,
14.8177969049142},
{-0.419906375775032, 3.66809284591753, -7.92355570463663,
-0.455284242503828, 2.78493233327911, -9.42984921706468,
0.0868463943846557, -0.486035562764322, -9.20358391830381},
{-1.03715582203046, 0.851759431928431, -18.2552237371827,
-3.97361448505785, -4.14831123989833, -9.73503865704962,
-0.0137791640986444, -0.587438903895606, -31.113236445281},
{-0.488661196605187, 8.92859176662732, 5.04121839035159, 17.3619847721455,
-7.94413188687239, -22.844424361919, -0.00665023249378502,
1.19839018235788, -0.210841298805848},
{-1.3631374257876, -0.299817770295189, 3.41377518635713, -2.22621654559539,
5.40436710160442, -4.39793227969093, 0.0436642901709855,
-1.78027544600405, -1.10214619769467},
{-0.0940766492071435, 4.62033414623526, 2.94753966098872, 9.4146058812013,
-3.66240254715736, -13.4981502764483, -0.00724238895954879,
-0.248783768193111, -1.04616678170011},
{-0.446474256123752, 3.8555840716226, -10.2650719311117, 4.16775627158457,
7.06133039504113, -1.41399271367562, -0.0111817383783991, 4.878521470496,
-1.92193663004063},
{-1.04962229675345, 1.37668509254858, 0.149634602145268, 1.10915750814357,
1.450404662274, 0.687692166842801, -0.839072756753414, 0.509999788254877,
-1.72055130600754}}};
const auto fWeightMatrix1to2 = std::array<std::array<simd::float_v, 11>, 8>{
{{nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan},
{nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan},
{nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan},
{nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan},
{nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan},
{nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan},
{nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan},
{nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan}}};
{{-0.807455788900782, -1.57847934072609, -0.602141395950992,
-0.468764746603926, 2.91083590379932, -0.781143636397729,
-1.49818825923089, 2.11496957400425, -3.29593795617159, 0.494130959098686,
-7.49347636093411},
{-1.1167253617035, 1.30657031447796, -0.465160141914853, 0.366559047688212,
-1.51633166727346, -0.420851938396777, 1.6984775167355, 0.100867296974804,
0.564699459284778, 0.586191261445691, -5.10400132978458},
{1.1578780065767, -1.67054825221276, -0.251246911016819, 0.184109546639294,
0.856695883199377, -0.325541573671961, -1.00207099644341,
1.25292993337302, -0.725805776795552, 1.00354188426928,
-1.43912318186564},
{-0.97759053555381, -1.67108034935821, -0.580419539930693,
-0.027624560413911, -0.647308762730501, -0.371248743500777,
-1.31187492597551, 0.27102013221006, -0.369172590209503,
-0.584243632853031, 2.65602337330211},
{-1.52119942430823, 0.820721655386528, -1.69306556038436,
0.558359347445987, -0.361087325918055, 0.826311131811265,
-0.395511032622557, -1.82330661847839, -0.588310358929765,
-0.604402018658415, 1.67613902034978},
{0.848509064699364, -0.678582018529978, 0.510495919033791,
0.874882532797076, 0.324685080231291, -1.15172790485086,
-0.548752447634189, 2.19649837886661, 0.3942971074473, 1.05647229315095,
-5.33629435835524},
{1.07709697981573, -2.75079530040449, 0.696321551488037, -1.36071135645975,
-0.161195931952194, 3.23873299066678, -1.39598213528866,
-4.89556931832426, -0.104678879473331, -0.920278878842116,
2.65380015363929},
{-1.63390282125409, 1.18596357349284, -0.00654547032391154,
0.162215051970754, -0.440363717136096, -2.18941620663639,
1.19657877233458, 3.74372200392778, 0.640836552714409, -0.568495189559283,
1.85355484777308}}};
const auto fWeightMatrix2to3 = std::array<simd::float_v, 9>{
{-nan, -nan, -nan, -nan, -nan, -nan, -nan, -nan, -nan}};
{0.636134552505469, 0.426261082203447, 0.434907836775301, 0.472241383644929,
-0.607293621918973, -0.337373090814836, -0.806765829161037,
-0.691462964748509, 1.81337120089119}};

View File

@ -110,7 +110,7 @@ def train_matching_ghost_mlp(
dataloader,
TMVA.Types.kMLP,
"matching_mlp",
"!H:V:TrainingMethod=BP:NeuronType=ReLU:EstimatorType=CE:VarTransform=Norm:NCycles=700:HiddenLayers=N+2,N:TestRate=50:Sampling=1.0:SamplingImportance=1.0:LearningRate=0.02:DecayRate=0.01:UseRegulator",
"!H:V:TrainingMethod=BP:NeuronType=ReLU:EstimatorType=CE:VarTransform=Norm:NCycles=700:HiddenLayers=N+2,N:TestRate=50:Sampling=1.0:SamplingImportance=1.0:LearningRate=0.02:DecayRate=0.01:!UseRegulator",
)
factory.TrainAllMethods()
factory.TestAllMethods()

View File

@ -220,10 +220,40 @@ def get_eff(eff, hist, tf, histoName, label, var):
eff[lab] = teff.CreateGraph()
eff[lab].SetName(lab)
eff[lab].SetTitle(lab + " not e^{-}")
if histoName.find("strange") != -1:
eff[lab].SetTitle(lab + " from stranges")
# if histoName.find("strange") != -1:
# eff[lab].SetTitle(lab + " from stranges")
# if histoName.find("electron") != -1:
# eff[lab].SetTitle(lab + " e^{-}")
if histoName.find("Forward") != -1:
if histoName.find("electron") != -1:
eff[lab].SetTitle(lab + " e^{-}")
eff[lab].SetTitle(lab + " Forward, e^{-}")
else:
eff[lab].SetTitle(lab + " Forward")
if histoName.find("Merged") != -1:
if histoName.find("electron") != -1:
eff[lab].SetTitle(lab + " MergedMatch, e^{-}")
else:
eff[lab].SetTitle(lab + " MergedMatch")
elif histoName.find("DefaultMatch") != -1:
if histoName.find("electron") != -1:
eff[lab].SetTitle(lab + " DefaultMatch, e^{-}")
else:
eff[lab].SetTitle(lab + " DefaultMatch")
elif histoName.find("Match") != -1:
if histoName.find("electron") != -1:
eff[lab].SetTitle(lab + " Match, e^{-}")
else:
eff[lab].SetTitle(lab + " Match")
if histoName.find("Seed") != -1:
if histoName.find("electron") != -1:
eff[lab].SetTitle(lab + " Seed, e^{-}")
else:
eff[lab].SetTitle(lab + " Seed")
if histoName.find("BestLong") != -1:
if histoName.find("electron") != -1:
eff[lab].SetTitle(lab + " BestLong, e^{-}")
else:
eff[lab].SetTitle(lab + " BestLong")
hist[lab] = denominator.Clone()
hist[lab].SetName("h_numerator_notElectrons")