Machine Learning Kurs im Rahmen der Studierendentage im SS 2023
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

195 lines
22 KiB

2 years ago
  1. {
  2. "cells": [
  3. {
  4. "cell_type": "markdown",
  5. "metadata": {},
  6. "source": [
  7. "# Simple example of logistic regression with scikit-learn"
  8. ]
  9. },
  10. {
  11. "cell_type": "code",
  12. "execution_count": 2,
  13. "metadata": {},
  14. "outputs": [],
  15. "source": [
  16. "import numpy as np\n",
  17. "import pandas as pd\n",
  18. "import matplotlib.pyplot as plt"
  19. ]
  20. },
  21. {
  22. "cell_type": "markdown",
  23. "metadata": {},
  24. "source": [
  25. "### Read data \n",
  26. "The data (hours studied and whether the exam was passed) are taken from the [Wikipedia article on logistic regression](https://en.wikipedia.org/wiki/Logistic_regression)."
  27. ]
  28. },
  29. {
  30. "cell_type": "code",
  31. "execution_count": 3,
  32. "metadata": {},
  33. "outputs": [],
  34. "source": [
  35. "# data columns: 1. hours studied, 2. passed (0/1)\n",
  36. "filename = \"https://www.physi.uni-heidelberg.de/~reygers/lectures/2021/ml/data/exam.txt\"\n",
  37. "df = pd.read_csv(filename, engine='python', sep='\\s+')"
  38. ]
  39. },
  40. {
  41. "cell_type": "code",
  42. "execution_count": 4,
  43. "metadata": {},
  44. "outputs": [],
  45. "source": [
  46. "x_tmp = df['hours_studied'].values\n",
  47. "x = np.reshape(x_tmp, (-1, 1))\n",
  48. "y = df['passed'].values"
  49. ]
  50. },
  51. {
  52. "cell_type": "markdown",
  53. "metadata": {},
  54. "source": [
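  55. "### Fit the model\n\n`penalty='none'` switches off scikit-learn's default L2 regularization, so this is a plain maximum-likelihood fit. Newer scikit-learn releases spell this option `penalty=None`."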
  56. ]
  57. },
  58. {
  59. "cell_type": "code",
  60. "execution_count": 5,
  61. "metadata": {},
  62. "outputs": [],
  63. "source": [
  64. "from sklearn.linear_model import LogisticRegression\n",
  65. "clf = LogisticRegression(penalty='none', fit_intercept=True)\n",
  66. "clf.fit(x, y);"
  67. ]
  68. },
  69. {
  70. "cell_type": "markdown",
  71. "metadata": {},
  72. "source": [
  73. "### Calculate predictions"
  74. ]
  75. },
  76. {
  77. "cell_type": "code",
  78. "execution_count": 6,
  79. "metadata": {},
  80. "outputs": [],
  81. "source": [
  82. "hours_studied_tmp = np.linspace(0., 6., 1000)\n",
  83. "hours_studied = np.reshape(hours_studied_tmp, (-1, 1))\n",
  84. "y_pred = clf.predict_proba(hours_studied)"
  85. ]
  86. },
  87. {
  88. "cell_type": "markdown",
  89. "metadata": {},
  90. "source": [
  91. "### Plot result"
  92. ]
  93. },
  94. {
  95. "cell_type": "code",
  96. "execution_count": 7,
  97. "metadata": {},
  98. "outputs": [
  99. {
  100. "data": {
  101. "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYYAAAEKCAYAAAAW8vJGAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy86wFpkAAAACXBIWXMAAAsTAAALEwEAmpwYAAAzgUlEQVR4nO3dd3gVZfbA8e9JBxJ6qKELSBEpodjAumJZy2LvBVF/Yllsa9nVZYuCruu66+7asFJEwd57YQUJvfcWaoAQEkLqPb8/ZqI3MQmTcGtyPs9zn9w7d2beM5eQc+etoqoYY4wxZWLCHYAxxpjIYonBGGNMOZYYjDHGlGOJwRhjTDmWGIwxxpQTF+4ADlfLli21c+fO4Q7DGGOiyrx583arampl70V9YujcuTMZGRnhDsMYY6KKiGyq6j2rSjLGGFOOJQZjjDHlWGIwxhhTjiUGY4wx5VhiMMYYU07IEoOITBKRXSKytIr3RUSeEpG1IrJYRAaGKjYTHHvyClm0ZR978gqjuoxQlROqa6mp2sQVimNCFVdNReq/Y02EsrvqS8C/gFeqeP8MoLv7GAr8x/1potA7C7dy74zFxMfEUOzzMXFUP87p3z7qyghVOaG6llDEFYpjQhVXTUXqv2NNheyOQVW/BfZWs8u5wCvqmA00FZG2oYnOBNKevELunbGYgmIfuYUlFBT7uGfG4oB+gwpFGaEqJ1TXEoq4QnFMqOKqqUj9d6yNSGpjaA9s8Xud6W77BREZIyIZIpKRlZUVkuCMd5nZB4mPKf+rFR8TQ2b2wagqI1TlhOpaaqo2cYXimFDFVVOhKENVKSwpZV9+Edv2HSS3oDhg5/YXlSOfVfVZ4FmA9PR0W2kowqQ1a0Cxz1duW7HPR1qzBlFVRqjKCdW11FRt4grFMaGKq6YOVYaqkl9USnZ+Efvyi8nOLyLnYDF5BSXkFZaQ6/786XVhCXkFxT9tyy0sIb+olFLfz3/y/nJ+Xy4f2ilg11AmkhLDVqCD3+s0d5uJMi2SE5k4qh/3VKhrbZGcGFVlhKqcUF1LKOIKxTGhisuLwpJSsnIL2ZVbyK79hfy6XztmLsgkBqFUlbSmDbjsuTk/JYOiUl+152uUEEtyUhzJiXEkJ8WTkhhHq5Skn7Y1TIilUWIcDeJjaZgQS3rn5ocVf1UklEt7ikhn4H1V7VvJe2cBY4EzcRqdn1LVIYc6Z3p6utpcSZFpT14hmdkHSWvWIGh/5EJRRqjKCdW11FRt4grFMcEuo9Sn7NhfQObefLbuO0hm9kEys/PZtq+AXbkF7MotZF/+L6tyYgRSkuJp3iiBlskJNGvoPJo2inefx9PU3dakQTwpSXEkJ8XRKCGO2BjxdB2BICLzVDW90ve8JgYROR84CWhFhbYJVb3Iw/FTgROBlsBO4CEg3j3+vyIiOL2WRgL5wLWqesi/+JYYjDG1VVLqIzP7IOt357E+6wDrsg6wcfcBMvfls31fASW+8n8fW6Uk0rZpA1qnJNKqcSKtUpJoVeF5i+TEkP6Br63qEoOnqiQR+RtwKzAL5496aU2DUNVLD/G+ArfU9LzGGHMoqkpm9kGWbdvP8u37Wb0jl3VZeWzak1+ueqdpw3i6tGzEgA7N+HW/BqQ1a0haswakNWtAu6YNSIqPDeNVhI7XNoargQtV9Z1gBmOMMYfL51PWZuWxcPM+lm3LYcX2XFZs309uYQkAItClRSO6piZz8pGt6JaaTNdU53XzRglhjj4yeE0M+cDKYAZijDG1kZNfzIIt2czfvI8Fm7NZuHnfT0mgYUIsvdo25twB7ejdtgm92zWmZ+sUGiTUj2/+teU1MTwK3CMiN6pqSTADMsaY6hwoLOHHjXv539rdzFq7hxU79qPqNPr2aJ3Cr/u3Y0CHpgzo2IyuLRsREwX1/ZHGa2J4Dvg1sFVEVgPlmuJV9eRAB2aMMeBUDS3dlsMXK3Yxa+1uFm7ZR4lPSYiNYWC
nptxxSg8Gd25Gvw5NSU6MpB740cvrp/hf4HjgY5zGZxtUZowJmoLiUmat3c3nK3bxxYqd7MotJEbgqLSm3DC8K8d1a8mgTs2sSihIvCaGi4HzVfWzYAZjjKm/CopL+XrVLt5dtI0vV+6ioNhHo4RYRvRM5ZQjW3PSka2scThEvCaGLGwUsjEmwEpKfcxat4d3F27j02U7yC0soWVyAhcMSuNXvdswtGtzEuPsriDUvCaGh4DxInKNquYFMyBjTN23eU8+r2ds5o2MTHblFpKSFMfIvm04p387junagrjYSJrfs/7xmhjuBjoDO0VkM79sfO4X4LiMMXVMUYmPT5fvYNqPW/h+7W5iBE7q2YoL0ztwYs/UejN4LBp4TQxvBjUKY0ydlX2giCk/bubl/21kV24h7Zs2YNxpPbgwPY22TcI7g6ypnKfEoKp/DHYgxpi6ZX1WHpNmbeDNeZkUFPs4oXtLJozqx/AeqVExl1B9Zp1+jTEBtXpnLk99sYYPlmwnPiaG8wa04/rju9KzTUq4QzMeeZ1ELwF4ALgU6Ig7K2oZVbXKQWPquVU7cnnqyzV8uGQ7DeNjuWlEN647rgupKZEzhbjxxusdw59wxjI8AvydnxujLwF+H5TIjDFRYcvefB77ZBXvLd5Gw/hYbh7RjdEndLUxB1HMa2K4CLhJVT8WkceBd1R1nYisAE4DnglahMaYiJSTX8zTX6/lpVkbiYmBm0d044YTutLMEkLU85oYWgPL3ed5QFP3+cfAhADHZIyJYMWlPl79YRNPfbmGnIPFXDAwjTt/1ZM2TZLCHZoJEK+JYTPQzv25FjgdmAccAxwMTmjGmEjz44a9PPj2ElbvzOP4I1py/5m96N2ucbjDMgHmNTG8BZwCzAb+AUwVkRuA9sBjQYrNGBMh9uQV8shHK3lzXibtmzbguavSObVXK5wVeU1d43Ucw31+z98UkS3AccBqVX0/WMEZY8JLVZkxfyt//mA5eQUl3DSiG7edcgQNE6yne11Wq39dVZ0DzAEQkXhVLT7EIcaYKLNrfwH3zVzCFyt3MbhzM/5y/lH0aG1jEeoDr+MY3gWuVdU9Fbb3AqYAA4IQmzEmDFSVdxdt4w/vLKOguJQHz+rFdcd1sZXQ6hGvUxg2A5aIyK/KNojIWJwG6MXBCMwYE3o5B4sZO2UBt09bSJeWjfjw9hMYfUJXSwr1jNeqpBHAg8B7IvIs0A2nR9L1qjo1WMEZY0JnweZsbp26gB05Bdx9ek9uHN7Vpr+up7w2Pvtw1mOIxRnpXAIMV9XZwQzOGBN8Pp/y/PfrmfjxKto0SWL6TccwsGOzcIdlwshrG0Mi8DhwA/AwcALO3cNoVX0neOEZY4JpX34Rv319IV+tyuKMvm14dFQ/mjSIP/SBpk7zWpU0z933eFXNABCRO4FpIvKaqt4QrACNMcGxakcuY17NYPu+Av50bh+uGNbJxiUYwHtimA3cpqr5ZRtU9W8i8jnwWlAiM8YEzcdLtzNu+iKSE+OYOmYYgzpZ1ZH5mdc2htFVbF8kIumBDckYEyw+n/L3z1fzzy/X0r9DU565chCtG9scR6Y8z10OROQMEXlfRJaLSAd322jg+KBFZ4wJmILiUm6ZMp9/frmWi9LTeP3GYZYUTKU8JQYRuRyYDqwBuvDzQj2xwD3BCc0YEyh7DxRx+fNz+GjpDh44sxcTRvUjMc7W1zKV83rHcA9wg6r+FqerapnZQP9AB2WMCZxNew4w6j//Y8nWHP59+UBuGN7VGplNtbwmhu7AD5VszwM8z7krIiNFZJWIrBWR31XyfkcR+UpEFojIYhE50+u5jTG/tHDLPn7z7/+RnV/ElNFDOfOotuEOyUQBr4lhG9Cjku3DgXVeTuAOjnsaOAPoDVwqIr0r7PYgMF1VB+AsG/pvj/EZYyr437rdXPbcbBomxjLz5mNJ79w83CGZKOE1MTwLPCUix7mvO4jI1cBE4D8ezzEEWKuq61W1CJgGnFt
hH+XnO5AmOAnJGFNDX67cybUvziWtWQNm3HQsXVOTwx2SiSJeu6tOFJEmwGdAEvAVUAg8rqpPeyyrPbDF73UmMLTCPg8Dn4r
  102. "text/plain": [
  103. "<Figure size 432x288 with 1 Axes>"
  104. ]
  105. },
  106. "metadata": {
  107. "needs_background": "light"
  108. },
  109. "output_type": "display_data"
  110. }
  111. ],
  112. "source": [
  113. "df.plot.scatter(x='hours_studied', y='passed')\n",
  114. "plt.plot(hours_studied, y_pred[:,1])\n",
  115. "plt.xlabel(\"preparation time in hours\", fontsize=14)\n",
  116. "plt.ylabel(\"probability of passing exam\", fontsize=14)\n",
  117. "plt.savefig(\"03_ml_basics_logistic_regression.pdf\")"
  118. ]
  119. },
  120. {
  121. "cell_type": "code",
  122. "execution_count": 8,
  123. "metadata": {},
  124. "outputs": [
  125. {
  126. "data": {
  127. "text/plain": [
  128. "{'C': 1.0,\n",
  129. " 'class_weight': None,\n",
  130. " 'dual': False,\n",
  131. " 'fit_intercept': True,\n",
  132. " 'intercept_scaling': 1,\n",
  133. " 'l1_ratio': None,\n",
  134. " 'max_iter': 100,\n",
  135. " 'multi_class': 'auto',\n",
  136. " 'n_jobs': None,\n",
  137. " 'penalty': 'none',\n",
  138. " 'random_state': None,\n",
  139. " 'solver': 'lbfgs',\n",
  140. " 'tol': 0.0001,\n",
  141. " 'verbose': 0,\n",
  142. " 'warm_start': False}"
  143. ]
  144. },
  145. "execution_count": 8,
  146. "metadata": {},
  147. "output_type": "execute_result"
  148. }
  149. ],
  150. "source": [
  151. "clf.get_params()"
  152. ]
  153. },
  154. {
  155. "cell_type": "code",
  156. "execution_count": 9,
  157. "metadata": {},
  158. "outputs": [
  159. {
  160. "name": "stdout",
  161. "output_type": "stream",
  162. "text": [
  163. "Coefficient: [[1.50464522]]\n",
  164. "Intercept: [-4.07771764]\n"
  165. ]
  166. }
  167. ],
  168. "source": [
  169. "print('Coefficient: ', clf.coef_)\n",
  170. "print('Intercept: ', clf.intercept_)"
  171. ]
  172. }
  173. ],
  174. "metadata": {
  175. "kernelspec": {
  176. "display_name": "Python 3",
  177. "language": "python",
  178. "name": "python3"
  179. },
  180. "language_info": {
  181. "codemirror_mode": {
  182. "name": "ipython",
  183. "version": 3
  184. },
  185. "file_extension": ".py",
  186. "mimetype": "text/x-python",
  187. "name": "python",
  188. "nbconvert_exporter": "python",
  189. "pygments_lexer": "ipython3",
  190. "version": "3.8.5"
  191. }
  192. },
  193. "nbformat": 4,
  194. "nbformat_minor": 4
  195. }