Machine Learning Kurs im Rahmen der Studierendentage im SS 2023
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

165 lines
3.6 KiB

2 years ago
  1. {
  2. "cells": [
  3. {
  4. "cell_type": "markdown",
  5. "id": "042acd49",
  6. "metadata": {},
  7. "source": [
  8. "# Test a minimizer: fit $y = W x^3 + b$ to noisy data with `scipy.optimize.minimize`"
  9. ]
  10. },
  11. {
  12. "cell_type": "code",
  13. "execution_count": null,
  14. "id": "cb51a492",
  15. "metadata": {},
  16. "outputs": [],
  17. "source": [
  18. "import numpy as np\n",
  19. "import matplotlib.pyplot as plt\n",
  20. "from scipy.optimize import minimize\n",
  21. "plt.style.use(\"ggplot\")\n",
  22. "from matplotlib import colors, cm"
  23. ]
  24. },
  25. {
  26. "cell_type": "markdown",
  27. "id": "2ac3651a",
  28. "metadata": {},
  29. "source": [
  30. "`plt.rcParams` sets matplotlib defaults globally: the settings below apply to\n",
  31. "every plot created afterwards in this kernel session."
  32. ]
  33. },
  34. {
  35. "cell_type": "code",
  36. "execution_count": null,
  37. "id": "97ef9933",
  38. "metadata": {},
  39. "outputs": [],
  40. "source": [
  41. "plt.rcParams.update({\"axes.grid\": False,\n",
  42. "                     \"font.size\": 20,\n",
  43. "                     \"figure.figsize\": (12, 9),\n",
  44. "                     \"lines.markersize\": 8})"
  45. ]
  46. },
  47. {
  48. "cell_type": "code",
  49. "execution_count": null,
  50. "id": "f15200f9",
  51. "metadata": {},
  52. "outputs": [],
  53. "source": [
  54. "rng = np.random.default_rng(42)  # fixed seed so Restart & Run All reproduces data and fit\n",
  55. "data = rng.uniform(size=100)  # 100 points, uniform in [0, 1)\n",
  56. "labels = 5. * data**3 + 1 + rng.normal(loc=0.0, scale=0.1, size=data.size)  # truth: W=5, b=1, gaussian smearing sigma=0.1"
  57. ]
  58. },
  59. {
  60. "cell_type": "code",
  61. "execution_count": null,
  62. "id": "7237f5ed",
  63. "metadata": {},
  64. "outputs": [],
  65. "source": [
  66. "plt.scatter(data, labels, label=\"data\")  # noisy samples of y = 5x^3 + 1\n",
  67. "plt.title(\"Generated data\")\n",
  68. "plt.xlabel(\"x\"); plt.ylabel(\"y\"); plt.legend()\n",
  69. "plt.show()"
  70. ]
  71. },
  72. {
  73. "cell_type": "code",
  74. "execution_count": null,
  75. "id": "0d6e104c",
  76. "metadata": {},
  77. "outputs": [],
  78. "source": [
  79. "def cost(params):\n",
  80. "    \"\"\"Chi2-like cost: mean squared residual of the model W*x^3 + b on the global (data, labels).\"\"\"\n",
  81. "    weight, offset = params\n",
  82. "    return np.mean((labels - (weight*data*data*data + offset))**2)"
  83. ]
  84. },
  85. {
  86. "cell_type": "markdown",
  87. "id": "8e00e16a",
  88. "metadata": {},
  89. "source": [
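  90. "Call the minimizer. `scipy.optimize` provides a collection of optimization algorithms for\n",
  91. "finding the minimum of a given function (a maximum is found by minimizing the negative of the function); `minimize` is its general-purpose entry point."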
  92. ]
  93. },
  94. {
  95. "cell_type": "code",
  96. "execution_count": null,
  97. "id": "433975c3",
  98. "metadata": {},
  99. "outputs": [],
  100. "source": [
  101. "res = minimize(cost, [1., 1.])  # initial guess W = b = 1; returns a scipy OptimizeResult\n",
  102. "# Report non-convergence instead of silently using whatever parameters the optimizer stopped at\n",
  103. "if not res.success:\n",
  104. "    print(\"warning: minimization did not converge:\", res.message)\n",
  105. "# res.x is the parameter vector at the minimum, res.fun the cost value there\n",
  106. "W, b = res.x\n",
  107. "print('function value at the minimum and fitted parameters',res.fun,' ',W,' ',b)"
  108. ]
  109. },
  110. {
  111. "cell_type": "code",
  112. "execution_count": null,
  113. "id": "1e1f4e81",
  114. "metadata": {},
  115. "outputs": [],
  116. "source": [
  117. "points = np.linspace(0, 1, 100)  # evaluation grid spanning the data range [0, 1]\n",
  118. "prediction = b + W*points*points*points  # fitted model evaluated on the grid"
  119. ]
  120. },
  121. {
  122. "cell_type": "code",
  123. "execution_count": null,
  124. "id": "d8de971e",
  125. "metadata": {},
  126. "outputs": [],
  127. "source": [
  128. "plt.scatter(data, labels, label=\"data\")\n",
  129. "plt.plot(points, prediction, label=\"model\", color=\"green\")  # fitted W*x^3 + b\n",
  130. "plt.title(f\"Fit: W = {W:.3f}, b = {b:.3f}\")\n",
  131. "plt.xlabel(\"x\"); plt.ylabel(\"y\"); plt.legend()\n",
  132. "plt.show()"
  133. ]
  134. },
  135. {
  136. "cell_type": "code",
  137. "execution_count": null,
  138. "id": "4a7d62c2",
  139. "metadata": {},
  140. "outputs": [],
  141. "source": []
  142. }
  143. ],
  144. "metadata": {
  145. "kernelspec": {
  146. "display_name": "Python 3 (ipykernel)",
  147. "language": "python",
  148. "name": "python3"
  149. },
  150. "language_info": {
  151. "codemirror_mode": {
  152. "name": "ipython",
  153. "version": 3
  154. },
  155. "file_extension": ".py",
  156. "mimetype": "text/x-python",
  157. "name": "python",
  158. "nbconvert_exporter": "python",
  159. "pygments_lexer": "ipython3",
  160. "version": "3.8.16"
  161. }
  162. },
  163. "nbformat": 4,
  164. "nbformat_minor": 5
  165. }