Machine Learning Kurs im Rahmen der Studierendentage im SS 2023
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

565 lines
30 KiB

2 years ago
  1. {
  2. "cells": [
  3. {
  4. "cell_type": "markdown",
  5. "metadata": {},
  6. "source": [
  7. "# Logistic regression with scikit-learn: heart disease data set"
  8. ]
  9. },
  10. {
  11. "cell_type": "code",
  12. "execution_count": 81,
  13. "metadata": {},
  14. "outputs": [],
  15. "source": [
  16. "import numpy as np\n",
  17. "import pandas as pd\n",
  18. "import matplotlib.pyplot as plt"
  19. ]
  20. },
  21. {
  22. "cell_type": "markdown",
  23. "metadata": {},
  24. "source": [
  25. "### Read data "
  26. ]
  27. },
  28. {
  29. "cell_type": "code",
  30. "execution_count": 82,
  31. "metadata": {},
  32. "outputs": [
  33. {
  34. "data": {
  35. "text/html": [
  36. "<div>\n",
  37. "<style scoped>\n",
  38. " .dataframe tbody tr th:only-of-type {\n",
  39. " vertical-align: middle;\n",
  40. " }\n",
  41. "\n",
  42. " .dataframe tbody tr th {\n",
  43. " vertical-align: top;\n",
  44. " }\n",
  45. "\n",
  46. " .dataframe thead th {\n",
  47. " text-align: right;\n",
  48. " }\n",
  49. "</style>\n",
  50. "<table border=\"1\" class=\"dataframe\">\n",
  51. " <thead>\n",
  52. " <tr style=\"text-align: right;\">\n",
  53. " <th></th>\n",
  54. " <th>age</th>\n",
  55. " <th>sex</th>\n",
  56. " <th>cp</th>\n",
  57. " <th>trestbps</th>\n",
  58. " <th>chol</th>\n",
  59. " <th>fbs</th>\n",
  60. " <th>restecg</th>\n",
  61. " <th>thalach</th>\n",
  62. " <th>exang</th>\n",
  63. " <th>oldpeak</th>\n",
  64. " <th>slope</th>\n",
  65. " <th>ca</th>\n",
  66. " <th>thal</th>\n",
  67. " <th>target</th>\n",
  68. " </tr>\n",
  69. " </thead>\n",
  70. " <tbody>\n",
  71. " <tr>\n",
  72. " <th>0</th>\n",
  73. " <td>63</td>\n",
  74. " <td>1</td>\n",
  75. " <td>3</td>\n",
  76. " <td>145</td>\n",
  77. " <td>233</td>\n",
  78. " <td>1</td>\n",
  79. " <td>0</td>\n",
  80. " <td>150</td>\n",
  81. " <td>0</td>\n",
  82. " <td>2.3</td>\n",
  83. " <td>0</td>\n",
  84. " <td>0</td>\n",
  85. " <td>1</td>\n",
  86. " <td>1</td>\n",
  87. " </tr>\n",
  88. " <tr>\n",
  89. " <th>1</th>\n",
  90. " <td>37</td>\n",
  91. " <td>1</td>\n",
  92. " <td>2</td>\n",
  93. " <td>130</td>\n",
  94. " <td>250</td>\n",
  95. " <td>0</td>\n",
  96. " <td>1</td>\n",
  97. " <td>187</td>\n",
  98. " <td>0</td>\n",
  99. " <td>3.5</td>\n",
  100. " <td>0</td>\n",
  101. " <td>0</td>\n",
  102. " <td>2</td>\n",
  103. " <td>1</td>\n",
  104. " </tr>\n",
  105. " <tr>\n",
  106. " <th>2</th>\n",
  107. " <td>41</td>\n",
  108. " <td>0</td>\n",
  109. " <td>1</td>\n",
  110. " <td>130</td>\n",
  111. " <td>204</td>\n",
  112. " <td>0</td>\n",
  113. " <td>0</td>\n",
  114. " <td>172</td>\n",
  115. " <td>0</td>\n",
  116. " <td>1.4</td>\n",
  117. " <td>2</td>\n",
  118. " <td>0</td>\n",
  119. " <td>2</td>\n",
  120. " <td>1</td>\n",
  121. " </tr>\n",
  122. " <tr>\n",
  123. " <th>3</th>\n",
  124. " <td>56</td>\n",
  125. " <td>1</td>\n",
  126. " <td>1</td>\n",
  127. " <td>120</td>\n",
  128. " <td>236</td>\n",
  129. " <td>0</td>\n",
  130. " <td>1</td>\n",
  131. " <td>178</td>\n",
  132. " <td>0</td>\n",
  133. " <td>0.8</td>\n",
  134. " <td>2</td>\n",
  135. " <td>0</td>\n",
  136. " <td>2</td>\n",
  137. " <td>1</td>\n",
  138. " </tr>\n",
  139. " <tr>\n",
  140. " <th>4</th>\n",
  141. " <td>57</td>\n",
  142. " <td>0</td>\n",
  143. " <td>0</td>\n",
  144. " <td>120</td>\n",
  145. " <td>354</td>\n",
  146. " <td>0</td>\n",
  147. " <td>1</td>\n",
  148. " <td>163</td>\n",
  149. " <td>1</td>\n",
  150. " <td>0.6</td>\n",
  151. " <td>2</td>\n",
  152. " <td>0</td>\n",
  153. " <td>2</td>\n",
  154. " <td>1</td>\n",
  155. " </tr>\n",
  156. " <tr>\n",
  157. " <th>...</th>\n",
  158. " <td>...</td>\n",
  159. " <td>...</td>\n",
  160. " <td>...</td>\n",
  161. " <td>...</td>\n",
  162. " <td>...</td>\n",
  163. " <td>...</td>\n",
  164. " <td>...</td>\n",
  165. " <td>...</td>\n",
  166. " <td>...</td>\n",
  167. " <td>...</td>\n",
  168. " <td>...</td>\n",
  169. " <td>...</td>\n",
  170. " <td>...</td>\n",
  171. " <td>...</td>\n",
  172. " </tr>\n",
  173. " <tr>\n",
  174. " <th>298</th>\n",
  175. " <td>57</td>\n",
  176. " <td>0</td>\n",
  177. " <td>0</td>\n",
  178. " <td>140</td>\n",
  179. " <td>241</td>\n",
  180. " <td>0</td>\n",
  181. " <td>1</td>\n",
  182. " <td>123</td>\n",
  183. " <td>1</td>\n",
  184. " <td>0.2</td>\n",
  185. " <td>1</td>\n",
  186. " <td>0</td>\n",
  187. " <td>3</td>\n",
  188. " <td>0</td>\n",
  189. " </tr>\n",
  190. " <tr>\n",
  191. " <th>299</th>\n",
  192. " <td>45</td>\n",
  193. " <td>1</td>\n",
  194. " <td>3</td>\n",
  195. " <td>110</td>\n",
  196. " <td>264</td>\n",
  197. " <td>0</td>\n",
  198. " <td>1</td>\n",
  199. " <td>132</td>\n",
  200. " <td>0</td>\n",
  201. " <td>1.2</td>\n",
  202. " <td>1</td>\n",
  203. " <td>0</td>\n",
  204. " <td>3</td>\n",
  205. " <td>0</td>\n",
  206. " </tr>\n",
  207. " <tr>\n",
  208. " <th>300</th>\n",
  209. " <td>68</td>\n",
  210. " <td>1</td>\n",
  211. " <td>0</td>\n",
  212. " <td>144</td>\n",
  213. " <td>193</td>\n",
  214. " <td>1</td>\n",
  215. " <td>1</td>\n",
  216. " <td>141</td>\n",
  217. " <td>0</td>\n",
  218. " <td>3.4</td>\n",
  219. " <td>1</td>\n",
  220. " <td>2</td>\n",
  221. " <td>3</td>\n",
  222. " <td>0</td>\n",
  223. " </tr>\n",
  224. " <tr>\n",
  225. " <th>301</th>\n",
  226. " <td>57</td>\n",
  227. " <td>1</td>\n",
  228. " <td>0</td>\n",
  229. " <td>130</td>\n",
  230. " <td>131</td>\n",
  231. " <td>0</td>\n",
  232. " <td>1</td>\n",
  233. " <td>115</td>\n",
  234. " <td>1</td>\n",
  235. " <td>1.2</td>\n",
  236. " <td>1</td>\n",
  237. " <td>1</td>\n",
  238. " <td>3</td>\n",
  239. " <td>0</td>\n",
  240. " </tr>\n",
  241. " <tr>\n",
  242. " <th>302</th>\n",
  243. " <td>57</td>\n",
  244. " <td>0</td>\n",
  245. " <td>1</td>\n",
  246. " <td>130</td>\n",
  247. " <td>236</td>\n",
  248. " <td>0</td>\n",
  249. " <td>0</td>\n",
  250. " <td>174</td>\n",
  251. " <td>0</td>\n",
  252. " <td>0.0</td>\n",
  253. " <td>1</td>\n",
  254. " <td>1</td>\n",
  255. " <td>2</td>\n",
  256. " <td>0</td>\n",
  257. " </tr>\n",
  258. " </tbody>\n",
  259. "</table>\n",
  260. "<p>303 rows × 14 columns</p>\n",
  261. "</div>"
  262. ],
  263. "text/plain": [
  264. " age sex cp trestbps chol fbs restecg thalach exang oldpeak \\\n",
  265. "0 63 1 3 145 233 1 0 150 0 2.3 \n",
  266. "1 37 1 2 130 250 0 1 187 0 3.5 \n",
  267. "2 41 0 1 130 204 0 0 172 0 1.4 \n",
  268. "3 56 1 1 120 236 0 1 178 0 0.8 \n",
  269. "4 57 0 0 120 354 0 1 163 1 0.6 \n",
  270. ".. ... ... .. ... ... ... ... ... ... ... \n",
  271. "298 57 0 0 140 241 0 1 123 1 0.2 \n",
  272. "299 45 1 3 110 264 0 1 132 0 1.2 \n",
  273. "300 68 1 0 144 193 1 1 141 0 3.4 \n",
  274. "301 57 1 0 130 131 0 1 115 1 1.2 \n",
  275. "302 57 0 1 130 236 0 0 174 0 0.0 \n",
  276. "\n",
  277. " slope ca thal target \n",
  278. "0 0 0 1 1 \n",
  279. "1 0 0 2 1 \n",
  280. "2 2 0 2 1 \n",
  281. "3 2 0 2 1 \n",
  282. "4 2 0 2 1 \n",
  283. ".. ... .. ... ... \n",
  284. "298 1 0 3 0 \n",
  285. "299 1 0 3 0 \n",
  286. "300 1 2 3 0 \n",
  287. "301 1 1 3 0 \n",
  288. "302 1 1 2 0 \n",
  289. "\n",
  290. "[303 rows x 14 columns]"
  291. ]
  292. },
  293. "execution_count": 82,
  294. "metadata": {},
  295. "output_type": "execute_result"
  296. }
  297. ],
  298. "source": [
  299. "# Local alternative to the URL below: filename = \"heart.csv\"\n",
  300. "filename = \"https://www.physi.uni-heidelberg.de/~reygers/lectures/2021/ml/data/heart.csv\"\n",
  301. "df = pd.read_csv(filename)  # parse the CSV into a DataFrame\n",
  302. "df  # last expression: rendered as the cell output"
  303. ]
  304. },
  305. {
  306. "cell_type": "code",
  307. "execution_count": 83,
  308. "metadata": {},
  309. "outputs": [],
  310. "source": [
  311. "y = df['target'].to_numpy()  # class labels as a NumPy array\n",
  312. "X = df.drop(columns='target')  # every remaining column is a feature"
  313. ]
  314. },
  315. {
  316. "cell_type": "code",
  317. "execution_count": 92,
  318. "metadata": {},
  319. "outputs": [],
  320. "source": [
  321. "from sklearn.model_selection import train_test_split\n",
  322. "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, shuffle=True, random_state=42)  # hold out half of the rows; fixed seed keeps the split repeatable"
  323. ]
  324. },
  325. {
  326. "cell_type": "markdown",
  327. "metadata": {},
  328. "source": [
  329. "### Fit the model"
  330. ]
  331. },
  332. {
  333. "cell_type": "code",
  334. "execution_count": 85,
  335. "metadata": {},
  336. "outputs": [
  337. {
  338. "name": "stdout",
  339. "output_type": "stream",
  340. "text": [
  341. "CPU times: user 427 ms, sys: 14.1 ms, total: 441 ms\n",
  342. "Wall time: 587 ms\n"
  343. ]
  344. },
  345. {
  346. "data": {
  347. "text/plain": [
  348. "LogisticRegression(max_iter=5000, penalty='none', tol=1e-05)"
  349. ]
  350. },
  351. "execution_count": 85,
  352. "metadata": {},
  353. "output_type": "execute_result"
  354. }
  355. ],
  356. "source": [
  357. "from sklearn.linear_model import LogisticRegression\n",
  358. "lr = LogisticRegression(penalty='none', fit_intercept=True, max_iter=5000, tol=1E-5)  # unregularised fit; NOTE(review): scikit-learn >= 1.2 deprecates penalty='none' in favour of penalty=None - confirm the target version\n",
  359. "%time lr.fit(X_train, y_train)"
  360. ]
  361. },
  362. {
  363. "cell_type": "markdown",
  364. "metadata": {},
  365. "source": [
  366. "### Test predictions on test data set"
  367. ]
  368. },
  369. {
  370. "cell_type": "code",
  371. "execution_count": 86,
  372. "metadata": {},
  373. "outputs": [
  374. {
  375. "name": "stdout",
  376. "output_type": "stream",
  377. "text": [
  378. " precision recall f1-score support\n",
  379. "\n",
  380. " 0 0.83 0.75 0.79 69\n",
  381. " 1 0.81 0.87 0.84 83\n",
  382. "\n",
  383. " accuracy 0.82 152\n",
  384. " macro avg 0.82 0.81 0.81 152\n",
  385. "weighted avg 0.82 0.82 0.81 152\n",
  386. "\n"
  387. ]
  388. }
  389. ],
  390. "source": [
  391. "from sklearn.metrics import classification_report\n",
  392. "y_pred_lr = lr.predict(X_test)  # hard class labels for the held-out rows\n",
  393. "print(classification_report(y_test, y_pred_lr))  # per-class precision, recall, f1-score and support"
  394. ]
  395. },
  396. {
  397. "cell_type": "markdown",
  398. "metadata": {},
  399. "source": [
  400. "### Compare two classifiers using the ROC curve"
  401. ]
  402. },
  403. {
  404. "cell_type": "code",
  405. "execution_count": 87,
  406. "metadata": {},
  407. "outputs": [],
  408. "source": [
  409. "from sklearn.ensemble import RandomForestClassifier\n",
  410. "rf = RandomForestClassifier(max_depth=3, random_state=42)  # fixed seed: without it the forest (and its ROC/AUC) changes on every run\n",
  411. "rf.fit(X_train, y_train)\n",
  412. "y_pred_rf = rf.predict(X_test)  # hard class labels for the held-out rows"
  413. ]
  414. },
  415. {
  416. "cell_type": "code",
  417. "execution_count": 88,
  418. "metadata": {},
  419. "outputs": [
  420. {
  421. "data": {
  422. "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYoAAAEOCAYAAACXX1DeAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy86wFpkAAAACXBIWXMAAAsTAAALEwEAmpwYAAAruElEQVR4nO3deXRUVfrv//fDmEREJIwGNMhgA9KiRMDhpyAiijY4IIptK+hVHFeriDi1DH7FAVG+9mpF/Iq0MygK9DUttop69YIQrzYCKqZpZJAZBW1meH5/nEpRqYRKpZJUZfi81qplzjm7znlOEfPU3vvsvc3dEREROZRaqQ5AREQqNyUKERGJSYlCRERiUqIQEZGYlChERCSmOqkOoLw1adLEs7OzUx2GiEiV8sUXX2x296bFHat2iSI7O5u8vLxUhyEiUqWY2Q+HOqamJxERiUmJQkREYlKiEBGRmJQoREQkJiUKERGJKWWJwsymmtlGM1tyiONmZk+ZWb6ZLTazk5Ido4iIpLZGMQ04N8bx84D2odf1wDNJiElERKKkbByFu39iZtkxigwEXvRgHvQFZtbIzFq6+7qKiGfB09dx+M/fVMSpq51f219Ej0tHpDoMEUmSytxHkQWsjtheE9pXhJldb2Z5Zpa3adOmpARXU7Xe8y8afP92qsMQkSSqFiOz3X0KMAUgJycnoZWYet70XLnGVF0tHX96qkMQkSSrzIliLdA6YrtVaJ+kWOs9/zpkwmjSoD7ND08r+0W6DIKcYWU/j4iUWWVuepoDXBV6+qknsK2i+ickfr+2v4jV9doWe2zHnv1s/nV32S+y/mv4+s2yn0dEykXKahRm9hrQC2hiZmuA0UBdAHefDOQC/YF8YAegr5eVQNCJXXxH9mXPzgdg+rBTynaRF84v2/tFpFyl8qmnISUcd+DmJIUjlc36r5OTMNTEJVKiytxHITVVl0HJuc76r4P/KlGIxKREIZVPzrDk/PFWE5dIXJQopFwtW7c93FdRWQzsmsUVPY4u/mCymriqMjXP1XhKFFJuBnYtdjxkSi1btx2g+ESRrCauqkzNc4IShZSjK3ocfehv7ikSs3aTrCauqky1LUGJQmqAytYcFrMpTKQSUqKQaq2yNYfFbAoTqaSUKKRaq2zNYZWpZiMSLyUKkSSrbE1hUMonw/QUVI2jRCGSRJWtKQxK+WSYnoKqkSyYKaP6yMnJ8by8vFSHIVJlXPbsfJat206nlg1LLPvAlpFk713ByrrHhveVasZg1UYqLTP7wt1zijumGoVIDVeaWs5n6b0LbRfMGBxXolBtpMpSohCp4UrX4V94ZuBSzRisMRlVlhKFiJRJrM55jRmpHpQoRCRhsZqtNGak+lCiEJGExWq2qmyPAEvilChEpMJENks9sGVb8MRUaL31Ep+W0hNSlYYShYhUiOhmqcgnpkp8WkpPSFUqShQiUiGKNksdfDKqxKel9IRUpaJEISIpEetpqRKbqdQslVRKFCKSdCUN8ovZTKVmqaRTohCRpCt5kF+MZio1SyWdEoWIVD2x1jpXs1S5U6IQkaol1lrnapaqEEoUIlK1xFrrXM1SFUKJQkQqvegnpEq10FI0NU2VmhKFiFRq0U9IlWqhpWhqmkqIEoWIVGrRT0jFnEMqVrMUqGkqQUoUIlLllGbd8SLNVJFNU2qGiosShYhUKaVZka9IM1Vk05SaoeKmNbNFpNqKtR74A1tG0v7ASuplnXBwZw2uYcRaM7tWsoOJZGbnmtl3ZpZvZncXc/xoM5tnZl+a2WIz65+KOEWkahrYNavYJAHw5p5T+L5W9sEd67+Gr99MTmBVTMpqFGZWG1gO9AXWAIuAIe6+LKLMFOBLd3/GzDoBue6eHeu8qlGISDzCU4MMj5oaZNg7KYootWLVKFLZR9EdyHf3FQBm9jowEFgWUcaBgq8DRwA/JjVCEanWohdWan9gJfU0NUgRqWx6ygJWR2y
vCe2LNAa40szWALnArcWdyMyuN7M8M8vbtGlTRcQqItVMdLNUkaaoSDW8WaqyP/U0BJjm7hPN7BTgJTM73t0PRBZy9ynAFAianlIQp4hUMUXHZ8A4Li5+MaUaPv4ilTWKtUDriO1WoX2RrgVmALj7fCANaJKU6EREBEhtolgEtDezNmZWD7gcmBNVZhXQB8DMOhIkCrUtiYgkUcoShbvvA24B5gLfADPcfamZjTOzAaFiI4DrzOyfwGvAUK9uAz9ERCq5lPZRuHsuQSd15L4HIn5eBpyW7LhEpGYq1Sy1NUhl78wWEUmKUs1SW8MoUYiIUMpZamsYJQoRkUMoaIp6YMs2AMbV0GYpJQoRkWLEmqW2pjVLKVGIiBSjUFPUC0cAhAfj1bRmKSUKEZF4RCx49MCWbezYs5+l42uHD//a/iJ6XDoiVdFVqJROMy4iUiV0GQQtuoQ3mzSoT0a9g0mi9Z5/0eD7t1MRWVKoRiEiUpKotbibh14Flo4/PekhJZMShYhIOcjeu6LarsWtRCEiUkafpfcGoDNUy7W41UchIlJGH2T0Z1zmhGB1vIi+jOpCNQoRkXIQa3BetKo2WE+JQkSkjGINzotWFQfrKVGIiJRRrMF50ariYD31UYiISExKFCIiEpMShYiIxFSmPgozywAyAYs+5u6rynJuERGpHEqdKMysFnAXcCvQIkbR2jGOiYhIFZFIjeIR4E5gKTAT2FKuEYmIVHURM80CVX5Kj0QSxZXAu+7ev7yDERGp8roMKrxdDab0SCRRHAnMLu9ARESqhaiZZgvVLKqoRJ56+hpoWd6BiIhI5ZRIohgL3GBmrcs7GBERqXwSaXrqBvwALDOzt4F/A/ujyri7P1jW4EREJPUSSRRjIn6+8hBlHFCiEBGBIuttB+tXFD8XVGWUSKJoU+5RiIhUV1FPQWXvXZGiQBJX6kTh7j9URCAiItVS1FNQK6vg+tplncIjk4M1jH+7uwbfiYhUMwlNCmhmJ5jZx8BG4PPQa6OZfWRmvy3PAEVEJLVKnSjM7HjgU+BUgoF340Ov2cBpwP8xs85xnutcM/vOzPLN7O5DlBlsZsvMbKmZvVraeEVEpGwSaXoaB+wFTnP3xZEHQknkk1CZS2KdxMxqA38B+gJrgEVmNsfdl0WUaQ/cE7rWT2bWLIF4RUSkDBJpejoD+Et0kgBw9yXA08CZcZynO5Dv7ivcfQ/wOjAwqsx1oWv9FDr/xgTiFRGRMkgkURwGrI9xfF2oTEmygNUR22tC+yJ1ADqY2WdmtsDMzi3uRGZ2vZnlmVnepk2b4ri0iIjEK5FEsQK4IMbxC0JlykMdoD3QCxgCPGdmjaILufsUd89x95ymTZuW06VFRAQSSxQvAv3M7FUz62xmtUOv483sFeAcYFoc51kLRM4X1Sq0L9IaYI6773X3fwPLCRKHiIgkSSKJ4nHgDeByYDGwK/T6J8G3/jeAiXGcZxHQ3szamFm90PnmRJWZRVCbwMyaEDRFVb1hjSIiVVgiI7P3A5eZ2f8AF3JwwN0KYJa7vx/nefaZ2S3AXIJlU6e6+1IzGwfkufuc0LFzzGwZwcSDIzWoT0QkuRIeme3u/wD+UZaLu3sukBu174GInx24I/QSEZEUSGhktoiI1Bwl1ijM7AGCacMfcvcDoe2SaD0KEZFqIp6mpzEEieJRYA+F16M4FK1HISJSTcSTKNoAhEZPh7dFRCQx2XtXhBcyosugQtOQV0YlJoro9Se0HoWISOKC1e2gMwQr30GlTxTl1pltZk1Ck/iJiMghfJDRn3GZE2DYO9CiS6rDiUsi04xfZWZTovY9DGwAvg3Ny3R4eQUoIiKplUiNYjgRTVZmlgOMAv4P8BzBrLAa9yAiUk0kMuCuHcE0HQUuBbYC57j7HjNzYDAwthziExGRFEukRnEEsC1iuw/wfsRTUXnA0WUNTEREKodEEsV6QjO4mllToCtBs1OBBgTzMomISDWQSNPTh8DNZrYV6E0
wuO6diOPHUXS6cBERqaISSRQPAKcCj4W2/8vdVwKYWR2CtbJnlkt0IiLV3fqvDw6+g0o5AC+RacbXmFlnoBOwzd1XRRzOAK4
  423. "text/plain": [
  424. "<Figure size 432x288 with 1 Axes>"
  425. ]
  426. },
  427. "metadata": {
  428. "needs_background": "light"
  429. },
  430. "output_type": "display_data"
  431. }
  432. ],
  433. "source": [
  434. "from sklearn.metrics import roc_curve\n",
  435. "\n",
  436. "y_pred_prob_lr = lr.predict_proba(X_test) # predicted probabilities, one column per class\n",
  437. "fpr_lr, tpr_lr, _ = roc_curve(y_test, y_pred_prob_lr[:,1]) # column 1 = probability of target == 1\n",
  438. "\n",
  439. "y_pred_prob_rf = rf.predict_proba(X_test) # predicted probabilities, one column per class\n",
  440. "fpr_rf, tpr_rf, _ = roc_curve(y_test, y_pred_prob_rf[:,1])\n",
  441. "# x = true positive rate (recall), y = 1 - false positive rate (specificity, NOT precision)\n",
  442. "plt.plot(tpr_lr, 1-fpr_lr, label=\"log. regression\")\n",
  443. "plt.plot(tpr_rf, 1-fpr_rf, label=\"random forest\")\n",
  444. "\n",
  445. "plt.xlabel('Recall', fontsize=18)\n",
  446. "plt.ylabel('Specificity (1 - FPR)', fontsize=18);\n",
  447. "plt.legend(fontsize=15)\n",
  448. "\n",
  449. "plt.savefig(\"03_ml_basics_log_regr_heart_disease.pdf\")"
  450. ]
  451. },
  452. {
  453. "cell_type": "code",
  454. "execution_count": 89,
  455. "metadata": {},
  456. "outputs": [
  457. {
  458. "name": "stdout",
  459. "output_type": "stream",
  460. "text": [
  461. "Area under Curve (AUC) scores: 0.81, 0.81\n"
  462. ]
  463. }
  464. ],
  465. "source": [
  466. "from sklearn.metrics import roc_auc_score\n",
  467. "auc_lr = roc_auc_score(y_test, lr.predict_proba(X_test)[:, 1])  # score with P(target == 1); hard 0/1 labels would collapse the ROC curve to a single point\n",
  468. "auc_rf = roc_auc_score(y_test, rf.predict_proba(X_test)[:, 1])\n",
  469. "print(f\"Area under Curve (AUC) scores: {auc_lr:.2f}, {auc_rf:.2f}\")"
  470. ]
  471. },
  472. {
  473. "cell_type": "markdown",
  474. "metadata": {},
  475. "source": [
  476. "# Check whether data preprocessing makes a difference in this case"
  477. ]
  478. },
  479. {
  480. "cell_type": "code",
  481. "execution_count": 93,
  482. "metadata": {},
  483. "outputs": [
  484. {
  485. "name": "stdout",
  486. "output_type": "stream",
  487. "text": [
  488. "CPU times: user 22.4 ms, sys: 2.89 ms, total: 25.3 ms\n",
  489. "Wall time: 27.6 ms\n",
  490. " precision recall f1-score support\n",
  491. "\n",
  492. " 0 0.80 0.81 0.81 70\n",
  493. " 1 0.84 0.83 0.83 82\n",
  494. "\n",
  495. " accuracy 0.82 152\n",
  496. " macro avg 0.82 0.82 0.82 152\n",
  497. "weighted avg 0.82 0.82 0.82 152\n",
  498. "\n"
  499. ]
  500. }
  501. ],
  502. "source": [
  503. "from sklearn.pipeline import make_pipeline\n",
  504. "from sklearn.preprocessing import StandardScaler\n",
  505. "pipe = make_pipeline(StandardScaler(), LogisticRegression(penalty='none', fit_intercept=True, max_iter=5000, tol=1E-5))  # same estimator settings as the unscaled model, so scaling is the only difference\n",
  506. "%time pipe.fit(X_train, y_train)\n",
  507. "y_pred_pipe = pipe.predict(X_test)  # hard class labels from the scaled model\n",
  508. "print(classification_report(y_test, y_pred_pipe))"
  509. ]
  510. },
  511. {
  512. "cell_type": "code",
  513. "execution_count": 91,
  514. "metadata": {},
  515. "outputs": [
  516. {
  517. "name": "stdout",
  518. "output_type": "stream",
  519. "text": [
  520. " precision recall f1-score support\n",
  521. "\n",
  522. " 0 0.83 0.75 0.79 69\n",
  523. " 1 0.81 0.87 0.84 83\n",
  524. "\n",
  525. " accuracy 0.82 152\n",
  526. " macro avg 0.82 0.81 0.81 152\n",
  527. "weighted avg 0.82 0.82 0.81 152\n",
  528. "\n"
  529. ]
  530. }
  531. ],
  532. "source": [
  533. "print(classification_report(y_test, y_pred_lr))  # report for the unscaled logistic regression again, to compare with the pipeline report"
  534. ]
  535. },
  536. {
  537. "cell_type": "code",
  538. "execution_count": null,
  539. "metadata": {},
  540. "outputs": [],
  541. "source": []
  542. }
  543. ],
  544. "metadata": {
  545. "kernelspec": {
  546. "display_name": "Python 3",
  547. "language": "python",
  548. "name": "python3"
  549. },
  550. "language_info": {
  551. "codemirror_mode": {
  552. "name": "ipython",
  553. "version": 3
  554. },
  555. "file_extension": ".py",
  556. "mimetype": "text/x-python",
  557. "name": "python",
  558. "nbconvert_exporter": "python",
  559. "pygments_lexer": "ipython3",
  560. "version": "3.8.5"
  561. }
  562. },
  563. "nbformat": 4,
  564. "nbformat_minor": 4
  565. }