Machine Learning Kurs im Rahmen der Studierendentage im SS 2023
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

291 lines
5.5 KiB

2 years ago
  1. {
  2. "cells": [
  3. {
  4. "cell_type": "markdown",
  5. "metadata": {},
  6. "source": [
  7. "Fit a cubic polynomial to data with iminuit, the Python interface to the Minuit2 minimizer\n",
  8. "https://iminuit.readthedocs.io/en/stable/"
  9. ]
  10. },
  11. {
  12. "cell_type": "code",
  13. "execution_count": null,
  14. "metadata": {},
  15. "outputs": [],
  16. "source": [
  17. "from matplotlib import pyplot as plt\n",
  18. "plt.rcParams[\"font.size\"] = 20\n",
  19. "import numpy as np"
  20. ]
  21. },
  22. {
  23. "cell_type": "markdown",
  24. "metadata": {},
  25. "source": [
  26. "Data: ten measurements (x, y) with uncertainties dx and dy"
  27. ]
  28. },
  29. {
  30. "cell_type": "code",
  31. "execution_count": null,
  32. "metadata": {},
  33. "outputs": [],
  34. "source": [
  35. "x = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype='d')\n",
  36. "dx = np.array([0.1,0.1,0.5,0.1,0.5,0.1,0.5,0.1,0.5,0.1], dtype='d')\n",
  37. "y = np.array([1.1 ,2.3 ,2.7 ,3.2 ,3.1 ,2.4 ,1.7 ,1.5 ,1.5 ,1.7 ], dtype='d')\n",
  38. "dy = np.array([0.15,0.22,0.29,0.39,0.31,0.21,0.13,0.15,0.19,0.13], dtype='d')"
  39. ]
  40. },
  41. {
  42. "cell_type": "markdown",
  43. "metadata": {},
  44. "source": [
  45. "Define the fit function: a cubic polynomial a0 + a1*x + a2*x^2 + a3*x^3"
  46. ]
  47. },
  48. {
  49. "cell_type": "code",
  50. "execution_count": null,
  51. "metadata": {},
  52. "outputs": [],
  53. "source": [
  54. "def pol3(a0, a1, a2, a3):\n",
  55. " return a0 + x*a1 + a2*x**2 + a3*x**3"
  56. ]
  57. },
  58. {
  59. "cell_type": "markdown",
  60. "metadata": {},
  61. "source": [
  62. "Least-squares function: sum of squared residuals, each divided by the squared y uncertainty dy^2 (chi-square); the x uncertainties dx are not used"
  63. ]
  64. },
  65. {
  66. "cell_type": "code",
  67. "execution_count": null,
  68. "metadata": {},
  69. "outputs": [],
  70. "source": [
  71. "def LSQ(a0, a1, a2, a3):\n",
  72. " return np.sum((y - pol3(a0, a1, a2, a3)) ** 2 / dy ** 2)"
  73. ]
  74. },
  75. {
  76. "cell_type": "markdown",
  77. "metadata": {},
  78. "source": [
  79. "Import the Minuit class"
  80. ]
  81. },
  82. {
  83. "cell_type": "code",
  84. "execution_count": null,
  85. "metadata": {},
  86. "outputs": [],
  87. "source": [
  88. "from iminuit import Minuit"
  89. ]
  90. },
  91. {
  92. "cell_type": "markdown",
  93. "metadata": {},
  94. "source": [
  95. "Create a Minuit instance that minimizes LSQ from the given start values; a3 is fixed for now"
  96. ]
  97. },
  98. {
  99. "cell_type": "code",
  100. "execution_count": null,
  101. "metadata": {},
  102. "outputs": [],
  103. "source": [
  104. "LSQ.errordef = Minuit.LEAST_SQUARES\n",
  105. "#LSQ.errordef = Minuit.LIKELIHOOD\n",
  106. "m = Minuit(LSQ,a0=-1.3, a1=2.6 ,a2=-0.24 ,a3=0.005)\n",
  107. "m.fixed[\"a3\"] = True \n",
  108. "m.params"
  109. ]
  110. },
  111. {
  112. "cell_type": "markdown",
  113. "metadata": {},
  114. "source": [
  115. "Release a3 and run the MIGRAD minimizer"
  116. ]
  117. },
  118. {
  119. "cell_type": "code",
  120. "execution_count": null,
  121. "metadata": {},
  122. "outputs": [],
  123. "source": [
  124. "m.fixed[\"a3\"] = False\n",
  125. "m.params\n",
  126. "m.migrad()"
  127. ]
  128. },
  129. {
  130. "cell_type": "markdown",
  131. "metadata": {},
  132. "source": [
  133. "Draw the confidence contours of a2 versus a3 (MNCONTOUR)"
  134. ]
  135. },
  136. {
  137. "cell_type": "code",
  138. "execution_count": null,
  139. "metadata": {},
  140. "outputs": [],
  141. "source": [
# Draw confidence contours of a2 versus a3 around the MIGRAD minimum.
# NOTE(review): cl=[1, 2, 3] presumably requests 1/2/3-sigma-equivalent
# contours; confirm against the iminuit docs for how cl >= 1 is interpreted.
# NOTE(review): the same contour is drawn again later in the notebook, after
# HESSE and MINOS have run; one of the two cells looks redundant.
m.draw_mncontour("a2", "a3", cl=[1, 2, 3])
  143. ]
  144. },
  145. {
  146. "cell_type": "markdown",
  147. "metadata": {},
  148. "source": [
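  149. "Improve the uncertainty estimates: HESSE recomputes the covariance matrix, MINOS computes asymmetric confidence intervals"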
  150. ]
  151. },
  152. {
  153. "cell_type": "code",
  154. "execution_count": null,
  155. "metadata": {},
  156. "outputs": [],
  157. "source": [
# Run the HESSE algorithm at the current minimum to update the symmetric
# parameter errors and the covariance matrix (see iminuit `Minuit.hesse`).
m.hesse()
  159. ]
  160. },
  161. {
  162. "cell_type": "code",
  163. "execution_count": null,
  164. "metadata": {},
  165. "outputs": [],
  166. "source": [
# Run MINOS, which computes (generally asymmetric) profile-likelihood
# confidence intervals for the parameters.
m.minos()
  168. ]
  169. },
  170. {
  171. "cell_type": "markdown",
  172. "metadata": {},
  173. "source": [
  174. "Access the fit results"
  175. ]
  176. },
  177. {
  178. "cell_type": "code",
  179. "execution_count": null,
  180. "metadata": {},
  181. "outputs": [],
  182. "source": [
  183. "print(m.values,m.errors)\n",
  184. "a0_fit = m.values[\"a0\"]\n",
  185. "a1_fit = m.values[\"a1\"]\n",
  186. "a2_fit = m.values[\"a2\"]\n",
  187. "a3_fit = m.values[\"a3\"]"
  188. ]
  189. },
  190. {
  191. "cell_type": "code",
  192. "execution_count": null,
  193. "metadata": {},
  194. "outputs": [],
  195. "source": [
  196. "print (m.covariance)"
  197. ]
  198. },
  199. {
  200. "cell_type": "markdown",
  201. "metadata": {},
  202. "source": [
  203. "Evaluate the fitted function on a fine grid for plotting"
  204. ]
  205. },
  206. {
  207. "cell_type": "code",
  208. "execution_count": null,
  209. "metadata": {},
  210. "outputs": [],
  211. "source": [
  212. "x_plot = np.linspace( 0.5, 10.5 , 500 )\n",
  213. "y_fit = a0_fit + a1_fit * x_plot + a2_fit * x_plot**2 + a3_fit * x_plot**3"
  214. ]
  215. },
  216. {
  217. "cell_type": "markdown",
  218. "metadata": {},
  219. "source": [
  220. "The Minos algorithm uses the profile likelihood method to compute (generally asymmetric) confidence intervals. The profile for a single parameter can be plotted:"
  221. ]
  222. },
  223. {
  224. "cell_type": "code",
  225. "execution_count": null,
  226. "metadata": {},
  227. "outputs": [],
  228. "source": [
# Plot the profile of the cost function versus a2, the quantity on which the
# Minos interval described above is based.
m.draw_profile("a2")
  230. ]
  231. },
  232. {
  233. "cell_type": "markdown",
  234. "metadata": {},
  235. "source": [
  236. "Draw 2D contours of the cost function around the minimum for two parameters"
  237. ]
  238. },
  239. {
  240. "cell_type": "code",
  241. "execution_count": null,
  242. "metadata": {},
  243. "outputs": [],
  244. "source": [
# Draw 2D confidence contours of a2 versus a3, now after HESSE and MINOS have run.
# NOTE(review): identical to the earlier contour cell apart from spacing;
# cl=[1, 2, 3] presumably means 1/2/3-sigma-equivalent levels — confirm in the iminuit docs.
m.draw_mncontour("a2", "a3" , cl=[1, 2, 3])
  246. ]
  247. },
  248. {
  249. "cell_type": "markdown",
  250. "metadata": {},
  251. "source": [
  252. "Plot the data and the fitted function with matplotlib"
  253. ]
  254. },
  255. {
  256. "cell_type": "code",
  257. "execution_count": null,
  258. "metadata": {},
  259. "outputs": [],
  260. "source": [
  261. "plt.figure()\n",
  262. "plt.errorbar(x, y, dy , dx, fmt=\"o\")\n",
  263. "plt.plot(x_plot, y_fit)\n",
  264. "plt.title(\"iminuit Fit Test\")\n",
  265. "plt.xlim(-0.1, 10.1)\n",
  266. "plt.show()"
  267. ]
  268. }
  269. ],
  270. "metadata": {
  271. "kernelspec": {
  272. "display_name": "Python 3 (ipykernel)",
  273. "language": "python",
  274. "name": "python3"
  275. },
  276. "language_info": {
  277. "codemirror_mode": {
  278. "name": "ipython",
  279. "version": 3
  280. },
  281. "file_extension": ".py",
  282. "mimetype": "text/x-python",
  283. "name": "python",
  284. "nbconvert_exporter": "python",
  285. "pygments_lexer": "ipython3",
  286. "version": "3.8.16"
  287. }
  288. },
  289. "nbformat": 4,
  290. "nbformat_minor": 4
  291. }