ML-Kurs-SS2023/notebooks/ml_basics_iris_softmax_regression.ipynb

368 lines
107 KiB
Plaintext
Raw Normal View History

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Simple classification example: the iris dataset"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"import pandas as pd\n",
"from sklearn import datasets\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.metrics import classification_report\n",
"from sklearn.metrics import accuracy_score\n",
"from sklearn.metrics import confusion_matrix"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# load the iris data set bundled with scikit-learn\n",
"# feature columns, in order: sepal length, sepal width, petal length, petal width (cm)\n",
"iris = datasets.load_iris()\n",
"X, y = iris.data, iris.target"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Sepal Length (cm)</th>\n",
" <th>Sepal Width (cm)</th>\n",
" <th>Petal Length (cm)</th>\n",
" <th>Petal Width (cm)</th>\n",
" <th>category</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>5.1</td>\n",
" <td>3.5</td>\n",
" <td>1.4</td>\n",
" <td>0.2</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>4.9</td>\n",
" <td>3.0</td>\n",
" <td>1.4</td>\n",
" <td>0.2</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>4.7</td>\n",
" <td>3.2</td>\n",
" <td>1.3</td>\n",
" <td>0.2</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>4.6</td>\n",
" <td>3.1</td>\n",
" <td>1.5</td>\n",
" <td>0.2</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>5.0</td>\n",
" <td>3.6</td>\n",
" <td>1.4</td>\n",
" <td>0.2</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Sepal Length (cm) Sepal Width (cm) Petal Length (cm) Petal Width (cm) \\\n",
"0 5.1 3.5 1.4 0.2 \n",
"1 4.9 3.0 1.4 0.2 \n",
"2 4.7 3.2 1.3 0.2 \n",
"3 4.6 3.1 1.5 0.2 \n",
"4 5.0 3.6 1.4 0.2 \n",
"\n",
" category \n",
"0 0 \n",
"1 0 \n",
"2 0 \n",
"3 0 \n",
"4 0 "
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# tabulate features and labels, purely for display purposes\n",
"feature_columns = [\"Sepal Length (cm)\", \"Sepal Width (cm)\",\n",
"                   \"Petal Length (cm)\", \"Petal Width (cm)\"]\n",
"df = pd.DataFrame(X, columns=feature_columns).assign(category=y)\n",
"df.head()"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['setosa', 'versicolor', 'virginica']"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# .tolist() gives plain Python strings; list(ndarray) would give np.str_ scalars,\n",
"# which NumPy >= 2 displays as np.str_('setosa') instead of 'setosa'\n",
"iris.target_names.tolist()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# hold out half of the samples as a test set; the fixed seed makes the split reproducible\n",
"x_train, x_test, y_train, y_test = train_test_split(\n",
"    X, y, test_size=0.5, random_state=42\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Text(0, 0.5, 'Petal width')"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA04AAAHCCAYAAADYTZkLAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy89olMNAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdd3QUVfvA8e9sS++9kVBD7yBFqiBVQVF6FWygYn/F3rGL72tFpUlTEVABC0Wa9Cq9QwIkJKRs+ia7e39/RKL5pUuSDfh8ztlzZO6de58Z4c48uzP3akophRBCCCGEEEKIEukcHYAQQgghhBBC1HSSOAkhhBBCCCFEGSRxEkIIIYQQQogySOIkhBBCCCGEEGWQxEkIIYQQQgghyiCJkxBCCCGEEEKUQRInIYQQQgghhCiDJE5CCCGEEEIIUQZJnIQQQgghhBCiDJI4CSGEEEIIIUQZDI4O4Io33niDadOmMXXqVGbMmFFsnTlz5jBhwoRC25ycnMjJySl3P3a7nYsXL+Lh4YGmaVcTshBCiApQSpGenk5oaCg6nXxv93dybRJCCMeoyLWpRiROO3fu5LPPPqN58+Zl1vX09OTYsWMFf67oBebixYtERERUOEYhhBCVIzY2lvDwcEeHUaPItUkIIRyrPNcmhydOGRkZjBo1is8//5xXX321zPqaphEcHPyP+/Pw8ADyT46np+c/bkcIIUTFpKWlERERUTAOi7/ItUkIIRyjItcmhydOU6ZMYcCAAfTq1atciVNGRgaRkZHY7XZat27N66+/TpMmTUqsb7FYsFgsBX9OT08H8n+5kouTEEJUP3kUragr50SuTUII4RjluTY59CHzxYsXs2fPHqZPn16u+tHR0cyaNYvvv/+e+fPnY7fb6dSpE+fPny9xn+nTp+Pl5VXwkUchhBBCCCGEEBXlsMQpNjaWqVOnsmDBApydncu1T8eOHRk7diwtW7akW7duLF26lICAAD777LMS95k2bRpms7ngExsbW1mHIIQQQgghhPiXcNijert37yYhIYHWrVsXbLPZbGzcuJEPP/wQi8WCXq8vtQ2j0UirVq04efJkiXWcnJxwcnKqtLiFEEIIIYQQ/z4OS5xuuukmDhw4UGjbhAkTaNiwIf/5z3/KTJogP9E6cOAA/fv3r6owhRBCCCGEEMJxiZOHhwdNmzYttM3NzQ0/P7+C7WPHjiUsLKzgHaiXX36ZDh06UK9ePVJTU3n77bc5d+4ckyZNqvb4hRBCCCGEEP8eDp9VrzQxMTGFFqJKSUnh7rvvJj4+Hh8fH9q0acOWLVto3LixA6MUQgghhBBCXO80pZRydBDVKS0tDS8vL8xms0z5KoQQ1UjG35LJuRFCCMeoyPjr0OnIhRBCCCGEEOJaIImTEEIIIYQQQpShRr/jJIQjxcfHM2fOHI4dO4aHhwd33HEHXbp0KdfK0kIIIYQQVSUnJ4dvv/2WjRs3opSia9euDB06tNxroyql2LhxI9999x3p6elER0czfvx4AObOncvRo0dxd3fnjjvuoGvXrnLv8yd5x0mIYnz88cdMnToV7BqeOm9ysZBhTaNb124s/3453t7ejg5RiGuOjL8lk3MjhCiv3bt3c+utA7h48RItm7qhabD3QCYhIYH88MNK2rZtW+r+qamp3Hbbraxfv4moWs6EBOrZdygHi0Wh04HRqNGyiTNxCTbOxuTQrduNLFv2Az4+PtV0hNVL3nES4ip8//33TJkyhWBrJJ3t/Wht68YN1t60pDPbft/OnXcOdXSIQogqMn36dNq1a4eHhweBgYEMHjyYY8eOlbrPnDlz0DSt0Ke83/oKIURFXLp0iT59ehEelM7hTZHsXh3Krl9DObI5ksjQTPr27U18fHypbQwbegf7923jh69CObktnM0/hvL5u/7Y7XbuHu1B7J5INv+YX/bj/FAO/LGdoXcOqaYjrNkkcRLi/3nl5Vfw0wURTUuMmgkATdPw10
JoYGvJmjWr2b17t4OjFEJUhQ0bNjBlyhS2bdvG6tWrycvL4+abbyYzM7PU/Tw9PYmLiyv4nDt3rpoiFkL8m8ycOZPs7Ax+/CqY6Hqmgu0N6pr4YV4QOTkZzJw5s8T9d+7cya+r1zLzHX8G9HIreATvo1lmunVy4X+vB+DjrQfy73363+TGzHf8WbP2N3bu3Fm1B3cNkMRJiL+Ji4tj957dhNijin2eN4BQnA0uLF++vPqDE0JUuZ9//pnx48fTpEkTWrRowZw5c4iJiSnzyxJN0wgODi74BAUFlVrfYrGQlpZW6COEEGVZvnwJt/Vzwd9PX6TMz1fP7f1dWLbs2xL3//777wkMcGJQX7eCbQmXrWzdlcPdo72Kvfe5tY8bQQFOLFu2rHIO4homiZMQf5OVlQWACVOx5TpNh0lzIjs7uzrDEkI4iNlsBsDX17fUehkZGURGRhIREcGgQYM4dOhQqfWnT5+Ol5dXwSciIqLSYhZCXL+ysjLx9y2aNF3h76snOzurlP2z8PHSo9f/lSBlZas/9y0+LdDrNXy89XLvgyROQhQSFhaGp4cnSSQUW56lMkjPM9O0adNqjkwIUd3sdjsPP/wwnTt3LvXffHR0NLNmzeL7779n/vz52O12OnXqxPnz50vcZ9q0aZjN5oJPbGxsVRyCEOI606xZK1ZvtFDc3G5KKVZvyKVp0xal7N+M46eyOBOTV7AtNMiAj7eO1RuKT7jOxuZx7GQWzZo1u/oDuMZJ4iTE3zg7OzNx0kTi9GdJV6mFyuzKzkndAby8vBg6VCaIEOJ6N2XKFA4ePMjixYtLrdexY0fGjh1Ly5Yt6datG0uXLiUgIIDPPvusxH2cnJzw9PQs9BFCiLLcf/9kDh/L5pM55iJln84zc/BoFvfdN7nE/YcOHYqXlyePPp9EXl5+8mUyadw1wpOZ88zsO2gpVD8vT/HYC0l4eXkybNiwyj2Ya5Cs4yTE//Piiy+yds1a9hzeSJAtAh8CsJBNvOEcWWSwbP4yXF1dHR2mEKIKPfDAA6xYsYKNGzcSHh5eoX2NRiOtWrXi5MmTVRSdEOLfqnv37kydOpUHn/6An9Zlc+ct+dORf/tjJitXZ/Dggw9y0003lbi/m5sb8+bNZ8iQ22nT+wJ3j3YnOEhPqtlOTi50HhjLhBGedOvkQvwlG18syODYqTyWLPkONze3Etv9t5BfnIT4fzw9Pdm0eRP/mfYkOX5pHGAbJ3UH6DmgO7///jsDBw50dIhCiCqilOKBBx5g2bJlrFu3jtq1a1e4DZvNxoEDBwgJCamCCIUQ/2aapvH+++8zd+5c4pLqMGHqJcY/dIkLibWZM2cOH3zwQZmL1d5yyy1s2rSZ+o168+gLSQy/J54Va0088sgTPDT1CZb9bGL4PfE8+kISdaN7sXHjJm699dZqOsKaTRbAFaIUdrsds9mMi4uLrMsixFW6FsbfyZMns3DhQr7//nuio6MLtnt5eeHi4gLA2LFjCQsLY/r06QC8/PLLdOjQgXr16pGamsrbb7/N8uXL2b17N40bNy5Xv9fCuRFC1Dzp6ekopf7xuJGTk0N2djZeXl7odPm/p/zb7n0qMv7Ko3pClEKn0123K2ULIYr65JNPgPzHYf5u9uzZjB8/HoCYmJiCGwyAlJQU7r77buLj4/Hx8aFNmzZs2bKl3EmTEEL8Ux4eHle1v7Ozc5HkSO59Sia/OAkhhKgWMv6WTM6NEEI4RkXGX3nHSQghhBBCCCHKIImTEEIIIYQQFWS328nMzCx2TaWq6CshIQGr1VrlfYmSSeIkhBBCCCFEOZ0+fZp7770XT0933N3d8ff34fHHHychIaHS+zpy5Ajt2rXDyclAUFAQTk5GoqOjWbduXaX3Jcomk0MIIYQQQghRDgcOHKB79y44GXN49F436tfxZN9BC19+8V++++4bNm/eSlhYWKX0tWfPHjp3ugGDwcYDd3nRurkTp87m8cmcU/Tp04v58xfJorTVTC
aHEEIIUS1k/C2ZnBshaj6lFK1aNQfrKdYuCcHHW19QFnM+jy6D4mjbvg/Lln1fKf1FhIeRa7nE1lURREUYC7anmm3
"text/plain": [
"<Figure size 1000x500 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# scatter the sepal and petal feature pairs side by side, colour-coded by class\n",
"# (explicit Figure/Axes objects instead of mixing plt.subplots with plt.subplot)\n",
"fig, (ax_sepal, ax_petal) = plt.subplots(1, 2, figsize=(10, 5))\n",
"\n",
"ax_sepal.scatter(X[:, 0], X[:, 1], c=y, edgecolor='k')\n",
"ax_sepal.set_xlabel('Sepal length')\n",
"ax_sepal.set_ylabel('Sepal width')\n",
"\n",
"petal_points = ax_petal.scatter(X[:, 2], X[:, 3], c=y, edgecolor='k')\n",
"ax_petal.set_xlabel('Petal length')\n",
"ax_petal.set_ylabel('Petal width')\n",
"\n",
"# one legend entry per class value (0, 1, 2), labelled with the species names\n",
"ax_petal.legend(petal_points.legend_elements()[0], iris.target_names.tolist(), title='Species');"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Softmax regression"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.linear_model import LogisticRegression\n",
"\n",
"# Unregularised softmax (multinomial logistic) regression.\n",
"# With the default lbfgs solver and more than two classes scikit-learn already fits\n",
"# the multinomial model, so the deprecated multi_class argument is not needed.\n",
"# penalty=None (scikit-learn >= 1.2) replaces the string 'none', which was removed in 1.4.\n",
"log_reg = LogisticRegression(penalty=None)\n",
"log_reg.fit(x_train, y_train);"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## k-nearest neighbors"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.neighbors import KNeighborsClassifier\n",
"\n",
"# vote among the 5 closest training samples; fit() returns the estimator itself\n",
"kn_neigh = KNeighborsClassifier(n_neighbors=5).fit(x_train, y_train)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Fisher linear discriminant"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.discriminant_analysis import LinearDiscriminantAnalysis\n",
"\n",
"# linear discriminant analysis with default settings; fit() returns the estimator itself\n",
"fisher_ld = LinearDiscriminantAnalysis().fit(x_train, y_train)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Classification accuracy"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"LogisticRegression\n",
"accuracy: 0.96\n",
"[[29 0 0]\n",
" [ 0 23 0]\n",
" [ 0 3 20]] \n",
"\n",
"KNeighborsClassifier\n",
"accuracy: 0.95\n",
"[[29 0 0]\n",
" [ 0 23 0]\n",
" [ 0 4 19]] \n",
"\n",
"LinearDiscriminantAnalysis\n",
"accuracy: 0.99\n",
"[[29 0 0]\n",
" [ 0 23 0]\n",
" [ 0 1 22]] \n",
"\n"
]
}
],
"source": [
"# evaluate every fitted classifier on the held-out test set\n",
"for clf in [log_reg, kn_neigh, fisher_ld]:\n",
"    y_pred = clf.predict(x_test)\n",
"    acc = accuracy_score(y_test, y_pred)\n",
"    print(type(clf).__name__)\n",
"    print(f\"accuracy: {acc:0.2f}\")\n",
"\n",
"    # confusion matrix (scikit-learn convention): row i = true class i,\n",
"    # column j = predicted class j, so off-diagonal entries are misclassifications\n",
"    print(confusion_matrix(y_test, y_pred), \"\\n\")"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" precision recall f1-score support\n",
"\n",
" 0 1.00 1.00 1.00 29\n",
" 1 0.88 1.00 0.94 23\n",
" 2 1.00 0.87 0.93 23\n",
"\n",
" accuracy 0.96 75\n",
" macro avg 0.96 0.96 0.96 75\n",
"weighted avg 0.96 0.96 0.96 75\n",
"\n"
]
}
],
"source": [
"# per-class precision / recall / F1 for the softmax regression model\n",
"y_pred = log_reg.predict(x_test)\n",
"report = classification_report(y_test, y_pred)\n",
"print(report)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
}
},
"nbformat": 4,
"nbformat_minor": 4
}