384 lines
107 KiB
Plaintext
384 lines
107 KiB
Plaintext
|
{
|
||
|
"cells": [
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"# Simple classification example: the iris dataset"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 1,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"import matplotlib.pyplot as plt\n",
|
||
|
"import pandas as pd\n",
|
||
|
"from sklearn import datasets\n",
|
||
|
"from sklearn.model_selection import train_test_split\n",
|
||
|
"from sklearn.metrics import classification_report\n",
|
||
|
"from sklearn.metrics import accuracy_score\n",
|
||
|
"from sklearn.metrics import confusion_matrix"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 2,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# import some data to play with\n",
|
||
|
"# columns: Sepal Length, Sepal Width, Petal Length and Petal Width\n",
|
||
|
"iris = datasets.load_iris()\n",
|
||
|
"X = iris.data\n",
|
||
|
"y = iris.target"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 3,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/html": [
|
||
|
"<div>\n",
|
||
|
"<style scoped>\n",
|
||
|
" .dataframe tbody tr th:only-of-type {\n",
|
||
|
" vertical-align: middle;\n",
|
||
|
" }\n",
|
||
|
"\n",
|
||
|
" .dataframe tbody tr th {\n",
|
||
|
" vertical-align: top;\n",
|
||
|
" }\n",
|
||
|
"\n",
|
||
|
" .dataframe thead th {\n",
|
||
|
" text-align: right;\n",
|
||
|
" }\n",
|
||
|
"</style>\n",
|
||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
||
|
" <thead>\n",
|
||
|
" <tr style=\"text-align: right;\">\n",
|
||
|
" <th></th>\n",
|
||
|
" <th>Sepal Length (cm)</th>\n",
|
||
|
" <th>Sepal Width (cm)</th>\n",
|
||
|
" <th>Petal Length (cm)</th>\n",
|
||
|
" <th>Petal Width (cm)</th>\n",
|
||
|
" <th>category</th>\n",
|
||
|
" </tr>\n",
|
||
|
" </thead>\n",
|
||
|
" <tbody>\n",
|
||
|
" <tr>\n",
|
||
|
" <th>0</th>\n",
|
||
|
" <td>5.1</td>\n",
|
||
|
" <td>3.5</td>\n",
|
||
|
" <td>1.4</td>\n",
|
||
|
" <td>0.2</td>\n",
|
||
|
" <td>0</td>\n",
|
||
|
" </tr>\n",
|
||
|
" <tr>\n",
|
||
|
" <th>1</th>\n",
|
||
|
" <td>4.9</td>\n",
|
||
|
" <td>3.0</td>\n",
|
||
|
" <td>1.4</td>\n",
|
||
|
" <td>0.2</td>\n",
|
||
|
" <td>0</td>\n",
|
||
|
" </tr>\n",
|
||
|
" <tr>\n",
|
||
|
" <th>2</th>\n",
|
||
|
" <td>4.7</td>\n",
|
||
|
" <td>3.2</td>\n",
|
||
|
" <td>1.3</td>\n",
|
||
|
" <td>0.2</td>\n",
|
||
|
" <td>0</td>\n",
|
||
|
" </tr>\n",
|
||
|
" <tr>\n",
|
||
|
" <th>3</th>\n",
|
||
|
" <td>4.6</td>\n",
|
||
|
" <td>3.1</td>\n",
|
||
|
" <td>1.5</td>\n",
|
||
|
" <td>0.2</td>\n",
|
||
|
" <td>0</td>\n",
|
||
|
" </tr>\n",
|
||
|
" <tr>\n",
|
||
|
" <th>4</th>\n",
|
||
|
" <td>5.0</td>\n",
|
||
|
" <td>3.6</td>\n",
|
||
|
" <td>1.4</td>\n",
|
||
|
" <td>0.2</td>\n",
|
||
|
" <td>0</td>\n",
|
||
|
" </tr>\n",
|
||
|
" </tbody>\n",
|
||
|
"</table>\n",
|
||
|
"</div>"
|
||
|
],
|
||
|
"text/plain": [
|
||
|
" Sepal Length (cm) Sepal Width (cm) Petal Length (cm) Petal Width (cm) \\\n",
|
||
|
"0 5.1 3.5 1.4 0.2 \n",
|
||
|
"1 4.9 3.0 1.4 0.2 \n",
|
||
|
"2 4.7 3.2 1.3 0.2 \n",
|
||
|
"3 4.6 3.1 1.5 0.2 \n",
|
||
|
"4 5.0 3.6 1.4 0.2 \n",
|
||
|
"\n",
|
||
|
" category \n",
|
||
|
"0 0 \n",
|
||
|
"1 0 \n",
|
||
|
"2 0 \n",
|
||
|
"3 0 \n",
|
||
|
"4 0 "
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 3,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# just to create a nice table\n",
|
||
|
"df = pd.DataFrame({\"Sepal Length (cm)\": X[:,0], \"Sepal Width (cm)\": X[:,1], \n",
|
||
|
" 'Petal Length (cm)': X[:,2], 'Petal Width (cm)': X[:,3], \n",
|
||
|
" 'category': y})\n",
|
||
|
"df.head()"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 4,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"['setosa', 'versicolor', 'virginica']"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 4,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"list(iris.target_names)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 5,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# split data into training and test data sets\n",
|
||
|
"x_train, x_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=42)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 6,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"Text(0, 0.5, 'Petal width')"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 6,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA04AAAHCCAYAAADYTZkLAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/P9b71AAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzddXgUVxcH4N+sJrtJNu6KhRAcgltwLS2uxSmUKh8tpQZVKrSFChQpbi1aXEuCu0MgWCBKPBvdZHfP90dK2u1ujMgmcN7n2afN3Dv3nhmSO3N2Zu4IRERgjDHGGGOMMVYokbkDYIwxxhhjjLGqjhMnxhhjjDHGGCsGJ06MMcYYY4wxVgxOnBhjjDHGGGOsGJw4McYYY4wxxlgxOHFijDHGGGOMsWJw4sQYY4wxxhhjxeDEiTHGGGOMMcaKwYkTY4wxxhhjjBWDEyfGGGOMMcYYK4bE3AE8MXfuXLz//vt48803MX/+fJN1QkJCEBwcbLQ8LCwMdevWLVE/er0eMTExsLa2hiAIZQmZMcZYKRER0tPT4e7uDpGIv7t7go9NjDFmHqU5LlWJxOncuXNYsmQJGjZsWKL6t2/fho2NTcHPTk5OJe4rJiYGXl5epY6RMcZY+YmMjISnp6e5w6gy+NjEGGPmVZLjktkTp4yMDIwcORJLly7F559/XqJ1nJ2dYWtr+1T9WVtbA8jfOf9OvhhjjFU8tVoNLy+vgrGY5eNjE2OMmUdpjktmT5ymTZuGPn36oGvXriVOnJo0aYKcnBzUq1cPH374ocnb957QaDTQaDQFP6enpwMAbGxs+ODEGGNmwrejGXqyP/jYxBhj5lGS45JZE6eNGzfi4sWLOHfuXInqu7m5YcmSJWjWrBk0Gg3WrFmDLl26ICQkBB06dDC5zty5c/HJJ5+UZ9iMMcYYY4yx54zZEqfIyEi8+eabOHDgACwsLEq0jr+/P/z9/Qt+bt26NSIjIzFv3rxCE6dZs2Zh+vTpBT8/uRzHGGOMMcYYYyVltimNLly4gPj4eDRr1gwSiQQSiQShoaH48ccfIZFIoNPpStROq1atcOfOnULL5XJ5wa0PfAsEY4wxxhhj7GmY7YpTly5dcO3aNYNl48aNQ926dTFz5kyIxeIStXPp0iW4ublVRIiMMcYYY4wxBsCMiZO1tTXq169vsEypVMLBwaFg+axZsxAdHY3Vq1cDAObPnw9fX18EBgYiNzcXa9euxZYtW7Bly5ZKj58xxhhjjDH2/DD7rHpFiY2NxaNHjwp+zs3NxYwZMxAdHQ1LS0sEBgZi9+7d6N27txmjZIwxxhhjjD3rBCIicwdRmdRqNVQqFdLS0vh5J8YYq2Q8BpvG+4UxxsyjNOOv2SaHYIwxxhhjjLHqghMnxhhjjDHGGCtGlX7GiTFzysrKwsaNG3Hy5EmIRCJ07twZAwYMgEwmM3dojDHGGHuOERGOHDmCrVu3IjMzE/Xq1cOYMWPg7Oxc4jbCw8OxatUqREdHw8XFBaNGjULNmjXx+++/48SJExAEoeDcRy6XV+DWVB/8jBNjJpw4cQL9X+iP5ORk2EocQNAjVZsMD3cP7N23Fw0aNDB3iIxVSzwGm8b7hTFWUomJiXjxxX44ceI0avhYwsVJhIvXckAkwqJFv2L8+PFFrq/X6/G///0P8+fPh72dDHVryXA3IhfxCblQKOTIytKgaUMldDrClRtZ8PJyx86de9CoUaNK2sLKxc84MVYGjx49Qs8ePUGpYrRGDzTTdUJzXWe0QjdkPM5Gl85dkZKSYu4wGWMVYO7cuQgKCoK1tTWcnZ3x4osv4vbt20WuExISAkEQjD63bt2qpKgZY88LIsKAAf0Rfvsi9m10R/gpDxzf6Y7Iiz4YPcgSEydOxIEDB4ps4+uvv8aCBfPx7WxHRF70wrEdbjix0w0KSwGN6gm4fdIH5/a74+IhD1wL8YajbSp69OiKpKSkStrKqosTJ8b+45dffkFejh
YN9a2hEKwKllsJKjTUtUZSUhJWrFhhxggZYxUlNDQU06ZNw+nTp3Hw4EFotVp0794dmZmZxa57+/ZtxMbGFnxq165dCREzxp4nx48fx7FjJ7FiviO6dVRCEAQAgIO9GL9+64zWzRX46qsvCl0/Ozsb8+Z9jWnjVZg+xQ4WFvmpwG/r1JBKBexe545afv88klDPX46da1yQmprC5z7gxIkxI5s3bYGjzh0SQWpUJhcs4QhXbN7ML11m7Fm0b98+jB07FoGBgWjUqBFWrFiBR48e4cKFC8Wu6+zsDFdX14KPWCwutK5Go4FarTb4MMZYcbZv3w4vDwv0CFYYlYlEAsaPUOLIkaNITU01uf6JEyeQnJyGSaNUBsv/3JeJIS9YQWVjPG65uUjQt5sltm7dVC7bUJ1x4sTYf2RlZUGGwieAkJIMWZlZlRgRY8xc0tLSAAD29vbF1m3SpAnc3NzQpUsXHDlypMi6c+fOhUqlKvh4eXmVS7yMsWdbVlYWHOzEEIkEk+WO9vmJT05OTqHr/7tewfJsvdGy/7abnc3nPpw4MfYfjRo1RKo40WQZESFNkoRGjRtWclSMscpGRJg+fTratWuH+vXrF1rPzc0NS5YswZYtW7B161b4+/ujS5cuOHr0aKHrzJo1C2lpaQWfyMjIitgExtgzpkGDBrgWloWYOK3J8oOhWXB2doCjo6PJ8sDAQADAgRDDJKh+gBwHQ00nRno94dBRDRo0aPz0gT8jOHFi7D9enfYqUnSJiKGHRmUPEY4MrRqvvvqqGSJjjFWm1157DVevXsWGDRuKrOfv749JkyahadOmaN26NRYuXIg+ffpg3rx5ha4jl8thY2Nj8GGMseKMHDkSlpaW+N+cRGi1hhNjn7+cgxUbMzFp0hRIJKbfOFSzZk10794Vn32fhrj4f5KvV15W4fwVDVZsTDNaZ/6SVNyLyMGUKVPLd2OqIX6PE2P/0a9fP0yYMAG//fYbEhEDJ/IAQY94URQSKQ7vv/8+WrZsae4wGWMV6PXXX8eOHTtw9OhReHp6lnr9Vq1aYe3atRUQGWPseaZSqbBixSoMGzYUt+/GYMIIJZydxPjrWBbWbM5Eo0ZNMWvWrCLbWLRoMdq3b4MmXaMxaaQVGgbKcCs8F0qFCBPfjseOfVkY2FcJnR74fXsm9h/JwMyZM9GmTZtK2sqqixMnxv5DEAQsWbIErVq1wvfffY8bt84CAJo2aoqf3v0Bw4YNM3OEjLGKQkR4/fXXsW3bNoSEhMDPz++p2rl06RLc3NzKOTrGGAMGDRqE0NCj+OqrL/Hmh3tBRHB3d8HMmdMxY8YMKJXKItevUaMGzpw5j7lz52LBspXIyEiGhYUMw4a9jDp16mDDhrUY8/pNAEDz5k2wfv07fO7zN34BLmNFICKkp6dDJBLBysqq+BUYY0Wq6mPwq6++ivXr1+PPP/+Ev79/wXKVSgVLS0sA+c8nRUdHY/Xq1QCA+fPnw9fXF4GBgcjNzcXatWvx1VdfYcuWLRgwYECJ+q3q+4UxVjXl5OQgOzsbKpUKIlHpn8DRarVQq9WwtraGVJo/m/CTcx9BEGBtbV3eIVc5pRl/+YoTY0UQBIFPYhh7jixatAgA0KlTJ4PlK1aswNixYwEAsbGxePToUUFZbm4uZsyYgejoaFhaWiIwMBC7d+9G7969KytsxthzysLCAhYWFk+9vkQiMZo1lM99CsdXnBhjjFUaHoNN4/3CGGPmUZrxl2fVY4wxxhhjjLFicOLEGGOMMcZYKRERMjMzodPpKrwvvV6P5OTkghfYMvPgxIkxxhhjjLESysjIwCeffAIvL3dYWVnB0tICw4cPw5UrV8q9r6ysLAwYMABWSgs4ODhAqVTCxcUR8+fPL/e+WPF4cgjGGGOMMcZKQK1Wo3Pnjrh58xpGD1KiXUsXxMRpsWzdn2jVaht27dqDLl26lEtfOTk5qFnTD/Hx8RjUzwq9OttDna7Hb+vT8Pbbb+P69etYtm
xZufTFSoYTJ8YYY4wxxkrg448/xp3wGzix0wONAuUFy1+fYIsXxz7GiBFD8ehRNORyeRGtlMzo0aMRHx+PHavd0av
|
||
|
"text/plain": [
|
||
|
"<Figure size 1000x500 with 2 Axes>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {},
|
||
|
"output_type": "display_data"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# plot with color code\n",
|
||
|
"plt.subplots(1, 2, figsize=(10, 5))\n",
|
||
|
"\n",
|
||
|
"plt.subplot(1, 2, 1)\n",
|
||
|
"plt.scatter(X[:, 0], X[:, 1], c=y, edgecolor='k')\n",
|
||
|
"plt.xlabel('Sepal length')\n",
|
||
|
"plt.ylabel('Sepal width')\n",
|
||
|
"\n",
|
||
|
"plt.subplot(1, 2, 2)\n",
|
||
|
"plt.scatter(X[:, 2], X[:, 3], c=y, edgecolor='k')\n",
|
||
|
"plt.xlabel('Petal length')\n",
|
||
|
"plt.ylabel('Petal width')"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"## Softmax regression"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 7,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"from sklearn.linear_model import LogisticRegression\n",
|
||
|
"# penalty=None disables regularization; the string 'none' is deprecated since scikit-learn 1.2\n",
|
||
|
"log_reg = LogisticRegression(multi_class='multinomial', penalty=None)\n",
|
||
|
"log_reg.fit(x_train, y_train);"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"## k-nearest neighbor"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 8,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"from sklearn.neighbors import KNeighborsClassifier\n",
|
||
|
"kn_neigh = KNeighborsClassifier(n_neighbors=5)\n",
|
||
|
"kn_neigh.fit(x_train, y_train);"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"## Fisher linear discriminant"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 9,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"from sklearn.discriminant_analysis import LinearDiscriminantAnalysis\n",
|
||
|
"fisher_ld = LinearDiscriminantAnalysis()\n",
|
||
|
"fisher_ld.fit(x_train, y_train);"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"## Classification accuracy"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 10,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"LogisticRegression\n",
|
||
|
"accuracy: 0.96\n",
|
||
|
"[[29 0 0]\n",
|
||
|
" [ 0 23 0]\n",
|
||
|
" [ 0 3 20]] \n",
|
||
|
"\n",
|
||
|
"KNeighborsClassifier\n",
|
||
|
"accuracy: 0.95\n",
|
||
|
"[[29 0 0]\n",
|
||
|
" [ 0 23 0]\n",
|
||
|
" [ 0 4 19]] \n",
|
||
|
"\n",
|
||
|
"LinearDiscriminantAnalysis\n",
|
||
|
"accuracy: 0.99\n",
|
||
|
"[[29 0 0]\n",
|
||
|
" [ 0 23 0]\n",
|
||
|
" [ 0 1 22]] \n",
|
||
|
"\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"for clf in [log_reg, kn_neigh, fisher_ld]:\n",
|
||
|
" y_pred = clf.predict(x_test)\n",
|
||
|
" acc = accuracy_score(y_test, y_pred)\n",
|
||
|
" print(type(clf).__name__)\n",
|
||
|
" print(f\"accuracy: {acc:0.2f}\")\n",
|
||
|
" \n",
|
||
|
" # confusion matrix: rows: true class, columns: predicted class\n",
|
||
|
" print(confusion_matrix(y_test, y_pred),\"\\n\")"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 11,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
" precision recall f1-score support\n",
|
||
|
"\n",
|
||
|
" 0 1.00 1.00 1.00 29\n",
|
||
|
" 1 0.88 1.00 0.94 23\n",
|
||
|
" 2 1.00 0.87 0.93 23\n",
|
||
|
"\n",
|
||
|
" accuracy 0.96 75\n",
|
||
|
" macro avg 0.96 0.96 0.96 75\n",
|
||
|
"weighted avg 0.96 0.96 0.96 75\n",
|
||
|
"\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"y_pred = log_reg.predict(x_test)\n",
|
||
|
"print(classification_report(y_test, y_pred))"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": []
|
||
|
}
|
||
|
],
|
||
|
"metadata": {
|
||
|
"kernelspec": {
|
||
|
"display_name": "Python 3 (ipykernel)",
|
||
|
"language": "python",
|
||
|
"name": "python3"
|
||
|
},
|
||
|
"language_info": {
|
||
|
"codemirror_mode": {
|
||
|
"name": "ipython",
|
||
|
"version": 3
|
||
|
},
|
||
|
"file_extension": ".py",
|
||
|
"mimetype": "text/x-python",
|
||
|
"name": "python",
|
||
|
"nbconvert_exporter": "python",
|
||
|
"pygments_lexer": "ipython3",
|
||
|
"version": "3.8.16"
|
||
|
}
|
||
|
},
|
||
|
"nbformat": 4,
|
||
|
"nbformat_minor": 4
|
||
|
}
|