Machine Learning Kurs im Rahmen der Studierendentage im SS 2023
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

147 lines
3.1 KiB

2 years ago
  1. {
  2. "cells": [
  3. {
  4. "cell_type": "markdown",
  5. "metadata": {},
  6. "source": [
  7. "Fit a 3rd-order polynomial to graph data using scikit-learn (unweighted least squares: the error bars are only drawn, not used in the fit). More info:\n",
  8. "https://www.datatechnotes.com/2018/06/polynomial-regression-fitting-in-python.html"
  9. ]
  10. },
  11. {
  12. "cell_type": "code",
  13. "execution_count": null,
  14. "metadata": {},
  15. "outputs": [],
  16. "source": [
  17. "import numpy as np\n",
  18. "from matplotlib import pyplot as plt\n",
  19. "from sklearn.linear_model import LinearRegression\n",
  20. "from sklearn.preprocessing import PolynomialFeatures\n",
  21. "\n",
  22. "# global matplotlib setting: larger default font for every figure in this notebook\n",
  23. "plt.rcParams[\"font.size\"] = 20"
  24. ]
  25. },
  26. {
  27. "cell_type": "markdown",
  28. "metadata": {},
  29. "source": [
  30. "## Data: points $(x, y)$ with their error bars $dx$ and $dy$"
  31. ]
  32. },
  33. {
  34. "cell_type": "code",
  35. "execution_count": null,
  36. "metadata": {},
  37. "outputs": [],
  38. "source": [
  39. "x = np.arange(1, 11, dtype=np.float64)  # x positions 1.0 .. 10.0\n",
  40. "dx = np.array([0.1, 0.1, 0.5, 0.1, 0.5, 0.1, 0.5, 0.1, 0.5, 0.1], dtype=np.float64)  # horizontal error bars (xerr in the plot)\n",
  41. "y = np.array([1.1, 2.3, 2.7, 3.2, 3.1, 2.4, 1.7, 1.5, 1.5, 1.7], dtype=np.float64)  # y values to fit\n",
  42. "dy = np.array([0.15, 0.22, 0.29, 0.39, 0.31, 0.21, 0.13, 0.15, 0.19, 0.13], dtype=np.float64)  # vertical error bars (yerr in the plot)"
  43. ]
  44. },
  45. {
  46. "cell_type": "markdown",
  47. "metadata": {},
  48. "source": [
  49. "## Polynomial features: expand $x$ into the columns $1, x, x^2, x^3$"
  50. ]
  51. },
  52. {
  53. "cell_type": "code",
  54. "execution_count": null,
  55. "metadata": {},
  56. "outputs": [],
  57. "source": [
  58. "polyModel = PolynomialFeatures(degree=3)  # expands one input column into 1, x, x^2, x^3\n",
  59. "xpol = polyModel.fit_transform(x.reshape(-1, 1))  # fits on the single raw feature and expands it in one step\n",
  60. "preg = polyModel  # alias for later cells; do NOT refit on xpol, that would make the transformer expect 4 input columns"
  61. ]
  62. },
  63. {
  64. "cell_type": "markdown",
  65. "metadata": {},
  66. "source": [
  67. "## Linear model: ordinary least-squares fit of $y$ on the polynomial features"
  68. ]
  69. },
  70. {
  71. "cell_type": "code",
  72. "execution_count": null,
  73. "metadata": {},
  74. "outputs": [],
  75. "source": [
  76. "linearModel = LinearRegression(fit_intercept=True)  # NOTE(review): xpol already has a constant bias column, so this intercept is redundant but harmless\n",
  77. "linearModel.fit(xpol, y.reshape(-1, 1))"
  78. ]
  79. },
  80. {
  81. "cell_type": "markdown",
  82. "metadata": {},
  83. "source": [
  84. "## Evaluate the fitted polynomial on a dense grid for plotting"
  85. ]
  86. },
  87. {
  88. "cell_type": "code",
  89. "execution_count": null,
  90. "metadata": {},
  91. "outputs": [],
  92. "source": [
  93. "x_plot = np.linspace(0.1, 10.1, num=200)  # dense grid so the fitted curve looks smooth\n",
  94. "polyfit = linearModel.predict(preg.fit_transform(x_plot[:, np.newaxis]))  # NOTE(review): fit_transform re-fits preg on the 1-column grid before expanding it to the same cubic columns"
  95. ]
  96. },
  97. {
  98. "cell_type": "markdown",
  99. "metadata": {},
  100. "source": [
  101. "## Plot the data with error bars together with the fitted curve"
  102. ]
  103. },
  104. {
  105. "cell_type": "code",
  106. "execution_count": null,
  107. "metadata": {},
  108. "outputs": [],
  109. "source": [
  110. "fig, ax = plt.subplots()\n",
  111. "ax.errorbar(x, y, yerr=dy, xerr=dx, fmt=\"o\", label=\"data\")  # keyword args make the dy/dx roles explicit\n",
  112. "ax.plot(x_plot, polyfit, label=\"3rd-order polynomial fit\")\n",
  113. "ax.set(title=\"scikit-learn Fit Test\", xlabel=\"x\", ylabel=\"y\", xlim=(-0.1, 10.1))\n",
  114. "ax.legend()\n",
  115. "plt.show()"
  116. ]
  117. },
  118. {
  119. "cell_type": "code",
  120. "execution_count": null,
  121. "metadata": {},
  122. "outputs": [],
  123. "source": []
  124. }
  125. ],
  126. "metadata": {
  127. "kernelspec": {
  128. "display_name": "Python 3 (ipykernel)",
  129. "language": "python",
  130. "name": "python3"
  131. },
  132. "language_info": {
  133. "codemirror_mode": {
  134. "name": "ipython",
  135. "version": 3
  136. },
  137. "file_extension": ".py",
  138. "mimetype": "text/x-python",
  139. "name": "python",
  140. "nbconvert_exporter": "python",
  141. "pygments_lexer": "ipython3",
  142. "version": "3.8.16"
  143. }
  144. },
  145. "nbformat": 4,
  146. "nbformat_minor": 4
  147. }