Machine Learning Kurs im Rahmen der Studierendentage im SS 2023
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

166 lines
3.7 KiB

2 years ago
  1. {
  2. "cells": [
  3. {
  4. "cell_type": "markdown",
  5. "id": "042acd49",
  6. "metadata": {},
  7. "source": [
  8. "# Test a minimizer"
  9. ]
  10. },
  11. {
  12. "cell_type": "code",
  13. "execution_count": null,
  14. "id": "cb51a492",
  15. "metadata": {},
  16. "outputs": [],
  17. "source": [
  18. "import tensorflow as tf\n",
  19. "import numpy as np\n",
  20. "import matplotlib.pyplot as plt\n",
  21. "from scipy.optimize import minimize\n",
  22. "plt.style.use(\"ggplot\")\n",
  23. "from matplotlib import colors, cm"
  24. ]
  25. },
  26. {
  27. "cell_type": "markdown",
  28. "id": "2ac3651a",
  29. "metadata": {},
  30. "source": [
  31. "`plt.rcParams` controls the appearance of plots globally,\n",
  32. "affecting all plots created afterwards in the current session."
  33. ]
  34. },
  35. {
  36. "cell_type": "code",
  37. "execution_count": null,
  38. "id": "97ef9933",
  39. "metadata": {},
  40. "outputs": [],
  41. "source": [
  42. "plt.rcParams[\"axes.grid\"] = False\n",
  43. "plt.rcParams.update({'font.size': 20})\n",
  44. "plt.rcParams.update({'figure.figsize': (12,9)})\n",
  45. "plt.rcParams['lines.markersize'] = 8"
  46. ]
  47. },
  48. {
  49. "cell_type": "code",
  50. "execution_count": null,
  51. "id": "f15200f9",
  52. "metadata": {},
  53. "outputs": [],
  54. "source": [
  55. "# Generate data points with gaussian smearing\n",
  56. "data = np.random.uniform(size=100)\n",
  57. "labels = 5.*data*data*data + 1 + np.random.normal(loc=0.0, scale=0.1, size=100)"
  58. ]
  59. },
  60. {
  61. "cell_type": "code",
  62. "execution_count": null,
  63. "id": "7237f5ed",
  64. "metadata": {},
  65. "outputs": [],
  66. "source": [
  67. "# show plot\n",
  68. "plt.scatter(data, labels, label=\"data\")\n",
  69. "plt.legend()\n",
  70. "plt.show()"
  71. ]
  72. },
  73. {
  74. "cell_type": "code",
  75. "execution_count": null,
  76. "id": "0d6e104c",
  77. "metadata": {},
  78. "outputs": [],
  79. "source": [
  80. "# define chi2 like cost function\n",
  81. "def cost(params):\n",
  82. " W, b = params\n",
  83. " return np.mean((labels - (W*data*data*data + b))**2)"
  84. ]
  85. },
  86. {
  87. "cell_type": "markdown",
  88. "id": "8e00e16a",
  89. "metadata": {},
  90. "source": [
  91. "## Call the minimizer\n",
  92. "`scipy.optimize.minimize` provides a common interface to a collection of optimization algorithms for finding the minimum of a given function (to find a maximum, minimize the negated function)."
  93. ]
  94. },
  95. {
  96. "cell_type": "code",
  97. "execution_count": null,
  98. "id": "433975c3",
  99. "metadata": {},
  100. "outputs": [],
  101. "source": [
  102. "res = minimize(cost, [1., 1.])\n",
  103. "# returns an OptimizeResult object\n",
  104. "# x :the solution (minimum) of the optimization problem, represented as an\n",
  105. "# array.\n",
  106. "# Results of the minimization\n",
  107. "W, b = res.x\n",
  108. "print ('function value at the minimum and fitted parameters',res.fun,' ',W,' ',b)"
  109. ]
  110. },
  111. {
  112. "cell_type": "code",
  113. "execution_count": null,
  114. "id": "1e1f4e81",
  115. "metadata": {},
  116. "outputs": [],
  117. "source": [
  118. "points = np.linspace(0, 1, 100)\n",
  119. "prediction = W*points*points*points + b"
  120. ]
  121. },
  122. {
  123. "cell_type": "code",
  124. "execution_count": null,
  125. "id": "d8de971e",
  126. "metadata": {},
  127. "outputs": [],
  128. "source": [
  129. "# plot fit model\n",
  130. "plt.scatter(data, labels, label=\"data\")\n",
  131. "plt.plot(points, prediction, label=\"model\", color=\"green\")\n",
  132. "plt.legend()\n",
  133. "plt.show()"
  134. ]
  135. },
  136. {
  137. "cell_type": "code",
  138. "execution_count": null,
  139. "id": "4a7d62c2",
  140. "metadata": {},
  141. "outputs": [],
  142. "source": []
  143. }
  144. ],
  145. "metadata": {
  146. "kernelspec": {
  147. "display_name": "Python 3 (ipykernel)",
  148. "language": "python",
  149. "name": "python3"
  150. },
  151. "language_info": {
  152. "codemirror_mode": {
  153. "name": "ipython",
  154. "version": 3
  155. },
  156. "file_extension": ".py",
  157. "mimetype": "text/x-python",
  158. "name": "python",
  159. "nbconvert_exporter": "python",
  160. "pygments_lexer": "ipython3",
  161. "version": "3.8.16"
  162. }
  163. },
  164. "nbformat": 4,
  165. "nbformat_minor": 5
  166. }