You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

212 lines
21 KiB

10 months ago
  1. {
  2. "cells": [
  3. {
  4. "cell_type": "code",
  5. "execution_count": 1,
  6. "metadata": {},
  7. "outputs": [],
  8. "source": [
  9. "import uproot\n",
  10. "import awkward as ak\n",
  11. "import matplotlib\n",
  12. "import matplotlib.pyplot as plt\n",
  13. "input_tree = uproot.open({\"/work/guenther/reco_tuner/data/param_data_selected_all_p.root\": \"Selected\"})  # NOTE(review): absolute, machine-specific path; consider a configurable data directory\n",
  14. "array = input_tree.arrays()  # presumably loads every branch of the tree; confirm against the uproot docs\n",
  15. "array[\"dSlope_fringe\"] = array[\"tx_ref\"] - array[\"tx\"]  # difference between the tx_ref and tx columns\n",
  16. "array[\"poqmag_gev\"] = 1. / ( array[\"signed_rel_current\"] * array[\"qop\"] * 1000. )  # NOTE(review): the factor 1000 looks like an MeV-to-GeV conversion; confirm units\n",
  17. "array[\"B_integral\"] = array[\"poqmag_gev\"] * array[\"dSlope_fringe\"]  # product of the two derived columns; used as the regression target"
  18. ]
  19. },
  20. {
  21. "cell_type": "code",
  22. "execution_count": 2,
  23. "metadata": {},
  24. "outputs": [
  25. {
  26. "data": {
  27. "text/plain": [
  28. "(array([1.00000e+00, 0.00000e+00, 0.00000e+00, 1.00000e+00, 0.00000e+00,\n",
  29. " 0.00000e+00, 0.00000e+00, 0.00000e+00, 0.00000e+00, 0.00000e+00,\n",
  30. " 0.00000e+00, 0.00000e+00, 0.00000e+00, 0.00000e+00, 1.00000e+00,\n",
  31. " 1.00000e+00, 0.00000e+00, 0.00000e+00, 0.00000e+00, 0.00000e+00,\n",
  32. " 0.00000e+00, 0.00000e+00, 1.00000e+00, 0.00000e+00, 0.00000e+00,\n",
  33. " 2.00000e+00, 2.00000e+00, 4.00000e+00, 0.00000e+00, 3.00000e+00,\n",
  34. " 0.00000e+00, 4.00000e+00, 3.00000e+00, 1.00000e+00, 4.00000e+00,\n",
  35. " 4.00000e+00, 1.00000e+00, 1.00000e+01, 8.00000e+00, 8.00000e+00,\n",
  36. " 2.30000e+01, 3.20000e+01, 6.30000e+01, 8.40000e+01, 1.22000e+02,\n",
  37. " 2.03000e+02, 3.10000e+02, 3.98000e+02, 5.90000e+02, 8.65000e+02,\n",
  38. " 1.08700e+03, 1.55300e+03, 2.01000e+03, 2.67400e+03, 3.49500e+03,\n",
  39. " 4.46300e+03, 5.85300e+03, 7.16800e+03, 8.92900e+03, 1.12130e+04,\n",
  40. " 1.37320e+04, 1.69420e+04, 2.09870e+04, 2.59060e+04, 3.24750e+04,\n",
  41. " 3.94560e+04, 4.81480e+04, 6.21400e+04, 8.28540e+04, 1.20767e+05,\n",
  42. " 2.15388e+05, 5.13195e+05, 5.20000e+02, 3.29000e+02, 1.83000e+02,\n",
  43. " 1.32000e+02, 1.02000e+02, 8.30000e+01, 6.40000e+01, 5.00000e+01,\n",
  44. " 2.70000e+01, 2.80000e+01, 2.20000e+01, 1.80000e+01, 2.00000e+01,\n",
  45. " 1.50000e+01, 1.00000e+01, 4.00000e+00, 7.00000e+00, 6.00000e+00,\n",
  46. " 5.00000e+00, 2.00000e+00, 2.00000e+00, 3.00000e+00, 0.00000e+00,\n",
  47. " 2.00000e+00, 3.00000e+00, 0.00000e+00, 0.00000e+00, 2.00000e+00]),\n",
  48. " array([-2.45885773, -2.44145641, -2.4240551 , -2.40665378, -2.38925247,\n",
  49. " -2.37185115, -2.35444984, -2.33704853, -2.31964721, -2.3022459 ,\n",
  50. " -2.28484458, -2.26744327, -2.25004195, -2.23264064, -2.21523933,\n",
  51. " -2.19783801, -2.1804367 , -2.16303538, -2.14563407, -2.12823275,\n",
  52. " -2.11083144, -2.09343013, -2.07602881, -2.0586275 , -2.04122618,\n",
  53. " -2.02382487, -2.00642355, -1.98902224, -1.97162093, -1.95421961,\n",
  54. " -1.9368183 , -1.91941698, -1.90201567, -1.88461436, -1.86721304,\n",
  55. " -1.84981173, -1.83241041, -1.8150091 , -1.79760778, -1.78020647,\n",
  56. " -1.76280516, -1.74540384, -1.72800253, -1.71060121, -1.6931999 ,\n",
  57. " -1.67579858, -1.65839727, -1.64099596, -1.62359464, -1.60619333,\n",
  58. " -1.58879201, -1.5713907 , -1.55398938, -1.53658807, -1.51918676,\n",
  59. " -1.50178544, -1.48438413, -1.46698281, -1.4495815 , -1.43218018,\n",
  60. " -1.41477887, -1.39737756, -1.37997624, -1.36257493, -1.34517361,\n",
  61. " -1.3277723 , -1.31037098, -1.29296967, -1.27556836, -1.25816704,\n",
  62. " -1.24076573, -1.22336441, -1.2059631 , -1.18856178, -1.17116047,\n",
  63. " -1.15375916, -1.13635784, -1.11895653, -1.10155521, -1.0841539 ,\n",
  64. " -1.06675258, -1.04935127, -1.03194996, -1.01454864, -0.99714733,\n",
  65. " -0.97974601, -0.9623447 , -0.94494338, -0.92754207, -0.91014076,\n",
  66. " -0.89273944, -0.87533813, -0.85793681, -0.8405355 , -0.82313418,\n",
  67. " -0.80573287, -0.78833156, -0.77093024, -0.75352893, -0.73612761,\n",
  68. " -0.7187263 ]),\n",
  69. " <BarContainer object of 100 artists>)"
  70. ]
  71. },
  72. "execution_count": 2,
  73. "metadata": {},
  74. "output_type": "execute_result"
  75. },
  76. {
  77. "data": {
  78. "image/png": "iVBORw0KGgoAAAANSUhEUgAAAicAAAGdCAYAAADJ6dNTAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy89olMNAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAj3ElEQVR4nO3df3AU9eH/8deJ5ACbnMaUQCDBVIWKkSCBYqggYWo0yg9FLVYnxg4wpYZaJnWslLEUp04orUinHEjaqdgZZ8w4VTpTGZl0qkIN1hATi6ZW6QQTSEIKhRw/aoLH+/OHX+7rkQC5y13uvbfPx8zNcHu7e+99Z9m88t73+70eY4wRAACAJS5JdAEAAAC+jHACAACsQjgBAABWIZwAAACrEE4AAIBVCCcAAMAqhBMAAGAVwgkAALDKpYkuQKTOnDmjtrY2paamyuPxJLo4AACgH4wxOn78uLKysnTJJRduG3FcOGlra1N2dnaiiwEAAKLQ2tqqsWPHXnAdx4WT1NRUSV8cXFpaWoJLAwAA+iMQCCg7Ozv0e/xCHBNO/H6//H6/gsGgJCktLY1wAgCAw/SnS4bHaQ/+CwQC8vl86urqIpwAAOAQkfz+ZrQOAACwimPCid/v18SJEzVt2rREFwUAAMQRt3UAAEDccVsHAAA4FuEEAABYxTHhhD4nAAC4A31OAABA3NHnBAAAOJZjwgm3dQAAcAdu6wAAgLjjtg4AAHAswgkAALCKY55KDACIr6ueeC3s/f61dyaoJHA7x7Sc0CEWAAB3cEw4KS8vV1NTk+rq6hJdFAAAEEeOCScAAMAdCCcAAMAqhBMAAGAVwgkAALCKY8IJo3UAAHAHx4QTRusAAOAOjgknAADAHQgnAADAKoQTAABgFcIJAACwCuEEAABYxTHhhKHEAAC4g2PCCUOJAQBwB8eEEwAA4A6EEwAAYBXCCQAAsArhBAAAWIVwAgAArEI4AQAAViGcAAAAq1ya6AIAAAbfVU+8lugiAOdFywkAALCKY8IJ09cDAOAOjgknTF8PAIA7OCacAAAAdyCcAAAAqxBOAACAVQgnAADAKoQTAABgFcIJAACwCuEEAABYhXACAACsQjgBAABWSUg4ufTSSzV58mRNnjxZS5YsSUQRAACApRLyVOLLL79cjY2NifhqAHAlnkIMJ+G2DgAAsErE4WTnzp2aN2+esrKy5PF4tG3btl7rbNq0Sbm5uRo2bJgKCgq0a9eusM8DgYAKCgp0880366233oq68AAAIPlEHE5Onjyp/Px8bdy4sc/Pq6urtWLFCq1atUoNDQ2aOXOmSkpK1NLSElpn//79qq+v13PPPaeHHnpIgUAg+iMAAABJJeJwUlJSop///OdauHBhn5+vX79eixcv1pIlS3Tddddpw4YNys7O1ubNm0PrZGVlSZLy8vI0ceJEffzxx+f9vu7ubgUCgbAXAABIXjHtc9LT06P6+noVFxeHLS8uLlZtba0k6ejRo+ru7pYkHThwQE1NTfra17523n1WVlbK5/OFXtnZ2bEsMgAAsExMw8nhw4cVDAaVmZkZtjwzM1MdHR2SpH/+85+aOnWq8vPzNXfuXP36179Wenr6efe5cuVKdXV1hV6tra2xLDIAALBMXIYSezyesPfGmNCyGTNmaO/evf3el9frldfrjWn5AACAvWLacpKRkaEhQ4aEWknO6uzs7NWaEim/36+JEydq2rRpA9oPAACwW0zDSUpKigoKClRTUxO2vKamRjNmzBjQvsvLy9XU1KS6uroB7QcAANgt4ts6J06c0L59+0Lvm5ub1djYqPT0dOXk5KiiokKlpaWaOnWqCgsLVVVVpZaWFi1btiymBQcA9I3ZYOF0EYeTPXv2qKioKPS+oqJCklRWVqatW7dq0aJFOnLkiJ566im1t7crLy9P27dv17hx4wZUUL/fL7/fr2AwOKD9AAAAu3mMMSbRhYhEIBCQz+dTV1eX0tLS
El0cALBOrFpO9q+9Myb7AaTIfn/zbB0AAGAVx4QTRusAAOAOjgknjNYBAMAdHBNOAACAO8RlhlgAwOBh6DCSjWNaTuhzAgCAOzgmnNDnBAAAd3BMOAEAAO5AOAEAAFYhnAAAAKs4JpzQIRYAAHfg2ToA4CCDOWyYZ+sglni2DgAAcCzCCQAAsArhBAAAWIVwAgAArOKYZ+v4/X75/X4Fg8FEFwUABg3PzYEbOablhOnrAQBwB8eEEwAA4A6EEwAAYBXCCQAAsIpjOsQCQLKj8yvwBVpOAACAVRwTTnjwHwAA7uCYcMJQYgAA3MEx4QQAALgD4QQAAFiF0ToAkCCMzgH6RssJAACwCuEEAABYhXACAACsQjgBAABWoUMsAAwCOr8C/eeYlhNmiAUAwB0cE06YIRYAAHdwTDgBAADuQDgBAABWIZwAAACrMFoHAOKA0TlA9Gg5AQAAViGcAAAAqxBOAACAVQgnAADAKnSIBYABovMrEFu0nAAAAKskLJycOnVK48aN02OPPZaoIgAAAAslLJw8/fTTmj59eqK+HgAAWCoh4eSTTz7RRx99pDvuuCMRXw8AACwWcTjZuXOn5s2bp6ysLHk8Hm3btq3XOps2bVJubq6GDRumgoIC7dq1K+zzxx57TJWVlVEXGgAS6aonXgt7AYitiMPJyZMnlZ+fr40bN/b5eXV1tVasWKFVq1apoaFBM2fOVElJiVpaWiRJf/rTnzR+/HiNHz9+YCUHAABJKeKhxCUlJSopKTnv5+vXr9fixYu1ZMkSSdKGDRu0Y8cObd68WZWVlXrnnXf00ksv6eWXX9aJEyd0+vRppaWl6ac//Wmf++vu7lZ3d3fofSAQiLTIAADAQWLa56Snp0f19fUqLi4OW15cXKza2lpJUmVlpVpbW7V//3796le/0tKlS88bTM6u7/P5Qq/s7OxYFhkAAFgmpuHk8OHDCgaDyszMDFuemZmpjo6OqPa5cuVKdXV1hV6tra2xKCoAALBUXGaI9Xg8Ye+NMb2WSdLDDz980X15vV55vd5YFQ0AAFgupuEkIyNDQ4YM6dVK0tnZ2as1JVJ+v19+v1/BYHBA+wGASDAaBxh8Mb2tk5KSooKCAtXU1IQtr6mp0YwZMwa07/LycjU1Namurm5A+wEAAHaLuOXkxIkT2rdvX+h9c3OzGhsblZ6erpycHFVUVKi0tFRTp05VYWGhqqqq1NLSomXLlsW04AAAIDlFHE727NmjoqKi0PuKigpJUllZmbZu3apFixbpyJEjeuqpp9Te3q68vDxt375d48aNG1BBua0DAIA7eIwxJtGFiEQgEJDP51NXV5fS0tISXRwASc7NfU72r70z0UVAEonk93fCHvwHAADQl7gMJY4HbusAGAxubikBbOGYlhNG6wAA4A6OCScAAMAdCCcAAMAqjgknfr9fEydO1LRp0xJdFAAAEEcMJQbgWnR+vTCGEiOWGEoMAAAci3ACAACsQjgBAABWcUw4oUMsAADu4JhwwiRsAAC4g2PCCQAAcAfHPFsHAAaKocOAM9ByAgAArEI4AQAAVnFMOGG0DgAA7uCYcMJoHQAA3MEx4QQAALgDo3UAJCVG5gDORcsJAACwCuEEAABYhXACAACs4phwwlBiAADcwTHhhKHEAAC4g2PCCQAAcAeGEgNICgwdBpIHLScAAMAqhBMAAGAVwgkAALAK4QQAAFiFcAIAAKxCOAEAAFZxzFBiv98vv9+vYDCY6KIASDCGDQPJzTEtJ8wQCwCAOzgmnAAAAHcgnAAAAKsQTgAAgFUIJwAAwCqEEwAAYBXHDCUG4F4MHQbchZYTAABgFcIJAACwCuEEAABYhXACAACsMujh5Pjx45o2bZomT56sG264Qb/97W8HuwgAAMBigz5aZ8SIEXrrrbc0YsQInTp1Snl5eVq4cKGuvPLKwS4KAACw0KC3nAwZMkQjRoyQJH322WcKBoMyxgx2MQAAgKUiDic7d+7UvHnzlJWV
JY/Ho23btvVaZ9OmTcrNzdWwYcNUUFCgXbt2hX1+7Ngx5efna+zYsXr88ceVkZER9QEAAIDkEvFtnZMnTyo/P1/f/e53dc8
  79. "text/plain": [
  80. "<Figure size 640x480 with 1 Axes>"
  81. ]
  82. },
  83. "metadata": {},
  84. "output_type": "display_data"
  85. }
  86. ],
  87. "source": [
  88. "plt.hist(array[\"B_integral\"], bins=100, log=True)  # 100-bin histogram of B_integral with a logarithmic count axis"
  89. ]
  90. },
  91. {
  92. "cell_type": "code",
  93. "execution_count": 30,
  94. "metadata": {},
  95. "outputs": [
  96. {
  97. "name": "stdout",
  98. "output_type": "stream",
  99. "text": [
  100. "['ty^2' 'tx^2' 'tx tx_ref' 'tx_ref^2' 'ty^4' 'ty^2 tx^2' 'ty^2 tx tx_ref'\n",
  101. " 'ty^2 tx_ref^2' 'tx^4' 'tx^3 tx_ref' 'tx_ref^4']\n",
  102. "intercept= -1.2094486121528516\n",
  103. "coef= {'ty^2': -2.7897043324822492, 'tx^2': -0.35976930628193077, 'tx tx_ref': -0.47138558705675454, 'tx_ref^2': -0.5600847231491961, 'ty^4': 14.009315350693472, 'ty^2 tx^2': -16.162818973243674, 'ty^2 tx tx_ref': -8.807994419844437, 'ty^2 tx_ref^2': -0.8753190393972976, 'tx^4': 2.98254201374128, 'tx^3 tx_ref': 0.9625408279466898, 'tx_ref^4': 0.10200564097830103}\n",
  104. "r2 score= 0.9916826041227943\n",
  105. "RMSE = 0.006014471039836984\n",
  106. "['ty^2', 'tx^2', 'tx tx_ref', 'tx_ref^2', 'ty^4', 'ty^2 tx^2', 'ty^2 tx tx_ref', 'ty^2 tx_ref^2', 'tx^4', 'tx^3 tx_ref', 'tx_ref^4']\n"
  107. ]
  108. }
  109. ],
  110. "source": [
  111. "from sklearn.preprocessing import PolynomialFeatures\n",
  112. "from sklearn.linear_model import LinearRegression, Lasso, Ridge\n",
  113. "from sklearn.model_selection import train_test_split\n",
  114. "from sklearn.pipeline import Pipeline\n",
  115. "from sklearn.metrics import mean_squared_error\n",
  116. "import numpy as np\n",
  117. "\n",
  118. "features = [\n",
  119. " \"ty\", \n",
  120. " \"tx\",\n",
  121. " \"tx_ref\",\n",
  122. "]\n",
  123. "target_feat = \"B_integral\"\n",
  124. "\n",
  125. "data = np.column_stack([ak.to_numpy(array[feat]) for feat in features])\n",
  126. "target = ak.to_numpy(array[target_feat])\n",
  127. "X_train, X_test, y_train, y_test = train_test_split(data, target, test_size=0.2, random_state=42)\n",
  128. "\n",
  129. "poly = PolynomialFeatures(degree=5, include_bias=False)\n",
  130. "X_train_model = poly.fit_transform( X_train )\n",
  131. "X_test_model = poly.fit_transform( X_test )\n",
  132. "poly_features = poly.get_feature_names_out(input_features=features)\n",
  133. "reduce = True\n",
  134. "if reduce:\n",
  135. " keep = [\n",
  136. " #'ty',\n",
  137. " #'tx',\n",
  138. " #'tx_ref',\n",
  139. " #'ty tx',\n",
  140. " #'ty tx_ref',\n",
  141. " 'ty^2',#keep\n",
  142. " 'tx^2',#keep\n",
  143. " 'tx tx_ref',#keep\n",
  144. " 'tx_ref^2',#keep\n",
  145. " 'ty^2 tx tx_ref',#keep\n",
  146. " 'ty^2 tx^2',#keep\n",
  147. " 'ty^2 tx_ref^2', #keep\n",
  148. " 'tx^4',#keep\n",
  149. " 'ty^4',#keep\n",
  150. " 'tx_ref^4',#keep\n",
  151. " #'tx_ref^5',\n",
  152. " 'tx^3 tx_ref', #keep\n",
  153. " #'tx tx_ref^3',\n",
  154. " #'tx^2 tx_ref^2',\n",
  155. " #'ty tx_ref^4',\n",
  156. " #'tx tx_ref^4',\n",
  157. " #'tx_ref^5',\n",
  158. " ]\n",
  159. " remove = [i for i, f in enumerate(poly_features) if (keep and f not in keep )]\n",
  160. " X_train_model = np.delete( X_train_model, remove, axis=1)\n",
  161. " X_test_model = np.delete( X_test_model, remove, axis=1)\n",
  162. " poly_features = np.delete(poly_features, remove )\n",
  163. " print(poly_features)\n",
  164. "if not reduce:\n",
  165. " lin_reg = Lasso( alpha=0.0000001)#Lasso(fit_intercept=False, alpha=0.001)\n",
  166. "else:\n",
  167. " lin_reg = LinearRegression()\n",
  168. "lin_reg.fit( X_train_model, y_train)\n",
  169. "y_pred_test = lin_reg.predict( X_test_model )\n",
  170. "print(\"intercept=\", lin_reg.intercept_)\n",
  171. "print(\"coef=\", dict(zip(poly_features, lin_reg.coef_)))\n",
  172. "print(\"r2 score=\", lin_reg.score(X_test_model, y_test))\n",
  173. "print(\"RMSE =\", mean_squared_error(y_test, y_pred_test, squared=False))\n",
  174. "print([key for key, val in dict(zip(poly_features, lin_reg.coef_)).items() if val != 0.0])"
  175. ]
  176. },
  177. {
  178. "cell_type": "code",
  179. "execution_count": null,
  180. "metadata": {},
  181. "outputs": [],
  182. "source": []
  183. }
  184. ],
  185. "metadata": {
  186. "kernelspec": {
  187. "display_name": "Python 3.10.6 (conda)",
  188. "language": "python",
  189. "name": "python3"
  190. },
  191. "language_info": {
  192. "codemirror_mode": {
  193. "name": "ipython",
  194. "version": 3
  195. },
  196. "file_extension": ".py",
  197. "mimetype": "text/x-python",
  198. "name": "python",
  199. "nbconvert_exporter": "python",
  200. "pygments_lexer": "ipython3",
  201. "version": "3.10.6"
  202. },
  203. "orig_nbformat": 4,
  204. "vscode": {
  205. "interpreter": {
  206. "hash": "a2eff8b4da8b8eebf5ee2e5f811f31a557e0a202b4d2f04f849b065340a6eda6"
  207. }
  208. }
  209. },
  210. "nbformat": 4,
  211. "nbformat_minor": 2
  212. }