Machine Learning Kurs im Rahmen der Studierendentage im SS 2023
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

367 lines
107 KiB

  1. {
  2. "cells": [
  3. {
  4. "cell_type": "markdown",
  5. "metadata": {},
  6. "source": [
  7. "# Simple classification example: the iris dataset"
  8. ]
  9. },
  10. {
  11. "cell_type": "code",
  12. "execution_count": 2,
  13. "metadata": {},
  14. "outputs": [],
  15. "source": [
  16. "import matplotlib.pyplot as plt\n",
  17. "import pandas as pd\n",
  18. "from sklearn import datasets\n",
  19. "from sklearn.model_selection import train_test_split\n",
  20. "from sklearn.metrics import classification_report\n",
  21. "from sklearn.metrics import accuracy_score\n",
  22. "from sklearn.metrics import confusion_matrix"
  23. ]
  24. },
  25. {
  26. "cell_type": "code",
  27. "execution_count": 3,
  28. "metadata": {},
  29. "outputs": [],
  30. "source": [
  31. "# import some data to play with\n",
  32. "# columns: Sepal Length, Sepal Width, Petal Length and Petal Width\n",
  33. "iris = datasets.load_iris()\n",
  34. "X = iris.data\n",
  35. "y = iris.target"
  36. ]
  37. },
  38. {
  39. "cell_type": "code",
  40. "execution_count": 4,
  41. "metadata": {},
  42. "outputs": [
  43. {
  44. "data": {
  45. "text/html": [
  46. "<div>\n",
  47. "<style scoped>\n",
  48. " .dataframe tbody tr th:only-of-type {\n",
  49. " vertical-align: middle;\n",
  50. " }\n",
  51. "\n",
  52. " .dataframe tbody tr th {\n",
  53. " vertical-align: top;\n",
  54. " }\n",
  55. "\n",
  56. " .dataframe thead th {\n",
  57. " text-align: right;\n",
  58. " }\n",
  59. "</style>\n",
  60. "<table border=\"1\" class=\"dataframe\">\n",
  61. " <thead>\n",
  62. " <tr style=\"text-align: right;\">\n",
  63. " <th></th>\n",
  64. " <th>Sepal Length (cm)</th>\n",
  65. " <th>Sepal Width (cm)</th>\n",
  66. " <th>Petal Length (cm)</th>\n",
  67. " <th>Petal Width (cm)</th>\n",
  68. " <th>category</th>\n",
  69. " </tr>\n",
  70. " </thead>\n",
  71. " <tbody>\n",
  72. " <tr>\n",
  73. " <th>0</th>\n",
  74. " <td>5.1</td>\n",
  75. " <td>3.5</td>\n",
  76. " <td>1.4</td>\n",
  77. " <td>0.2</td>\n",
  78. " <td>0</td>\n",
  79. " </tr>\n",
  80. " <tr>\n",
  81. " <th>1</th>\n",
  82. " <td>4.9</td>\n",
  83. " <td>3.0</td>\n",
  84. " <td>1.4</td>\n",
  85. " <td>0.2</td>\n",
  86. " <td>0</td>\n",
  87. " </tr>\n",
  88. " <tr>\n",
  89. " <th>2</th>\n",
  90. " <td>4.7</td>\n",
  91. " <td>3.2</td>\n",
  92. " <td>1.3</td>\n",
  93. " <td>0.2</td>\n",
  94. " <td>0</td>\n",
  95. " </tr>\n",
  96. " <tr>\n",
  97. " <th>3</th>\n",
  98. " <td>4.6</td>\n",
  99. " <td>3.1</td>\n",
  100. " <td>1.5</td>\n",
  101. " <td>0.2</td>\n",
  102. " <td>0</td>\n",
  103. " </tr>\n",
  104. " <tr>\n",
  105. " <th>4</th>\n",
  106. " <td>5.0</td>\n",
  107. " <td>3.6</td>\n",
  108. " <td>1.4</td>\n",
  109. " <td>0.2</td>\n",
  110. " <td>0</td>\n",
  111. " </tr>\n",
  112. " </tbody>\n",
  113. "</table>\n",
  114. "</div>"
  115. ],
  116. "text/plain": [
  117. " Sepal Length (cm) Sepal Width (cm) Petal Length (cm) Petal Width (cm) \\\n",
  118. "0 5.1 3.5 1.4 0.2 \n",
  119. "1 4.9 3.0 1.4 0.2 \n",
  120. "2 4.7 3.2 1.3 0.2 \n",
  121. "3 4.6 3.1 1.5 0.2 \n",
  122. "4 5.0 3.6 1.4 0.2 \n",
  123. "\n",
  124. " category \n",
  125. "0 0 \n",
  126. "1 0 \n",
  127. "2 0 \n",
  128. "3 0 \n",
  129. "4 0 "
  130. ]
  131. },
  132. "execution_count": 4,
  133. "metadata": {},
  134. "output_type": "execute_result"
  135. }
  136. ],
  137. "source": [
  138. "# just to create a nice table\n",
  139. "df = pd.DataFrame({\"Sepal Length (cm)\": X[:,0], \"Sepal Width (cm)\": X[:,1], \n",
  140. " 'Petal Length (cm)': X[:,2], 'Petal Width (cm)': X[:,3], \n",
  141. " 'category': y})\n",
  142. "df.head()"
  143. ]
  144. },
  145. {
  146. "cell_type": "code",
  147. "execution_count": 5,
  148. "metadata": {},
  149. "outputs": [
  150. {
  151. "data": {
  152. "text/plain": [
  153. "['setosa', 'versicolor', 'virginica']"
  154. ]
  155. },
  156. "execution_count": 5,
  157. "metadata": {},
  158. "output_type": "execute_result"
  159. }
  160. ],
  161. "source": [
  162. "list(iris.target_names)"
  163. ]
  164. },
  165. {
  166. "cell_type": "code",
  167. "execution_count": 6,
  168. "metadata": {},
  169. "outputs": [],
  170. "source": [
  171. "# split data into training and test data sets\n",
  172. "x_train, x_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=42)"
  173. ]
  174. },
  175. {
  176. "cell_type": "code",
  177. "execution_count": 7,
  178. "metadata": {},
  179. "outputs": [
  180. {
  181. "data": {
  182. "text/plain": [
  183. "Text(0, 0.5, 'Petal width')"
  184. ]
  185. },
  186. "execution_count": 7,
  187. "metadata": {},
  188. "output_type": "execute_result"
  189. },
  190. {
  191. "data": {
  192. "image/png": "iVBORw0KGgoAAAANSUhEUgAAA04AAAHCCAYAAADYTZkLAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy89olMNAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdd3QUVfvA8e9sS++9kVBD7yBFqiBVQVF6FWygYn/F3rGL72tFpUlTEVABC0Wa9Cq9QwIkJKRs+ia7e39/RKL5pUuSDfh8ztlzZO6de58Z4c48uzP3akophRBCCCGEEEKIEukcHYAQQgghhBBC1HSSOAkhhBBCCCFEGSRxEkIIIYQQQogySOIkhBBCCCGEEGWQxEkIIYQQQgghyiCJkxBCCCGEEEKUQRInIYQQQgghhCiDJE5CCCGEEEIIUQZJnIQQQgghhBCiDJI4CSGEEEIIIUQZDI4O4Io33niDadOmMXXqVGbMmFFsnTlz5jBhwoRC25ycnMjJySl3P3a7nYsXL+Lh4YGmaVcTshBCiApQSpGenk5oaCg6nXxv93dybRJCCMeoyLWpRiROO3fu5LPPPqN58+Zl1vX09OTYsWMFf67oBebixYtERERUOEYhhBCVIzY2lvDwcEeHUaPItUkIIRyrPNcmhydOGRkZjBo1is8//5xXX321zPqaphEcHPyP+/Pw8ADyT46np+c/bkcIIUTFpKWlERERUTAOi7/ItUkIIRyjItcmhydOU6ZMYcCAAfTq1atciVNGRgaRkZHY7XZat27N66+/TpMmTUqsb7FYsFgsBX9OT08H8n+5kouTEEJUP3kUragr50SuTUII4RjluTY59CHzxYsXs2fPHqZPn16u+tHR0cyaNYvvv/+e+fPnY7fb6dSpE+fPny9xn+nTp+Pl5VXwkUchhBBCCCGEEBXlsMQpNjaWqVOnsmDBApydncu1T8eOHRk7diwtW7akW7duLF26lICAAD777LMS95k2bRpms7ngExsbW1mHIIQQQgghhPiXcNijert37yYhIYHWrVsXbLPZbGzcuJEPP/wQi8WCXq8vtQ2j0UirVq04efJkiXWcnJxwcnKqtLiFEEIIIYQQ/z4OS5xuuukmDhw4UGjbhAkTaNiwIf/5z3/KTJogP9E6cOAA/fv3r6owhRBCCCGEEMJxiZOHhwdNmzYttM3NzQ0/P7+C7WPHjiUsLKzgHaiXX36ZDh06UK9ePVJTU3n77bc5d+4ckyZNqvb4hRBCCCGEEP8eDp9VrzQxMTGFFqJKSUnh7rvvJj4+Hh8fH9q0acOWLVto3LixA6MUQgghhBBCXO80pZRydBDVKS0tDS8vL8xms0z5KoQQ1UjG35LJuRFCCMeoyPjr0OnIhRBCCCGEEOJaIImTEEIIIYQQQpShRr/jJIQjxcfHM2fOHI4dO4aHhwd33HEHXbp0KdfK0kIIIYQQVSUnJ4dvv/2WjRs3opSia9euDB06tNxroyql2LhxI9999x3p6elER0czfvx4AObOncvRo0dxd3fnjjvuoGvXrnLv8yd5x0mIYnz88cdMnToV7BqeOm9ysZBhTaNb124s/3453t7ejg5RiGuOjL8lk3MjhCiv3bt3c+utA7h48RItm7qhabD3QCYhIYH88MNK2rZtW+r+qamp3Hbbraxfv4moWs6EBOrZdygHi0Wh04HRqNGyiTNxCTbOxuTQrduNLFv2Az4+PtV0hNVL3nES4ip8//33TJkyhWBrJJ3t/Wht68YN1t60pDPbft/OnXcOdXSIQogqMn36dNq1a4eHhweBgYEMHjyYY8eOlbrPnDlz0DSt0Ke83/oKIURFXLp0iT59ehEelM7hTZHsXh3Krl9DObI5ksjQTPr27U18fHypbQwbegf7923jh69CObktnM0/hvL5u/7Y7XbuHu1B7J5INv+YX/bj/FAO/LGdoXcOqaYjrNkkcRLi/3nl5Vfw0wURTUuMmgk
ATdPw10JoYGvJmjWr2b17t4OjFEJUhQ0bNjBlyhS2bdvG6tWrycvL4+abbyYzM7PU/Tw9PYmLiyv4nDt3rpoiFkL8m8ycOZPs7Ax+/CqY6Hqmgu0N6pr4YV4QOTkZzJw5s8T9d+7cya+r1zLzHX8G9HIreATvo1lmunVy4X+vB+DjrQfy73363+TGzHf8WbP2N3bu3Fm1B3cNkMRJiL+Ji4tj957dhNijin2eN4BQnA0uLF++vPqDE0JUuZ9//pnx48fTpEkTWrRowZw5c4iJiSnzyxJN0wgODi74BAUFlVrfYrGQlpZW6COEEGVZvnwJt/Vzwd9PX6TMz1fP7f1dWLbs2xL3//777wkMcGJQX7eCbQmXrWzdlcPdo72Kvfe5tY8bQQFOLFu2rHIO4homiZMQf5OVlQWACVOx5TpNh0lzIjs7uzrDEkI4iNlsBsDX17fUehkZGURGRhIREcGgQYM4dOhQqfWnT5+Ol5dXwSciIqLSYhZCXL+ysjLx9y2aNF3h76snOzurlP2z8PHSo9f/lSBlZas/9y0+LdDrNXy89XLvgyROQhQSFhaGp4cnSSQUW56lMkjPM9O0adNqjkwIUd3sdjsPP/wwnTt3LvXffHR0NLNmzeL7779n/vz52O12OnXqxPnz50vcZ9q0aZjN5oJPbGxsVRyCEOI606xZK1ZvtFDc3G5KKVZvyKVp0xal7N+M46eyOBOTV7AtNMiAj7eO1RuKT7jOxuZx7GQWzZo1u/oDuMZJ4iTE3zg7OzNx0kTi9GdJV6mFyuzKzkndAby8vBg6VCaIEOJ6N2XKFA4ePMjixYtLrdexY0fGjh1Ly5Yt6datG0uXLiUgIIDPPvusxH2cnJzw9PQs9BFCiLLcf/9kDh/L5pM55iJln84zc/BoFvfdN7nE/YcOHYqXlyePPp9EXl5+8mUyadw1wpOZ88zsO2gpVD8vT/HYC0l4eXkybNiwyj2Ya5Cs4yTE//Piiy+yds1a9hzeSJAtAh8CsJBNvOEcWWSwbP4yXF1dHR2mEKIKPfDAA6xYsYKNGzcSHh5eoX2NRiOtWrXi5MmTVRSdEOLfqnv37kydOpUHn/6An9Zlc+ct+dORf/tjJitXZ/Dggw9y0003lbi/m5sb8+bNZ8iQ22nT+wJ3j3YnOEhPqtlOTi50HhjLhBGedOvkQvwlG18syODYqTyWLPkONze3Etv9t5BfnIT4fzw9Pdm0eRP/mfYkOX5pHGAbJ3UH6DmgO7///jsDBw50dIhCiCqilOKBBx5g2bJlrFu3jtq1a1e4DZvNxoEDBwgJCamCCIUQ/2aapvH+++8zd+5c4pLqMGHqJcY/dIkLibWZM2cOH3zwQZmL1d5yyy1s2rSZ+o168+gLSQy/J54Va0088sgTPDT1CZb9bGL4PfE8+kISdaN7sXHjJm699dZqOsKaTRbAFaIUdrsds9mMi4uLrMsixFW6FsbfyZMns3DhQr7//nuio6MLtnt5eeHi4gLA2LFjCQsLY/r06QC8/PLLdOjQgXr16pGamsrbb7/N8uXL2b17N40bNy5Xv9fCuRFC1Dzp6ekopf7xuJGTk0N2djZeXl7odPm/p/zb7n0qMv7Ko3pClEKn0123K2ULIYr65JNPgPzHYf5u9uzZjB8/HoCYmJiCGwyAlJQU7r77buLj4/Hx8aFNmzZs2bKl3EmTEEL8Ux4eHle1v7Ozc5HkSO59Sia/OAkhhKgWMv6WTM6NEEI4RkXGX3nHSQghhBBCCCHKIImTEEIIIYQQFWS328nMzCx2TaWq6CshIQGr1VrlfYmSSeIkhBBCCCFEOZ0+fZp7770XT0933N3d8ff34fHHHychIaHS+zpy5Ajt2rXDyclAUFAQTk5GoqOjWbduXaX3Jcomk0MIIYQQQghRDgcOHKB79y44GXN49F436tfxZN9BC19+8V++++4bNm/eSlhYWKX0tWfPHjp3ugGDwcYDd3nRurkTp87m8cmcU/Tp04v58xf
JorTVTCaHEEIIUS1k/C2ZnBshaj6lFK1aNQfrKdYuCcHHW19QFnM+jy6D4mjbvg/Lln1fKf1FhIeRa7nE1lURREUYC7anmm3
  193. "text/plain": [
  194. "<Figure size 1000x500 with 2 Axes>"
  195. ]
  196. },
  197. "metadata": {},
  198. "output_type": "display_data"
  199. }
  200. ],
  201. "source": [
  202. "# plot with color code\n",
  203. "plt.subplots(1, 2, figsize=(10, 5))\n",
  204. "\n",
  205. "plt.subplot(1, 2, 1)\n",
  206. "plt.scatter(X[:, 0], X[:, 1], c=y, edgecolor='k')\n",
  207. "plt.xlabel('Sepal length')\n",
  208. "plt.ylabel('Sepal width')\n",
  209. "\n",
  210. "plt.subplot(1, 2, 2)\n",
  211. "plt.scatter(X[:, 2], X[:, 3], c=y, edgecolor='k')\n",
  212. "plt.xlabel('Petal length')\n",
  213. "plt.ylabel('Petal width')"
  214. ]
  215. },
  216. {
  217. "cell_type": "markdown",
  218. "metadata": {},
  219. "source": [
  220. "## Softmax regression"
  221. ]
  222. },
  223. {
  224. "cell_type": "code",
  225. "execution_count": 8,
  226. "metadata": {},
  227. "outputs": [],
  228. "source": [
  229. "from sklearn.linear_model import LogisticRegression\n",
  230. "log_reg = LogisticRegression(multi_class='multinomial', penalty='none')\n",
  231. "log_reg.fit(x_train, y_train);"
  232. ]
  233. },
  234. {
  235. "cell_type": "markdown",
  236. "metadata": {},
  237. "source": [
  238. "## k-nearest neighbor"
  239. ]
  240. },
  241. {
  242. "cell_type": "code",
  243. "execution_count": 9,
  244. "metadata": {},
  245. "outputs": [],
  246. "source": [
  247. "from sklearn.neighbors import KNeighborsClassifier\n",
  248. "kn_neigh = KNeighborsClassifier(n_neighbors=5)\n",
  249. "kn_neigh.fit(x_train, y_train);"
  250. ]
  251. },
  252. {
  253. "cell_type": "markdown",
  254. "metadata": {},
  255. "source": [
  256. "## Fisher linear discriminant"
  257. ]
  258. },
  259. {
  260. "cell_type": "code",
  261. "execution_count": 10,
  262. "metadata": {},
  263. "outputs": [],
  264. "source": [
  265. "from sklearn.discriminant_analysis import LinearDiscriminantAnalysis\n",
  266. "fisher_ld = LinearDiscriminantAnalysis()\n",
  267. "fisher_ld.fit(x_train, y_train);"
  268. ]
  269. },
  270. {
  271. "cell_type": "markdown",
  272. "metadata": {},
  273. "source": [
  274. "## Classification accuracy"
  275. ]
  276. },
  277. {
  278. "cell_type": "code",
  279. "execution_count": 11,
  280. "metadata": {},
  281. "outputs": [
  282. {
  283. "name": "stdout",
  284. "output_type": "stream",
  285. "text": [
  286. "LogisticRegression\n",
  287. "accuracy: 0.96\n",
  288. "[[29 0 0]\n",
  289. " [ 0 23 0]\n",
  290. " [ 0 3 20]] \n",
  291. "\n",
  292. "KNeighborsClassifier\n",
  293. "accuracy: 0.95\n",
  294. "[[29 0 0]\n",
  295. " [ 0 23 0]\n",
  296. " [ 0 4 19]] \n",
  297. "\n",
  298. "LinearDiscriminantAnalysis\n",
  299. "accuracy: 0.99\n",
  300. "[[29 0 0]\n",
  301. " [ 0 23 0]\n",
  302. " [ 0 1 22]] \n",
  303. "\n"
  304. ]
  305. }
  306. ],
  307. "source": [
  308. "for clf in [log_reg, kn_neigh, fisher_ld]:\n",
  309. " y_pred = clf.predict(x_test)\n",
  310. " acc = accuracy_score(y_test, y_pred)\n",
  311. " print(type(clf).__name__)\n",
  312. " print(f\"accuracy: {acc:0.2f}\")\n",
  313. " \n",
  314. " # confusion matrix: rows: true class, columns: predicted class\n",
  315. " print(confusion_matrix(y_test, y_pred),\"\\n\")"
  316. ]
  317. },
  318. {
  319. "cell_type": "code",
  320. "execution_count": 12,
  321. "metadata": {},
  322. "outputs": [
  323. {
  324. "name": "stdout",
  325. "output_type": "stream",
  326. "text": [
  327. " precision recall f1-score support\n",
  328. "\n",
  329. " 0 1.00 1.00 1.00 29\n",
  330. " 1 0.88 1.00 0.94 23\n",
  331. " 2 1.00 0.87 0.93 23\n",
  332. "\n",
  333. " accuracy 0.96 75\n",
  334. " macro avg 0.96 0.96 0.96 75\n",
  335. "weighted avg 0.96 0.96 0.96 75\n",
  336. "\n"
  337. ]
  338. }
  339. ],
  340. "source": [
  341. "y_pred = log_reg.predict(x_test)\n",
  342. "print(classification_report(y_test, y_pred))"
  343. ]
  344. }
  345. ],
  346. "metadata": {
  347. "kernelspec": {
  348. "display_name": "Python 3",
  349. "language": "python",
  350. "name": "python3"
  351. },
  352. "language_info": {
  353. "codemirror_mode": {
  354. "name": "ipython",
  355. "version": 3
  356. },
  357. "file_extension": ".py",
  358. "mimetype": "text/x-python",
  359. "name": "python",
  360. "nbconvert_exporter": "python",
  361. "pygments_lexer": "ipython3",
  362. "version": "3.10.9"
  363. }
  364. },
  365. "nbformat": 4,
  366. "nbformat_minor": 4
  367. }