{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Logistic regression with scikit-learn: heart disease data set" ] }, { "cell_type": "code", "execution_count": 81, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd\n", "import matplotlib.pyplot as plt" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Read data " ] }, { "cell_type": "code", "execution_count": 82, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
agesexcptrestbpscholfbsrestecgthalachexangoldpeakslopecathaltarget
063131452331015002.30011
137121302500118703.50021
241011302040017201.42021
356111202360117800.82021
457001203540116310.62021
.............................................
29857001402410112310.21030
29945131102640113201.21030
30068101441931114103.41230
30157101301310111511.21130
30257011302360017400.01120
\n", "

303 rows × 14 columns

\n", "
" ], "text/plain": [ " age sex cp trestbps chol fbs restecg thalach exang oldpeak \\\n", "0 63 1 3 145 233 1 0 150 0 2.3 \n", "1 37 1 2 130 250 0 1 187 0 3.5 \n", "2 41 0 1 130 204 0 0 172 0 1.4 \n", "3 56 1 1 120 236 0 1 178 0 0.8 \n", "4 57 0 0 120 354 0 1 163 1 0.6 \n", ".. ... ... .. ... ... ... ... ... ... ... \n", "298 57 0 0 140 241 0 1 123 1 0.2 \n", "299 45 1 3 110 264 0 1 132 0 1.2 \n", "300 68 1 0 144 193 1 1 141 0 3.4 \n", "301 57 1 0 130 131 0 1 115 1 1.2 \n", "302 57 0 1 130 236 0 0 174 0 0.0 \n", "\n", " slope ca thal target \n", "0 0 0 1 1 \n", "1 0 0 2 1 \n", "2 2 0 2 1 \n", "3 2 0 2 1 \n", "4 2 0 2 1 \n", ".. ... .. ... ... \n", "298 1 0 3 0 \n", "299 1 0 3 0 \n", "300 1 2 3 0 \n", "301 1 1 3 0 \n", "302 1 1 2 0 \n", "\n", "[303 rows x 14 columns]" ] }, "execution_count": 82, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# filename = \"heart.csv\"\n", "filename = \"https://www.physi.uni-heidelberg.de/~reygers/lectures/2021/ml/data/heart.csv\"\n", "df = pd.read_csv(filename)\n", "df" ] }, { "cell_type": "code", "execution_count": 83, "metadata": {}, "outputs": [], "source": [ "y = df['target'].values\n", "X = df[[col for col in df.columns if col!=\"target\"]]" ] }, { "cell_type": "code", "execution_count": 92, "metadata": {}, "outputs": [], "source": [ "from sklearn.model_selection import train_test_split\n", "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, shuffle=True, random_state=42)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Fit the model" ] }, { "cell_type": "code", "execution_count": 85, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 427 ms, sys: 14.1 ms, total: 441 ms\n", "Wall time: 587 ms\n" ] }, { "data": { "text/plain": [ "LogisticRegression(max_iter=5000, penalty='none', tol=1e-05)" ] }, "execution_count": 85, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sklearn.linear_model import 
LogisticRegression\n", "# NOTE: scikit-learn 1.2+ spells the unpenalized fit as penalty=None (the string 'none' was removed in 1.4).\n", "lr = LogisticRegression(penalty='none', fit_intercept=True, max_iter=5000, tol=1E-5)\n", "%time lr.fit(X_train, y_train)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Test predictions on test data set" ] },
{ "cell_type": "code", "execution_count": 86, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " precision recall f1-score support\n", "\n", " 0 0.83 0.75 0.79 69\n", " 1 0.81 0.87 0.84 83\n", "\n", " accuracy 0.82 152\n", " macro avg 0.82 0.81 0.81 152\n", "weighted avg 0.82 0.82 0.81 152\n", "\n" ] } ], "source": [ "from sklearn.metrics import classification_report\n", "y_pred_lr = lr.predict(X_test)\n", "print(classification_report(y_test, y_pred_lr))" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### Compare two classifiers using the ROC curve" ] },
{ "cell_type": "code", "execution_count": 87, "metadata": {}, "outputs": [], "source": [ "from sklearn.ensemble import RandomForestClassifier\n", "# Fixed seed so the forest, and the ROC curve / AUC derived from it, are reproducible on re-run.\n", "rf = RandomForestClassifier(max_depth=3, random_state=42)\n", "rf.fit(X_train, y_train)\n", "y_pred_rf = rf.predict(X_test)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sklearn.metrics import roc_curve\n", "\n", "# roc_curve scans thresholds over a continuous score, so use the class-1 probabilities.\n", "y_pred_prob_lr = lr.predict_proba(X_test) # predicted probabilities\n", "fpr_lr, tpr_lr, _ = roc_curve(y_test, y_pred_prob_lr[:,1])\n", "\n", "y_pred_prob_rf = rf.predict_proba(X_test) # predicted probabilities\n", "fpr_rf, tpr_rf, _ = roc_curve(y_test, y_pred_prob_rf[:,1])\n", "\n", "# x axis: true positive rate (recall); y axis: 1 - false positive rate (specificity, not precision)\n", "plt.plot(tpr_lr, 1-fpr_lr, label=\"log. regression\")\n", "plt.plot(tpr_rf, 1-fpr_rf, label=\"random forest\")\n", "\n", "plt.xlabel('Recall', fontsize=18)\n", "plt.ylabel('Specificity (1 - FPR)', fontsize=18)\n", "plt.legend(fontsize=15)\n", "\n", "plt.savefig(\"03_ml_basics_log_regr_heart_disease.pdf\")" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from sklearn.metrics import roc_auc_score\n", "# AUC needs the continuous scores used for the ROC curve above; the hard 0/1 labels\n", "# from predict() would only evaluate a single threshold.\n", "auc_lr = roc_auc_score(y_test, y_pred_prob_lr[:,1])\n", "auc_rf = roc_auc_score(y_test, y_pred_prob_rf[:,1])\n", "print(f\"Area under Curve (AUC) scores: {auc_lr:.2f}, {auc_rf:.2f}\")" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "# Check whether data preprocessing makes a difference in this case" ] },
{ "cell_type": "code", "execution_count": 93, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 22.4 ms, sys: 2.89 ms, total: 25.3 ms\n", "Wall time: 27.6 ms\n", " precision recall f1-score support\n", "\n", " 0 0.80 0.81 0.81 70\n", " 1 0.84 0.83 0.83 82\n", "\n", " accuracy 0.82 152\n", " macro avg 0.82 0.82 0.82 152\n", "weighted avg 0.82 0.82 0.82 152\n", "\n" ] } ], "source": [ "from sklearn.pipeline import make_pipeline\n", "from sklearn.preprocessing import StandardScaler\n", "# NOTE(review): fit_intercept=False here, while the baseline `lr` uses fit_intercept=True,\n", "# so scaling is not the only difference between the two models -- confirm this is intended.\n", "pipe = make_pipeline(StandardScaler(), LogisticRegression(penalty='none', fit_intercept=False, max_iter=5000, tol=1E-5))\n", "%time pipe.fit(X_train, y_train)\n", "y_pred_pipe = pipe.predict(X_test)\n", "print(classification_report(y_test, y_pred_pipe))" ] },
{ "cell_type": "code", "execution_count": 91, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " precision recall f1-score support\n", "\n", " 0 0.83 0.75 0.79 69\n", " 1 0.81 0.87 0.84 83\n", "\n", " accuracy 0.82 152\n", " macro avg 0.82 0.81 0.81 152\n", "weighted avg 0.82 0.82 0.81 152\n", "\n" ] } ], "source": [ "print(classification_report(y_test, y_pred_lr))" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.5" } }, "nbformat": 4, "nbformat_minor": 4 }