566 lines
30 KiB
Plaintext
566 lines
30 KiB
Plaintext
|
{
|
|||
|
"cells": [
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"# Logistic regression with scikit-learn: heart disease data set"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 81,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"import numpy as np\n",
|
|||
|
"import pandas as pd\n",
|
|||
|
"import matplotlib.pyplot as plt"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"### Read data "
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 82,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>age</th>\n",
|
|||
|
" <th>sex</th>\n",
|
|||
|
" <th>cp</th>\n",
|
|||
|
" <th>trestbps</th>\n",
|
|||
|
" <th>chol</th>\n",
|
|||
|
" <th>fbs</th>\n",
|
|||
|
" <th>restecg</th>\n",
|
|||
|
" <th>thalach</th>\n",
|
|||
|
" <th>exang</th>\n",
|
|||
|
" <th>oldpeak</th>\n",
|
|||
|
" <th>slope</th>\n",
|
|||
|
" <th>ca</th>\n",
|
|||
|
" <th>thal</th>\n",
|
|||
|
" <th>target</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>0</th>\n",
|
|||
|
" <td>63</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>3</td>\n",
|
|||
|
" <td>145</td>\n",
|
|||
|
" <td>233</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>150</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>2.3</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>1</th>\n",
|
|||
|
" <td>37</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>2</td>\n",
|
|||
|
" <td>130</td>\n",
|
|||
|
" <td>250</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>187</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>3.5</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>2</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>2</th>\n",
|
|||
|
" <td>41</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>130</td>\n",
|
|||
|
" <td>204</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>172</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>1.4</td>\n",
|
|||
|
" <td>2</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>2</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>3</th>\n",
|
|||
|
" <td>56</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>120</td>\n",
|
|||
|
" <td>236</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>178</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0.8</td>\n",
|
|||
|
" <td>2</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>2</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>4</th>\n",
|
|||
|
" <td>57</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>120</td>\n",
|
|||
|
" <td>354</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>163</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>0.6</td>\n",
|
|||
|
" <td>2</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>2</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>...</th>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>298</th>\n",
|
|||
|
" <td>57</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>140</td>\n",
|
|||
|
" <td>241</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>123</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>0.2</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>3</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>299</th>\n",
|
|||
|
" <td>45</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>3</td>\n",
|
|||
|
" <td>110</td>\n",
|
|||
|
" <td>264</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>132</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>1.2</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>3</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>300</th>\n",
|
|||
|
" <td>68</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>144</td>\n",
|
|||
|
" <td>193</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>141</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>3.4</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>2</td>\n",
|
|||
|
" <td>3</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>301</th>\n",
|
|||
|
" <td>57</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>130</td>\n",
|
|||
|
" <td>131</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>115</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>1.2</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>3</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>302</th>\n",
|
|||
|
" <td>57</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>130</td>\n",
|
|||
|
" <td>236</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>174</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>2</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"<p>303 rows × 14 columns</p>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" age sex cp trestbps chol fbs restecg thalach exang oldpeak \\\n",
|
|||
|
"0 63 1 3 145 233 1 0 150 0 2.3 \n",
|
|||
|
"1 37 1 2 130 250 0 1 187 0 3.5 \n",
|
|||
|
"2 41 0 1 130 204 0 0 172 0 1.4 \n",
|
|||
|
"3 56 1 1 120 236 0 1 178 0 0.8 \n",
|
|||
|
"4 57 0 0 120 354 0 1 163 1 0.6 \n",
|
|||
|
".. ... ... .. ... ... ... ... ... ... ... \n",
|
|||
|
"298 57 0 0 140 241 0 1 123 1 0.2 \n",
|
|||
|
"299 45 1 3 110 264 0 1 132 0 1.2 \n",
|
|||
|
"300 68 1 0 144 193 1 1 141 0 3.4 \n",
|
|||
|
"301 57 1 0 130 131 0 1 115 1 1.2 \n",
|
|||
|
"302 57 0 1 130 236 0 0 174 0 0.0 \n",
|
|||
|
"\n",
|
|||
|
" slope ca thal target \n",
|
|||
|
"0 0 0 1 1 \n",
|
|||
|
"1 0 0 2 1 \n",
|
|||
|
"2 2 0 2 1 \n",
|
|||
|
"3 2 0 2 1 \n",
|
|||
|
"4 2 0 2 1 \n",
|
|||
|
".. ... .. ... ... \n",
|
|||
|
"298 1 0 3 0 \n",
|
|||
|
"299 1 0 3 0 \n",
|
|||
|
"300 1 2 3 0 \n",
|
|||
|
"301 1 1 3 0 \n",
|
|||
|
"302 1 1 2 0 \n",
|
|||
|
"\n",
|
|||
|
"[303 rows x 14 columns]"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 82,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# filename = \"heart.csv\"\n",
|
|||
|
"filename = \"https://www.physi.uni-heidelberg.de/~reygers/lectures/2021/ml/data/heart.csv\"\n",
|
|||
|
"df = pd.read_csv(filename)\n",
|
|||
|
"df"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 83,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"y = df['target'].values\n",
|
|||
|
"X = df[[col for col in df.columns if col!=\"target\"]]"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 92,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"from sklearn.model_selection import train_test_split\n",
|
|||
|
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, shuffle=True, random_state=42)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"### Fit the model"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 85,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"CPU times: user 427 ms, sys: 14.1 ms, total: 441 ms\n",
|
|||
|
"Wall time: 587 ms\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/plain": [
|
|||
|
"LogisticRegression(max_iter=5000, penalty='none', tol=1e-05)"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 85,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"from sklearn.linear_model import LogisticRegression\n",
|
|||
|
"lr = LogisticRegression(penalty='none', fit_intercept=True, max_iter=5000, tol=1E-5)\n",
|
|||
|
"%time lr.fit(X_train, y_train)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"### Test predictions on test data set"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 86,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
" precision recall f1-score support\n",
|
|||
|
"\n",
|
|||
|
" 0 0.83 0.75 0.79 69\n",
|
|||
|
" 1 0.81 0.87 0.84 83\n",
|
|||
|
"\n",
|
|||
|
" accuracy 0.82 152\n",
|
|||
|
" macro avg 0.82 0.81 0.81 152\n",
|
|||
|
"weighted avg 0.82 0.82 0.81 152\n",
|
|||
|
"\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"from sklearn.metrics import classification_report\n",
|
|||
|
"y_pred_lr = lr.predict(X_test)\n",
|
|||
|
"print(classification_report(y_test, y_pred_lr))"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"### Compare two classifiers using the ROC curve"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 87,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"from sklearn.ensemble import RandomForestClassifier\n",
|
|||
|
"rf = RandomForestClassifier(max_depth=3)\n",
|
|||
|
"rf.fit(X_train, y_train)\n",
|
|||
|
"y_pred_rf = rf.predict(X_test)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 88,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYoAAAEOCAYAAACXX1DeAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy86wFpkAAAACXBIWXMAAAsTAAALEwEAmpwYAAAruElEQVR4nO3deXRUVfrv//fDmEREJIwGNMhgA9KiRMDhpyAiijY4IIptK+hVHFeriDi1DH7FAVG+9mpF/Iq0MygK9DUttop69YIQrzYCKqZpZJAZBW1meH5/nEpRqYRKpZJUZfi81qplzjm7znlOEfPU3vvsvc3dEREROZRaqQ5AREQqNyUKERGJSYlCRERiUqIQEZGYlChERCSmOqkOoLw1adLEs7OzUx2GiEiV8sUXX2x296bFHat2iSI7O5u8vLxUhyEiUqWY2Q+HOqamJxERiUmJQkREYlKiEBGRmJQoREQkJiUKERGJKWWJwsymmtlGM1tyiONmZk+ZWb6ZLTazk5Ido4iIpLZGMQ04N8bx84D2odf1wDNJiElERKKkbByFu39iZtkxigwEXvRgHvQFZtbIzFq6+7qKiGfB09dx+M/fVMSpq51f219Ej0tHpDoMEUmSytxHkQWsjtheE9pXhJldb2Z5Zpa3adOmpARXU7Xe8y8afP92qsMQkSSqFiOz3X0KMAUgJycnoZWYet70XLnGVF0tHX96qkMQkSSrzIliLdA6YrtVaJ+kWOs9/zpkwmjSoD7ND08r+0W6DIKcYWU/j4iUWWVuepoDXBV6+qknsK2i+ickfr+2v4jV9doWe2zHnv1s/nV32S+y/mv4+s2yn0dEykXKahRm9hrQC2hiZmuA0UBdAHefDOQC/YF8YAegr5eVQNCJXXxH9mXPzgdg+rBTynaRF84v2/tFpFyl8qmnISUcd+DmJIUjlc36r5OTMNTEJVKiytxHITVVl0HJuc76r4P/KlGIxKREIZVPzrDk/PFWE5dIXJQopFwtW7c93FdRWQzsmsUVPY4u/mCymriqMjXP1XhKFFJuBnYtdjxkSi1btx2g+ESRrCauqkzNc4IShZSjK3ocfehv7ikSs3aTrCauqky1LUGJQmqAytYcFrMpTKQSUqKQaq2yNYfFbAoTqaSUKKRaq2zNYZWpZiMSLyUKkSSrbE1hUMonw/QUVI2jRCGSRJWtKQxK+WSYnoKqkSyYKaP6yMnJ8by8vFSHIVJlXPbsfJat206nlg1LLPvAlpFk713ByrrHhveVasZg1UYqLTP7wt1zijumGoVIDVeaWs5n6b0LbRfMGBxXolBtpMpSohCp4UrX4V94ZuBSzRisMRlVlhKFiJRJrM55jRmpHpQoRCRhsZqtNGak+lCiEJGExWq2qmyPAEvilChEpMJENks9sGVb8MRUaL31Ep+W0hNSlYYShYhUiOhmqcgnpkp8WkpPSFUqShQiUiGKNksdfDKqxKel9IRUpaJEISIpEetpqRKbqdQslVRKFCKSdCUN8ovZTKVmqaRTohCRpCt5kF+MZio1SyWdEoWIVD2x1jpXs1S5U6IQkaol1lrnapaqEEoUIlK1xFrrXM1SFUKJQkQqvegnpEq10FI0NU2VmhKFiFRq0U9IlWqhpWhqmkqIEoWIVGrRT0jFnEMqVrMUqGkqQUoUIlLllGbd8SLNVJFNU2qGiosShYhUKaVZka9IM1Vk05SaoeKmNbNFpNqKtR74A1tG0v7ASuplnXBwZw2uYcRaM7tWsoOJZGbnmtl3ZpZvZncXc/xoM5tnZl+a2WIz65+KOEWkahrYNavYJAHw5p5T+L5W9sEd67+Gr99MTmBVTMpqFGZWG1gO9AXWAIuAIe6+LKLMFOBLd3/GzDoBue6eHeu8qlGISDzCU4MMj5oaZNg7KYootWLVKFLZR9EdyHf3FQBm9jowEFgWUcaBgq8DRwA/JjVCEanWohdWan9gJfU0NUgRqWx6ygJWR2yvCe2LNAa40szWALnArcWdyMyuN7M8M8vbtGlTRcQqItVMdLNUkaaoSDW8WaqyP/U0BJjm7hPN7BTgJTM73t0PRBZy9ynAFAianlIQp4hUMUXHZ8A4Li5+MaUaPv4ilTWKtUDriO1WoX2RrgVmALj7fCANaJKU6EREBEhtolgEtDezNmZWD7gcmBNVZhXQB8DMOhIkCrUtiYgkUcoShbvvA24B5gLfADPcfamZjTOzAaFiI4DrzOyfwGvAUK9uAz9ERCq5lPZRuHsuQSd15L4HIn5eBpyW7LhEpGYq1Sy1NUhl78wWEUmKUs1SW8MoUYiIUMpZamsYJQoRkUMoaIp6YMs2AMbV0GYpJQoRkWLEmqW2pjVLKVGIiBSjUFPUC0cAhAfj1bRmKSUKEZF4RCx49MCWbezYs5+l42uHD//a/iJ6XDoiVdFVqJROMy4iUiV0GQQtuoQ3mzSoT0a9g0mi9Z5/0eD7t1MRWVKoRiEiUpKotbibh14Flo4/PekhJZMShYhIOcjeu6LarsWtRCEiUkafpfcGoDNUy7W41UchIlJGH2T0Z1zmhGB1vIi+jOpCNQoRkXIQa3BetKo2WE+JQkSkjGINzotWFQfrKVGIiJRRrMF50ariYD31UYiISExKFCIiEpMShYiIxFSmPgozywAyAYs+5u6rynJuERGpHEqdKMysFnAXcCvQIkbR2jGOiYhIFZFIjeIR4E5gKTAT2FKuEYmIVHURM80CVX5Kj0QSxZXAu+7ev7yDERGp8roMKrxdDab0SCRRHAnMLu9ARESqhaiZZgvVLKqoRJ56+hpoWd6BiIhI5ZRIohgL3GBmrcs7GBERqXwSaXrqBvwALDOzt4F/A/ujyri7P1jW4EREJPUSSRRjIn6+8hBlHFCiEBGBIuttB+tXFD8XVGWUSKJoU+5RiIhUV1FPQWXvXZGiQBJX6kTh7j9URCAiItVS1FNQK6vg+tplncIjk4M1jH+7uwbfiYhUMwlNCmhmJ5jZx8BG4PPQa6OZfWRmvy3PAEVEJLVKnSjM7HjgU+BUgoF340Ov2cBpwP8xs85xnutcM/vOzPLN7O5DlBlsZsvMbKmZvVraeEVEpGwSaXoaB+wFTnP3xZEHQknkk1CZS2KdxMxqA38B+gJrgEVmNsfdl0WUaQ/cE7rWT2bWLIF4RUSkDBJpejoD+Et0kgBw9yXA08CZcZynO5Dv7ivcfQ/wOjAwqsx1oWv9FDr/xgTiFRGRMkgkURwGrI9xfF2oTEmygNUR22tC+yJ1ADqY2WdmtsDMzi3uRGZ2vZnlmVnepk2b4ri0iIjEK5FEsQK4IMbxC0JlykMdoD3QCxgCPGdmjaILufsUd89x95ymTZuW06VFRAQSSxQvAv3M7FUz62xmtUOv483sFeAcYFoc51kLRM4X1Sq0L9IaYI6773X3fwPLCRKHiIgkSSKJ4nHgDeByYDGwK/T6J8G3/jeAiXGcZxHQ3szamFm90PnmRJWZRVCbwMyaEDRFVb1hjSIiVVgiI7P3A5eZ2f8AF3JwwN0KYJa7vx/nefaZ2S3AXIJlU6e6+1IzGwfkufuc0LFzzGwZwcSDIzWoT0QkuRIeme3u/wD+UZaLu3sukBu174GInx24I/QSEZEUSGhktoiI1Bwl1ijM7AGCacMfcvcDoe2SaD0KEZFqIp6mpzEEieJRYA+F16M4FK1HISJSTcSTKNoAhEZPh7dFRCQx2XtXhBcyosugQtOQV0YlJoro9Se0HoWISOKC1e2gMwQr30GlTxTl1pltZk1Ck/iJiMghfJDRn3GZE2DYO9CiS6rDiUsi04xfZWZTovY9DGwAvg3Ny3R4eQUoIiKplUiNYjgRTVZmlgOMAv4P8BzBrLAa9yAiUk0kMuCuHcE0HQUuBbYC57j7HjNzYDAwthziExGRFEukRnEEsC1iuw/wfsRTUXnA0WUNTEREKodEEsV6QjO4mllToCtBs1OBBgTzMomISDWQSNPTh8DNZrYV6E0wuO6diOPHUXS6cBERqaISSRQPAKcCj4W2/8vdVwKYWR2CtbJnlkt0IiLV3fqvDw6+g0o5AC+RacbXmFlnoBOwzd1XRRzOAK4
|
|||
|
"text/plain": [
|
|||
|
"<Figure size 432x288 with 1 Axes>"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {
|
|||
|
"needs_background": "light"
|
|||
|
},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"from sklearn.metrics import roc_curve\n",
|
|||
|
"\n",
|
|||
|
"y_pred_prob_lr = lr.predict_proba(X_test) # predicted probabilities\n",
|
|||
|
"fpr_lr, tpr_lr, _ = roc_curve(y_test, y_pred_prob_lr[:,1])\n",
|
|||
|
"\n",
|
|||
|
"y_pred_prob_rf = rf.predict_proba(X_test) # predicted probabilities\n",
|
|||
|
"fpr_rf, tpr_rf, _ = roc_curve(y_test, y_pred_prob_rf[:,1])\n",
|
|||
|
"\n",
|
|||
|
"plt.plot(tpr_lr, 1-fpr_lr, label=\"log. regression\")\n",
|
|||
|
"plt.plot(tpr_rf, 1-fpr_rf, label=\"random forest\")\n",
|
|||
|
"\n",
|
|||
|
"plt.xlabel('Recall', fontsize=18)\n",
|
|||
|
"plt.ylabel('Precision', fontsize=18);\n",
|
|||
|
"plt.legend(fontsize=15)\n",
|
|||
|
"\n",
|
|||
|
"plt.savefig(\"03_ml_basics_log_regr_heart_disease.pdf\")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 89,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"Area under Curve (AUC) scores: 0.81, 0.81\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"from sklearn.metrics import roc_auc_score\n",
|
|||
|
"auc_lr = roc_auc_score(y_test,y_pred_lr)\n",
|
|||
|
"auc_rf = roc_auc_score(y_test,y_pred_rf)\n",
|
|||
|
"print(f\"Area under Curve (AUC) scores: {auc_lr:.2f}, {auc_rf:.2f}\")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"# Check wether data preprocessing makes a difference in this case"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 93,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"CPU times: user 22.4 ms, sys: 2.89 ms, total: 25.3 ms\n",
|
|||
|
"Wall time: 27.6 ms\n",
|
|||
|
" precision recall f1-score support\n",
|
|||
|
"\n",
|
|||
|
" 0 0.80 0.81 0.81 70\n",
|
|||
|
" 1 0.84 0.83 0.83 82\n",
|
|||
|
"\n",
|
|||
|
" accuracy 0.82 152\n",
|
|||
|
" macro avg 0.82 0.82 0.82 152\n",
|
|||
|
"weighted avg 0.82 0.82 0.82 152\n",
|
|||
|
"\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"from sklearn.pipeline import make_pipeline\n",
|
|||
|
"from sklearn.preprocessing import StandardScaler\n",
|
|||
|
"pipe = make_pipeline(StandardScaler(), LogisticRegression(penalty='none', fit_intercept=False, max_iter=5000, tol=1E-5))\n",
|
|||
|
"%time pipe.fit(X_train, y_train)\n",
|
|||
|
"y_pred_pipe = pipe.predict(X_test)\n",
|
|||
|
"print(classification_report(y_test, y_pred_pipe))"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 91,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
" precision recall f1-score support\n",
|
|||
|
"\n",
|
|||
|
" 0 0.83 0.75 0.79 69\n",
|
|||
|
" 1 0.81 0.87 0.84 83\n",
|
|||
|
"\n",
|
|||
|
" accuracy 0.82 152\n",
|
|||
|
" macro avg 0.82 0.81 0.81 152\n",
|
|||
|
"weighted avg 0.82 0.82 0.81 152\n",
|
|||
|
"\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"print(classification_report(y_test, y_pred_lr))"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": null,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": []
|
|||
|
}
|
|||
|
],
|
|||
|
"metadata": {
|
|||
|
"kernelspec": {
|
|||
|
"display_name": "Python 3",
|
|||
|
"language": "python",
|
|||
|
"name": "python3"
|
|||
|
},
|
|||
|
"language_info": {
|
|||
|
"codemirror_mode": {
|
|||
|
"name": "ipython",
|
|||
|
"version": 3
|
|||
|
},
|
|||
|
"file_extension": ".py",
|
|||
|
"mimetype": "text/x-python",
|
|||
|
"name": "python",
|
|||
|
"nbconvert_exporter": "python",
|
|||
|
"pygments_lexer": "ipython3",
|
|||
|
"version": "3.8.5"
|
|||
|
}
|
|||
|
},
|
|||
|
"nbformat": 4,
|
|||
|
"nbformat_minor": 4
|
|||
|
}
|