ML-Kurs-SS2023/notebooks/04_decision_trees_ex_1_compare_tree_classifiers.ipynb

410 lines
11 KiB
Plaintext
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Logistic regression with scikit-learn: heart disease data set"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Read data "
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>age</th>\n",
" <th>sex</th>\n",
" <th>cp</th>\n",
" <th>trestbps</th>\n",
" <th>chol</th>\n",
" <th>fbs</th>\n",
" <th>restecg</th>\n",
" <th>thalach</th>\n",
" <th>exang</th>\n",
" <th>oldpeak</th>\n",
" <th>slope</th>\n",
" <th>ca</th>\n",
" <th>thal</th>\n",
" <th>target</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>63</td>\n",
" <td>1</td>\n",
" <td>3</td>\n",
" <td>145</td>\n",
" <td>233</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>150</td>\n",
" <td>0</td>\n",
" <td>2.3</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>37</td>\n",
" <td>1</td>\n",
" <td>2</td>\n",
" <td>130</td>\n",
" <td>250</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>187</td>\n",
" <td>0</td>\n",
" <td>3.5</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>2</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>41</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>130</td>\n",
" <td>204</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>172</td>\n",
" <td>0</td>\n",
" <td>1.4</td>\n",
" <td>2</td>\n",
" <td>0</td>\n",
" <td>2</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>56</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>120</td>\n",
" <td>236</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>178</td>\n",
" <td>0</td>\n",
" <td>0.8</td>\n",
" <td>2</td>\n",
" <td>0</td>\n",
" <td>2</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>57</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>120</td>\n",
" <td>354</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>163</td>\n",
" <td>1</td>\n",
" <td>0.6</td>\n",
" <td>2</td>\n",
" <td>0</td>\n",
" <td>2</td>\n",
" <td>1</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>298</th>\n",
" <td>57</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>140</td>\n",
" <td>241</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>123</td>\n",
" <td>1</td>\n",
" <td>0.2</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>3</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>299</th>\n",
" <td>45</td>\n",
" <td>1</td>\n",
" <td>3</td>\n",
" <td>110</td>\n",
" <td>264</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>132</td>\n",
" <td>0</td>\n",
" <td>1.2</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>3</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>300</th>\n",
" <td>68</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>144</td>\n",
" <td>193</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>141</td>\n",
" <td>0</td>\n",
" <td>3.4</td>\n",
" <td>1</td>\n",
" <td>2</td>\n",
" <td>3</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>301</th>\n",
" <td>57</td>\n",
" <td>1</td>\n",
" <td>0</td>\n",
" <td>130</td>\n",
" <td>131</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>115</td>\n",
" <td>1</td>\n",
" <td>1.2</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>3</td>\n",
" <td>0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>302</th>\n",
" <td>57</td>\n",
" <td>0</td>\n",
" <td>1</td>\n",
" <td>130</td>\n",
" <td>236</td>\n",
" <td>0</td>\n",
" <td>0</td>\n",
" <td>174</td>\n",
" <td>0</td>\n",
" <td>0.0</td>\n",
" <td>1</td>\n",
" <td>1</td>\n",
" <td>2</td>\n",
" <td>0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>303 rows × 14 columns</p>\n",
"</div>"
],
"text/plain": [
" age sex cp trestbps chol fbs restecg thalach exang oldpeak \\\n",
"0 63 1 3 145 233 1 0 150 0 2.3 \n",
"1 37 1 2 130 250 0 1 187 0 3.5 \n",
"2 41 0 1 130 204 0 0 172 0 1.4 \n",
"3 56 1 1 120 236 0 1 178 0 0.8 \n",
"4 57 0 0 120 354 0 1 163 1 0.6 \n",
".. ... ... .. ... ... ... ... ... ... ... \n",
"298 57 0 0 140 241 0 1 123 1 0.2 \n",
"299 45 1 3 110 264 0 1 132 0 1.2 \n",
"300 68 1 0 144 193 1 1 141 0 3.4 \n",
"301 57 1 0 130 131 0 1 115 1 1.2 \n",
"302 57 0 1 130 236 0 0 174 0 0.0 \n",
"\n",
" slope ca thal target \n",
"0 0 0 1 1 \n",
"1 0 0 2 1 \n",
"2 2 0 2 1 \n",
"3 2 0 2 1 \n",
"4 2 0 2 1 \n",
".. ... .. ... ... \n",
"298 1 0 3 0 \n",
"299 1 0 3 0 \n",
"300 1 2 3 0 \n",
"301 1 1 3 0 \n",
"302 1 1 2 0 \n",
"\n",
"[303 rows x 14 columns]"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"filename = \"https://www.physi.uni-heidelberg.de/~reygers/lectures/2021/ml/data/heart.csv\"\n",
"df = pd.read_csv(filename)\n",
"df"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"y = df['target'].values\n",
"X = df[[col for col in df.columns if col!=\"target\"]]"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.model_selection import train_test_split\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, shuffle=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Fit the model"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.linear_model import LogisticRegression\n",
"from sklearn.ensemble import RandomForestClassifier\n",
"from sklearn.ensemble import AdaBoostClassifier\n",
"from sklearn.ensemble import GradientBoostingClassifier\n",
"\n",
"lr = LogisticRegression(penalty='none', fit_intercept=True, max_iter=5000, tol=1E-5)\n",
"rf = RandomForestClassifier(max_depth=3)\n",
"ab = AdaBoostClassifier()\n",
"gb = GradientBoostingClassifier()\n",
"\n",
"classifiers = [lr, rf, ab, gb]"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"LogisticRegression\n",
"RandomForestClassifier\n",
"AdaBoostClassifier\n",
"GradientBoostingClassifier\n"
]
}
],
"source": [
"for clf in classifiers:\n",
" print(clf.__class__.__name__)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Train models and compare ROC curves"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"### Your code here ###"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
}
},
"nbformat": 4,
"nbformat_minor": 4
}