ML-Kurs-SS2023/notebooks/03_ml_basics_tf_binary_classification_example.ipynb

355 lines
255 KiB
Plaintext
Raw Permalink Normal View History

2023-03-31 17:34:03 +02:00
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
2023-03-31 17:34:03 +02:00
"id": "344b183c",
"metadata": {},
"outputs": [],
"source": [
"#\n",
"# train a simple TensorFlow model to perform binary classification on a generated\n",
"# 2-dimensional dataset \n",
"# 02/2023\n",
"# "
]
},
{
"cell_type": "code",
"execution_count": 2,
2023-03-31 17:34:03 +02:00
"id": "92c9d0a1",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2023-04-11 16:48:11.741526: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: SSE4.1 SSE4.2 AVX AVX2 AVX512F AVX512_VNNI FMA\n",
"To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
"2023-04-11 16:48:11.802317: I tensorflow/core/util/port.cc:104] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n"
]
}
],
2023-03-31 17:34:03 +02:00
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import tensorflow as tf"
]
},
{
"cell_type": "code",
"execution_count": 3,
2023-03-31 17:34:03 +02:00
"id": "3814ea1d",
"metadata": {},
"outputs": [],
"source": [
"# Generate toy data\n",
"np.random.seed(4321)\n",
"n_samples = 1000"
]
},
{
"cell_type": "markdown",
"id": "84d8bc5e",
"metadata": {},
"source": [
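"Machine learning algorithms generally train best when the input features are of order 1. The toy data generated here already has that scale (unit variance, class means at -1 and +1), so no normalization is needed."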
]
},
{
"cell_type": "code",
"execution_count": 4,
2023-03-31 17:34:03 +02:00
"id": "a1d437df",
"metadata": {},
"outputs": [],
"source": [
"class1_data = np.random.multivariate_normal([-1., -1.], [[1., 0.], [0., 1.]], n_samples)\n",
"class2_data = np.random.multivariate_normal([1.0, 1.0], [[1., 0.], [0., 1.]], n_samples)"
]
},
{
"cell_type": "code",
"execution_count": 5,
2023-03-31 17:34:03 +02:00
"id": "6bbccf56",
"metadata": {},
"outputs": [],
"source": [
"# the two classes are merged into one training set and one-hot labels are assigned: [1, 0] for class 1 and [0, 1] for class 2\n",
"train_data = np.concatenate([class1_data, class2_data])\n",
"toy_labels = np.zeros(train_data.shape)\n",
"toy_labels[:n_samples, 0] = 1\n",
"toy_labels[n_samples:, 1] = 1"
]
},
{
"cell_type": "code",
"execution_count": 6,
2023-03-31 17:34:03 +02:00
"id": "f8dd0511",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAjUAAAHFCAYAAAAKbwgcAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/P9b71AAAACXBIWXMAAA9hAAAPYQGoP6dpAACVI0lEQVR4nO2deXwU5f3HP7MLIXcgm4CQDQQp9Va8KzaSKN4oElBB/CmepYISFKwiGlZRrCgJxapVq6hpwECCVGs9ggnSYi1W8UKtIGeInDaRw+Bunt8fw2z2mOOZa3d2832/Xs8r2dmZ53lmduZ5vvN9vofAGGMgCIIgCIJIcFzx7gBBEARBEIQVkFBDEARBEERSQEINQRAEQRBJAQk1BEEQBEEkBSTUEARBEASRFJBQQxAEQRBEUkBCDUEQBEEQSQEJNQRBEARBJAUk1BAEQRAEkRSQUGOShQsXQhAEfPTRR/HuSpBHHnkEr732mqk6Nm3aBEEQsHDhQt3Hrlu3DrNmzcKmTZtM9cEqZs2aBUEQwrY99dRTsufW1NQEQRCwdOnSGPVOmwkTJqCoqMjQsatXr8asWbPwv//9z9I+WYXcb+MEBEHArFmz4t2NmBJ5ztLYFo/n2K77Vq3eoqIijBgxwtL2zN5HZsZhO8Yyqc6mpibL6rQaEmqSECuEGjOsW7cOPp/PMULNzTffjA8++CBsm5JQ40Tuv/9+LFu2zNCxq1evhs/nc6xQQziXSy+9FB988AH69u0b87btum9j/Tx88MEHuPnmm2PSFiHSLd4dIAi78Xq98Hq98e6GYQYNGhTvLiQMBw4cQHp6ery74Uj0Xpv8/Hzk5+fb2KPk51e/+pXmPgcPHkRqaqojNZaJCGlqbGDChAnIzMzE+vXrcckllyAzMxOFhYW466670N7eHtxPUi0+9thjePjhh9G/f3+kpqbitNNOw4oVK6LqlFuCiFTfC4KA/fv346WXXoIgCBAEASUlJar93b59O6666ipkZWUhJycHV199Nb7//vuo/T766COMHTsWRUVFSEtLQ1FREcaNG4fNmzcH91m4cCGuvPJKAEBpaWmwD5JW5N1338XIkSPh9XqRmpqKX/ziF/jNb36D3bt3q/aRMYY+ffpg0qRJwW2BQAC9evWCy+XCjh07gtvnzZuHbt26Bd/GIq9RUVERvvzyS6xcuTLYv8hr+/PPP+O+++5Dv379kJ2djeHDh+Obb75R7WNoW5988gnKysqQnZ2NnJwcXHvttdi1a1fYvh0dHXjsscdw9NFHo0ePHujduzeuu+46bNu2LWw/ud9eEARMnjwZr7zyCo455hikp6fjpJNOwhtvvBHWl+nTpwMABg4cGDxXSXX83nvvoaSkBB6PB2lpaejfvz9Gjx6NAwcOaJ5nTU0NzjrrLGRmZiIzMxNDhgzBn//857B9XnjhBZx00klITU1Fbm4uRo0aha+++kqzbt7rUlJSguOPPx7vv/8+hg4divT0dNx4440AgLa2NkybNg0DBw5ESkoKCgoKUF5ejv3794fV0dbWhltuuQUejweZmZm46KKL8N///lezjwDw008/4a677sKQIUOQk5OD3NxcnHXWWVi+fLnsOS1YsABDhgxBWloaevbsiV/96lf461//ast1lcagzz//HBdccAGysrJw3nnn6TpnueUn6ZqvWbMGxcXFSE9Px5FHHolHH30UHR0dYcd/+eWXuOCCC5Ceno78/HxMmjQJf/vb3zSXL7TuW977Q2+9Em+99RZOOeUUpKWl4eijj8YLL7wQVdf333+P3/zmN/B6vUhJScHAgQPh8/ng9/vD9lNa0nvnnXdw4403Ij8/H+np6WHzghbr16/HDTfcgMGDByM9PR0FBQW47LLL8Pnnn8vu/9NPP+HOO+/EEUccgbS0NAwbNgyffPJJ1H4fffQRLr/8cuTm5iI1NRUnn3wyamtrNfvz3XffYezYsejXrx969OiBPn364Lzzzs
PatWu5z8lKSFNjEz///DMuv/xy3HTTTbjrrrvw/vvv46GHHkJOTg4eeOCBsH2ffPJJDBgwAFVVVcEH9uKLL8bKlStx1lln6Wr3gw8+wLnnnovS0lLcf//9AIDs7GzF/Q8ePIjhw4dj+/btmDNnDn75y1/ib3/7G66++uqofTdt2oSjjjoKY8eORW5uLlpaWvD000/j9NNPx7p165CXl4dLL70UjzzyCGbMmIE//vGPOOWUUwB0ahs2bNiAs846CzfffDNycnKwadMmzJs3D7/+9a/x+eefo3v37rL9FAQB5557LhoaGoLbPvroI/zvf/9DWloaVqxYgWuuuQYA0NDQgFNPPRU9e/aUrWvZsmUYM2YMcnJy8NRTTwEAevToEbbPjBkzcPbZZ+P5559HW1sbfve73+Gyyy7DV199BbfbrXg9JUaNGoWrrroKEydOxJdffon7778f69atw4cffhg8x9/+9rd49tlnMXnyZIwYMQKbNm3C/fffj6amJnz88cfIy8tTbeNvf/sb1qxZgwcffBCZmZl47LHHMGrUKHzzzTc48sgjcfPNN2Pv3r1YsGAB6uvrg8sIxx57LDZt2oRLL70UxcXFeOGFF9CzZ080NzfjrbfewqFDh1Tf6B944AE89NBDKCsrw1133YWcnBx88cUXYcLtnDlzMGPGDIwbNw5z5szBnj17MGvWLJx11llYs2YNBg8erFi/nuvS0tKCa6+9FnfffTceeeQRuFwuHDhwAMOGDcO2bdswY8YMnHjiifjyyy/xwAMP4PPPP0dDQwMEQQBjDFdccQVWr16NBx54AKeffjr++c9/4uKLL9b8fQGgvb0de/fuxbRp01BQUIBDhw6hoaEBZWVlePHFF3HdddcF950wYQKqq6tx00034cEHH0RKSgo+/vjjMIHB6ut66NAhXH755fjNb36De+65B36/3/Q5A+KEPn78eNx1112oqKjAsmXLcO+996Jfv37Bc25pacGwYcOQkZGBp59+Gr1798aiRYswefJkzfrV7lvA+HOjVS8AfPrpp7jrrrtwzz33oE+fPnj++edx00034Re/+AXOOeec4PmfccYZcLlceOCBBzBo0CB88MEHmD17NjZt2oQXX3xR8xxvvPFGXHrppXjllVewf/9+xXFPju3bt8Pj8eDRRx9Ffn4+9u7di5deeglnnnkmPvnkExx11FFh+8+YMQOnnHIKnn/+ebS2tmLWrFkoKSnBJ598giOPPBIA0NjYiIsuughnnnkmnnnmGeTk5GDx4sW4+uqrceDAAUyYMEGxP5dccgkCgQAee+wx9O/fH7t378bq1avjt+TNCFO8+OKLDABbs2ZNcNv111/PALDa2tqwfS+55BJ21FFHBT9v3LiRAWD9+vVjBw8eDG5va2tjubm5bPjw4WF1DhgwIKr9iooKFvkzZmRksOuvv56r/08//TQDwJYvXx62/ZZbbmEA2Isvvqh4rN/vZ/v27WMZGRls/vz5we1LlixhAFhjY6Nq2x0dHeznn39mmzdvlu1DJM8//zwDwLZs2cIYY2z27Nns6KOPZpdffjm74YYbGGOMHTp0iGVkZLAZM2YEj5O7RscddxwbNmxYVBuNjY0MALvkkkvCttfW1jIA7IMPPlDto9TW1KlTw7b/5S9/YQBYdXU1Y4yxr776igFgt912W9h+H374IQMQ1n+53x4A69OnD2trawtu+/7775nL5WJz5swJbps7dy4DwDZu3Bh2/NKlSxkAtnbtWtXzieS7775jbrebjR8/XnGfH374gaWlpUVdwy1btrAePXqwa665Jrgt8rfRc12GDRvGALAVK1aE7TtnzhzmcrnCnknGOs/5zTffZIwx9ve//50BCLt3GWPs4YcfZgBYRUWFypWIxu/3s59//pnddNNN7OSTTw5uf//99xkAdt999ykea/V1lcagF154IWxfPecsjW2h9450zT/88MOw44899lh24YUXBj9Pnz6dCYLAvvzyy7D9Lr
zwQq6xQem+1XN/6KmXMcYGDBjAUlNT2ebNm4PbDh48yHJzc9lvfvOb4Lbf/OY3LDMzM2w/xhh7/PHHGYCwc1a6ptd
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
2023-03-31 17:34:03 +02:00
"source": [
"# Plot the input data with points colored according to their labels\n",
"plt.scatter(class1_data[:, 0], class1_data[:, 1], color='red')\n",
"plt.scatter(class2_data[:, 0], class2_data[:, 1], color='blue')\n",
"plt.title(\"Input data with points colored according to their labels\")\n",
"plt.xlabel(\"Feature 1\")\n",
"plt.ylabel(\"Feature 2\")\n",
"plt.show()"
]
},
{
"cell_type": "markdown",
"id": "75fb1af9",
"metadata": {},
"source": [
"Build the model. A Sequential model is a linear stack of layers. In a Dense (fully connected) layer, every node is connected to every output of the previous layer. Here we use two hidden layers with 32 nodes each; input_shape=(2,) means that the input data is 2-dimensional. The hidden layers use the 'relu' (rectified linear unit) activation.\n",
"The softmax activation of the final layer maps the output of the model to a probability distribution over the 2 classes. "
]
},
{
"cell_type": "code",
"execution_count": 7,
2023-03-31 17:34:03 +02:00
"id": "ad6dcef9",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2023-04-11 16:48:13.355862: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: SSE4.1 SSE4.2 AVX AVX2 AVX512F AVX512_VNNI FMA\n",
"To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
]
}
],
2023-03-31 17:34:03 +02:00
"source": [
"model = tf.keras.models.Sequential([\n",
" tf.keras.layers.Dense(32, activation='relu', input_shape=(2,)),\n",
" tf.keras.layers.Dense(32, activation='relu'),\n",
" tf.keras.layers.Dense(2, activation='softmax')\n",
"])"
]
},
{
"cell_type": "markdown",
"id": "a9c38493",
"metadata": {},
"source": [
"The Adam optimizer is a gradient-based optimization algorithm for updating the weights of \n",
"a neural network. It adapts the effective learning rate of each weight using estimates of the first and second moments of the \n",
"gradients of the loss function with respect to the weights.\n",
"The loss function is BinaryCrossentropy since we have 2 classes of data.\n",
"The accuracy metric measures the fraction of instances for which the model \n",
"correctly predicted the class label; it is computed as the number of correct\n",
"predictions divided by the total number of instances (here the training data, since no separate test set is used)."
]
},
{
"cell_type": "code",
"execution_count": 8,
2023-03-31 17:34:03 +02:00
"id": "7f6dc4ef",
"metadata": {},
"outputs": [],
"source": [
"model.compile(optimizer='adam',\n",
"# loss=tf.keras.losses.CategoricalCrossentropy(),\n",
" loss=tf.keras.losses.BinaryCrossentropy(),\n",
" metrics=['accuracy'])"
]
},
{
"cell_type": "markdown",
"id": "a6397c0f",
"metadata": {},
"source": [
"The returned history object contains the loss and accuracy of each epoch of the training process.\n",
"The model is trained by dividing the entire training data into smaller batches\n",
"of a specified size, updating the model's parameters after each batch.\n",
"The batch_size parameter determines the number of samples used in each batch; with 2000 samples and batch_size=32 this gives 63 batches per epoch. "
]
},
{
"cell_type": "code",
"execution_count": 9,
2023-03-31 17:34:03 +02:00
"id": "b98e46ba",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Epoch 1/20\n",
"63/63 [==============================] - 1s 1ms/step - loss: 0.4911 - accuracy: 0.8340\n",
"Epoch 2/20\n",
"63/63 [==============================] - 0s 889us/step - loss: 0.2459 - accuracy: 0.9255\n",
"Epoch 3/20\n",
"63/63 [==============================] - 0s 815us/step - loss: 0.1932 - accuracy: 0.9270\n",
"Epoch 4/20\n",
"63/63 [==============================] - 0s 797us/step - loss: 0.1898 - accuracy: 0.9265\n",
"Epoch 5/20\n",
"63/63 [==============================] - 0s 797us/step - loss: 0.1895 - accuracy: 0.9275\n",
"Epoch 6/20\n",
"63/63 [==============================] - 0s 806us/step - loss: 0.1894 - accuracy: 0.9265\n",
"Epoch 7/20\n",
"63/63 [==============================] - 0s 839us/step - loss: 0.1890 - accuracy: 0.9275\n",
"Epoch 8/20\n",
"63/63 [==============================] - 0s 852us/step - loss: 0.1887 - accuracy: 0.9255\n",
"Epoch 9/20\n",
"63/63 [==============================] - 0s 793us/step - loss: 0.1885 - accuracy: 0.9265\n",
"Epoch 10/20\n",
"63/63 [==============================] - 0s 799us/step - loss: 0.1882 - accuracy: 0.9245\n",
"Epoch 11/20\n",
"63/63 [==============================] - 0s 846us/step - loss: 0.1885 - accuracy: 0.9255\n",
"Epoch 12/20\n",
"63/63 [==============================] - 0s 853us/step - loss: 0.1883 - accuracy: 0.9260\n",
"Epoch 13/20\n",
"63/63 [==============================] - 0s 816us/step - loss: 0.1890 - accuracy: 0.9225\n",
"Epoch 14/20\n",
"63/63 [==============================] - 0s 816us/step - loss: 0.1887 - accuracy: 0.9255\n",
"Epoch 15/20\n",
"63/63 [==============================] - 0s 834us/step - loss: 0.1877 - accuracy: 0.9255\n",
"Epoch 16/20\n",
"63/63 [==============================] - 0s 807us/step - loss: 0.1888 - accuracy: 0.9250\n",
"Epoch 17/20\n",
"63/63 [==============================] - 0s 811us/step - loss: 0.1881 - accuracy: 0.9255\n",
"Epoch 18/20\n",
"63/63 [==============================] - 0s 854us/step - loss: 0.1878 - accuracy: 0.9255\n",
"Epoch 19/20\n",
"63/63 [==============================] - 0s 865us/step - loss: 0.1879 - accuracy: 0.9260\n",
"Epoch 20/20\n",
"63/63 [==============================] - 0s 852us/step - loss: 0.1881 - accuracy: 0.9260\n"
]
}
],
2023-03-31 17:34:03 +02:00
"source": [
"history = model.fit(train_data, toy_labels, epochs=20, batch_size=32, verbose=1)"
]
},
{
"cell_type": "code",
"execution_count": 10,
2023-03-31 17:34:03 +02:00
"id": "806497d2",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAA94AAAGHCAYAAABGc4o9AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/P9b71AAAACXBIWXMAAA9hAAAPYQGoP6dpAABpzUlEQVR4nO3deVyVZf7/8fdZ4LAIiKKIiUiWS5lmWApmm4U5rTNTUU2UM/pNJ1vMZtGxpvQ7M7ZMjt8WLUszG8ec0ibnly1UlpraYtieWWqYgQixI9s59+8POAeOLHLwHM7h8Ho+HucB3Oe6b66LG7343J9rMRmGYQgAAAAAAPiE2d8VAAAAAAAgmBF4AwAAAADgQwTeAAAAAAD4EIE3AAAAAAA+ROANAAAAAIAPEXgDAAAAAOBDBN4AAAAAAPgQgTcAAAAAAD5E4A0AAAAAgA8ReANBZuXKlTKZTProo4/8XRUAAHCURx55RCaTSSNGjPB3VQB0IgJvAAAAoJOsWLFCkvTFF1/o/fff93NtAHQWAm8AAACgE3z00Uf65JNPdMkll0iSli9f7ucatayystLfVQCCDoE30A1t3bpVEydOVFRUlCIiIpSWlqZXXnnFrUxlZaV+97vfKTk5WWFhYerVq5fGjBmjNWvWuMrs3btX1157rfr37y+bzab4+HhNnDhRu3bt6uQWAQAQ+JyB9v3336+0tDQ9//zzzYLcgwcP6uabb1ZiYqJCQ0PVv39/XXXVVTp06JCrTHFxse666y6deOKJstls6tu3r372s5/p66+/liS98847MplMeuedd9yuvX//fplMJq1cudJ1bMqUKerRo4c+++wzpaenKyoqShMnTpQkZWVl6YorrtCAAQMUFhamk046SdOnT1dBQUGztn399de67rrrFB8fL5vNpoEDB+rGG29UdXW19u/fL6vVqoULFzY7b/PmzTKZTHrhhRc69DMFugqrvysAoHO9++67uuiiizRy5EgtX75cNptNS5Ys0WWXXaY1a9YoIyNDkjR79mw999xz+stf/qLRo0eroqJCn3/+uQoLC13X+tnPfia73a4HH3xQAwcOVEFBgbZt26bi4mI/tQ4AgMB05MgRrVmzRmeeeaZGjBih3/zmN5o2bZpeeOEF3XTTTZLqg+4zzzxTtbW1+tOf/qSRI0eqsLBQr7/+uoqKihQfH6+ysjKdffbZ2r9/v/74xz9q7NixKi8v1+bNm5Wbm6thw4Z5XLeamhpdfvnlmj59uubMmaO6ujpJ0nfffafU1FRNmzZNMTEx2r9/vxYtWqSzzz5bn332mUJCQiRJn3zyic4++2zFxcVpwYIFOvnkk5Wbm6sNGzaopqZGgwYN0uWXX64nnnhCf/jDH2SxWFzf+7HHHlP//v3185//3As/ZSCAGQCCyjPPPGNIMj788MMW3x83bpzRt29fo6yszHWsrq7OGDFihDFgwADD4XAYhmEYI0aMMK688spWv09BQYEhyVi8eLF3GwAAQBBatWqVIcl44oknDMMwjLKyMqNHjx7GhAkTXGV+85vfGCEhIcaXX37Z6nUWLFhgSDKysrJaLbNp0yZDkrFp0ya34/v27TMkGc8884zr2E033WRIMlasWNFm/R0Oh1FbW2t8//33hiTj5Zdfdr13wQUXGD179jTy8/OPWaeXXnrJdezgwYOG1Wo15s+f3+b3BoIBQ82BbqSiokLvv/++rrrqKvXo0cN13GKxKDMzUz/88IN2794tSTrrrLP06quvas6cOXrnnXd05MgRt2v16tVLgwcP1kMPPaRFixYpOztbDoejU9sDAEBXsXz5coWHh+vaa6+VJPXo0UNXX321tmzZoj179kiSXn31VZ1//vkaPnx4q9d59dVXNWTIEF144YVerd8vf/nLZsfy8/M1Y8YMJSYmymq1KiQkRElJSZKkr776SlL91LR3331X11xzjfr06dPq9c877zyNGjVKjz/+uOvYE088IZ
PJpJtvvtmrbQECEYE30I0UFRXJMAwlJCQ0e69///6S5BpK/sgjj+iPf/yj/vOf/+j8889Xr169dOWVV7r+ODCZTHrrrbc0adIkPfjggzrjjDPUp08f3X777SorK+u8RgEAEOC+/fZbbd68WZdccokMw1BxcbGKi4t11VVXSWpc6fzw4cMaMGBAm9dqTxlPRUREKDo62u2Yw+FQenq61q9frz/84Q9666239MEHH2jHjh2S5HogX1RUJLvd3q463X777Xrrrbe0e/du1dbW6qmnntJVV12lfv36ebU9QCAi8Aa6kdjYWJnNZuXm5jZ778cff5QkxcXFSZIiIyM1f/58ff3118rLy9PSpUu1Y8cOXXbZZa5zkpKStHz5cuXl5Wn37t268847tWTJEv3+97/vnAYBANAFrFixQoZh6MUXX1RsbKzr5Vzd/Nlnn5XdblefPn30ww8/tHmt9pQJCwuTJFVXV7sdb2lRNKn+YfrRPv/8c33yySd66KGHdNttt+m8887TmWeeqd69e7uV69WrlywWyzHrJEnXX3+9evfurccff1wvvPCC8vLyNHPmzGOeBwQDAm+gG4mMjNTYsWO1fv16t6HjDodD//znPzVgwAANGTKk2Xnx8fGaMmWKrrvuOu3evbvFbUaGDBmiu+++W6eddpo+/vhjn7YDAICuwm6369lnn9XgwYO1adOmZq+77rpLubm5evXVVzV58mRt2rTJNe2rJZMnT9Y333yjt99+u9UygwYNkiR9+umnbsc3bNjQ7no7g3GbzeZ2/Mknn3T7Ojw8XOeee65eeOGFVgN7p7CwMN1888169tlntWjRIp1++ukaP358u+sEdGWsag4Eqbffflv79+9vdnzhwoW66KKLdP755+t3v/udQkNDtWTJEn3++edas2aNq6MdO3asLr30Uo0cOVKxsbH66quv9Nxzzyk1NVURERH69NNPdeutt+rqq6/WySefrNDQUL399tv69NNPNWfOnE5uLQAAgenVV1/Vjz/+qAceeEDnnXdes/dHjBihxx57TMuXL9djjz2mV199Veecc47+9Kc/6bTTTlNxcbFee+01zZ49W8OGDdOsWbO0du1aXXHFFZozZ47OOussHTlyRO+++64uvfRSnX/++erXr58uvPBCLVy4ULGxsUpKStJbb72l9evXt7vew4YN0+DBgzVnzhwZhqFevXrpv//9r7KyspqVda50PnbsWM2ZM0cnnXSSDh06pA0bNujJJ59UVFSUq+wtt9yiBx98UDt37tTTTz/doZ8p0CX5d203AN7mXNW8tde+ffuMLVu2GBdccIERGRlphIeHG+PGjTP++9//ul1nzpw5xpgxY4zY2FjDZrMZJ554onHnnXcaBQUFhmEYxqFDh4wpU6YYw4YNMyIjI40ePXoYI0eONP7xj38YdXV1/mg6AAAB58orrzRCQ0PbXPH72muvNaxWq5GXl2ccOHDA+M1vfmP069fPCAkJMfr3729cc801xqFDh1zli4qKjDvuuMMYOHCgERISYvTt29e45JJLjK+//tpVJjc317jqqquMXr16GTExMcYNN9xgfPTRRy2uah4ZGdlivb788kvjoosuMqKioozY2Fjj6quvNnJycgxJxr333tus7NVXX2307t3bCA0NNQYOHGhMmTLFqKqqanbd8847z+jVq5dRWVnZzp8i0PWZDMMw/Bb1AwAAAOg28vPzlZSUpNtuu00PPvigv6sDdBqGmgMAAADwqR9++EF79+7VQw89JLPZrDvuuMPfVQI6FYurAQAAAPCpp59+Wuedd56++OILrV69WieccIK/qwR0KoaaAwAAAADgQ2S8AQAAAADwIQJvAAAAAAB8iMAbAAAAAAAfCppVzR0Oh3788UdFRUXJZDL5uzoAgG7OMAyVlZWpf//+Mpt5zu0N9PUAgEDT3v4+aALvH3/8UYmJif6uBgAAbg4cOKABAwb4uxpBgb4eABCojtXfB03gHRUVJam+wd
HR0X6uDQCguystLVViYqKrf8Lxo68HAASa9vb3HQq8lyxZooceeki5ubk69dRTtXjxYk2YMKHFsu+8847OP//8Zse
"text/plain": [
"<Figure size 1200x400 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
2023-03-31 17:34:03 +02:00
"source": [
"# Plot loss and accuracy\n",
"plt.figure(figsize=(12, 4))\n",
"\n",
"plt.subplot(1, 2, 1)\n",
"plt.plot(history.history['loss'])\n",
"plt.title('Loss')\n",
"plt.xlabel('Epoch')\n",
"\n",
"plt.subplot(1, 2, 2)\n",
"plt.plot(history.history['accuracy'])\n",
"plt.title('Accuracy')\n",
"plt.xlabel('Epoch')\n",
"\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 11,
2023-03-31 17:34:03 +02:00
"id": "ca6da322",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"7772/7772 [==============================] - 4s 528us/step\n"
]
}
],
2023-03-31 17:34:03 +02:00
"source": [
"# Plot data points and decision boundary\n",
"x_min, x_max = train_data[:, 0].min() - .5, train_data[:, 0].max() + .5\n",
"y_min, y_max = train_data[:, 1].min() - .5, train_data[:, 1].max() + .5\n",
"xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.02), np.arange(y_min, y_max, 0.02))\n",
"# stack the flattened xx and yy column-wise into an array of (x, y) grid points covering the area of our inputs\n",
"grid = np.c_[xx.ravel(), yy.ravel()]\n",
"# get the predicted class probabilities for each point of the grid.\n",
"# The result Z is an array with shape (n_points, n_classes) where n_points\n",
"# is the number of points in the grid (not the n_samples variable used above) and n_classes is the number of\n",
"# classes in the toy_labels. Z contains the predicted class probabilities\n",
"# for each point in the grid.\n",
"Z = model.predict(grid)\n",
"# The line Z = np.argmax(Z, axis=1) is used to convert the predicted probabilities\n",
"# into class labels.\n",
"Z = np.argmax(Z, axis=1)\n",
"# reshaped Z variable is used to create the contour plot of the model's predictions \n",
"# on the grid.\n",
"Z = Z.reshape(xx.shape)\n"
]
},
{
"cell_type": "code",
"execution_count": 12,
2023-03-31 17:34:03 +02:00
"id": "54c02602",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAp4AAAKTCAYAAACw6AhNAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/P9b71AAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOz9eWxk+ZUn+n3vjX0lGdwZ3LdMMpnMfatKSdVSS+P2vBLGM7Dwym24B3jWlD1drRoMBLzXgt1At40SZPfDm64neNyT8GsbGIzGemOMp7r91KOskrJKWZXJZDJJZjK5LxEMxr7v+/35jyCjGMlYbpDB/XyAQlUxLiN+JIOME+d3fudwjDEGQgghhBBCDhl/3AsghBBCCCHnAwWehBBCCCHkSFDgSQghhBBCjgQFnoQQQggh5EhQ4EkIIYQQQo4EBZ6EEEIIIeRIUOBJCCGEEEKOhPS4F1COIAiw2WzQ6XTgOO64l0MIIYQQQt7AGEM4HEZHRwd4vnxO80QHnjabDV1dXce9DEIIIYQQUoHFYkFnZ2fZa0504KnT6QAA3/rn/3dIFapjXg0hhBBCCHlTJhnH5/+3/10+bivnRAeeO9vrUoUKUoX6mFdDCCGEEEJKEVMWSYeLCCGEEELIkaDAkxBCCCGEHIlTEXhuBeLHvQRCCCGEEHJApyLwzDIGiz923MsghBBCCCEHcCoCz//+R9+m4JMQQggh5JQ7FYGn/8Mf48GPv4csY8e9FEIIIYQQsk+nIvAEAN/7HwAAZT0JIYQQQk6pUxN4AshnPSn4JIQQQgg5fU5V4Ol7/4N88Gn2UfBJCCGEEHKanKrAE8gFn1N/8R0wUPBJCCGEEHKanLrAEwA+fe9DCj4JIYQQQk6ZUxl4AhR8EkIIIYScNqc28ARywefEH8Qp+CSEEEIIOQVOdeAJAI8+fpgPPk2+6HEvhxBCCCGElHDqA08gF3z+MvUJAFDwSQghhBByQp2JwBMANsyg4JMQQggh5AQ7M4EnkAs+f7aSCz6p5pMQQggh5GQ5U4Hnjgc//h4YaK47IYQQQshJciYDz7qPcnPdacudEEIIIeTkOJOBJ9V7EkIIIYScPGcy8AS+rve8/84YBZ+EEEIIISfAmQ08d7z74CMKPgkh5BAwxsAY1dMTQsSTHvcCjsK7Dz4CfvgTPH40BwnHoatBfdxLIoSQUysUDGJzYx0elxOCIECj1aKzuxftnZ3g+TOfzyCEHMC5+Qvx7oOP8ODH30OWMVj81GqJEEL2w+104PmTL+F2OiAIAgAgGolgaX4Or6Zf5D9GCCHFnJvAEwB873+QDz6pzychhFQnnU7j9ewMABTdYve6XbBubh7xqgghp8m5CjyBXPA59RffAQMFn4QQUg2HdatiRnPLbKK6T0JISecu8ASAT9/7EA9+/L3jXgYhhJwqoVCw4jXxeAxCNnsEqyGEnEbnMvDcQVlPQggRj+d4cBxX8TqODhgRQko4t38dduo9KfgkhBBxGptbKm6jNzQ20sl2QkhJ5/qvA9V7EkKIeE0tLVCp1WWznj19/Ue4IkLIaXOuA08gV+9JwSchhFTG8zyu3rwNhVJZ9PYLo2MwNDUf8aoIIafJuWggX8mn732IiR99F3d+pYLZF0OPgRrME0JIMSq1Gne/8S24nQ64nU4IQhZanR4dnV1QqlTHvTxCyAlHgee2Rx8/zAefJl8UvQbNcS+JEEJOJJ7n0dregdb2juNeCiHklDn3W+27Pfr4IX6Z+gQAaLY7IYQQQkiNUeD5hg0zKPgkhBBCCDkEFHgWsTv4pANHhBBCCCG1QYFnCRtm5E+7E0IIIYSQg6PAs4y1/+ZDALTlTgghhBBSCxR4lkH1noQQQgghtUOBZwUbZuBnKxR8EkIIIYQcFAWeIv1s5RPcf2eMgk9CCCGEkH2iwLMK7z74KB98Wvx02p0QQgghpBoUeFbp3Qcf4cGPv4csYxR8EkIIIYRUgQLPffC9/0
E++KQ+n4QQQggh4lDguU++9z/I9/mk4JMQQgghpDIKPA/g0/c+pOCTEEIIIUQkCjwPaHfwSQghhBBCSqPAswYyT/4OAM11J4QQQggphwLPGnj08UPacieEEEIIqYACzxqhek9CCCGEkPIo8KyhT9/7EBN/EKfgkxBCCCGkCAo8a+zRxw/zwSeN1ySEEEII+RoFnodgJ/gEQMEnIYQQQsg2CjwPyaOPH+KXqU8AUPBJCCGEEAJQ4HmoNsyg4JMQQgghZBsFnodswwz8bOUT3H9njIJPQgghhJxrFHgekf9WNXncSyCEEEIIOVYUeB4R8396CIC23AkhhBByflHgeUR2ttwBCj4JIYQQcj5R4HnEqN6TEEIIIecVBZ7H4N0HH+WDT4ufJhwRQggh5HygwPOYvPvgIzz48feQZYyCT0IIIYScCxR4HiPf+x9Q8EkIIYSQc4MCz2O2O/g0+yj4JIQQQsjZdWSB509/+lNwHId/8S/+xVE95Knhe/8DTP3Fd8BAwSchhBBCzq4jCTwnJyfxb/7Nv8H4+PhRPNyp9Ol7H1LwSQghhJAz7dADz0gkgj/8wz/EgwcP0NDQcNgPd6p9+t6HmPiDOBio5pMQQkghxhi8bjc2N9Zh3TQjkYgf95IIqdqhB55//Md/jH/4D/8hfv/3f7/itclkEqFQqOCf8+bRxw9x/50xCOy4V0IIIdVhjCGdSiGTTh/3Us6cgN+Hrz7/LWanJrG6tIil+df46tFvsfDqJQQhe9zLI0Q06WHe+b//9/8eL168wOSkuDnlP/3pT/Hnf/7nh7mkU+Gvvt2KG4/mYPbF0GNQH/dyCCGkLMYYtswmWMwmJOK5LJxOX4ee/n60tLUf8+pOv0g4hJnJZxAEYc9tdusWstkMxq5eP4aVEVK9Q8t4WiwWfPjhh/i3//bfQqlUivqcP/3TP0UwGMz/Y7FYDmt5JxrVexIChIIBWEwmbJlNiEYix72cEyMSDsPlsMPrdiObPf5MF2MMczPTWFlcyAedABAOBTE3M42NtdVjXN3ZYFpbBWOlt8FcDgfC53CHkJxOh5bxnJqagsvlwo0bN/Ify2az+OKLL/Dzn/8cyWQSEomk4HMUCgUUCsVhLelU+fS9DzHxo+/izq9UlPkk50o8FsPczDTCoWDBxw1NTRgdvwq5XH5MKztekXAYi3MvEQp+/X2RSKXo7R9Ad18/OI47lnU57Ta4nY6St2+sLKO5pQVanf4IV3V2ZLNZuBxOAKUDT47j4LTboNPT95icfIeW8fzOd76DV69eYWZmJv/PzZs38Yd/+IeYmZnZE3SSvR59/DB/2Ihmu5PzIJ1K4cXEE4TDe7M3fq8XM5MTRbcbz7pYNIqpp08QChZ+X7KZDNaWl7C2vHRMKwO2zOayt3McB+s53b2qhWwmg3JB5450OnX4iyGkBg4t46nT6TA2NlbwMY1Gg8bGxj0fJ6U9+vghftkD/ED+fZh8UfQaNMe9JEIOjdWyiWQyWfQ2xtj2NrMDbR0dR7yy47WxuoyskEWpAGRzYx2d3T1QqlRHuzAA0Ui47O25nxttA++XVCYDz0vKHiBijEGpol0xcjrQ5KJTYMMM/DL1CQBQ5pOcaXbrVsVrHLbK15wl2UwGLocDKFPjBwAOm/WIVlSI5yvvXkkkh3qO9UzjeR4dnZ0VSynaO4xHtCJCDuZIA89Hjx7hX/2rf3WUD3lmbJiBn61Q8EnOtrSINjyp1PnaUkyn02UPlgC57exSmeLD1tLeVjEoamltO6LVnE09A4OQy+Ulv899g0PHku0mZD8o43nK/GzlE9x/Z4yCT3ImKZXlXzw5joNafb62FKUyWcXAjjF2bAczu3r6Sq+P46BQKtHaTi2VDkKhUODmvbfR1NK65+MXRsfQOzB4TCsjpHq0/3EKvfvgI+CHP8GTz1+jq+F8vQiTs83Y1Y2l+bmStzPG0N7ZdYQrOn5SqRTNrW1wOctvt7ce01arWqPBlRu38Gp6CplMJh
+EMsagVKpw9eYtSKT0UnNQCqUSl69dRyqZRCwaBS+RQKfXH1s3A0L2i/4anGLZCttvhJw2bUYj7FZLQcug3ZpbW2F
"text/plain": [
"<Figure size 800x800 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
2023-03-31 17:34:03 +02:00
"source": [
"plt.figure(figsize=(8, 8))\n",
"plt.contourf(xx, yy, Z, cmap=plt.cm.RdBu, alpha=.8)\n",
"plt.scatter(train_data[:, 0], train_data[:, 1], c=np.argmax(toy_labels, axis=1), cmap=plt.cm.RdBu)\n",
"\n",
"plt.show()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}