|
|
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Simple example of logistic regression with scikit-learn"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.linear_model import LogisticRegression"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Read data\n",
    "Data are from the [Wikipedia article on logistic regression](https://en.wikipedia.org/wiki/Logistic_regression): the number of hours each student studied and whether the exam was passed (1) or failed (0)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# columns: hours_studied, passed (0/1); fields are whitespace-separated\n",
    "filename = \"https://www.physi.uni-heidelberg.de/~reygers/lectures/2021/ml/data/exam.txt\"\n",
    "df = pd.read_csv(filename, engine='python', sep=r'\\s+')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# scikit-learn expects a 2D feature matrix of shape (n_samples, n_features)\n",
    "x = df['hours_studied'].to_numpy().reshape(-1, 1)\n",
    "y = df['passed'].to_numpy()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Fit the model\n",
    "Regularization is switched off (`penalty=None`), so the fit is a plain maximum-likelihood estimate of the coefficient and intercept."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# penalty=None (no regularization) needs scikit-learn >= 1.2; the old string 'none' was removed in 1.4\n",
    "clf = LogisticRegression(penalty=None, fit_intercept=True)\n",
    "clf.fit(x, y);"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Calculate predictions\n",
    "Predicted class probabilities on a fine grid of preparation times from 0 to 6 hours; column 1 of `predict_proba` is the probability of passing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "hours_studied = np.linspace(0., 6., 1000).reshape(-1, 1)\n",
    "y_pred = clf.predict_proba(hours_studied)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Plot result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png":
