|
|
{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Logistic regression with scikit-learn: heart disease data set" ] }, { "cell_type": "code", "execution_count": 81, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd\n", "import matplotlib.pyplot as plt" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Read data " ] }, { "cell_type": "code", "execution_count": 82, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>age</th>\n", " <th>sex</th>\n", " <th>cp</th>\n", " <th>trestbps</th>\n", " <th>chol</th>\n", " <th>fbs</th>\n", " <th>restecg</th>\n", " <th>thalach</th>\n", " <th>exang</th>\n", " <th>oldpeak</th>\n", " <th>slope</th>\n", " <th>ca</th>\n", " <th>thal</th>\n", " <th>target</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td>63</td>\n", " <td>1</td>\n", " <td>3</td>\n", " <td>145</td>\n", " <td>233</td>\n", " <td>1</td>\n", " <td>0</td>\n", " <td>150</td>\n", " <td>0</td>\n", " <td>2.3</td>\n", " <td>0</td>\n", " <td>0</td>\n", " <td>1</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td>37</td>\n", " <td>1</td>\n", " <td>2</td>\n", " <td>130</td>\n", " <td>250</td>\n", " <td>0</td>\n", " <td>1</td>\n", " <td>187</td>\n", " <td>0</td>\n", " <td>3.5</td>\n", " <td>0</td>\n", " <td>0</td>\n", " <td>2</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>2</th>\n", " <td>41</td>\n", " <td>0</td>\n", " <td>1</td>\n", " <td>130</td>\n", " <td>204</td>\n", " <td>0</td>\n", " <td>0</td>\n", " <td>172</td>\n", " <td>0</td>\n", " <td>1.4</td>\n", " 
<td>2</td>\n", " <td>0</td>\n", " <td>2</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>3</th>\n", " <td>56</td>\n", " <td>1</td>\n", " <td>1</td>\n", " <td>120</td>\n", " <td>236</td>\n", " <td>0</td>\n", " <td>1</td>\n", " <td>178</td>\n", " <td>0</td>\n", " <td>0.8</td>\n", " <td>2</td>\n", " <td>0</td>\n", " <td>2</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>4</th>\n", " <td>57</td>\n", " <td>0</td>\n", " <td>0</td>\n", " <td>120</td>\n", " <td>354</td>\n", " <td>0</td>\n", " <td>1</td>\n", " <td>163</td>\n", " <td>1</td>\n", " <td>0.6</td>\n", " <td>2</td>\n", " <td>0</td>\n", " <td>2</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>...</th>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " </tr>\n", " <tr>\n", " <th>298</th>\n", " <td>57</td>\n", " <td>0</td>\n", " <td>0</td>\n", " <td>140</td>\n", " <td>241</td>\n", " <td>0</td>\n", " <td>1</td>\n", " <td>123</td>\n", " <td>1</td>\n", " <td>0.2</td>\n", " <td>1</td>\n", " <td>0</td>\n", " <td>3</td>\n", " <td>0</td>\n", " </tr>\n", " <tr>\n", " <th>299</th>\n", " <td>45</td>\n", " <td>1</td>\n", " <td>3</td>\n", " <td>110</td>\n", " <td>264</td>\n", " <td>0</td>\n", " <td>1</td>\n", " <td>132</td>\n", " <td>0</td>\n", " <td>1.2</td>\n", " <td>1</td>\n", " <td>0</td>\n", " <td>3</td>\n", " <td>0</td>\n", " </tr>\n", " <tr>\n", " <th>300</th>\n", " <td>68</td>\n", " <td>1</td>\n", " <td>0</td>\n", " <td>144</td>\n", " <td>193</td>\n", " <td>1</td>\n", " <td>1</td>\n", " <td>141</td>\n", " <td>0</td>\n", " <td>3.4</td>\n", " <td>1</td>\n", " <td>2</td>\n", " <td>3</td>\n", " <td>0</td>\n", " </tr>\n", " <tr>\n", " <th>301</th>\n", " <td>57</td>\n", " <td>1</td>\n", " <td>0</td>\n", " <td>130</td>\n", " <td>131</td>\n", " <td>0</td>\n", " <td>1</td>\n", " <td>115</td>\n", 
" <td>1</td>\n", " <td>1.2</td>\n", " <td>1</td>\n", " <td>1</td>\n", " <td>3</td>\n", " <td>0</td>\n", " </tr>\n", " <tr>\n", " <th>302</th>\n", " <td>57</td>\n", " <td>0</td>\n", " <td>1</td>\n", " <td>130</td>\n", " <td>236</td>\n", " <td>0</td>\n", " <td>0</td>\n", " <td>174</td>\n", " <td>0</td>\n", " <td>0.0</td>\n", " <td>1</td>\n", " <td>1</td>\n", " <td>2</td>\n", " <td>0</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "<p>303 rows × 14 columns</p>\n", "</div>" ], "text/plain": [ " age sex cp trestbps chol fbs restecg thalach exang oldpeak \\\n", "0 63 1 3 145 233 1 0 150 0 2.3 \n", "1 37 1 2 130 250 0 1 187 0 3.5 \n", "2 41 0 1 130 204 0 0 172 0 1.4 \n", "3 56 1 1 120 236 0 1 178 0 0.8 \n", "4 57 0 0 120 354 0 1 163 1 0.6 \n", ".. ... ... .. ... ... ... ... ... ... ... \n", "298 57 0 0 140 241 0 1 123 1 0.2 \n", "299 45 1 3 110 264 0 1 132 0 1.2 \n", "300 68 1 0 144 193 1 1 141 0 3.4 \n", "301 57 1 0 130 131 0 1 115 1 1.2 \n", "302 57 0 1 130 236 0 0 174 0 0.0 \n", "\n", " slope ca thal target \n", "0 0 0 1 1 \n", "1 0 0 2 1 \n", "2 2 0 2 1 \n", "3 2 0 2 1 \n", "4 2 0 2 1 \n", ".. ... .. ... ... 
\n", "298 1 0 3 0 \n", "299 1 0 3 0 \n", "300 1 2 3 0 \n", "301 1 1 3 0 \n", "302 1 1 2 0 \n", "\n", "[303 rows x 14 columns]" ] }, "execution_count": 82, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# filename = \"heart.csv\"\n", "filename = \"https://www.physi.uni-heidelberg.de/~reygers/lectures/2021/ml/data/heart.csv\"\n", "df = pd.read_csv(filename)\n", "df" ] }, { "cell_type": "code", "execution_count": 83, "metadata": {}, "outputs": [], "source": [ "y = df['target'].values\n", "X = df[[col for col in df.columns if col!=\"target\"]]" ] }, { "cell_type": "code", "execution_count": 92, "metadata": {}, "outputs": [], "source": [ "from sklearn.model_selection import train_test_split\n", "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, shuffle=True, random_state=42)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Fit the model" ] }, { "cell_type": "code", "execution_count": 85, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 427 ms, sys: 14.1 ms, total: 441 ms\n", "Wall time: 587 ms\n" ] }, { "data": { "text/plain": [ "LogisticRegression(max_iter=5000, penalty='none', tol=1e-05)" ] }, "execution_count": 85, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sklearn.linear_model import LogisticRegression\n", "lr = LogisticRegression(penalty='none', fit_intercept=True, max_iter=5000, tol=1E-5)\n", "%time lr.fit(X_train, y_train)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Test predictions on test data set" ] }, { "cell_type": "code", "execution_count": 86, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " precision recall f1-score support\n", "\n", " 0 0.83 0.75 0.79 69\n", " 1 0.81 0.87 0.84 83\n", "\n", " accuracy 0.82 152\n", " macro avg 0.82 0.81 0.81 152\n", "weighted avg 0.82 0.82 0.81 152\n", "\n" ] } ], "source": [ "from sklearn.metrics import 
classification_report\n", "y_pred_lr = lr.predict(X_test)\n", "print(classification_report(y_test, y_pred_lr))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Compare two classifiers using the ROC curve" ] }, { "cell_type": "code", "execution_count": 87, "metadata": {}, "outputs": [], "source": [ "from sklearn.ensemble import RandomForestClassifier\n", "rf = RandomForestClassifier(max_depth=3)\n", "rf.fit(X_train, y_train)\n", "y_pred_rf = rf.predict(X_test)" ] }, { "cell_type": "code", "execution_count": 88, "metadata": {}, "outputs": [ { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYoAAAEOCAYAAACXX1DeAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy86wFpkAAAACXBIWXMAAAsTAAALEwEAmpwYAAAruElEQVR4nO3deXRUVfrv//fDmEREJIwGNMhgA9KiRMDhpyAiijY4IIptK+hVHFeriDi1DH7FAVG+9mpF/Iq0MygK9DUttop69YIQrzYCKqZpZJAZBW1meH5/nEpRqYRKpZJUZfi81qplzjm7znlOEfPU3vvsvc3dEREROZRaqQ5AREQqNyUKERGJSYlCRERiUqIQEZGYlChERCSmOqkOoLw1adLEs7OzUx2GiEiV8sUXX2x296bFHat2iSI7O5u8vLxUhyEiUqWY2Q+HOqamJxERiUmJQkREYlKiEBGRmJQoREQkJiUKERGJKWWJwsymmtlGM1tyiONmZk+ZWb6ZLTazk5Ido4iIpLZGMQ04N8bx84D2odf1wDNJiElERKKkbByFu39iZtkxigwEXvRgHvQFZtbIzFq6+7qKiGfB09dx+M/fVMSpq51f219Ej0tHpDoMEUmSytxHkQWsjtheE9pXhJldb2Z5Zpa3adOmpARXU7Xe8y8afP92qsMQkSSqFiOz3X0KMAUgJycnoZWYet70XLnGVF0tHX96qkMQkSSrzIliLdA6YrtVaJ+kWOs9/zpkwmjSoD7ND08r+0W6DIKcYWU/j4iUWWVuepoDXBV6+qknsK2i+ickfr+2v4jV9doWe2zHnv1s/nV32S+y/mv4+s2yn0dEykXKahRm9hrQC2hiZmuA0UBdAHefDOQC/YF8YAegr5eVQNCJXXxH9mXPzgdg+rBTynaRF84v2/tFpFyl8qmnISUcd+DmJIUjlc36r5OTMNTEJVKiytxHITVVl0HJuc76r4P/KlGIxKREIZVPzrDk/PFWE5dIXJQopFwtW7c93FdRWQzsmsUVPY4u/mCymriqMjXP1XhKFFJuBnYtdjxkSi1btx2g+ESRrCauqkzNc4IShZSjK3ocfehv7ikSs3aTrCauqky1LUGJQmqAytYcFrMpTKQSUqKQaq2yNYfFbAoTqaSUKKRaq2zNYZWpZiMSLyUKkSSrbE1hUMonw/QUVI2jRCGSRJWtKQxK+WSYnoKqkSyYKaP6yMnJ8by8vFSHIVJlXPbsfJat206nlg1LLPvAlpFk713ByrrHhveVasZg1UYqLTP7wt1zijumGoVIDVeaWs5n6b0LbRfMGBxXolBtpMpSohCp4UrX4V94ZuBSzRisMRlVlhKFiJRJrM55jRmpHpQoRCRhsZqtNGak+lCiEJGExWq2qmyPAEvilChEpMJENks9sGVb8MRU
aL31Ep+W0hNSlYYShYhUiOhmqcgnpkp8WkpPSFUqShQiUiGKNksdfDKqxKel9IRUpaJEISIpEetpqRKbqdQslVRKFCKSdCUN8ovZTKVmqaRTohCRpCt5kF+MZio1SyWdEoWIVD2x1jpXs1S5U6IQkaol1lrnapaqEEoUIlK1xFrrXM1SFUKJQkQqvegnpEq10FI0NU2VmhKFiFRq0U9IlWqhpWhqmkqIEoWIVGrRT0jFnEMqVrMUqGkqQUoUIlLllGbd8SLNVJFNU2qGiosShYhUKaVZka9IM1Vk05SaoeKmNbNFpNqKtR74A1tG0v7ASuplnXBwZw2uYcRaM7tWsoOJZGbnmtl3ZpZvZncXc/xoM5tnZl+a2WIz65+KOEWkahrYNavYJAHw5p5T+L5W9sEd67+Gr99MTmBVTMpqFGZWG1gO9AXWAIuAIe6+LKLMFOBLd3/GzDoBue6eHeu8qlGISDzCU4MMj5oaZNg7KYootWLVKFLZR9EdyHf3FQBm9jowEFgWUcaBgq8DRwA/JjVCEanWohdWan9gJfU0NUgRqWx6ygJWR2yvCe2LNAa40szWALnArcWdyMyuN7M8M8vbtGlTRcQqItVMdLNUkaaoSDW8WaqyP/U0BJjm7hPN7BTgJTM73t0PRBZy9ynAFAianlIQp4hUMUXHZ8A4Li5+MaUaPv4ilTWKtUDriO1WoX2RrgVmALj7fCANaJKU6EREBEhtolgEtDezNmZWD7gcmBNVZhXQB8DMOhIkCrUtiYgkUcoShbvvA24B5gLfADPcfamZjTOzAaFiI4DrzOyfwGvAUK9uAz9ERCq5lPZRuHsuQSd15L4HIn5eBpyW7LhEpGYq1Sy1NUhl78wWEUmKUs1SW8MoUYiIUMpZamsYJQoRkUMoaIp6YMs2AMbV0GYpJQoRkWLEmqW2pjVLKVGIiBSjUFPUC0cAhAfj1bRmKSUKEZF4RCx49MCWbezYs5+l42uHD//a/iJ6XDoiVdFVqJROMy4iUiV0GQQtuoQ3mzSoT0a9g0mi9Z5/0eD7t1MRWVKoRiEiUpKotbibh14Flo4/PekhJZMShYhIOcjeu6LarsWtRCEiUkafpfcGoDNUy7W41UchIlJGH2T0Z1zmhGB1vIi+jOpCNQoRkXIQa3BetKo2WE+JQkSkjGINzotWFQfrKVGIiJRRrMF50ariYD31UYiISExKFCIiEpMShYiIxFSmPgozywAyAYs+5u6rynJuERGpHEqdKMysFnAXcCvQIkbR2jGOiYhIFZFIjeIR4E5gKTAT2FKuEYmIVHURM80CVX5Kj0QSxZXAu+7ev7yDERGp8roMKrxdDab0SCRRHAnMLu9ARESqhaiZZgvVLKqoRJ56+hpoWd6BiIhI5ZRIohgL3GBmrcs7GBERqXwSaXrqBvwALDOzt4F/A/ujyri7P1jW4EREJPUSSRRjIn6+8hBlHFCiEBGBIuttB+tXFD8XVGWUSKJoU+5RiIhUV1FPQWXvXZGiQBJX6kTh7j9URCAiItVS1FNQK6vg+tplncIjk4M1jH+7uwbfiYhUMwlNCmhmJ5jZx8BG4PPQa6OZfWRmvy3PAEVEJLVKnSjM7HjgU+BUgoF340Ov2cBpwP8xs85xnutcM/vOzPLN7O5DlBlsZsvMbKmZvVraeEVEpGwSaXoaB+wFTnP3xZEHQknkk1CZS2KdxMxqA38B+gJrgEVmNsfdl0WUaQ/cE7rWT2bWLIF4RUSkDBJpejoD+Et0kgBw9yXA08CZcZynO5Dv7ivcfQ/wOjAwqsx1oWv9FDr/xgTiFRGRMkgkURwGrI9xfF2oTEmygNUR22tC+yJ1ADqY2WdmtsDMzi3uRGZ2vZnlmVnepk2b4ri0iIjEK5FEsQK4IMbxC0JlykMdoD3QCxgCPGdmjaILufsUd89x95ymTZuW06VFRAQSSxQvAv3M7FUz62xmtUOv483sFeAcYFoc51kLRM4X1Sq0L9IaYI6773X3fwPL
CRKHiIgkSSKJ4nHgDeByYDGwK/T6J8G3/jeAiXGcZxHQ3szamFm90PnmRJWZRVCbwMyaEDRFVb1hjSIiVVgiI7P3A5eZ2f8AF3JwwN0KYJa7vx/nefaZ2S3AXIJlU6e6+1IzGwfkufuc0LFzzGwZwcSDIzWoT0QkuRIeme3u/wD+UZaLu3sukBu174GInx24I/QSEZEUSGhktoiI1Bwl1ijM7AGCacMfcvcDoe2SaD0KEZFqIp6mpzEEieJRYA+F16M4FK1HISJSTcSTKNoAhEZPh7dFRCQx2XtXhBcyosugQtOQV0YlJoro9Se0HoWISOKC1e2gMwQr30GlTxTl1pltZk1Ck/iJiMghfJDRn3GZE2DYO9CiS6rDiUsi04xfZWZTovY9DGwAvg3Ny3R4eQUoIiKplUiNYjgRTVZmlgOMAv4P8BzBrLAa9yAiUk0kMuCuHcE0HQUuBbYC57j7HjNzYDAwthziExGRFEukRnEEsC1iuw/wfsRTUXnA0WUNTEREKodEEsV6QjO4mllToCtBs1OBBgTzMomISDWQSNPTh8DNZrYV6E0wuO6diOPHUXS6cBERqaISSRQPAKcCj4W2/8vdVwKYWR2CtbJnlkt0IiLV3fqvDw6+g0o5AC+RacbXmFlnoBOwzd1XRRzOAK4 "text/plain": [ "<Figure size 432x288 with 1 Axes>" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "from sklearn.metrics import roc_curve\n", "\n", "y_pred_prob_lr = lr.predict_proba(X_test) # predicted probabilities\n", "fpr_lr, tpr_lr, _ = roc_curve(y_test, y_pred_prob_lr[:,1])\n", "\n", "y_pred_prob_rf = rf.predict_proba(X_test) # predicted probabilities\n", "fpr_rf, tpr_rf, _ = roc_curve(y_test, y_pred_prob_rf[:,1])\n", "\n", "plt.plot(tpr_lr, 1-fpr_lr, label=\"log. 
regression\")\n", "plt.plot(tpr_rf, 1-fpr_rf, label=\"random forest\")\n", "\n", "# x axis is TPR (recall); y axis is 1 - FPR, i.e. specificity (true negative rate), not precision\n", "plt.xlabel('Recall', fontsize=18)\n", "plt.ylabel('Specificity (1 - FPR)', fontsize=18);\n", "plt.legend(fontsize=15)\n", "\n", "plt.savefig(\"03_ml_basics_log_regr_heart_disease.pdf\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sklearn.metrics import roc_auc_score\n", "# roc_auc_score needs continuous scores: hard 0/1 labels from predict() collapse the ROC curve to a single threshold\n", "auc_lr = roc_auc_score(y_test, y_pred_prob_lr[:,1])\n", "auc_rf = roc_auc_score(y_test, y_pred_prob_rf[:,1])\n", "print(f\"Area under Curve (AUC) scores: {auc_lr:.2f}, {auc_rf:.2f}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Check whether data preprocessing makes a difference in this case" ] }, { "cell_type": "code", "execution_count": 93, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 22.4 ms, sys: 2.89 ms, total: 25.3 ms\n", "Wall time: 27.6 ms\n", " precision recall f1-score support\n", "\n", " 0 0.80 0.81 0.81 70\n", " 1 0.84 0.83 0.83 82\n", "\n", " accuracy 0.82 152\n", " macro avg 0.82 0.82 0.82 152\n", "weighted avg 0.82 0.82 0.82 152\n", "\n" ] } ], "source": [ "from sklearn.pipeline import make_pipeline\n", "from sklearn.preprocessing import StandardScaler\n", "pipe = make_pipeline(StandardScaler(), LogisticRegression(penalty='none', fit_intercept=False, max_iter=5000, tol=1E-5))\n", "%time pipe.fit(X_train, y_train)\n", "y_pred_pipe = pipe.predict(X_test)\n", "print(classification_report(y_test, y_pred_pipe))" ] }, { "cell_type": "code", "execution_count": 91, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " precision recall f1-score support\n", "\n", " 0 0.83 0.75 0.79 69\n", " 1 0.81 0.87 0.84 83\n", "\n", " accuracy 0.82 152\n", " macro avg 0.82 0.81 0.81 152\n", "weighted avg 0.82 0.82 0.81 152\n", "\n" ] } ], "source": [ "print(classification_report(y_test, 
y_pred_lr))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.5" } }, "nbformat": 4, "nbformat_minor": 4 }
|