368 lines
107 KiB
Plaintext
368 lines
107 KiB
Plaintext
|
{
|
||
|
"cells": [
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"# Simple classification example: the iris dataset"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 2,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"import matplotlib.pyplot as plt\n",
|
||
|
"import pandas as pd\n",
|
||
|
"from sklearn import datasets\n",
|
||
|
"from sklearn.model_selection import train_test_split\n",
|
||
|
"from sklearn.metrics import classification_report\n",
|
||
|
"from sklearn.metrics import accuracy_score\n",
|
||
|
"from sklearn.metrics import confusion_matrix"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 3,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# import some data to play with\n",
|
||
|
"# columns: Sepal Length, Sepal Width, Petal Length and Petal Width\n",
|
||
|
"iris = datasets.load_iris()\n",
|
||
|
"X = iris.data\n",
|
||
|
"y = iris.target"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 4,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/html": [
|
||
|
"<div>\n",
|
||
|
"<style scoped>\n",
|
||
|
" .dataframe tbody tr th:only-of-type {\n",
|
||
|
" vertical-align: middle;\n",
|
||
|
" }\n",
|
||
|
"\n",
|
||
|
" .dataframe tbody tr th {\n",
|
||
|
" vertical-align: top;\n",
|
||
|
" }\n",
|
||
|
"\n",
|
||
|
" .dataframe thead th {\n",
|
||
|
" text-align: right;\n",
|
||
|
" }\n",
|
||
|
"</style>\n",
|
||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
||
|
" <thead>\n",
|
||
|
" <tr style=\"text-align: right;\">\n",
|
||
|
" <th></th>\n",
|
||
|
" <th>Sepal Length (cm)</th>\n",
|
||
|
" <th>Sepal Width (cm)</th>\n",
|
||
|
" <th>Petal Length (cm)</th>\n",
|
||
|
" <th>Petal Width (cm)</th>\n",
|
||
|
" <th>category</th>\n",
|
||
|
" </tr>\n",
|
||
|
" </thead>\n",
|
||
|
" <tbody>\n",
|
||
|
" <tr>\n",
|
||
|
" <th>0</th>\n",
|
||
|
" <td>5.1</td>\n",
|
||
|
" <td>3.5</td>\n",
|
||
|
" <td>1.4</td>\n",
|
||
|
" <td>0.2</td>\n",
|
||
|
" <td>0</td>\n",
|
||
|
" </tr>\n",
|
||
|
" <tr>\n",
|
||
|
" <th>1</th>\n",
|
||
|
" <td>4.9</td>\n",
|
||
|
" <td>3.0</td>\n",
|
||
|
" <td>1.4</td>\n",
|
||
|
" <td>0.2</td>\n",
|
||
|
" <td>0</td>\n",
|
||
|
" </tr>\n",
|
||
|
" <tr>\n",
|
||
|
" <th>2</th>\n",
|
||
|
" <td>4.7</td>\n",
|
||
|
" <td>3.2</td>\n",
|
||
|
" <td>1.3</td>\n",
|
||
|
" <td>0.2</td>\n",
|
||
|
" <td>0</td>\n",
|
||
|
" </tr>\n",
|
||
|
" <tr>\n",
|
||
|
" <th>3</th>\n",
|
||
|
" <td>4.6</td>\n",
|
||
|
" <td>3.1</td>\n",
|
||
|
" <td>1.5</td>\n",
|
||
|
" <td>0.2</td>\n",
|
||
|
" <td>0</td>\n",
|
||
|
" </tr>\n",
|
||
|
" <tr>\n",
|
||
|
" <th>4</th>\n",
|
||
|
" <td>5.0</td>\n",
|
||
|
" <td>3.6</td>\n",
|
||
|
" <td>1.4</td>\n",
|
||
|
" <td>0.2</td>\n",
|
||
|
" <td>0</td>\n",
|
||
|
" </tr>\n",
|
||
|
" </tbody>\n",
|
||
|
"</table>\n",
|
||
|
"</div>"
|
||
|
],
|
||
|
"text/plain": [
|
||
|
" Sepal Length (cm) Sepal Width (cm) Petal Length (cm) Petal Width (cm) \\\n",
|
||
|
"0 5.1 3.5 1.4 0.2 \n",
|
||
|
"1 4.9 3.0 1.4 0.2 \n",
|
||
|
"2 4.7 3.2 1.3 0.2 \n",
|
||
|
"3 4.6 3.1 1.5 0.2 \n",
|
||
|
"4 5.0 3.6 1.4 0.2 \n",
|
||
|
"\n",
|
||
|
" category \n",
|
||
|
"0 0 \n",
|
||
|
"1 0 \n",
|
||
|
"2 0 \n",
|
||
|
"3 0 \n",
|
||
|
"4 0 "
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 4,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# just to create a nice table\n",
|
||
|
"df = pd.DataFrame({\"Sepal Length (cm)\": X[:,0], \"Sepal Width (cm)\": X[:,1], \n",
|
||
|
" 'Petal Length (cm)': X[:,2], 'Petal Width (cm)': X[:,3], \n",
|
||
|
" 'category': y})\n",
|
||
|
"df.head()"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 5,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"['setosa', 'versicolor', 'virginica']"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 5,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"list(iris.target_names)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 6,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# split data into training and test data sets\n",
|
||
|
"x_train, x_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=42)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 7,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"data": {
|
||
|
"text/plain": [
|
||
|
"Text(0, 0.5, 'Petal width')"
|
||
|
]
|
||
|
},
|
||
|
"execution_count": 7,
|
||
|
"metadata": {},
|
||
|
"output_type": "execute_result"
|
||
|
},
|
||
|
{
|
||
|
"data": {
|
||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA04AAAHCCAYAAADYTZkLAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy89olMNAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdd3QUVfvA8e9sS++9kVBD7yBFqiBVQVF6FWygYn/F3rGL72tFpUlTEVABC0Wa9Cq9QwIkJKRs+ia7e39/RKL5pUuSDfh8ztlzZO6de58Z4c48uzP3akophRBCCCGEEEKIEukcHYAQQgghhBBC1HSSOAkhhBBCCCFEGSRxEkIIIYQQQogySOIkhBBCCCGEEGWQxEkIIYQQQgghyiCJkxBCCCGEEEKUQRInIYQQQgghhCiDJE5CCCGEEEIIUQZJnIQQQgghhBCiDJI4CSGEEEIIIUQZDI4O4Io33niDadOmMXXqVGbMmFFsnTlz5jBhwoRC25ycnMjJySl3P3a7nYsXL+Lh4YGmaVcTshBCiApQSpGenk5oaCg6nXxv93dybRJCCMeoyLWpRiROO3fu5LPPPqN58+Zl1vX09OTYsWMFf67oBebixYtERERUOEYhhBCVIzY2lvDwcEeHUaPItUkIIRyrPNcmhydOGRkZjBo1is8//5xXX321zPqaphEcHPyP+/Pw8ADyT46np+c/bkcIIUTFpKWlERERUTAOi7/ItUkIIRyjItcmhydOU6ZMYcCAAfTq1atciVNGRgaRkZHY7XZat27N66+/TpMmTUqsb7FYsFgsBX9OT08H8n+5kouTEEJUP3kUragr50SuTUII4RjluTY59CHzxYsXs2fPHqZPn16u+tHR0cyaNYvvv/+e+fPnY7fb6dSpE+fPny9xn+nTp+Pl5VXwkUchhBBCCCGEEBXlsMQpNjaWqVOnsmDBApydncu1T8eOHRk7diwtW7akW7duLF26lICAAD777LMS95k2bRpms7ngExsbW1mHIIQQQgghhPiXcNijert37yYhIYHWrVsXbLPZbGzcuJEPP/wQi8WCXq8vtQ2j0UirVq04efJkiXWcnJxwcnKqtLiFEEIIIYQQ/z4OS5xuuukmDhw4UGjbhAkTaNiwIf/5z3/KTJogP9E6cOAA/fv3r6owhRBCCCGEEMJxiZOHhwdNmzYttM3NzQ0/P7+C7WPHjiUsLKzgHaiXX36ZDh06UK9ePVJTU3n77bc5d+4ckyZNqvb4hRBCCCGEEP8eDp9VrzQxMTGFFqJKSUnh7rvvJj4+Hh8fH9q0acOWLVto3LixA6MUQgghhBBCXO80pZRydBDVKS0tDS8vL8xms0z5KoQQ1UjG35LJuRFCCMeoyPjr0OnIhRBCCCGEEOJaIImTEEIIIYQQQpShRr/jJIQjxcfHM2fOHI4dO4aHhwd33HEHXbp0KdfK0kIIIYQQVSUnJ4dvv/2WjRs3opSia9euDB06tNxroyql2LhxI9999x3p6elER0czfvx4AObOncvRo0dxd3fnjjvuoGvXrnLv8yd5x0mIYnz88cdMnToV7BqeOm9ysZBhTaNb124s/3453t7ejg5RiGuOjL8lk3MjhCiv3bt3c+utA7h48RItm7qhabD3QCYhIYH88MNK2rZtW+r+qamp3Hbbraxfv4moWs6EBOrZdygHi0Wh04HRqNGyiTNxCTbOxuTQrduNLFv2Az4+PtV0hNVL3nES4ip8//33TJkyhWBrJJ3t/Wht68YN1t60pDPbft/OnXcOdXSIQogqMn36dNq1a4eHhweBgYEMHjyYY8eOlbrPnDlz0DSt0Ke83/oKIURFXLp0iT59ehEelM7hTZHsXh3Krl9DObI5ksjQTPr27U18fHypbQwbegf7923jh69CObktnM0/hvL5u/7Y7XbuHu1B7J5INv+YX/bj/FAO/LGdoXcOqaYjrNkkcRLi/3nl5Vfw0wURTUuMmgkATdPw10
JoYGvJmjWr2b17t4OjFEJUhQ0bNjBlyhS2bdvG6tWrycvL4+abbyYzM7PU/Tw9PYmLiyv4nDt3rpoiFkL8m8ycOZPs7Ax+/CqY6Hqmgu0N6pr4YV4QOTkZzJw5s8T9d+7cya+r1zLzHX8G9HIreATvo1lmunVy4X+vB+DjrQfy73363+TGzHf8WbP2N3bu3Fm1B3cNkMRJiL+Ji4tj957dhNijin2eN4BQnA0uLF++vPqDE0JUuZ9//pnx48fTpEkTWrRowZw5c4iJiSnzyxJN0wgODi74BAUFlVrfYrGQlpZW6COEEGVZvnwJt/Vzwd9PX6TMz1fP7f1dWLbs2xL3//777wkMcGJQX7eCbQmXrWzdlcPdo72Kvfe5tY8bQQFOLFu2rHIO4homiZMQf5OVlQWACVOx5TpNh0lzIjs7uzrDEkI4iNlsBsDX17fUehkZGURGRhIREcGgQYM4dOhQqfWnT5+Ol5dXwSciIqLSYhZCXL+ysjLx9y2aNF3h76snOzurlP2z8PHSo9f/lSBlZas/9y0+LdDrNXy89XLvgyROQhQSFhaGp4cnSSQUW56lMkjPM9O0adNqjkwIUd3sdjsPP/wwnTt3LvXffHR0NLNmzeL7779n/vz52O12OnXqxPnz50vcZ9q0aZjN5oJPbGxsVRyCEOI606xZK1ZvtFDc3G5KKVZvyKVp0xal7N+M46eyOBOTV7AtNMiAj7eO1RuKT7jOxuZx7GQWzZo1u/oDuMZJ4iTE3zg7OzNx0kTi9GdJV6mFyuzKzkndAby8vBg6VCaIEOJ6N2XKFA4ePMjixYtLrdexY0fGjh1Ly5Yt6datG0uXLiUgIIDPPvusxH2cnJzw9PQs9BFCiLLcf/9kDh/L5pM55iJln84zc/BoFvfdN7nE/YcOHYqXlyePPp9EXl5+8mUyadw1wpOZ88zsO2gpVD8vT/HYC0l4eXkybNiwyj2Ya5Cs4yTE//Piiy+yds1a9hzeSJAtAh8CsJBNvOEcWWSwbP4yXF1dHR2mEKIKPfDAA6xYsYKNGzcSHh5eoX2NRiOtWrXi5MmTVRSdEOLfqnv37kydOpUHn/6An9Zlc+ct+dORf/tjJitXZ/Dggw9y0003lbi/m5sb8+bNZ8iQ22nT+wJ3j3YnOEhPqtlOTi50HhjLhBGedOvkQvwlG18syODYqTyWLPkONze3Etv9t5BfnIT4fzw9Pdm0eRP/mfYkOX5pHGAbJ3UH6DmgO7///jsDBw50dIhCiCqilOKBBx5g2bJlrFu3jtq1a1e4DZvNxoEDBwgJCamCCIUQ/2aapvH+++8zd+5c4pLqMGHqJcY/dIkLibWZM2cOH3zwQZmL1d5yyy1s2rSZ+o168+gLSQy/J54Va0088sgTPDT1CZb9bGL4PfE8+kISdaN7sXHjJm699dZqOsKaTRbAFaIUdrsds9mMi4uLrMsixFW6FsbfyZMns3DhQr7//nuio6MLtnt5eeHi4gLA2LFjCQsLY/r06QC8/PLLdOjQgXr16pGamsrbb7/N8uXL2b17N40bNy5Xv9fCuRFC1Dzp6ekopf7xuJGTk0N2djZeXl7odPm/p/zb7n0qMv7Ko3pClEKn0123K2ULIYr65JNPgPzHYf5u9uzZjB8/HoCYmJiCGwyAlJQU7r77buLj4/Hx8aFNmzZs2bKl3EmTEEL8Ux4eHle1v7Ozc5HkSO59Sia/OAkhhKgWMv6WTM6NEEI4RkXGX3nHSQghhBBCCCHKIImTEEIIIYQQFWS328nMzCx2TaWq6CshIQGr1VrlfYmSSeIkhBBCCCFEOZ0+fZp7770XT0933N3d8ff34fHHHychIaHS+zpy5Ajt2rXDyclAUFAQTk5GoqOjWbduXaX3Jcomk0MIIYQQQghRDgcOHKB79y44GXN49F436tfxZN9BC19+8V++++4bNm/eSlhYWKX0tWfPHjp3ugGDwcYDd3nRurkTp87m8cmcU/Tp04v58xfJorTVTC
aHEEIIUS1k/C2ZnBshaj6lFK1aNQfrKdYuCcHHW19QFnM+jy6D4mjbvg/Lln1fKf1FhIeRa7nE1lURREUYC7anmm3
|
||
|
"text/plain": [
|
||
|
"<Figure size 1000x500 with 2 Axes>"
|
||
|
]
|
||
|
},
|
||
|
"metadata": {},
|
||
|
"output_type": "display_data"
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"# plot with color code\n",
|
||
|
"plt.subplots(1, 2, figsize=(10, 5))\n",
|
||
|
"\n",
|
||
|
"plt.subplot(1, 2, 1)\n",
|
||
|
"plt.scatter(X[:, 0], X[:, 1], c=y, edgecolor='k')\n",
|
||
|
"plt.xlabel('Sepal length')\n",
|
||
|
"plt.ylabel('Sepal width')\n",
|
||
|
"\n",
|
||
|
"plt.subplot(1, 2, 2)\n",
|
||
|
"plt.scatter(X[:, 2], X[:, 3], c=y, edgecolor='k')\n",
|
||
|
"plt.xlabel('Petal length')\n",
|
||
|
"plt.ylabel('Petal width')"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"## Softmax regression"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 8,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"from sklearn.linear_model import LogisticRegression\n",
|
||
|
"log_reg = LogisticRegression(multi_class='multinomial', penalty='none')\n",
|
||
|
"log_reg.fit(x_train, y_train);"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"## k-nearest neighbor"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 9,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"from sklearn.neighbors import KNeighborsClassifier\n",
|
||
|
"kn_neigh = KNeighborsClassifier(n_neighbors=5)\n",
|
||
|
"kn_neigh.fit(x_train, y_train);"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"## Fisher linear discriminant"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 10,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"from sklearn.discriminant_analysis import LinearDiscriminantAnalysis\n",
|
||
|
"fisher_ld = LinearDiscriminantAnalysis()\n",
|
||
|
"fisher_ld.fit(x_train, y_train);"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"## Classification accuracy"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 11,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
"LogisticRegression\n",
|
||
|
"accuracy: 0.96\n",
|
||
|
"[[29 0 0]\n",
|
||
|
" [ 0 23 0]\n",
|
||
|
" [ 0 3 20]] \n",
|
||
|
"\n",
|
||
|
"KNeighborsClassifier\n",
|
||
|
"accuracy: 0.95\n",
|
||
|
"[[29 0 0]\n",
|
||
|
" [ 0 23 0]\n",
|
||
|
" [ 0 4 19]] \n",
|
||
|
"\n",
|
||
|
"LinearDiscriminantAnalysis\n",
|
||
|
"accuracy: 0.99\n",
|
||
|
"[[29 0 0]\n",
|
||
|
" [ 0 23 0]\n",
|
||
|
" [ 0 1 22]] \n",
|
||
|
"\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"for clf in [log_reg, kn_neigh, fisher_ld]:\n",
|
||
|
" y_pred = clf.predict(x_test)\n",
|
||
|
" acc = accuracy_score(y_test, y_pred)\n",
|
||
|
" print(type(clf).__name__)\n",
|
||
|
" print(f\"accuracy: {acc:0.2f}\")\n",
|
||
|
" \n",
|
||
|
" # confusion matrix: rows: true class, columns: predicted class\n",
|
||
|
" print(confusion_matrix(y_test, y_pred),\"\\n\")"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": 12,
|
||
|
"metadata": {},
|
||
|
"outputs": [
|
||
|
{
|
||
|
"name": "stdout",
|
||
|
"output_type": "stream",
|
||
|
"text": [
|
||
|
" precision recall f1-score support\n",
|
||
|
"\n",
|
||
|
" 0 1.00 1.00 1.00 29\n",
|
||
|
" 1 0.88 1.00 0.94 23\n",
|
||
|
" 2 1.00 0.87 0.93 23\n",
|
||
|
"\n",
|
||
|
" accuracy 0.96 75\n",
|
||
|
" macro avg 0.96 0.96 0.96 75\n",
|
||
|
"weighted avg 0.96 0.96 0.96 75\n",
|
||
|
"\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"source": [
|
||
|
"y_pred = log_reg.predict(x_test)\n",
|
||
|
"print(classification_report(y_test, y_pred))"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"metadata": {
|
||
|
"kernelspec": {
|
||
|
"display_name": "Python 3",
|
||
|
"language": "python",
|
||
|
"name": "python3"
|
||
|
},
|
||
|
"language_info": {
|
||
|
"codemirror_mode": {
|
||
|
"name": "ipython",
|
||
|
"version": 3
|
||
|
},
|
||
|
"file_extension": ".py",
|
||
|
"mimetype": "text/x-python",
|
||
|
"name": "python",
|
||
|
"nbconvert_exporter": "python",
|
||
|
"pygments_lexer": "ipython3",
|
||
|
"version": "3.10.9"
|
||
|
}
|
||
|
},
|
||
|
"nbformat": 4,
|
||
|
"nbformat_minor": 4
|
||
|
}
|