{ "cells": [ { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# A simple neural network with one hidden layer in pure Python\n", "\n", "## Introduction\n", "We consider a simple feed-forward neural network with one hidden layer:" ] }, { "attachments": { "48b1ed6e-8e2b-4883-82ac-a2bbed6e2885.png": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAASwAAAEsCAYAAAB5fY51AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy89olMNAAAACXBIWXMAAA9hAAAPYQGoP6dpAABEKklEQVR4nO2ddVxU6ffHD62EgYWBhd2K2F2767prrmLHioqBrWv3rh1rdxeIioGN3WAuusbaIqlISc39/P7gN8/XkYY7DHc879drXvta5855zjzc+5mnzjkGAEAMwzAKwFDXDjAMw6QVFiyGYRQDCxbDMIqBBYthGMXAgsUwjGJgwWIYRjGwYDEMoxhYsBiGUQwsWAzDKAYWLIZhFAMLFsMwioEFi2EYxcCCxTCMYmDBYhhGMbBgMQyjGFiwGIZRDCxYDMMoBhYshmEUAwsWwzCKgQWLYRjFwILFMIxiYMFiGEYxsGAxDKMYWLAYhlEMLFgMwygGFiyGYRQDCxbDMIqBBYthGMXAgsUwjGJgwWIYRjGwYDEMoxhYsBiGUQwsWAzDKAYWLIZhFAMLFsMwioEFi2EYxcCCxTCMYmDBYhhGMbBgMQyjGFiwGIZRDCxYDMMoBhYshmEUAwsWwzCKgQWLYRjFwILFMIxiMNa1A0zqxMTEkL+/P8XHx5OVlRUVKFCADAwMdO2WXsN9nj3hEVY25dGjRzRq1CiqWbMmWVpaUsmSJalMmTJUqFAhKly4MLVr14527NhB0dHRunZVb+A+z/4YAICunWD+x7Nnz2jYsGF05swZKliwILVr147s7e2pTJkyZGxsTB8/fqR79+7RlStX6OLFi2RtbU3Tp0+nESNGkKEh//5kBO5zBQEm27B69WrkzJkTpUuXxp49exATE5Pi9U+fPsWgQYNARGjUqBHevn2bRZ7qD9znyoIFK5swbdo0EBGGDh2KiIiIdH32woULsLW1RfHixfHixQsteah/cJ8rDxasbMD69etBRFiwYEGGbbx9+xZ2dnYoW7YswsPDZfROP+E+VyYsWDrmv//+g7m5OQYPHpxpW8+ePYO5uTmGDh0qg2f6C/e5cmHB0jHt2rVDyZIlZfuFXrlyJYgIt2/flsWePsJ9rlxYsHTIs2fPQETYsmWLbDbj4+NRokQJ9OnTRzab+gT3ubLhPVkdsmXLFsqbNy85OjrKZtPIyIicnZ1p//79FBYWJptdfYH7XNmwYOmQq1evUqtWrShnzpyy2m3Xrh3FxMSQj4+PrHb1Ae5zZcOCpSMkSaI7d+5Q7dq1ZbddoUIFMjc354fnG7jPlQ8Llo6IiIigiIgIKlGihOy2jYyMyNbWlj58+CC7bSXDfa58WLB0BP4/IkpbAbWGhoaiDSYB7nPlw4KlIywsLMjU1JQCAgJktw2AAgICyNraWnbbSob7XPmwYOkIY2Njqlq1qlbWPF69ekUfP36kmjVrym5byXCfKx8WLB1St25dOn/+PKlUKlntenl5kYGBgVYWl5UO97myYcHSIX369KE3b97QiRMnZLMJgP7++28qV66c7Fv3+oC2+nzNmjX0448/UqFChWSzyySG82HpEADk4OBAxsbGdPXqVTIyMsq0zePHj1O7du2IiChv3rw0ZswYcnFxoVy5cmXatj6gzT4/duwY/fzz
zzJ4ySSLbg7YM2quXLkCAwMDLFq0KNO2Pn36hCJFiqB69eooV64ciAhEhDx58mDWrFkIDQ2VwWPlo40+//HHHyFJkgzeMSnBgpUNGDt2LExMTHD8+PEM24iKikLLli2RO3duvH37FvHx8dizZw8qVKigIVwzZ87Ep0+f5HNeoWijzxntw4KVDYiNjUX79u1hYmKCzZs3p/uX+t27d2jcuDFy5syJixcvarwXHx+PvXv3olKlSkK4cufOjenTp+Pjx49yfg1FIUefN2nSJMk+Z7QHC1Y2ITY2FgMGDAARoW3btvj3339T/Ux0dDTWr18PS0tL5MiRAxcuXEj2WpVKhf3796Ny5cpCuHLlyoVp06YhJCREzq+iGDLT57lz54aNjQ2uXLmSBZ4yaliwshlHjhxBkSJFQERo2bIlVq1ahevXryM4OBifP3/Gixcv4O7ujjFjxiBPnjwgIpiZmYGIMHfu3FTtq1QquLm5oUqVKkK4rKysMGXKFAQHB8v+fVQqFaKiohAbGyu7bblIT59bW1uDiFCqVKnvVuh1CQtWNuTLly/YuXMnGjduDBMTEyEsX7+MjY3FKGnz5s0gIpiYmOD+/ftpakOlUuHAgQOoVq2asGlpaYlJkyYhKCgow76rVCqcOXMGTk5OqFmzpob/NjY2aNu2LRYsWICAgIAMt6EN0tLnNjY26NGjB4gIBgYG8PX11bXb3x18rCGbExMTQw8fPqQ3b96QSqUiKysrCg0Npe7du5ORkRGpVCratm0bHTp0iDw8PKhmzZp08+ZNMjExSZN9SZLIw8ODZs2aRffv3yciIktLSxo+fDiNHTuW8ufPn2Zf3dzcaMqUKfTs2TMqV64cNWrUiGrUqEHW1tYUHx9Pz549Ix8fH7pw4QKpVCrq0aMHLVy4kAoWLJihvtEWSfV5tWrVqEiRIkRE1KlTJzp06BA5OjrS3r17deztd4auFZNJP9HR0TA3Nxe//A0aNMCHDx/EdGXWrFnptqlSqXDo0CHUrFlT2LWwsMCECRMQGBiY4mdDQkLQpUsXEBF++eUXXLp0KcVF7JCQECxZsgT58uVD/vz54e7unm5/dcn9+/fFKOuff/7RtTvfFSxYCuWXX34RDw0R4eHDh9i7d6+YLt69ezdDdiVJgoeHB2rVqiWEy9zcHOPGjUtyGhcQEIAqVarA2toarq6u6WorICAAHTt2BBFh9erVGfJXV3Tu3BlEhK5du+rale8KFiyFsmbNGhAR8uXLByLC8OHDIUmSeJCqVauWalHQlJAkCUePHoW9vb2GcI0dOxb+/v4AEkZ69vb2KFSoEB49epThdkaOHAkigpubW4b9zWoePHggfjAePnyoa3e+G1iwFMrLly9BRDA0NBRnqyIjIxEQEID8+fODiDBt2rRMtyNJEo4dOwYHBwchXDlz5sTo0aPh4uICExMT3LlzJ9NtdOnSBdbW1vjw4UOmfc4qfvvtNxARunTpomtXvhtYsBSM+jBowYIFNSrBuLq6gohgZGQEb29vWdqSJAmenp6oW7euEC4DAwPMmTNHFvtBQUEoWLAgunfvLou9rODhw4diSp7W3Vkmc7BgKZixY8eCiMR6U926dcV7Xbt2BRGhSpUqiI6Olq1NSZJw8uRJFCxYEAUKFMjUtPNbVq5cCSMjI7x79042m9pG3c+dOnXStSvfBSxYCubcuXMgIuTPnx9GRkYgIty7dw/A/0YsRITJkyfL2m5ERAQsLCwwY8YMWe1+/vwZFhYWmD17tqx2tYmvr68YZan7ntEenA9LwTRq1IgsLS0pODiYmjdvTkRE69evJyKi/Pnz09q1a4mIaP78+XT79m3Z2r179y5FRkZShw4dZLNJRJQrVy5q1aoVXbp0SVa72qRSpUrUrVs3IiKaNWuWjr3Rf1iwFIypqSm1atWKiIiKFStGRES7du2iiIgIIko44Ni9e3eSJIn69u1L0dHRsrTr7e1NZmZmVLlyZVnsfY29vT15e3srqpjD
9OnTycDAgA4dOkT37t3TtTt6DQuWwnnbti0RET158oTKlClD4eHhtG/fPvH+ypUrqVChQvT48WOaOXOmLG2+ffuWSpQokebT9OmhTJkyFJoaSpGRkbLb1hYVK1ak7t27ExHJ1sdM0rBgKZyffvqJiIhu3LhBPXv2JKL/TQuJiPLlyyf+f9GiRXTjxo1Mt6lSqcjY2DjTdpJCnQFU7pzr2mbatGlkaGhIHh4edOfOHV27o7ewYCmcYsWKUdWqVQkA2djYkKmpKXl7e2s8NO3bt6fevXuTJEnUr18/+vLlS6bazJMnDwUHB2fW9SQJCQkhY2NjMjc314p9bVGhQgUeZWUBLFh6gHpaePXqVerUqRMRaY6yiIhWrFhBhQsXpidPntC0adMy1V716tUpMDCQ/Pz8MmUnKe7c } }, "cell_type": "markdown", "metadata": {}, "source": [ "![nn.png](attachment:48b1ed6e-8e2b-4883-82ac-a2bbed6e2885.png)" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "In this example the input vector of the neural network has two features, i.e., the input is a two-dimensional vector:\n", "\n", "$$\n", "\\mathbf x = (x_0, x_1).\n", "$$\n", "\n", "We consider a set of $n$ vectors as training data. The training data can therefore be written as an $n \\times 2$ matrix where each row represents a feature vector:\n", "\n", "$$ \n", "X = \n", "\\begin{pmatrix}\n", "x_{00} & x_{01} \\\\\n", "x_{10} & x_{11} \\\\\n", "\\vdots & \\vdots \\\\\n", "x_{n-1\\,0} & x_{n-1\\,1} \n", "\\end{pmatrix} $$\n", "\n", "The known labels (1 = 'signal', 0 = 'background') are stored in an $n$-dimensional column vector $\\mathbf y$.\n", "\n", "In the following, $n_1$ denotes the number of neurons in the hidden layer. The weights for the connections from the input layer (layer 0) to the hidden layer (layer 1) are given by the following matrix:\n", "\n", "$$\n", "W^{(1)} = \n", "\\begin{pmatrix}\n", "w_{00}^{(1)} \\dots w_{0 \\, n_1-1}^{(1)} \\\\\n", "w_{10}^{(1)} \\dots w_{1 \\, n_1-1}^{(1)} \n", "\\end{pmatrix}\n", "$$\n", "\n", "Each neuron in the hidden layer is assigned a bias $\\mathbf b^{(1)} = (b^{(1)}_0, \\ldots, b^{(1)}_{n_1-1})$. The neuron in the output layer has the bias $\\mathbf b^{(2)}$. 
With that, the output values of the network for the matrix $X$ of input feature vectors are given by\n", "\n", "$$\n", "\\begin{align}\n", "Z^{(1)} &= X W^{(1)} + \\mathbf b^{(1)} \\\\\n", "A^{(1)} &= \\sigma(Z^{(1)}) \\\\\n", "Z^{(2)} &= A^{(1)} W^{(2)} + \\mathbf b^{(2)} \\\\\n", "A^{(2)} &= \\sigma(Z^{(2)})\n", "\\end{align}\n", "$$\n", "\n", "The loss function for a given set of weights is given by\n", "\n", "$$ L = \\sum_{i=0}^{n-1} (y_{\\mathrm{pred},i} - y_{\\mathrm{true},i})^2 $$\n", "\n", "We can now calculate the gradient of the loss function w.r.t. the weights. With the per-sample loss $\\tilde L = (y_\\mathrm{pred} - y_\\mathrm{true})^2$, the gradients for the weights connecting the hidden layer to the output layer are given by: \n", "\n", "$$ \\frac{\\partial \\tilde L}{\\partial w_i^{(2)}} = \\frac{\\partial \\tilde L}{\\partial a_k^{(2)}} \\frac{\\partial a_k^{(2)}}{\\partial w_i^{(2)}} = \\frac{\\partial \\tilde L}{\\partial a_k^{(2)}} \\frac{\\partial a_k^{(2)}}{\\partial z_k^{(2)}} \\frac{\\partial z_k^{(2)}}{\\partial w_i^{(2)}} = 2 (a_k^{(2)} - y_k) a_k^{(2)} (1 - a_k^{(2)}) a_{k,i}^{(1)}$$\n", "\n", "Applying the chain rule further, we also obtain the gradients for the weights from the input layer to the hidden layer: \n", "\n", "$$ \\frac{\\partial \\tilde L}{\\partial w_{ij}^{(1)}} = \\frac{\\partial \\tilde L}{\\partial a_k^{(2)}} \\frac{\\partial a_k^{(2)}}{\\partial z_k^{(2)}} \\frac{\\partial z_k^{(2)}}{\\partial a_{k,j}^{(1)}} \\frac{\\partial a_{k,j}^{(1)}}{\\partial z_{k,j}^{(1)}} \\frac{\\partial z_{k,j}^{(1)}}{\\partial w_{ij}^{(1)}} $$" ] }, { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "## A simple neural network class" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# A simple feed-forward neural network with one hidden layer\n", "# see also https://towardsdatascience.com/how-to-build-your-own-neural-network-from-scratch-in-python-68998a08e4f6\n", "\n", "import numpy as np\n", "\n", "class NeuralNetwork:\n", " def __init__(self, x, 
y):\n", " n1 = 3 # number of neurons in the hidden layer\n", " self.input = x\n", " self.weights1 = np.random.rand(self.input.shape[1],n1)\n", " self.bias1 = np.random.rand(n1)\n", " self.weights2 = np.random.rand(n1,1)\n", " self.bias2 = np.random.rand(1) \n", " self.y = y\n", " self.output = np.zeros(y.shape)\n", " self.learning_rate = 0.01\n", " self.n_train = 0\n", " self.loss_history = []\n", "\n", " def sigmoid(self, x):\n", " return 1/(1+np.exp(-x))\n", "\n", " def sigmoid_derivative(self, x):\n", " return x * (1 - x)\n", "\n", " def feedforward(self):\n", " self.layer1 = self.sigmoid(self.input @ self.weights1 + self.bias1)\n", " self.output = self.sigmoid(self.layer1 @ self.weights2 + self.bias2)\n", "\n", " def backprop(self):\n", "\n", " # delta1: [m, 1], m = number of training data\n", " delta1 = 2 * (self.y - self.output) * self.sigmoid_derivative(self.output)\n", "\n", " # Gradient w.r.t. weights from hidden to output layer: [n1, 1] matrix, n1 = # neurons in hidden layer\n", " d_weights2 = self.layer1.T @ delta1\n", " d_bias2 = np.sum(delta1) \n", " \n", " # shape of delta2: [m, n1], m = number of training data, n1 = # neurons in hidden layer\n", " delta2 = (delta1 @ self.weights2.T) * self.sigmoid_derivative(self.layer1)\n", " d_weights1 = self.input.T @ delta2\n", " d_bias1 = np.ones(delta2.shape[0]) @ delta2 \n", " \n", " # update weights and biases\n", " self.weights1 += self.learning_rate * d_weights1\n", " self.weights2 += self.learning_rate * d_weights2\n", "\n", " self.bias1 += self.learning_rate * d_bias1\n", " self.bias2 += self.learning_rate * d_bias2\n", "\n", " def train(self, X, y):\n", " self.output = np.zeros(y.shape)\n", " self.input = X\n", " self.y = y\n", " self.feedforward()\n", " self.backprop()\n", " self.n_train += 1\n", " if (self.n_train %1000 == 0):\n", " loss = np.sum((self.y - self.output)**2)\n", " print(\"loss: \", loss)\n", " self.loss_history.append(loss)\n", " \n", " def predict(self, X):\n", " self.output = 
np.zeros((X.shape[0], 1))\n", " self.input = X\n", " self.feedforward()\n", " return self.output\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Create toy data\n", "We create three toy data sets:\n", "1. two moon-like distributions\n", "2. circles\n", "3. linearly separable data sets" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# https://scikit-learn.org/stable/auto_examples/classification/plot_classifier_comparison.html#sphx-glr-auto-examples-classification-plot-classifier-comparison-py\n", "import numpy as np\n", "from sklearn.datasets import make_moons, make_circles, make_classification\n", "from sklearn.model_selection import train_test_split\n", "\n", "X, y = make_classification(\n", " n_features=2, n_redundant=0, n_informative=2, random_state=1, n_clusters_per_class=1\n", ")\n", "rng = np.random.RandomState(2)\n", "X += 2 * rng.uniform(size=X.shape)\n", "linearly_separable = (X, y)\n", "\n", "datasets = [\n", " make_moons(n_samples=200, noise=0.1, random_state=0),\n", " make_circles(n_samples=200, noise=0.1, factor=0.5, random_state=1),\n", " linearly_separable,\n", "]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Create training and test data set" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# datasets: 0 = moons, 1 = circles, 2 = linearly separable\n", "X, y = datasets[1]\n", "X_train, X_test, y_train, y_test = train_test_split(\n", " X, y, test_size=0.4, random_state=42\n", ")\n", "\n", "x_min, x_max = X[:, 0].min() - 0.5, X[:, 0].max() + 0.5\n", "y_min, y_max = X[:, 1].min() - 0.5, X[:, 1].max() + 0.5\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Train the model" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "loss: 28.591431249971087\n", "loss: 
19.174944855091578\n", "loss: 18.300519116661075\n", "loss: 5.44035901972833\n", "loss: 2.2654992441410906\n", "loss: 1.6923656607186892\n", "loss: 1.3715971480249087\n", "loss: 1.1473150221090382\n", "loss: 0.9774346378363713\n", "loss: 0.8457117685917934\n", "loss: 0.7429652120737472\n", "loss: 0.6621808985042399\n", "loss: 0.5977165926831687\n", "loss: 0.545283043346378\n", "loss: 0.5017902977940301\n", "loss: 0.46506515287723293\n", "loss: 0.4335772706016494\n", "loss: 0.40623169342909965\n", "loss: 0.3822273847227754\n", "loss: 0.36096446182458697\n", "loss: 0.3419836665195889\n", "loss: 0.3249263905044797\n", "loss: 0.3095077414631703\n", "loss: 0.29549797484687557\n", "loss: 0.282709394404349\n", "loss: 0.27098690712728085\n", "loss: 0.2602010759266338\n", "loss: 0.2502429170283057\n", "loss: 0.24101994107129043\n", "loss: 0.23245309736167535\n", "loss: 0.2244743850815736\n", "loss: 0.2170249645242441\n", "loss: 0.21005364833790718\n", "loss: 0.2035156851277511\n", "loss: 0.19737177048767093\n", "loss: 0.19158723674048994\n", "loss: 0.18613138439559326\n", "loss: 0.1809769269368725\n", "loss: 0.17609952694233805\n", "loss: 0.17147740633361158\n", "loss: 0.16709101719249247\n", "loss: 0.16292276236867193\n", "loss: 0.15895675725575403\n", "loss: 0.15517862578972155\n", "loss: 0.1515753250400271\n", "loss: 0.1481349938036084\n", "loss: 0.1448468214395504\n", "loss: 0.1417009338445398\n", "loss: 0.13868829400248622\n", "loss: 0.13580061497353096\n", "loss: 0.1330302835389617\n", "loss: 0.1303702930059422\n", "loss: 0.127814183912036\n", "loss: 0.12535599156436183\n", "loss: 0.12299019950967911\n", "loss: 0.1207116981660866\n", "loss: 0.11851574795923153\n", "loss: 0.11639794640004714\n", "loss: 0.11435419862018448\n", "loss: 0.11238069094815246\n", "loss: 0.11047386716576169\n", "loss: 0.10863040713256283\n", "loss: 0.10684720750694056\n", "loss: 0.1051213643275356\n", "loss: 0.10345015724866172\n", "loss: 0.10183103524916717\n", "loss: 0.10026160365640727\n", 
"loss: 0.09873961234613571\n", "loss: 0.09726294499576621\n", "loss: 0.09582960928281026\n", "loss: 0.09443772793285743\n", "loss: 0.09308553053235864\n", "loss: 0.0917713460310074\n", "loss: 0.09049359586685884\n", "loss: 0.08925078765463015\n", "loss: 0.08804150938406197\n", "loss: 0.08686442408087547\n", "loss: 0.08571826488782217\n", "loss: 0.08460183052776402\n", "loss: 0.08351398111459095\n", "loss: 0.08245363428124586\n", "loss: 0.08141976159718958\n", "loss: 0.08041138525037475\n", "loss: 0.07942757497120295\n", "loss: 0.07846744517812901\n", "loss: 0.07753015232648876\n", "loss: 0.0766148924438704\n", "loss: 0.0757208988368859\n", "loss: 0.07484743995559193\n", "loss: 0.07399381740306453\n", "loss: 0.07315936407873515\n", "loss: 0.07234344244512234\n", "loss: 0.07154544290849069\n", "loss: 0.07076478230479659\n", "loss: 0.07000090248301835\n", "loss: 0.06925326897863726\n", "loss: 0.06852136977065329\n", "loss: 0.06780471411604974\n", "loss: 0.06710283145614956\n", "loss: 0.06641527038972668\n" ] } ], "source": [ "y_train = y_train.reshape(-1, 1)\n", "\n", "nn = NeuralNetwork(X_train, y_train)\n", "\n", "for i in range(100000):\n", " nn.train(X_train, y_train)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Plot the loss vs. 
the number of epochs" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Text(0, 0.5, 'loss')" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAjIAAAG2CAYAAABlBWwKAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy89olMNAAAACXBIWXMAAA9hAAAPYQGoP6dpAAA36ElEQVR4nO3de3hU1b3/8c9Mkpkk5ka4JCAJIlCughQwRqgHjygi3vHUWk4P9qD81KBiesRi1Za2Gi+nlbZS+9QewT6KtipiQYtFEPASQAMRUUQDKKAkKEgCgVxn/f5IZicDAUJmz96Z8H49zzyZ2XvtPd9sq/l07bXX8hhjjAAAAKKQ1+0CAAAA2oogAwAAohZBBgAARC2CDAAAiFoEGQAAELUIMgAAIGoRZAAAQNQiyAAAgKhFkAEAAFGLIAMAAKKWq0HmiSee0NChQ5WSkqKUlBTl5ubqn//8p7W/qqpKeXl56ty5s5KSkjRp0iSVlZW5WDEAAGhPPG6utbR48WLFxMSoX79+Msbo6aef1qOPPqoNGzZo8ODBuuWWW/Tqq69q/vz5Sk1N1fTp0+X1evXOO++4VTIAAGhHXA0yLUlPT9ejjz6qa6+9Vl27dtWCBQt07bXXSpI++eQTDRw4UIWFhTr33HNdrhQAALgt1u0Cgurr6/XCCy+osrJSubm5KioqUm1trcaNG2e1GTBggLKzs48bZKqrq1VdXW19DgQC2rdvnzp37iyPxxPx3wMAAITPGKMDBw6oR48e8nqPPRLG9SDz4YcfKjc3V1VVVUpKStLLL7+sQYMGqbi4WD6fT2lpaSHtMzIyVFpaeszzFRQUaPbs2RGuGgAAOGHnzp3q2bPnMfe7HmT69++v4uJilZeX68UXX9SUKVO0atWqNp9v1qxZys/Ptz6Xl5crOztbO3fuVEpKih0lAwCACKuoqFBWVpaSk5OP2871IOPz+dS3b19J0ogRI/Tee+/pd7/7na677jrV1NRo//79Ib0yZWVlyszMPOb5/H6//H7/UduDT0YBAIDocaJhIe1uHplAIKDq6mqNGDFCcXFxWr58ubVvy5Yt2rFjh3Jzc12sEAAAtBeu9sjMmjVLEyZMUHZ2tg4cOKAFCxZo5cqVev3115WamqqpU6cqPz9f6enpSklJ0W233abc3FyeWAIAAJJcDjJ79uzRf/3Xf2n37t1KTU3V0KFD9frrr+uiiy6SJD322GPyer2aNGmSqqurNX78eP3xj390s2QAANCOtLt5ZOxWUVGh1NRUlZeXM0YGAIAo0dq/3+1ujAwAAEBrEWQAAEDUIsgAAICoRZABAABRiyADAACiFkEGAABELYIMAACIWgQZAAAQtVxfNDJaVVTVqvxQrZLjY5WW6HO7HAAATkn0yLTRrxZ/rO898qaeXbvD7VIAADhlEWTaKNEXI0mqqq13uRIAAE5dBJk2im8MModqCDIAALiFINNGiXENw4sO0yMDAIBrCDJtlOBruHSH6ZEBAMA1BJk2SvA19sgQZAAAcA1Bpo0S4hrHyHBrCQAA1xBk2sh6aokeGQAAXEOQaaOmHpk6lysBAODURZBpo4TGHhnGyAAA4B6CTBsFe2QIMgAAuIcg00bBMTLMIwMAgHsIMm0UH8fMvgAAuI0g00bBHpnquoACAeNyNQAAnJoIMm0UHOwrcXsJAAC3EGTaKD6WIAMAgNsIMm3k9XoUH8d6SwAAuIkgE4ZEHytgAwDgJoJMGBJ4cgkAAFcR
ZMLA7L4AALiLIBOGpknxWG8JAAA3EGTCEG8tUxBwuRIAAE5NBJkwBHtkDtXQIwMAgBsIMmGwFo7kqSUAAFxBkAkDg30BAHAXQSYMPH4NAIC7CDJhCI6RqeLWEgAAriDIhIEeGQAA3EWQCUMCSxQAAOAqgkwYElg0EgAAVxFkwsCikQAAuIsgE4Z4JsQDAMBVBJkwJFoT4rFEAQAAbiDIhKFpQjx6ZAAAcANBJgxWkGGMDAAAriDIhMFaa4mnlgAAcAVBJgyJrLUEAICrCDJhsGb2ra2XMcblagAAOPUQZMIQHCNjjFRdx5NLAAA4zdUgU1BQoFGjRik5OVndunXTVVddpS1btoS0GTt2rDweT8jr5ptvdqniUMEeGYnbSwAAuMHVILNq1Srl5eVpzZo1WrZsmWpra3XxxRersrIypN1NN92k3bt3W69HHnnEpYpDxcZ45YtpXKaAJ5cAAHBcrJtfvnTp0pDP8+fPV7du3VRUVKTzzz/f2p6YmKjMzEyny2uV+DivauoDrIANAIAL2tUYmfLycklSenp6yPZnn31WXbp00ZAhQzRr1iwdOnTomOeorq5WRUVFyCuSgustVdEjAwCA41ztkWkuEAhoxowZGj16tIYMGWJt/+EPf6hevXqpR48e2rhxo+6++25t2bJFCxcubPE8BQUFmj17tlNlWwN+6ZEBAMB57SbI5OXladOmTXr77bdDtk+bNs16f9ZZZ6l79+668MILtXXrVvXp0+eo88yaNUv5+fnW54qKCmVlZUWsbmtSPHpkAABwXLsIMtOnT9eSJUu0evVq9ezZ87htc3JyJEklJSUtBhm/3y+/3x+ROlvCeksAALjH1SBjjNFtt92ml19+WStXrlTv3r1PeExxcbEkqXv37hGurnUSWW8JAADXuBpk8vLytGDBAr3yyitKTk5WaWmpJCk1NVUJCQnaunWrFixYoEsvvVSdO3fWxo0bdeedd+r888/X0KFD3SzdEh/HGBkAANziapB54oknJDVMetfcvHnzdMMNN8jn8+mNN97QnDlzVFlZqaysLE2aNEn33nuvC9W2jPWWAABwj+u3lo4nKytLq1atcqiatmEFbAAA3NOu5pGJRgmMkQEAwDUEmTAlMEYGAADXEGTCFBwjw8y+AAA4jyATJp5aAgDAPQSZMAXXWmKMDAAAziPIhCnB13AJeWoJAADnEWTClBBHjwwAAG4hyIQpkdWvAQBwDUEmTAk8tQQAgGsIMmFqmkeG1a8BAHAaQSZMCdxaAgDANQSZMDEhHgAA7iHIhCl4a6m23qi2PuByNQAAnFoIMmEK3lqSeAQbAACnEWTC5IvxyutpeM+keAAAOIsgEyaPx9O0TAFBBgAARxFkbMDCkQAAuIMgY4Pgk0uMkQEAwFkEGRsEn1zi1hIAAM4iyNgggR4ZAABcQZCxAcsUAADgDoKMDZjdFwAAdxBkbBDPeksAALiCIGODxDjGyAAA4AaCjA2swb70yAAA4CiCjA0IMgAAuIMgYwPrqSVuLQEA4CiCjA2sp5bokQEAwFEEGRsksNYSAACuIMjYICG4+jW3lgAAcBRBxgastQQAgDsIMjZg9WsAANxBkLFBPGstAQDgCoKMDZrWWgq4XAkAAKcWgowNEnz0yAAA4AaCjA0SWGsJAABXEGRskNDs1lIgYFyuBgCAUwdBxgbBMTKSVFVHrwwAAE4hyNggPrYpyDC7LwAAziHI2MDr9Sg+ruFSMikeAADOIcjYhAG/AAA4jyBjk8Tgekv0yAAA4BiCjE2Ct5YYIwMAgHMIMjYJ9shUcWsJAADHEGRskmCtt0SQAQDAKQQZmySwAjYAAI4jyNgkOCneYdZbAgDAMQQZm3BrCQAA57kaZAoKCjRq1CglJyerW7duuuqqq7Rly5aQNlVVVcrLy1Pnzp2VlJSkSZMmqayszKWKj41bSwAAOM/VILNq1Srl5eVpzZo1WrZsmWpra3XxxRersrLSanPnnXdq8eLFeuGF
F7Rq1Sp99dVXuuaaa1ysumXWhHj0yAAA4JhYN7986dKlIZ/nz5+vbt26qaioSOeff77Ky8v1f//3f1qwYIH+/d//XZI0b948DRw4UGvWrNG5557rRtktSqRHBgAAx7WrMTLl5eWSpPT0dElSUVGRamtrNW7cOKvNgAEDlJ2drcLCwhbPUV1drYqKipCXE+J9jJEBAMBp7SbIBAIBzZgxQ6NHj9aQIUMkSaWlpfL5fEpLSwtpm5GRodLS0hbPU1BQoNTUVOuVlZUV6dIlSYmstQQAgOPaTZDJy8vTpk2b9Pzzz4d1nlmzZqm8vNx67dy506YKj88a7Eu "text/plain": [ "<Figure size 640x480 with 1 Axes>" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import matplotlib.pyplot as plt\n", "plt.plot(nn.loss_history)\n", "plt.xlabel(\"# epochs / 1000\")\n", "plt.ylabel(\"loss\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "<matplotlib.colorbar.Colorbar at 0x12fe75c30>" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAssAAAJMCAYAAAAMveu7AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy89olMNAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdd3gUZdfH8e/MpvfeE5JQQu+9CdLFgr2jqPhaHxUfe8HesLcHOzYEFRQUpEhHmvROgPSQ3nvZmfePkJCQXbLpCZzPda2SmdmZOyEkv733zLkVXdd1hBBCCCGEELWorT0AIYQQQggh2ioJy0IIIYQQQpghYVkIIYQQQggzJCwLIYQQQghhhoRlIYQQQgghzJCwLIQQQgghhBkSloUQQgghhDBDwrIQQgghhBBmSFgWQgghhBDCDAnLQgghhBBCmCFhWQghhBBCtKqNGzdy2WWXERAQgKIo/P7773U+Z/369fTv3x9bW1s6derEvHnzaux/4YUXUBSlxqNr1671HpuEZSGEEEII0aoKCgro06cPn3zyiUXHR0dHM3XqVMaOHcvevXt5+OGHueuuu1i5cmWN43r06EFSUlLVY/PmzfUem1W9nyGEEEIIIUQTmjJlClOmTLH4+Llz5xIWFsY777wDQLdu3di8eTPvvfcekyZNqjrOysoKPz+/Ro3tvAvLmqZx6tQpnJ2dURSltYcjhBBCiAuYruvk5eUREBCAqrb+G/rFxcWUlpa2yLV0Xa+VxWxtbbG1tW30ubdu3cr48eNrbJs0aRIPP/xwjW3Hjx8nICAAOzs7hg0bxuuvv05ISEi9rnXeheVTp04RHBzc2sMQQgghhKgSHx9PUFBQq46huLiYwMBQMjNTWuR6Tk5O5Ofn19g2e/ZsXnjhhUafOzk5GV9f3xrbfH19yc3NpaioCHt7e4YMGcK8efOIiIggKSmJF198kVGjRnHw4EGcnZ0tvtZ5F5YrP/l933yJs4NDK49GCCGEEBeyvMJC+sy4q17hrLmUlpaSmZnCD4sO4eDYvOMpLMjjlqt7EB8fj4uLS9X2pphVtlT1so7evXszZMgQOnTowM8//8ydd95p8XnOu7BcOd
3v7OAgYVkIIYQQbUJbKg11cHTG0dGl7gObgIuLS42w3FT8/PxISak5Q56SkoKLiwv29vYmn+Pm5kaXLl04ceJEva7V+sUzQgghhBBC1MOwYcNYs2ZNjW2rV69m2LBhZp+Tn5/PyZMn8ff3r9e1JCwLIYQQQohWlZ+fz969e9m7dy9Q0Rpu7969xMXFAfDUU08xffr0quPvueceoqKiePzxxzl69CiffvopP//8M4888kjVMf/973/ZsGEDMTExbNmyhSuvvBKDwcCNN95Yr7Gdd2UYQgghhBCifdm5cydjx46t+njWrFkA3HbbbcybN4+kpKSq4AwQFhbGsmXLeOSRR/jggw8ICgriyy+/rNE2LiEhgRtvvJGMjAy8vb0ZOXIk27Ztw9vbu15jk7AshBBCCCFa1ZgxY9B13ez+s1fnq3zOnj17zD5nwYIFTTE0KcMQQgghhBDCHAnLQgghhBBCmCFhWQghhBBCCDMkLAshhBBCCGGGhGUhhBBCCCHMkLAshBBCCCGEGRKWhRBCCCGEMEPCshBCCCGEEGZIWBZCCCGEEMIMCctCCCGEEEKYIWFZCCGEEEIIMyQsCyGEEEIIYYaEZSGEEEIIIcyQsCyEEEIIIYQZEpaFEEIIIYQwQ8KyEEIIIYQQZkhYFkIIIYQQwgwJy0IIIYQQQpghYVkIIYQQQggzmjUsb9y4kcsuu4yAgAAUReH3338/5/Hr169HUZRaj+Tk5OYcphBCCCGEECY1a1guKCigT58+fPLJJ/V63rFjx0hKSqp6+Pj4NNMIhRBCCCGEMM+qOU8+ZcoUpkyZUu/n+fj44Obm1vQDEkIIIYQQoh7aZM1y37598ff3Z8KECfzzzz/nPLakpITc3NwaDyGEEEIIIZpCmwrL/v7+zJ07l0WLFrFo0SKCg4MZM2YMu3fvNvuc119/HVdX16pHcHBwC45YCCGEEEKcz5q1DKO+IiIiiIiIqPp4+PDhnDx5kvfee4/vv//e5HOeeuopZs2aVfVxbm6uBGYhhBBCCNEk2lRYNmXw4MFs3rzZ7H5bW1tsbW1bcERCCCGEEOJC0abKMEzZu3cv/v7+rT0MIYQQQghxAWrWmeX8/HxOnDhR9XF0dDR79+7Fw8ODkJAQnnrqKRITE/nuu+8AeP/99wkLC6NHjx4UFxfz5ZdfsnbtWlatWtWcwxRCCCGEEMKkZg3LO3fuZOzYsVUfV9YW33bbbcybN4+kpCTi4uKq9peWlvLoo4+SmJiIg4MDvXv35u+//65xDiGEEEIIIVpKs4blMWPGoOu62f3z5s2r8fHjjz/O448/3pxDEkIIIYQQwmJtvmZZCCGEEEKI1iJhWQghhBBCCDMkLAshhBBCCGGGhGUhhBBCCCHMkLAshBBCCCGEGRKWhRBCCCGEMEPCshBCCCGEEGZIWBZCCCGEEMIMCctCCCGEEEKYIWFZCCGEEEIIMyQsCyGEEEIIYYaEZSGEEEIIIcyQsCyEEEIIIYQZEpaFEEIIIYQwQ8KyEEIIIYQQZkhYFkIIIYQQwgwJy0IIIYQQQpghYVkIIYQQQggzJCwLIYQQQghhhoRlIYQQQgghzJCwLIQQQgghhBkSloUQQgghhDBDwrIQQgghhBBmSFgWQgghhBDCDAnLQgghhBBCmCFhWQghhBBCCDMkLAshhBBCCGGGhGUhhBBCCCHMkLAshBBCCCGEGRKWhRBCCCGEMEPCshBCCCGEEGZIWBZCCCGEEMIMCctCCCGEEEKYIWFZCCGEEEIIMyQsCyGEEEIIYYaEZSGEEEIIIcyQsCyEEEIIIdqETz75hNDQUOzs7BgyZAg7duwwe2xZWRkvvfQSHTt2xM7Ojj59+rBixYpGndMUCctCCCGEEKLVLVy4kFmzZjF79mx2795Nnz59mDRpEqmpqSaPf/bZZ/nss8/46KOPOHz4MPfccw9XXnkle/bsafA5TZGwLIQQQgghWt27777LzJkzmTFjBt27d2fu3Lk4ODjw9d
dfmzz++++/5+mnn+aSSy4hPDyce++9l0suuYR33nmnwec0RcKyEEIIIYRoFrm5uTUeJSUlJo8rLS1l165djB8/vmqbqqqMHz+erVu3mnxOSUkJdnZ2NbbZ29uzefPmBp/TFCuLjxRCCCGEEO2eh48zTk7OzXoN23wdgODg4BrbZ8+ezQsvvFDr+PT0dIxGI76+vjW2+/r6cvToUZPXmDRpEu+++y6jR4+mY8eOrFmzhsWLF2M0Ght8TlMkLAshRB00TSM2JYXY5BTKSktxcXamS3AQ7s7N+8tGCCHau/j4eFxcXKo+trW1bbJzf/DBB8ycOZOuXbuiKAodO3ZkxowZ9SqxsISEZSGEOIfC4hL+3raN3Kho/MrKcFRVThmNRLq60HPAAPp16dzaQxRCiDbLxcWlRlg2x8vLC4PBQEpKSo3tKSkp+Pn5mXyOt7c3v//+O8XFxWRkZBAQEMCTTz5JeHh4g89pitQsCyHEOazfuRPt6DGmubsxtUMIY4ODuK5DCIPLjRzauo0TCYmtPUQhhGj3bGxsGDBgAGvWrKnapmkaa9asYdiwYed8rp2dHYGBgZSXl7No0SKuuOKKRp+zOplZFkIIM1KyssiIimaKlyee9vZV21VFobe3F2nxCRyMPE6noMBWHKUQQpwfZs2axW233cbAgQMZPHgw77//PgUFBcyYMQOA6dOnExgYyOuvvw7A9u3bSUxMpG/fviQmJvLCCy+gaRqPP/64xee0hIRlIYQw41RaOk5FRQR4e5nc38nVhejkZPKLinCqFqaFEELU3/XXX09aWhrPP/88ycnJ9O3blxUrVlTdoBcXF4eqnimKKC4u5tlnnyUqKgonJycuueQSvv/+e9zc3Cw+pyUkLAshhBmarmMAFEUxud9aVUEvQ9f1lh2YEEKcpx544AEeeOABk/vWr19f4+OLLrqIw4cPN+qclpCaZSGEMMPT1YVca2syi4t "text/plain": [ "<Figure size 900x700 with 2 Axes>" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import matplotlib.pyplot as plt\n", "from matplotlib.colors import ListedColormap\n", "\n", "cm = plt.cm.RdBu\n", "cm_bright = ListedColormap([\"#FF0000\", \"#0000FF\"])\n", "\n", "xv = np.linspace(x_min, x_max, 10)\n", "yv = np.linspace(y_min, y_max, 10)\n", "Xv, Yv = np.meshgrid(xv, yv)\n", "XYpairs = np.vstack([ Xv.reshape(-1), Yv.reshape(-1)])\n", "zv = nn.predict(XYpairs.T)\n", "Zv = zv.reshape(Xv.shape)\n", "\n", "fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(9, 7))\n", "ax.set_aspect(1)\n", "cn = ax.contourf(Xv, Yv, Zv, cmap=\"coolwarm_r\", alpha=0.4)\n", "\n", "ax.scatter(X_train[:, 0], X_train[:, 1], c=y_train, cmap=cm_bright, edgecolors=\"k\")\n", "\n", "# Plot the testing points\n", "ax.scatter(X_test[:, 0], X_test[:, 1], c=y_test, cmap=cm_bright, alpha=0.4, edgecolors=\"k\")\n", "\n", "ax.set_xlim(x_min, x_max)\n", "ax.set_ylim(y_min, y_max)\n", "# ax.set_xticks(())\n", "# 
ax.set_yticks(())\n", "\n", "fig.colorbar(cn)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.9" }, "vscode": { "interpreter": { "hash": "b0fa6594d8f4cbf19f97940f81e996739fb7646882a419484c72d19e05852a7e" } } }, "nbformat": 4, "nbformat_minor": 4 }