"iVBORw0KGgoAAAANSUhEUgAAAYYAAAEKCAYAAAAW8vJGAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy86wFpkAAAACXBIWXMAAAsTAAALEwEAmpwYAAAzgUlEQVR4nO3dd3gVZfbA8e9JBxJ6qKELSBEpodjAumJZy2LvBVF/Yllsa9nVZYuCruu66+7asFJEwd57YQUJvfcWaoAQEkLqPb8/ZqI3MQmTcGtyPs9zn9w7d2beM5eQc+etoqoYY4wxZWLCHYAxxpjIYonBGGNMOZYYjDHGlGOJwRhjTDmWGIwxxpQTF+4ADlfLli21c+fO4Q7DGGOiyrx583arampl70V9YujcuTMZGRnhDsMYY6KKiGyq6j2rSjLGGFOOJQZjjDHlWGIwxhhTjiUGY4wx5VhiMMYYU07IEoOITBKRXSKytIr3RUSeEpG1IrJYRAaGKjYTHHvyClm0ZR978gqjuoxQlROqa6mp2sQVimNCFVdNReq/Y02EsrvqS8C/gFeqeP8MoLv7GAr8x/1potA7C7dy74zFxMfEUOzzMXFUP87p3z7qyghVOaG6llDEFYpjQhVXTUXqv2NNheyOQVW/BfZWs8u5wCvqmA00FZG2oYnOBNKevELunbGYgmIfuYUlFBT7uGfG4oB+gwpFGaEqJ1TXEoq4QnFMqOKqqUj9d6yNSGpjaA9s8Xud6W77BREZIyIZIpKRlZUVkuCMd5nZB4mPKf+rFR8TQ2b2wagqI1TlhOpaaqo2cYXimFDFVVOhKENVKSwpZV9+Edv2HSS3oDhg5/YXlSOfVfVZ4FmA9PR0W2kowqQ1a0Cxz1duW7HPR1qzBlFVRqjKCdW11FRt4grFMaGKq6YOVYaqkl9USnZ+Efvyi8nOLyLnYDF5BSXkFZaQ6/786XVhCXkFxT9tyy0sIb+olFLfz3/y/nJ+Xy4f2ilg11AmkhLDVqCD3+s0d5uJMi2SE5k4qh/3VKhrbZGcGFVlhKqcUF1LKOIKxTGhisuLwpJSsnIL2ZVbyK79hfy6XztmLsgkBqFUlbSmDbjsuTk/JYOiUl+152uUEEtyUhzJiXEkJ8WTkhhHq5Skn7Y1TIilUWIcDeJjaZgQS3rn5ocVf1UklEt7ikhn4H1V7VvJe2cBY4EzcRqdn1LVIYc6Z3p6utpcSZFpT14hmdkHSWvWIGh/5EJRRqjKCdW11FRt4grFMcEuo9Sn7NhfQObefLbuO0hm9kEys/PZtq+AXbkF7MotZF/+L6tyYgRSkuJp3iiBlskJNGvoPJo2inefx9PU3dakQTwpSXEkJ8XRKCGO2BjxdB2BICLzVDW90ve8JgYROR84CWhFhbYJVb3Iw/FTgROBlsBO4CEg3j3+vyIiOL2WRgL5wLWqesi/+JYYjDG1VVLqIzP7IOt357E+6wDrsg6wcfcBMvfls31fASW+8n8fW6Uk0rZpA1qnJNKqcSKtUpJoVeF5i+TEkP6Br63qEoOnqiQR+RtwKzAL5496aU2DUNVLD/G+ArfU9LzGGHMoqkpm9kGWbdvP8u37Wb0jl3VZeWzak1+ueqdpw3i6tGzEgA7N+HW/BqQ1a0haswakNWtAu6YNSIqPDeNVhI7XNoargQtV9Z1gBmOMMYfL51PWZuWxcPM+lm3LYcX2XFZs309uYQkAItClRSO6piZz8pGt6JaaTNdU53XzRglhjj4yeE0M+cDKYAZijDG1kZNfzIIt2czfvI8Fm7NZuHnfT0mgYUIsvdo25twB7ejdtgm92zWmZ+sUGiTUj2/+teU1MTwK3CMiN6pqSTADMsaY6hwoLOHHjXv539rdzFq7hxU79qPqNPr2aJ3Cr/u3Y0CHpgzo2IyuLRsREwX1/ZHGa2J4Dvg1sFVEVgPlmuJV9eRAB2aMMeBUDS3dlsMXK3Yxa+1uFm7ZR4lPSYiNYWCnptxxSg8Gd25Gvw5NSU6
MpB740cvrp/hf4HjgY5zGZxtUZowJmoLiUmat3c3nK3bxxYqd7MotJEbgqLSm3DC8K8d1a8mgTs2sSihIvCaGi4HzVfWzYAZjjKm/CopL+XrVLt5dtI0vV+6ioNhHo4RYRvRM5ZQjW3PSka2scThEvCaGLGwUsjEmwEpKfcxat4d3F27j02U7yC0soWVyAhcMSuNXvdswtGtzEuPsriDUvCaGh4DxInKNquYFMyBjTN23eU8+r2ds5o2MTHblFpKSFMfIvm04p387junagrjYSJrfs/7xmhjuBjoDO0VkM79sfO4X4LiMMXVMUYmPT5fvYNqPW/h+7W5iBE7q2YoL0ztwYs/UejN4LBp4TQxvBjUKY0ydlX2giCk/bubl/21kV24h7Zs2YNxpPbgwPY22TcI7g6ypnKfEoKp/DHYgxpi6ZX1WHpNmbeDNeZkUFPs4oXtLJozqx/AeqVExl1B9Zp1+jTEBtXpnLk99sYYPlmwnPiaG8wa04/rju9KzTUq4QzMeeZ1ELwF4ALgU6Ig7K2oZVbXKQWPquVU7cnnqyzV8uGQ7DeNjuWlEN647rgupKZEzhbjxxusdw59wxjI8AvydnxujLwF+H5TIjDFRYcvefB77ZBXvLd5Gw/hYbh7RjdEndLUxB1HMa2K4CLhJVT8WkceBd1R1nYisAE4DnglahMaYiJSTX8zTX6/lpVkbiYmBm0d044YTutLMEkLU85oYWgPL3ed5QFP3+cfAhADHZIyJYMWlPl79YRNPfbmGnIPFXDAwjTt/1ZM2TZLCHZoJEK+JYTPQzv25FjgdmAccAxwMTmjGmEjz44a9PPj2ElbvzOP4I1py/5m96N2ucbjDMgHmNTG8BZwCzAb+AUwVkRuA9sBjQYrNGBMh9uQV8shHK3lzXibtmzbguavSObVXK5wVeU1d43Ucw31+z98UkS3AccBqVX0/WMEZY8JLVZkxfyt//mA5eQUl3DSiG7edcgQNE6yne11Wq39dVZ0DzAEQkXhVLT7EIcaYKLNrfwH3zVzCFyt3MbhzM/5y/lH0aG1jEeoDr+MY3gWuVdU9Fbb3AqYAA4IQmzEmDFSVdxdt4w/vLKOguJQHz+rFdcd1sZXQ6hGvUxg2A5aIyK/KNojIWJwG6MXBCMwYE3o5B4sZO2UBt09bSJeWjfjw9hMYfUJXSwr1jNeqpBHAg8B7IvIs0A2nR9L1qjo1WMEZY0JnweZsbp26gB05Bdx9ek9uHN7Vpr+up7w2Pvtw1mOIxRnpXAIMV9XZwQzOGBN8Pp/y/PfrmfjxKto0SWL6TccwsGOzcIdlwshrG0Mi8DhwA/AwcALO3cNoVX0neOEZY4JpX34Rv319IV+tyuKMvm14dFQ/mjSIP/SBpk7zWpU0z933eFXNABCRO4FpIvKaqt4QrACNMcGxakcuY17NYPu+Av50bh+uGNbJxiUYwHtimA3cpqr5ZRtU9W8i8jnwWlAiM8YEzcdLtzNu+iKSE+OYOmYYgzpZ1ZH5mdc2htFVbF8kIumBDckYEyw+n/L3z1fzzy/X0r9DU565chCtG9scR6Y8z10OROQMEXlfRJaLSAd322jg+KBFZ4wJmILiUm6ZMp9/frmWi9LTeP3GYZYUTKU8JQYRuRyYDqwBuvDzQj2xwD3BCc0YEyh7DxRx+fNz+GjpDh44sxcTRvUjMc7W1zKV83rHcA9wg6r+FqerapnZQP9AB2WMCZxNew4w6j//Y8nWHP59+UBuGN7VGplNtbwmhu7AD5VszwM8z7krIiNFZJWIrBWR31XyfkcR+UpEFojIYhE50+u5jTG/tHDLPn7z7/+RnV/ElNFDOfOotuEOyUQBr4lhG9Cjku3DgXVeTuAOjnsaOAPoDVwqIr0r7PYgMF1VB+AsG/pvj/EZYyr437rdXPbcbBomxjLz5mNJ79w83CGZKOE1MTwLPCUix7mvO4jI1cBE4D8ezzEEWKuq61W1CJgGnFthH+XnO5AmOAnJGFNDX67
cybUvziWtWQNm3HQsXVOTwx2SiSJeu6tOFJEmwGdAEvAVUAg8rqpPeyyrPbDF73UmMLTCPg8Dn4r "text/plain": [ "<Figure size 432x288 with 1 Axes>" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [
    "# measured outcomes (0 = failed, 1 = passed) overlaid with the fitted probability of passing\n",
    "fig, ax = plt.subplots()\n",
    "df.plot.scatter(x='hours_studied', y='passed', ax=ax)\n",
    "ax.plot(hours_studied, y_pred[:, 1])\n",
    "ax.set_xlabel(\"preparation time in hours\", fontsize=14)\n",
    "ax.set_ylabel(\"probability of passing exam\", fontsize=14)\n",
    "fig.savefig(\"03_ml_basics_logistic_regression.pdf\")"
   ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'C': 1.0,\n", " 'class_weight': None,\n", " 'dual': False,\n", " 'fit_intercept': True,\n", " 'intercept_scaling': 1,\n", " 'l1_ratio': None,\n", " 'max_iter': 100,\n", " 'multi_class': 'auto',\n", " 'n_jobs': None,\n", " 'penalty': 'none',\n", " 'random_state': None,\n", " 'solver': 'lbfgs',\n", " 'tol': 0.0001,\n", " 'verbose': 0,\n", " 'warm_start': False}" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "clf.get_params()" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Coefficient: [[1.50464522]]\n", "Intercept: [-4.07771764]\n" ] } ], "source": [ "print('Coefficient: ', clf.coef_)\n", "print('Intercept: ', clf.intercept_)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.5" } }, "nbformat": 4, "nbformat_minor": 4 }
|