410 lines
11 KiB
Plaintext
410 lines
11 KiB
Plaintext
|
{
|
|||
|
"cells": [
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"# Logistic regression vs. tree-ensemble classifiers with scikit-learn: heart disease data set"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 2,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"import numpy as np\n",
|
|||
|
"import pandas as pd\n",
|
|||
|
"import matplotlib.pyplot as plt"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"### Read data "
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 3,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>age</th>\n",
|
|||
|
" <th>sex</th>\n",
|
|||
|
" <th>cp</th>\n",
|
|||
|
" <th>trestbps</th>\n",
|
|||
|
" <th>chol</th>\n",
|
|||
|
" <th>fbs</th>\n",
|
|||
|
" <th>restecg</th>\n",
|
|||
|
" <th>thalach</th>\n",
|
|||
|
" <th>exang</th>\n",
|
|||
|
" <th>oldpeak</th>\n",
|
|||
|
" <th>slope</th>\n",
|
|||
|
" <th>ca</th>\n",
|
|||
|
" <th>thal</th>\n",
|
|||
|
" <th>target</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>0</th>\n",
|
|||
|
" <td>63</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>3</td>\n",
|
|||
|
" <td>145</td>\n",
|
|||
|
" <td>233</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>150</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>2.3</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>1</th>\n",
|
|||
|
" <td>37</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>2</td>\n",
|
|||
|
" <td>130</td>\n",
|
|||
|
" <td>250</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>187</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>3.5</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>2</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>2</th>\n",
|
|||
|
" <td>41</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>130</td>\n",
|
|||
|
" <td>204</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>172</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>1.4</td>\n",
|
|||
|
" <td>2</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>2</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>3</th>\n",
|
|||
|
" <td>56</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>120</td>\n",
|
|||
|
" <td>236</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>178</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0.8</td>\n",
|
|||
|
" <td>2</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>2</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>4</th>\n",
|
|||
|
" <td>57</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>120</td>\n",
|
|||
|
" <td>354</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>163</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>0.6</td>\n",
|
|||
|
" <td>2</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>2</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>...</th>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>298</th>\n",
|
|||
|
" <td>57</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>140</td>\n",
|
|||
|
" <td>241</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>123</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>0.2</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>3</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>299</th>\n",
|
|||
|
" <td>45</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>3</td>\n",
|
|||
|
" <td>110</td>\n",
|
|||
|
" <td>264</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>132</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>1.2</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>3</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>300</th>\n",
|
|||
|
" <td>68</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>144</td>\n",
|
|||
|
" <td>193</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>141</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>3.4</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>2</td>\n",
|
|||
|
" <td>3</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>301</th>\n",
|
|||
|
" <td>57</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>130</td>\n",
|
|||
|
" <td>131</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>115</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>1.2</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>3</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>302</th>\n",
|
|||
|
" <td>57</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>130</td>\n",
|
|||
|
" <td>236</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>174</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" <td>0.0</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>2</td>\n",
|
|||
|
" <td>0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"<p>303 rows × 14 columns</p>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" age sex cp trestbps chol fbs restecg thalach exang oldpeak \\\n",
|
|||
|
"0 63 1 3 145 233 1 0 150 0 2.3 \n",
|
|||
|
"1 37 1 2 130 250 0 1 187 0 3.5 \n",
|
|||
|
"2 41 0 1 130 204 0 0 172 0 1.4 \n",
|
|||
|
"3 56 1 1 120 236 0 1 178 0 0.8 \n",
|
|||
|
"4 57 0 0 120 354 0 1 163 1 0.6 \n",
|
|||
|
".. ... ... .. ... ... ... ... ... ... ... \n",
|
|||
|
"298 57 0 0 140 241 0 1 123 1 0.2 \n",
|
|||
|
"299 45 1 3 110 264 0 1 132 0 1.2 \n",
|
|||
|
"300 68 1 0 144 193 1 1 141 0 3.4 \n",
|
|||
|
"301 57 1 0 130 131 0 1 115 1 1.2 \n",
|
|||
|
"302 57 0 1 130 236 0 0 174 0 0.0 \n",
|
|||
|
"\n",
|
|||
|
" slope ca thal target \n",
|
|||
|
"0 0 0 1 1 \n",
|
|||
|
"1 0 0 2 1 \n",
|
|||
|
"2 2 0 2 1 \n",
|
|||
|
"3 2 0 2 1 \n",
|
|||
|
"4 2 0 2 1 \n",
|
|||
|
".. ... .. ... ... \n",
|
|||
|
"298 1 0 3 0 \n",
|
|||
|
"299 1 0 3 0 \n",
|
|||
|
"300 1 2 3 0 \n",
|
|||
|
"301 1 1 3 0 \n",
|
|||
|
"302 1 1 2 0 \n",
|
|||
|
"\n",
|
|||
|
"[303 rows x 14 columns]"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 3,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# Heart-disease table: 13 feature columns plus the 'target' label column\n",
|
|||
|
"filename = \"https://www.physi.uni-heidelberg.de/~reygers/lectures/2021/ml/data/heart.csv\"\n",
|
|||
|
"df = pd.read_csv(filepath_or_buffer=filename)\n",
|
|||
|
"df"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 4,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"# Separate the 'target' label from the 13 feature columns (column order preserved)\n",
|
|||
|
"y = df['target'].to_numpy()\n",
|
|||
|
"X = df.drop(columns='target')"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 5,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"from sklearn.model_selection import train_test_split\n",
|
|||
|
"\n",
|
|||
|
"# Fixed random_state: the 50/50 split, and every ROC curve computed on it, is reproducible on re-run\n",
|
|||
|
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, shuffle=True, random_state=42)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"### Define the classifiers"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 6,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"from sklearn.linear_model import LogisticRegression\n",
|
|||
|
"from sklearn.ensemble import RandomForestClassifier\n",
|
|||
|
"from sklearn.ensemble import AdaBoostClassifier\n",
|
|||
|
"from sklearn.ensemble import GradientBoostingClassifier\n",
|
|||
|
"\n",
|
|||
|
"# penalty=None fits an unregularised model (requires scikit-learn >= 1.2);\n",
|
|||
|
"# the old string value 'none' is deprecated in 1.2 and rejected from 1.4 on\n",
|
|||
|
"lr = LogisticRegression(penalty=None, fit_intercept=True, max_iter=5000, tol=1E-5)\n",
|
|||
|
"# random_state pins the randomised ensembles so the ROC comparison is reproducible\n",
|
|||
|
"rf = RandomForestClassifier(max_depth=3, random_state=42)\n",
|
|||
|
"ab = AdaBoostClassifier(random_state=42)\n",
|
|||
|
"gb = GradientBoostingClassifier(random_state=42)\n",
|
|||
|
"\n",
|
|||
|
"classifiers = [lr, rf, ab, gb]"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 7,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"LogisticRegression\n",
|
|||
|
"RandomForestClassifier\n",
|
|||
|
"AdaBoostClassifier\n",
|
|||
|
"GradientBoostingClassifier\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"# List the estimators that will be compared, one class name per line\n",
|
|||
|
"for clf in classifiers:\n",
|
|||
|
"    print(type(clf).__name__)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"### Train models and compare ROC curves"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 10,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"### Your code here ###"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"metadata": {
|
|||
|
"kernelspec": {
|
|||
|
"display_name": "Python 3",
|
|||
|
"language": "python",
|
|||
|
"name": "python3"
|
|||
|
},
|
|||
|
"language_info": {
|
|||
|
"codemirror_mode": {
|
|||
|
"name": "ipython",
|
|||
|
"version": 3
|
|||
|
},
|
|||
|
"file_extension": ".py",
|
|||
|
"mimetype": "text/x-python",
|
|||
|
"name": "python",
|
|||
|
"nbconvert_exporter": "python",
|
|||
|
"pygments_lexer": "ipython3",
|
|||
|
"version": "3.8.5"
|
|||
|
}
|
|||
|
},
|
|||
|
"nbformat": 4,
|
|||
|
"nbformat_minor": 4
|
|||
|
}
|