ML-Kurs-SS2023/notebooks/03_ml_basics_logistic_regression.ipynb

196 lines
22 KiB
Plaintext
Raw Normal View History

2023-04-05 17:35:33 +02:00
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Simple example of logistic regression with scikit-learn"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import pandas as pd\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
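"### Read data\n",
"The data are taken from the [Wikipedia article on logistic regression](https://en.wikipedia.org/wiki/Logistic_regression): for each student, the number of hours spent studying and whether the exam was passed (1) or failed (0)."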
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# data: 1. hours studied, 2. passed (0/1)\n",
"filename = \"https://www.physi.uni-heidelberg.de/~reygers/lectures/2021/ml/data/exam.txt\"\n",
"# raw string: a plain '\\s+' literal is an invalid escape sequence in Python\n",
"df = pd.read_csv(filename, engine='python', sep=r'\\s+')\n",
"assert {'hours_studied', 'passed'} <= set(df.columns), \"unexpected columns in exam data\""
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"# scikit-learn expects a 2D feature matrix of shape (n_samples, n_features)\n",
"x = df['hours_studied'].to_numpy().reshape(-1, 1)\n",
"y = df['passed'].to_numpy()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
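"### Fit the model\n",
"We fit $P(\\text{pass} \\mid x) = \\frac{1}{1 + e^{-(\\beta_0 + \\beta_1 x)}}$ without regularization, where $x$ is the preparation time in hours."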
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"from sklearn.linear_model import LogisticRegression\n",
"# penalty=None: plain (unregularized) maximum-likelihood fit.\n",
"# The string 'none' was deprecated in scikit-learn 1.2 and removed in 1.4.\n",
"clf = LogisticRegression(penalty=None, fit_intercept=True)\n",
"clf.fit(x, y);"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Calculate predicted probabilities"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# evaluate the fitted model on a fine grid of preparation times (0 to 6 hours)\n",
"hours_studied = np.linspace(0., 6., 1000).reshape(-1, 1)\n",
"# predict_proba returns one column per class in clf.classes_ order,\n",
"# i.e. [P(passed=0), P(passed=1)] for the 0/1 labels\n",
"y_pred = clf.predict_proba(hours_studied)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Plot result"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYYAAAEKCAYAAAAW8vJGAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy86wFpkAAAACXBIWXMAAAsTAAALEwEAmpwYAAAzgUlEQVR4nO3dd3gVZfbA8e9JBxJ6qKELSBEpodjAumJZy2LvBVF/Yllsa9nVZYuCruu66+7asFJEwd57YQUJvfcWaoAQEkLqPb8/ZqI3MQmTcGtyPs9zn9w7d2beM5eQc+etoqoYY4wxZWLCHYAxxpjIYonBGGNMOZYYjDHGlGOJwRhjTDmWGIwxxpQTF+4ADlfLli21c+fO4Q7DGGOiyrx583arampl70V9YujcuTMZGRnhDsMYY6KKiGyq6j2rSjLGGFOOJQZjjDHlWGIwxhhTjiUGY4wx5VhiMMYYU07IEoOITBKRXSKytIr3RUSeEpG1IrJYRAaGKjYTHHvyClm0ZR978gqjuoxQlROqa6mp2sQVimNCFVdNReq/Y02EsrvqS8C/gFeqeP8MoLv7GAr8x/1potA7C7dy74zFxMfEUOzzMXFUP87p3z7qyghVOaG6llDEFYpjQhVXTUXqv2NNheyOQVW/BfZWs8u5wCvqmA00FZG2oYnOBNKevELunbGYgmIfuYUlFBT7uGfG4oB+gwpFGaEqJ1TXEoq4QnFMqOKqqUj9d6yNSGpjaA9s8Xud6W77BREZIyIZIpKRlZUVkuCMd5nZB4mPKf+rFR8TQ2b2wagqI1TlhOpaaqo2cYXimFDFVVOhKENVKSwpZV9+Edv2HSS3oDhg5/YXlSOfVfVZ4FmA9PR0W2kowqQ1a0Cxz1duW7HPR1qzBlFVRqjKCdW11FRt4grFMaGKq6YOVYaqkl9USnZ+Efvyi8nOLyLnYDF5BSXkFZaQ6/786XVhCXkFxT9tyy0sIb+olFLfz3/y/nJ+Xy4f2ilg11AmkhLDVqCD3+s0d5uJMi2SE5k4qh/3VKhrbZGcGFVlhKqcUF1LKOIKxTGhisuLwpJSsnIL2ZVbyK79hfy6XztmLsgkBqFUlbSmDbjsuTk/JYOiUl+152uUEEtyUhzJiXEkJ8WTkhhHq5Skn7Y1TIilUWIcDeJjaZgQS3rn5ocVf1UklEt7ikhn4H1V7VvJe2cBY4EzcRqdn1LVIYc6Z3p6utpcSZFpT14hmdkHSWvWIGh/5EJRRqjKCdW11FRt4grFMcEuo9Sn7NhfQObefLbuO0hm9kEys/PZtq+AXbkF7MotZF/+L6tyYgRSkuJp3iiBlskJNGvoPJo2inefx9PU3dakQTwpSXEkJ8XRKCGO2BjxdB2BICLzVDW90ve8JgYROR84CWhFhbYJVb3Iw/FTgROBlsBO4CEg3j3+vyIiOL2WRgL5wLWqesi/+JYYjDG1VVLqIzP7IOt357E+6wDrsg6wcfcBMvfls31fASW+8n8fW6Uk0rZpA1qnJNKqcSKtUpJoVeF5i+TEkP6Br63qEoOnqiQR+RtwKzAL5496aU2DUNVLD/G+ArfU9LzGGHMoqkpm9kGWbdvP8u37Wb0jl3VZeWzak1+ueqdpw3i6tGzEgA7N+HW/BqQ1a0haswakNWtAu6YNSIqPDeNVhI7XNoargQtV9Z1gBmOMMYfL51PWZuWxcPM+lm3LYcX2XFZs309uYQkAItClRSO6piZz8pGt6JaaTNdU53XzRglhjj4yeE0M+cDKYAZijDG1kZNfzIIt2czfvI8Fm7NZuHnfT0mgYUIsvdo25twB7ejdtgm92zWmZ+sUGiTUj2/+teU1MTwK3CMiN6pqSTADMsaY6hwoLOHHjXv539rdzFq7hxU79qPqNPr2aJ3Cr/u3Y0CHpgzo2IyuLRsREwX1/ZHGa2J4Dvg1sFVEVgPlmuJV9eRAB2aMMeBUDS3dlsMXK3Yxa+1uFm7ZR4lPSYiNYWCnptxxSg
8Gd25Gvw5NSU6MpB740cvrp/hf4HjgY5zGZxtUZowJmoLiUmat3c3nK3bxxYqd7MotJEbgqLSm3DC8K8d1a8mgTs2sSihIvCaGi4HzVfWzYAZjjKm/CopL+XrVLt5dtI0vV+6ioNhHo4RYRvRM5ZQjW3PSka2scThEvCaGLGwUsjEmwEpKfcxat4d3F27j02U7yC0soWVyAhcMSuNXvdswtGtzEuPsriDUvCaGh4DxInKNquYFMyBjTN23eU8+r2ds5o2MTHblFpKSFMfIvm04p387junagrjYSJrfs/7xmhjuBjoDO0VkM79sfO4X4LiMMXVMUYmPT5fvYNqPW/h+7W5iBE7q2YoL0ztwYs/UejN4LBp4TQxvBjUKY0ydlX2giCk/bubl/21kV24h7Zs2YNxpPbgwPY22TcI7g6ypnKfEoKp/DHYgxpi6ZX1WHpNmbeDNeZkUFPs4oXtLJozqx/AeqVExl1B9Zp1+jTEBtXpnLk99sYYPlmwnPiaG8wa04/rju9KzTUq4QzMeeZ1ELwF4ALgU6Ig7K2oZVbXKQWPquVU7cnnqyzV8uGQ7DeNjuWlEN647rgupKZEzhbjxxusdw59wxjI8AvydnxujLwF+H5TIjDFRYcvefB77ZBXvLd5Gw/hYbh7RjdEndLUxB1HMa2K4CLhJVT8WkceBd1R1nYisAE4DnglahMaYiJSTX8zTX6/lpVkbiYmBm0d044YTutLMEkLU85oYWgPL3ed5QFP3+cfAhADHZIyJYMWlPl79YRNPfbmGnIPFXDAwjTt/1ZM2TZLCHZoJEK+JYTPQzv25FjgdmAccAxwMTmjGmEjz44a9PPj2ElbvzOP4I1py/5m96N2ucbjDMgHmNTG8BZwCzAb+AUwVkRuA9sBjQYrNGBMh9uQV8shHK3lzXibtmzbguavSObVXK5wVeU1d43Ucw31+z98UkS3AccBqVX0/WMEZY8JLVZkxfyt//mA5eQUl3DSiG7edcgQNE6yne11Wq39dVZ0DzAEQkXhVLT7EIcaYKLNrfwH3zVzCFyt3MbhzM/5y/lH0aG1jEeoDr+MY3gWuVdU9Fbb3AqYAA4IQmzEmDFSVdxdt4w/vLKOguJQHz+rFdcd1sZXQ6hGvUxg2A5aIyK/KNojIWJwG6MXBCMwYE3o5B4sZO2UBt09bSJeWjfjw9hMYfUJXSwr1jNeqpBHAg8B7IvIs0A2nR9L1qjo1WMEZY0JnweZsbp26gB05Bdx9ek9uHN7Vpr+up7w2Pvtw1mOIxRnpXAIMV9XZwQzOGBN8Pp/y/PfrmfjxKto0SWL6TccwsGOzcIdlwshrG0Mi8DhwA/AwcALO3cNoVX0neOEZY4JpX34Rv319IV+tyuKMvm14dFQ/mjSIP/SBpk7zWpU0z933eFXNABCRO4FpIvKaqt4QrACNMcGxakcuY17NYPu+Av50bh+uGNbJxiUYwHtimA3cpqr5ZRtU9W8i8jnwWlAiM8YEzcdLtzNu+iKSE+OYOmYYgzpZ1ZH5mdc2htFVbF8kIumBDckYEyw+n/L3z1fzzy/X0r9DU565chCtG9scR6Y8z10OROQMEXlfRJaLSAd322jg+KBFZ4wJmILiUm6ZMp9/frmWi9LTeP3GYZYUTKU8JQYRuRyYDqwBuvDzQj2xwD3BCc0YEyh7DxRx+fNz+GjpDh44sxcTRvUjMc7W1zKV83rHcA9wg6r+FqerapnZQP9AB2WMCZxNew4w6j//Y8nWHP59+UBuGN7VGplNtbwmhu7AD5VszwM8z7krIiNFZJWIrBWR31XyfkcR+UpEFojIYhE50+u5jTG/tHDLPn7z7/+RnV/ElNFDOfOotuEOyUQBr4lhG9Cjku3DgXVeTuAOjnsaOAPoDVwqIr0r7PYgMF1VB+AsG/pvj/EZYyr437rdXPbcbBomxjLz5mNJ79w83CGZKOE1MTwLPCUix7mvO4jI1cBE4D8ezzEEWKuq61W1CJgGnFthH+XnO5
AmOAnJGFNDX67cybUvziWtWQNm3HQsXVOTwx2SiSJeu6tOFJEmwGdAEvAVUAg8rqpPeyyrPbDF73UmMLTCPg8Dn4r
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# observed 0/1 outcomes together with the fitted probability of passing\n",
"ax = df.plot.scatter(x='hours_studied', y='passed')\n",
"ax.plot(hours_studied, y_pred[:, 1])\n",
"ax.set_xlabel(\"preparation time in hours\", fontsize=14)\n",
"ax.set_ylabel(\"probability of passing exam\", fontsize=14)\n",
"ax.figure.savefig(\"03_ml_basics_logistic_regression.pdf\")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'C': 1.0,\n",
" 'class_weight': None,\n",
" 'dual': False,\n",
" 'fit_intercept': True,\n",
" 'intercept_scaling': 1,\n",
" 'l1_ratio': None,\n",
" 'max_iter': 100,\n",
" 'multi_class': 'auto',\n",
" 'n_jobs': None,\n",
" 'penalty': 'none',\n",
" 'random_state': None,\n",
" 'solver': 'lbfgs',\n",
" 'tol': 0.0001,\n",
" 'verbose': 0,\n",
" 'warm_start': False}"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# hyperparameters of the fitted estimator (deep=True is the default)\n",
"clf.get_params(deep=True)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Coefficient: [[1.50464522]]\n",
"Intercept: [-4.07771764]\n"
]
}
],
"source": [
"# slope and offset of the linear predictor inside the logistic function\n",
"for label, value in [('Coefficient: ', clf.coef_), ('Intercept: ', clf.intercept_)]:\n",
"    print(label, value)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
}
},
"nbformat": 4,
"nbformat_minor": 4
}