Machine Learning Kurs im Rahmen der Studierendentage im SS 2023
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

409 lines
11 KiB

  1. {
  2. "cells": [
  3. {
  4. "cell_type": "markdown",
  5. "metadata": {},
  6. "source": [
  7. "# Logistic regression vs. tree ensembles with scikit-learn: heart disease data set"
  8. ]
  9. },
  10. {
  11. "cell_type": "code",
  12. "execution_count": 2,
  13. "metadata": {},
  14. "outputs": [],
  15. "source": [
  16. "import numpy as np\n",
  17. "import pandas as pd\n",
  18. "import matplotlib.pyplot as plt"
  19. ]
  20. },
  21. {
  22. "cell_type": "markdown",
  23. "metadata": {},
  24. "source": [
  25. "### Read data "
  26. ]
  27. },
  28. {
  29. "cell_type": "code",
  30. "execution_count": 3,
  31. "metadata": {},
  32. "outputs": [
  33. {
  34. "data": {
  35. "text/html": [
  36. "<div>\n",
  37. "<style scoped>\n",
  38. " .dataframe tbody tr th:only-of-type {\n",
  39. " vertical-align: middle;\n",
  40. " }\n",
  41. "\n",
  42. " .dataframe tbody tr th {\n",
  43. " vertical-align: top;\n",
  44. " }\n",
  45. "\n",
  46. " .dataframe thead th {\n",
  47. " text-align: right;\n",
  48. " }\n",
  49. "</style>\n",
  50. "<table border=\"1\" class=\"dataframe\">\n",
  51. " <thead>\n",
  52. " <tr style=\"text-align: right;\">\n",
  53. " <th></th>\n",
  54. " <th>age</th>\n",
  55. " <th>sex</th>\n",
  56. " <th>cp</th>\n",
  57. " <th>trestbps</th>\n",
  58. " <th>chol</th>\n",
  59. " <th>fbs</th>\n",
  60. " <th>restecg</th>\n",
  61. " <th>thalach</th>\n",
  62. " <th>exang</th>\n",
  63. " <th>oldpeak</th>\n",
  64. " <th>slope</th>\n",
  65. " <th>ca</th>\n",
  66. " <th>thal</th>\n",
  67. " <th>target</th>\n",
  68. " </tr>\n",
  69. " </thead>\n",
  70. " <tbody>\n",
  71. " <tr>\n",
  72. " <th>0</th>\n",
  73. " <td>63</td>\n",
  74. " <td>1</td>\n",
  75. " <td>3</td>\n",
  76. " <td>145</td>\n",
  77. " <td>233</td>\n",
  78. " <td>1</td>\n",
  79. " <td>0</td>\n",
  80. " <td>150</td>\n",
  81. " <td>0</td>\n",
  82. " <td>2.3</td>\n",
  83. " <td>0</td>\n",
  84. " <td>0</td>\n",
  85. " <td>1</td>\n",
  86. " <td>1</td>\n",
  87. " </tr>\n",
  88. " <tr>\n",
  89. " <th>1</th>\n",
  90. " <td>37</td>\n",
  91. " <td>1</td>\n",
  92. " <td>2</td>\n",
  93. " <td>130</td>\n",
  94. " <td>250</td>\n",
  95. " <td>0</td>\n",
  96. " <td>1</td>\n",
  97. " <td>187</td>\n",
  98. " <td>0</td>\n",
  99. " <td>3.5</td>\n",
  100. " <td>0</td>\n",
  101. " <td>0</td>\n",
  102. " <td>2</td>\n",
  103. " <td>1</td>\n",
  104. " </tr>\n",
  105. " <tr>\n",
  106. " <th>2</th>\n",
  107. " <td>41</td>\n",
  108. " <td>0</td>\n",
  109. " <td>1</td>\n",
  110. " <td>130</td>\n",
  111. " <td>204</td>\n",
  112. " <td>0</td>\n",
  113. " <td>0</td>\n",
  114. " <td>172</td>\n",
  115. " <td>0</td>\n",
  116. " <td>1.4</td>\n",
  117. " <td>2</td>\n",
  118. " <td>0</td>\n",
  119. " <td>2</td>\n",
  120. " <td>1</td>\n",
  121. " </tr>\n",
  122. " <tr>\n",
  123. " <th>3</th>\n",
  124. " <td>56</td>\n",
  125. " <td>1</td>\n",
  126. " <td>1</td>\n",
  127. " <td>120</td>\n",
  128. " <td>236</td>\n",
  129. " <td>0</td>\n",
  130. " <td>1</td>\n",
  131. " <td>178</td>\n",
  132. " <td>0</td>\n",
  133. " <td>0.8</td>\n",
  134. " <td>2</td>\n",
  135. " <td>0</td>\n",
  136. " <td>2</td>\n",
  137. " <td>1</td>\n",
  138. " </tr>\n",
  139. " <tr>\n",
  140. " <th>4</th>\n",
  141. " <td>57</td>\n",
  142. " <td>0</td>\n",
  143. " <td>0</td>\n",
  144. " <td>120</td>\n",
  145. " <td>354</td>\n",
  146. " <td>0</td>\n",
  147. " <td>1</td>\n",
  148. " <td>163</td>\n",
  149. " <td>1</td>\n",
  150. " <td>0.6</td>\n",
  151. " <td>2</td>\n",
  152. " <td>0</td>\n",
  153. " <td>2</td>\n",
  154. " <td>1</td>\n",
  155. " </tr>\n",
  156. " <tr>\n",
  157. " <th>...</th>\n",
  158. " <td>...</td>\n",
  159. " <td>...</td>\n",
  160. " <td>...</td>\n",
  161. " <td>...</td>\n",
  162. " <td>...</td>\n",
  163. " <td>...</td>\n",
  164. " <td>...</td>\n",
  165. " <td>...</td>\n",
  166. " <td>...</td>\n",
  167. " <td>...</td>\n",
  168. " <td>...</td>\n",
  169. " <td>...</td>\n",
  170. " <td>...</td>\n",
  171. " <td>...</td>\n",
  172. " </tr>\n",
  173. " <tr>\n",
  174. " <th>298</th>\n",
  175. " <td>57</td>\n",
  176. " <td>0</td>\n",
  177. " <td>0</td>\n",
  178. " <td>140</td>\n",
  179. " <td>241</td>\n",
  180. " <td>0</td>\n",
  181. " <td>1</td>\n",
  182. " <td>123</td>\n",
  183. " <td>1</td>\n",
  184. " <td>0.2</td>\n",
  185. " <td>1</td>\n",
  186. " <td>0</td>\n",
  187. " <td>3</td>\n",
  188. " <td>0</td>\n",
  189. " </tr>\n",
  190. " <tr>\n",
  191. " <th>299</th>\n",
  192. " <td>45</td>\n",
  193. " <td>1</td>\n",
  194. " <td>3</td>\n",
  195. " <td>110</td>\n",
  196. " <td>264</td>\n",
  197. " <td>0</td>\n",
  198. " <td>1</td>\n",
  199. " <td>132</td>\n",
  200. " <td>0</td>\n",
  201. " <td>1.2</td>\n",
  202. " <td>1</td>\n",
  203. " <td>0</td>\n",
  204. " <td>3</td>\n",
  205. " <td>0</td>\n",
  206. " </tr>\n",
  207. " <tr>\n",
  208. " <th>300</th>\n",
  209. " <td>68</td>\n",
  210. " <td>1</td>\n",
  211. " <td>0</td>\n",
  212. " <td>144</td>\n",
  213. " <td>193</td>\n",
  214. " <td>1</td>\n",
  215. " <td>1</td>\n",
  216. " <td>141</td>\n",
  217. " <td>0</td>\n",
  218. " <td>3.4</td>\n",
  219. " <td>1</td>\n",
  220. " <td>2</td>\n",
  221. " <td>3</td>\n",
  222. " <td>0</td>\n",
  223. " </tr>\n",
  224. " <tr>\n",
  225. " <th>301</th>\n",
  226. " <td>57</td>\n",
  227. " <td>1</td>\n",
  228. " <td>0</td>\n",
  229. " <td>130</td>\n",
  230. " <td>131</td>\n",
  231. " <td>0</td>\n",
  232. " <td>1</td>\n",
  233. " <td>115</td>\n",
  234. " <td>1</td>\n",
  235. " <td>1.2</td>\n",
  236. " <td>1</td>\n",
  237. " <td>1</td>\n",
  238. " <td>3</td>\n",
  239. " <td>0</td>\n",
  240. " </tr>\n",
  241. " <tr>\n",
  242. " <th>302</th>\n",
  243. " <td>57</td>\n",
  244. " <td>0</td>\n",
  245. " <td>1</td>\n",
  246. " <td>130</td>\n",
  247. " <td>236</td>\n",
  248. " <td>0</td>\n",
  249. " <td>0</td>\n",
  250. " <td>174</td>\n",
  251. " <td>0</td>\n",
  252. " <td>0.0</td>\n",
  253. " <td>1</td>\n",
  254. " <td>1</td>\n",
  255. " <td>2</td>\n",
  256. " <td>0</td>\n",
  257. " </tr>\n",
  258. " </tbody>\n",
  259. "</table>\n",
  260. "<p>303 rows × 14 columns</p>\n",
  261. "</div>"
  262. ],
  263. "text/plain": [
  264. " age sex cp trestbps chol fbs restecg thalach exang oldpeak \\\n",
  265. "0 63 1 3 145 233 1 0 150 0 2.3 \n",
  266. "1 37 1 2 130 250 0 1 187 0 3.5 \n",
  267. "2 41 0 1 130 204 0 0 172 0 1.4 \n",
  268. "3 56 1 1 120 236 0 1 178 0 0.8 \n",
  269. "4 57 0 0 120 354 0 1 163 1 0.6 \n",
  270. ".. ... ... .. ... ... ... ... ... ... ... \n",
  271. "298 57 0 0 140 241 0 1 123 1 0.2 \n",
  272. "299 45 1 3 110 264 0 1 132 0 1.2 \n",
  273. "300 68 1 0 144 193 1 1 141 0 3.4 \n",
  274. "301 57 1 0 130 131 0 1 115 1 1.2 \n",
  275. "302 57 0 1 130 236 0 0 174 0 0.0 \n",
  276. "\n",
  277. " slope ca thal target \n",
  278. "0 0 0 1 1 \n",
  279. "1 0 0 2 1 \n",
  280. "2 2 0 2 1 \n",
  281. "3 2 0 2 1 \n",
  282. "4 2 0 2 1 \n",
  283. ".. ... .. ... ... \n",
  284. "298 1 0 3 0 \n",
  285. "299 1 0 3 0 \n",
  286. "300 1 2 3 0 \n",
  287. "301 1 1 3 0 \n",
  288. "302 1 1 2 0 \n",
  289. "\n",
  290. "[303 rows x 14 columns]"
  291. ]
  292. },
  293. "execution_count": 3,
  294. "metadata": {},
  295. "output_type": "execute_result"
  296. }
  297. ],
  298. "source": [
  299. "filename = \"https://www.physi.uni-heidelberg.de/~reygers/lectures/2021/ml/data/heart.csv\"\n",
  300. "df = pd.read_csv(filename)\n",
  301. "df"
  302. ]
  303. },
  304. {
  305. "cell_type": "code",
  306. "execution_count": 4,
  307. "metadata": {},
  308. "outputs": [],
  309. "source": [
  310. "y = df['target'].values\n",
  311. "X = df[[col for col in df.columns if col!=\"target\"]]"
  312. ]
  313. },
  314. {
  315. "cell_type": "code",
  316. "execution_count": 5,
  317. "metadata": {},
  318. "outputs": [],
  319. "source": [
  320. "from sklearn.model_selection import train_test_split\n",
  321. "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, shuffle=True)"
  322. ]
  323. },
  324. {
  325. "cell_type": "markdown",
  326. "metadata": {},
  327. "source": [
  328. "### Define the classifiers"
  329. ]
  330. },
  331. {
  332. "cell_type": "code",
  333. "execution_count": 6,
  334. "metadata": {},
  335. "outputs": [],
  336. "source": [
  337. "from sklearn.linear_model import LogisticRegression\n",
  338. "from sklearn.ensemble import RandomForestClassifier\n",
  339. "from sklearn.ensemble import AdaBoostClassifier\n",
  340. "from sklearn.ensemble import GradientBoostingClassifier\n",
  341. "\n",
  342. "lr = LogisticRegression(penalty='none', fit_intercept=True, max_iter=5000, tol=1E-5)\n",
  343. "rf = RandomForestClassifier(max_depth=3)\n",
  344. "ab = AdaBoostClassifier()\n",
  345. "gb = GradientBoostingClassifier()\n",
  346. "\n",
  347. "classifiers = [lr, rf, ab, gb]"
  348. ]
  349. },
  350. {
  351. "cell_type": "code",
  352. "execution_count": 7,
  353. "metadata": {},
  354. "outputs": [
  355. {
  356. "name": "stdout",
  357. "output_type": "stream",
  358. "text": [
  359. "LogisticRegression\n",
  360. "RandomForestClassifier\n",
  361. "AdaBoostClassifier\n",
  362. "GradientBoostingClassifier\n"
  363. ]
  364. }
  365. ],
  366. "source": [
  367. "for clf in classifiers:\n",
  368. " print(clf.__class__.__name__)"
  369. ]
  370. },
  371. {
  372. "cell_type": "markdown",
  373. "metadata": {},
  374. "source": [
  375. "### Train models and compare ROC curves"
  376. ]
  377. },
  378. {
  379. "cell_type": "code",
  380. "execution_count": 10,
  381. "metadata": {},
  382. "outputs": [],
  383. "source": [
  384. "### Your code here: fit each estimator in `classifiers` on (X_train, y_train), then compare their ROC curves on (X_test, y_test) in one figure (e.g. via predict_proba and sklearn.metrics.roc_curve) ###"
  385. ]
  386. }
  387. ],
  388. "metadata": {
  389. "kernelspec": {
  390. "display_name": "Python 3",
  391. "language": "python",
  392. "name": "python3"
  393. },
  394. "language_info": {
  395. "codemirror_mode": {
  396. "name": "ipython",
  397. "version": 3
  398. },
  399. "file_extension": ".py",
  400. "mimetype": "text/x-python",
  401. "name": "python",
  402. "nbconvert_exporter": "python",
  403. "pygments_lexer": "ipython3",
  404. "version": "3.8.5"
  405. }
  406. },
  407. "nbformat": 4,
  408. "nbformat_minor": 4
  409. }