training of nn
This commit is contained in:
parent
c0c20a8af0
commit
fad2d95a1c
2
main.py
2
main.py
@ -134,7 +134,7 @@ if args.matching_weights:
|
||||
os.chdir(os.path.dirname(os.path.realpath(__file__)))
|
||||
train_matching_ghost_mlp(
|
||||
prepare_data=args.prepare,
|
||||
input_file="data/ghost_data_D_default_phi_eta.root",
|
||||
input_file="data/ghost_data_B_default_phi_eta.root",
|
||||
tree_name="PrMatchNN_3e224c41.PrMCDebugMatchToolNN/MVAInputAndOutput",
|
||||
outdir="neural_net_training",
|
||||
exclude_electrons=False,
|
||||
|
@ -1,28 +1,73 @@
|
||||
const auto fMin = std::array<simd::float_v, 8>{
|
||||
{1.4334048501e-05, 1.63528045505e-06, 9.53674316406e-06, 3.0517578125e-05,
|
||||
7.06594437361e-06, 1.16415321827e-09, -3.14159274101, 1.99012887478}};
|
||||
{2.32376150961e-05, 2.85693909063e-06, 3.0517578125e-05, 0.0001220703125,
|
||||
3.00072133541e-05, 1.86264514923e-08, -3.14159274101, 1.99001383781}};
|
||||
const auto fMax = std::array<simd::float_v, 8>{
|
||||
{14.9999303818, 0.150984346867, 249.944519043, 249.72227478, 1.2982006073,
|
||||
0.14879232645, 3.14159274101, 5.00999403}};
|
||||
{14.9998626709, 0.123768046498, 249.997253418, 249.93522644, 1.33115208149,
|
||||
0.14322412014, 3.14159274101, 5.00999116898}};
|
||||
const auto fWeightMatrix0to1 = std::array<std::array<simd::float_v, 9>, 10>{
|
||||
{{-nan, -nan, -nan, -nan, -nan, -nan, -nan, -nan, -nan},
|
||||
{-nan, -nan, -nan, -nan, -nan, -nan, -nan, -nan, -nan},
|
||||
{-nan, -nan, -nan, -nan, -nan, -nan, -nan, -nan, -nan},
|
||||
{-nan, -nan, -nan, -nan, -nan, -nan, -nan, -nan, -nan},
|
||||
{-nan, -nan, -nan, -nan, -nan, -nan, -nan, -nan, -nan},
|
||||
{-nan, -nan, -nan, -nan, -nan, -nan, -nan, -nan, -nan},
|
||||
{-nan, -nan, -nan, -nan, -nan, -nan, -nan, -nan, -nan},
|
||||
{-nan, -nan, -nan, -nan, -nan, -nan, -nan, -nan, -nan},
|
||||
{-nan, -nan, -nan, -nan, -nan, -nan, -nan, -nan, -nan},
|
||||
{-nan, -nan, -nan, -nan, -nan, -nan, -nan, -nan, -nan}}};
|
||||
{{-0.995728805317666, 1.76111236349049, -6.59156372267116,
|
||||
-1.97552437720502, 6.1099278217908, 1.53100155205014, 0.0108712803445021,
|
||||
-0.377003374986967, -0.463996851734649},
|
||||
{-0.995611128642738, 1.85732454978616, -0.103489755042309,
|
||||
2.14670091457823, 5.05298718406462, 1.57276959085594, 0.0353247464108457,
|
||||
-1.86211461574573, 4.27461653614615},
|
||||
{-0.796326797853357, -7.68130747901943, 7.19311634432287, 9.03180964447068,
|
||||
-7.63418437895213, 11.9297597348633, 0.114073700050958, -1.25510020672501,
|
||||
14.8177969049142},
|
||||
{-0.419906375775032, 3.66809284591753, -7.92355570463663,
|
||||
-0.455284242503828, 2.78493233327911, -9.42984921706468,
|
||||
0.0868463943846557, -0.486035562764322, -9.20358391830381},
|
||||
{-1.03715582203046, 0.851759431928431, -18.2552237371827,
|
||||
-3.97361448505785, -4.14831123989833, -9.73503865704962,
|
||||
-0.0137791640986444, -0.587438903895606, -31.113236445281},
|
||||
{-0.488661196605187, 8.92859176662732, 5.04121839035159, 17.3619847721455,
|
||||
-7.94413188687239, -22.844424361919, -0.00665023249378502,
|
||||
1.19839018235788, -0.210841298805848},
|
||||
{-1.3631374257876, -0.299817770295189, 3.41377518635713, -2.22621654559539,
|
||||
5.40436710160442, -4.39793227969093, 0.0436642901709855,
|
||||
-1.78027544600405, -1.10214619769467},
|
||||
{-0.0940766492071435, 4.62033414623526, 2.94753966098872, 9.4146058812013,
|
||||
-3.66240254715736, -13.4981502764483, -0.00724238895954879,
|
||||
-0.248783768193111, -1.04616678170011},
|
||||
{-0.446474256123752, 3.8555840716226, -10.2650719311117, 4.16775627158457,
|
||||
7.06133039504113, -1.41399271367562, -0.0111817383783991, 4.878521470496,
|
||||
-1.92193663004063},
|
||||
{-1.04962229675345, 1.37668509254858, 0.149634602145268, 1.10915750814357,
|
||||
1.450404662274, 0.687692166842801, -0.839072756753414, 0.509999788254877,
|
||||
-1.72055130600754}}};
|
||||
const auto fWeightMatrix1to2 = std::array<std::array<simd::float_v, 11>, 8>{
|
||||
{{nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan},
|
||||
{nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan},
|
||||
{nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan},
|
||||
{nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan},
|
||||
{nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan},
|
||||
{nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan},
|
||||
{nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan},
|
||||
{nan, nan, nan, nan, nan, nan, nan, nan, nan, nan, nan}}};
|
||||
{{-0.807455788900782, -1.57847934072609, -0.602141395950992,
|
||||
-0.468764746603926, 2.91083590379932, -0.781143636397729,
|
||||
-1.49818825923089, 2.11496957400425, -3.29593795617159, 0.494130959098686,
|
||||
-7.49347636093411},
|
||||
{-1.1167253617035, 1.30657031447796, -0.465160141914853, 0.366559047688212,
|
||||
-1.51633166727346, -0.420851938396777, 1.6984775167355, 0.100867296974804,
|
||||
0.564699459284778, 0.586191261445691, -5.10400132978458},
|
||||
{1.1578780065767, -1.67054825221276, -0.251246911016819, 0.184109546639294,
|
||||
0.856695883199377, -0.325541573671961, -1.00207099644341,
|
||||
1.25292993337302, -0.725805776795552, 1.00354188426928,
|
||||
-1.43912318186564},
|
||||
{-0.97759053555381, -1.67108034935821, -0.580419539930693,
|
||||
-0.027624560413911, -0.647308762730501, -0.371248743500777,
|
||||
-1.31187492597551, 0.27102013221006, -0.369172590209503,
|
||||
-0.584243632853031, 2.65602337330211},
|
||||
{-1.52119942430823, 0.820721655386528, -1.69306556038436,
|
||||
0.558359347445987, -0.361087325918055, 0.826311131811265,
|
||||
-0.395511032622557, -1.82330661847839, -0.588310358929765,
|
||||
-0.604402018658415, 1.67613902034978},
|
||||
{0.848509064699364, -0.678582018529978, 0.510495919033791,
|
||||
0.874882532797076, 0.324685080231291, -1.15172790485086,
|
||||
-0.548752447634189, 2.19649837886661, 0.3942971074473, 1.05647229315095,
|
||||
-5.33629435835524},
|
||||
{1.07709697981573, -2.75079530040449, 0.696321551488037, -1.36071135645975,
|
||||
-0.161195931952194, 3.23873299066678, -1.39598213528866,
|
||||
-4.89556931832426, -0.104678879473331, -0.920278878842116,
|
||||
2.65380015363929},
|
||||
{-1.63390282125409, 1.18596357349284, -0.00654547032391154,
|
||||
0.162215051970754, -0.440363717136096, -2.18941620663639,
|
||||
1.19657877233458, 3.74372200392778, 0.640836552714409, -0.568495189559283,
|
||||
1.85355484777308}}};
|
||||
const auto fWeightMatrix2to3 = std::array<simd::float_v, 9>{
|
||||
{-nan, -nan, -nan, -nan, -nan, -nan, -nan, -nan, -nan}};
|
||||
{0.636134552505469, 0.426261082203447, 0.434907836775301, 0.472241383644929,
|
||||
-0.607293621918973, -0.337373090814836, -0.806765829161037,
|
||||
-0.691462964748509, 1.81337120089119}};
|
||||
|
@ -110,7 +110,7 @@ def train_matching_ghost_mlp(
|
||||
dataloader,
|
||||
TMVA.Types.kMLP,
|
||||
"matching_mlp",
|
||||
"!H:V:TrainingMethod=BP:NeuronType=ReLU:EstimatorType=CE:VarTransform=Norm:NCycles=700:HiddenLayers=N+2,N:TestRate=50:Sampling=1.0:SamplingImportance=1.0:LearningRate=0.02:DecayRate=0.01:UseRegulator",
|
||||
"!H:V:TrainingMethod=BP:NeuronType=ReLU:EstimatorType=CE:VarTransform=Norm:NCycles=700:HiddenLayers=N+2,N:TestRate=50:Sampling=1.0:SamplingImportance=1.0:LearningRate=0.02:DecayRate=0.01:!UseRegulator",
|
||||
)
|
||||
factory.TrainAllMethods()
|
||||
factory.TestAllMethods()
|
||||
|
@ -220,10 +220,40 @@ def get_eff(eff, hist, tf, histoName, label, var):
|
||||
eff[lab] = teff.CreateGraph()
|
||||
eff[lab].SetName(lab)
|
||||
eff[lab].SetTitle(lab + " not e^{-}")
|
||||
if histoName.find("strange") != -1:
|
||||
eff[lab].SetTitle(lab + " from stranges")
|
||||
# if histoName.find("strange") != -1:
|
||||
# eff[lab].SetTitle(lab + " from stranges")
|
||||
# if histoName.find("electron") != -1:
|
||||
# eff[lab].SetTitle(lab + " e^{-}")
|
||||
if histoName.find("Forward") != -1:
|
||||
if histoName.find("electron") != -1:
|
||||
eff[lab].SetTitle(lab + " e^{-}")
|
||||
eff[lab].SetTitle(lab + " Forward, e^{-}")
|
||||
else:
|
||||
eff[lab].SetTitle(lab + " Forward")
|
||||
if histoName.find("Merged") != -1:
|
||||
if histoName.find("electron") != -1:
|
||||
eff[lab].SetTitle(lab + " MergedMatch, e^{-}")
|
||||
else:
|
||||
eff[lab].SetTitle(lab + " MergedMatch")
|
||||
elif histoName.find("DefaultMatch") != -1:
|
||||
if histoName.find("electron") != -1:
|
||||
eff[lab].SetTitle(lab + " DefaultMatch, e^{-}")
|
||||
else:
|
||||
eff[lab].SetTitle(lab + " DefaultMatch")
|
||||
elif histoName.find("Match") != -1:
|
||||
if histoName.find("electron") != -1:
|
||||
eff[lab].SetTitle(lab + " Match, e^{-}")
|
||||
else:
|
||||
eff[lab].SetTitle(lab + " Match")
|
||||
if histoName.find("Seed") != -1:
|
||||
if histoName.find("electron") != -1:
|
||||
eff[lab].SetTitle(lab + " Seed, e^{-}")
|
||||
else:
|
||||
eff[lab].SetTitle(lab + " Seed")
|
||||
if histoName.find("BestLong") != -1:
|
||||
if histoName.find("electron") != -1:
|
||||
eff[lab].SetTitle(lab + " BestLong, e^{-}")
|
||||
else:
|
||||
eff[lab].SetTitle(lab + " BestLong")
|
||||
|
||||
hist[lab] = denominator.Clone()
|
||||
hist[lab].SetName("h_numerator_notElectrons")
|
||||
|
Loading…
Reference in New Issue
Block a user