{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# A simple neural network with one hidden layer in pure Python\n",
"\n",
"## Introduction\n",
"We consider a simple feed-forward neural network with one hidden layer:"
]
},
{
"attachments": {
"48b1ed6e-8e2b-4883-82ac-a2bbed6e2885.png": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAASwAAAEsCAYAAAB5fY51AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy89olMNAAAACXBIWXMAAA9hAAAPYQGoP6dpAABEKklEQVR4nO2ddVxU6ffHD62EgYWBhd2K2F2767prrmLHioqBrWv3rh1rdxeIioGN3WAuusbaIqlISc39/P7gN8/XkYY7DHc879drXvta5855zjzc+5mnzjkGAEAMwzAKwFDXDjAMw6QVFiyGYRQDCxbDMIqBBYthGMXAgsUwjGJgwWIYRjGwYDEMoxhYsBiGUQwsWAzDKAYWLIZhFAMLFsMwioEFi2EYxcCCxTCMYmDBYhhGMbBgMQyjGFiwGIZRDCxYDMMoBhYshmEUAwsWwzCKgQWLYRjFwILFMIxiYMFiGEYxsGAxDKMYWLAYhlEMLFgMwygGFiyGYRQDCxbDMIqBBYthGMXAgsUwjGJgwWIYRjGwYDEMoxhYsBiGUQwsWAzDKAYWLIZhFAMLFsMwioEFi2EYxcCCxTCMYmDBYhhGMbBgMQyjGFiwGIZRDCxYDMMoBhYshmEUAwsWwzCKgQWLYRjFwILFMIxiMNa1A0zqxMTEkL+/P8XHx5OVlRUVKFCADAwMdO2WXsN9nj3hEVY25dGjRzRq1CiqWbMmWVpaUsmSJalMmTJUqFAhKly4MLVr14527NhB0dHRunZVb+A+z/4YAICunWD+x7Nnz2jYsGF05swZKliwILVr147s7e2pTJkyZGxsTB8/fqR79+7RlStX6OLFi2RtbU3Tp0+nESNGkKEh//5kBO5zBQEm27B69WrkzJkTpUuXxp49exATE5Pi9U+fPsWgQYNARGjUqBHevn2bRZ7qD9znyoIFK5swbdo0EBGGDh2KiIiIdH32woULsLW1RfHixfHixQsteah/cJ8rDxasbMD69etBRFiwYEGGbbx9+xZ2dnYoW7YswsPDZfROP+E+VyYsWDrmv//+g7m5OQYPHpxpW8+ePYO5uTmGDh0qg2f6C/e5cmHB0jHt2rVDyZIlZfuFXrlyJYgIt2/flsWePsJ9rlxYsHTIs2fPQETYsmWLbDbj4+NRokQJ9OnTRzab+gT3ubLhPVkdsmXLFsqbNy85OjrKZtPIyIicnZ1p//79FBYWJptdfYH7XNmwYOmQq1evUqtWrShnzpyy2m3Xrh3FxMSQj4+PrHb1Ae5zZcOCpSMkSaI7d+5Q7dq1ZbddoUIFMjc354fnG7jPlQ8Llo6IiIigiIgIKlGihOy2jYyMyNbWlj58+CC7bSXDfa58WLB0BP4/IkpbAbWGhoaiDSYB7nPlw4KlIywsLMjU1JQCAgJktw2AAgICyNraWnbbSob7XPmwYOkIY2Njqlq1qlbWPF69ekUfP36kmjVrym5byXCfKx8WLB1St25dOn/+PKlUKlntenl5kYGBgVYWl5UO97myYcHSIX369KE3b97QiRMnZLMJgP7++28qV66c7Fv3+oC2+nzNmjX0448/UqFChWSzyySG82HpEADk4OBAxsbGdPXqVTIyMsq0zePHj1O7du2IiChv3rw0ZswYcnFxoVy5cmXatj6gzT4/duwY/fzzzzJ4ySSLbg7YM2quXLkCAwMDLFq0KNO2Pn36hCJFiqB69eooV64ciAhEhDx58mDWrFkIDQ2VwWPlo40+//HHHyFJkgzeMSnBgpUNGDt2LExMTHD8+PEM24iKikLLli2RO3duvH37FvHx8dizZw8qVKigIVwzZ87Ep0+f5HNeoWijzxntw4KVDYiNjUX79u1hYmKCzZs3p/uX+t27d2jcuDFy5syJixcvarwXHx+PvXv3olKlSkK4cufOjenTp+Pjx49yfg1FIUefN2nSJMk+Z7QHC1Y2ITY2FgMGDAARoW3btvj3339T/Ux0dDTWr18PS0tL5M
iRAxcuXEj2WpVKhf3796Ny5cpCuHLlyoVp06YhJCREzq+iGDLT57lz54aNjQ2uXLmSBZ4yaliwshlHjhxBkSJFQERo2bIlVq1ahevXryM4OBifP3/Gixcv4O7ujjFjxiBPnjwgIpiZmYGIMHfu3FTtq1QquLm5oUqVKkK4rKysMGXKFAQHB8v+fVQqFaKiohAbGyu7bblIT59bW1uDiFCqVKnvVuh1CQtWNuTLly/YuXMnGjduDBMTEyEsX7+MjY3FKGnz5s0gIpiYmOD+/ftpakOlUuHAgQOoVq2asGlpaYlJkyYhKCgow76rVCqcOXMGTk5OqFmzpob/NjY2aNu2LRYsWICAgIAMt6EN0tLnNjY26NGjB4gIBgYG8PX11bXb3x18rCGbExMTQw8fPqQ3b96QSqUiKysrCg0Npe7du5ORkRGpVCratm0bHTp0iDw8PKhmzZp08+ZNMjExSZN9SZLIw8ODZs2aRffv3yciIktLSxo+fDiNHTuW8ufPn2Zf3dzcaMqUKfTs2TMqV64cNWrUiGrUqEHW1tYUHx9Pz549Ix8fH7pw4QKpVCrq0aMHLVy4kAoWLJihvtEWSfV5tWrVqEiRIkRE1KlTJzp06BA5OjrS3r17deztd4auFZNJP9HR0TA3Nxe//A0aNMCHDx/EdGXWrFnptqlSqXDo0CHUrFlT2LWwsMCECRMQGBiY4mdDQkLQpUsXEBF++eUXXLp0KcVF7JCQECxZsgT58uVD/vz54e7unm5/dcn9+/fFKOuff/7RtTvfFSxYCuWXX34RDw0R4eHDh9i7d6+YLt69ezdDdiVJgoeHB2rVqiWEy9zcHOPGjUtyGhcQEIAqVarA2toarq6u6WorICAAHTt2BBFh9erVGfJXV3Tu3BlEhK5du+rale8KFiyFsmbNGhAR8uXLByLC8OHDIUmSeJCqVauWalHQlJAkCUePHoW9vb2GcI0dOxb+/v4AEkZ69vb2KFSoEB49epThdkaOHAkigpubW4b9zWoePHggfjAePnyoa3e+G1iwFMrLly9BRDA0NBRnqyIjIxEQEID8+fODiDBt2rRMtyNJEo4dOwYHBwchXDlz5sTo0aPh4uICExMT3LlzJ9NtdOnSBdbW1vjw4UOmfc4qfvvtNxARunTpomtXvhtYsBSM+jBowYIFNSrBuLq6gohgZGQEb29vWdqSJAmenp6oW7euEC4DAwPMmTNHFvtBQUEoWLAgunfvLou9rODhw4diSp7W3Vkmc7BgKZixY8eCiMR6U926dcV7Xbt2BRGhSpUqiI6Olq1NSZJw8uRJFCxYEAUKFMjUtPNbVq5cCSMjI7x79042m9pG3c+dOnXStSvfBSxYCubcuXMgIuTPnx9GRkYgIty7dw/A/0YsRITJkyfL2m5ERAQsLCwwY8YMWe1+/vwZFhYWmD17tqx2tYmvr68YZan7ntEenA9LwTRq1IgsLS0pODiYmjdvTkRE69evJyKi/Pnz09q1a4mIaP78+XT79m3Z2r179y5FRkZShw4dZLNJRJQrVy5q1aoVXbp0SVa72qRSpUrUrVs3IiKaNWuWjr3Rf1iwFIypqSm1atWKiIiKFStGRES7du2iiIgIIko44Ni9e3eSJIn69u1L0dHRsrTr7e1NZmZmVLlyZVnsfY29vT15e3srqpjD9OnTycDAgA4dOkT37t3TtTt6DQuWwmnbti0RET158oTKlClD4eHhtG/fPvH+ypUrqVChQvT48WOaOXOmLG2+ffuWSpQokebT9OmhTJkyFBoaSpGRkbLb1hYVK1ak7t27ExHJ1sdM0rBgKZyffvqJiIhu3LhBPXv2JKL/TQuJiPLlyyf+f9GiRXTjxo1Mt6lSqcjY2DjTdpJCnQFU7pzr2mbatGlkaGhIHh4edOfOHV27o7ewYCmcYsWKUdWqVQkA2djYkKmpKXl7e2s8NO3bt6fevXuTJEnUr18/+vLlS6bazJMnDwUHB2fW9SQJCQkhY2NjMj
c314p9bVGhQgUeZWUBLFh6gHpaePXqVerUqRMRaY6yiIhWrFhBhQsXpidPntC0adMy1V716tUpMDCQ/Pz8MmUnKe7c
}
},
"cell_type": "markdown",
"metadata": {},
"source": [
"![nn.png](attachment:48b1ed6e-8e2b-4883-82ac-a2bbed6e2885.png)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this example the input vector of the neural network has two features, i.e., the input is a two-dimensional vector:\n",
"\n",
"$$\n",
"\\mathbf x = (x_0, x_1).\n",
"$$\n",
"\n",
"We consider a set of $n$ vectors as training data. The training data can therefore be written as an $n \\times 2$ matrix where each row represents a feature vector:\n",
"\n",
"$$ \n",
"X = \n",
"\\begin{pmatrix}\n",
"x_{00} & x_{01} \\\\\n",
"x_{10} & x_{11} \\\\\n",
"\\vdots & \\vdots \\\\\n",
"x_{n-1\\,0} & x_{n-1\\,1} \n",
"\\end{pmatrix} $$\n",
"\n",
"The known labels (1 = 'signal', 0 = 'background') are stored in a $n$-dimensional column vector $\\mathbf y$.\n",
"\n",
"In the following, $n_1$ denotes the number of neurons in the hidden layer. The weights for the connections from the input layer (layer 0) to the hidden layer (layer 1) are given by the following matrix:\n",
"\n",
"$$\n",
"W^{(1)} = \n",
"\\begin{pmatrix}\n",
"w_{00}^{(1)} \\dots w_{0 \\, n_1-1}^{(1)} \\\\\n",
"w_{10}^{(1)} \\dots w_{1 \\, n_1-1}^{(1)} \n",
"\\end{pmatrix}\n",
"$$\n",
"\n",
"Each neuron in the hidden layer is assigned a bias $\\mathbf b^{(1)} = (b^{(1)}_0, \\ldots, b^{(1)}_{n_1-1})$. The neuron in the output layer has the bias $\\mathbf b^{(2)}$. With that, the output values of the network for the matrix $X$ of input feature vectors are given by\n",
"\n",
"$$\n",
"\\begin{align}\n",
"Z^{(1)} &= X W^{(1)} + \\mathbf b^{(1)} \\\\\n",
"A^{(1)} &= \\sigma(Z^{(1)}) \\\\\n",
"Z^{(2)} &= A^{(1)} W^{(2)} + \\mathbf b^{(2)} \\\\\n",
"A^{(2)} &= \\sigma(Z^{(2)})\n",
"\\end{align}\n",
"$$\n",
"\n",
"The loss function for a given set of weights is given by\n",
"\n",
"$$ L = \\sum_{i=0}^{n-1} (y_{\\mathrm{pred},i} - y_{\\mathrm{true},i})^2 $$\n",
"\n",
"We can now calculate the gradient of the loss function w.r.t. the weights. With the definition $\\tilde L = (y_\\mathrm{pred} - y_\\mathrm{true})^2$, the gradients for the weights from the hidden layer to the output layer are given by: \n",
"\n",
"$$ \\frac{\\partial \\tilde L}{\\partial w_i^{(2)}} = \\frac{\\partial \\tilde L}{\\partial a_k^{(2)}} \\frac{\\partial a_k^{(2)}}{\\partial w_i^{(2)}} = \\frac{\\partial \\tilde L}{\\partial a_k^{(2)}} \\frac{\\partial a_k^{(2)}}{\\partial z_k^{(2)}} \\frac{\\partial z_k^{(2)}}{\\partial w_i^{(2)}} = 2 (a_k^{(2)} - y_k) a_k^{(2)} (1 - a_k^{(2)}) a_{k,i}^{(1)}$$\n",
"\n",
"Applying the chain rule further, the gradients for the weights from the input layer to the hidden layer read: \n",
"\n",
"$$ \\frac{\\partial \\tilde L}{\\partial w_{ij}^{(1)}} = \\frac{\\partial \\tilde L}{\\partial a_k^{(2)}} \\frac{\\partial a_k^{(2)}}{\\partial z_k^{(2)}} \\frac{\\partial z_k^{(2)}}{\\partial a_{k,j}^{(1)}} \\frac{\\partial a_{k,j}^{(1)}}{\\partial z_{k,j}^{(1)}} \\frac{\\partial z_{k,j}^{(1)}}{\\partial w_{ij}^{(1)}} $$"
]
},
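{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check of the derivation (an illustrative sketch, not part of the original notebook), the analytic gradient for the weights from the hidden layer to the output layer can be compared against a central finite-difference approximation of the loss:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Finite-difference check of the analytic gradient (illustrative sketch)\n",
"import numpy as np\n",
"\n",
"def sigmoid(x):\n",
"    return 1 / (1 + np.exp(-x))\n",
"\n",
"rng = np.random.default_rng(0)\n",
"a1 = rng.random((1, 3))    # hidden-layer activations for a single sample\n",
"w2 = rng.random((3, 1))    # weights hidden -> output\n",
"b2 = rng.random(1)\n",
"y_true = np.array([[1.0]])\n",
"\n",
"def loss(w):\n",
"    a2 = sigmoid(a1 @ w + b2)\n",
"    return float(np.sum((a2 - y_true) ** 2))\n",
"\n",
"# analytic gradient: 2 (a2 - y) a2 (1 - a2) a1\n",
"a2 = sigmoid(a1 @ w2 + b2)\n",
"grad_analytic = (2 * (a2 - y_true) * a2 * (1 - a2) * a1).T\n",
"\n",
"# central finite differences\n",
"eps = 1e-6\n",
"grad_fd = np.zeros_like(w2)\n",
"for i in range(w2.shape[0]):\n",
"    w_plus, w_minus = w2.copy(), w2.copy()\n",
"    w_plus[i] += eps\n",
"    w_minus[i] -= eps\n",
"    grad_fd[i] = (loss(w_plus) - loss(w_minus)) / (2 * eps)\n",
"\n",
"print(np.allclose(grad_analytic, grad_fd))\n"
]
},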
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## A simple neural network class"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"# A simple feed-forward neural network with one hidden layer\n",
"# see also https://towardsdatascience.com/how-to-build-your-own-neural-network-from-scratch-in-python-68998a08e4f6\n",
"\n",
"import numpy as np\n",
"\n",
"class NeuralNetwork:\n",
" def __init__(self, x, y):\n",
" n1 = 3 # number of neurons in the hidden layer\n",
" self.input = x\n",
" self.weights1 = np.random.rand(self.input.shape[1],n1)\n",
" self.bias1 = np.random.rand(n1)\n",
" self.weights2 = np.random.rand(n1,1)\n",
" self.bias2 = np.random.rand(1) \n",
" self.y = y\n",
" self.output = np.zeros(y.shape)\n",
" self.learning_rate = 0.01\n",
" self.n_train = 0\n",
" self.loss_history = []\n",
"\n",
" def sigmoid(self, x):\n",
" return 1/(1+np.exp(-x))\n",
"\n",
" def sigmoid_derivative(self, x):\n",
" # note: x is expected to be the sigmoid output, i.e. x = sigmoid(z)\n",
" return x * (1 - x)\n",
"\n",
" def feedforward(self):\n",
" self.layer1 = self.sigmoid(self.input @ self.weights1 + self.bias1)\n",
" self.output = self.sigmoid(self.layer1 @ self.weights2 + self.bias2)\n",
"\n",
" def backprop(self):\n",
"\n",
" # delta1: [m, 1], m = number of training data\n",
" delta1 = 2 * (self.y - self.output) * self.sigmoid_derivative(self.output)\n",
"\n",
" # Gradient w.r.t. weights from hidden to output layer: [n1, 1] matrix, n1 = # neurons in hidden layer\n",
" d_weights2 = self.layer1.T @ delta1\n",
" d_bias2 = np.sum(delta1) \n",
" \n",
" # shape of delta2: [m, n1], m = number of training data, n1 = # neurons in hidden layer\n",
" delta2 = (delta1 @ self.weights2.T) * self.sigmoid_derivative(self.layer1)\n",
" d_weights1 = self.input.T @ delta2\n",
" d_bias1 = np.ones(delta2.shape[0]) @ delta2 \n",
" \n",
" # update weights and biases\n",
" self.weights1 += self.learning_rate * d_weights1\n",
" self.weights2 += self.learning_rate * d_weights2\n",
"\n",
" self.bias1 += self.learning_rate * d_bias1\n",
" self.bias2 += self.learning_rate * d_bias2\n",
"\n",
" def train(self, X, y):\n",
" self.output = np.zeros(y.shape)\n",
" self.input = X\n",
" self.y = y\n",
" self.feedforward()\n",
" self.backprop()\n",
" self.n_train += 1\n",
" if self.n_train % 1000 == 0:\n",
" loss = np.sum((self.y - self.output)**2)\n",
" print(\"loss: \", loss)\n",
" self.loss_history.append(loss)\n",
" \n",
" def predict(self, X):\n",
" self.output = np.zeros((X.shape[0], 1))\n",
" self.input = X\n",
" self.feedforward()\n",
" return self.output\n",
" \n",
" def get_loss_history(self):\n",
" # a method named loss_history would be shadowed by the attribute of the same name\n",
" return self.loss_history\n"
]
},
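{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that `sigmoid_derivative` expects the sigmoid *output* as its argument: for $a = \\sigma(z)$ it computes $\\sigma'(z) = a(1-a)$. A quick numerical check of this identity (illustrative only):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"z = np.array([-1.0, 0.0, 1.0])\n",
"a = 1 / (1 + np.exp(-z))                            # sigmoid activation\n",
"deriv_from_a = a * (1 - a)                          # what sigmoid_derivative computes\n",
"closed_form = np.exp(-z) / (1 + np.exp(-z)) ** 2    # sigma'(z) evaluated directly\n",
"print(np.allclose(deriv_from_a, closed_form))\n"
]
},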
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Create toy data\n",
"We create three toy data sets\n",
"1. two moon-like distributions\n",
"2. circles\n",
"3. linearly separable data sets"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# https://scikit-learn.org/stable/auto_examples/classification/plot_classifier_comparison.html#sphx-glr-auto-examples-classification-plot-classifier-comparison-py\n",
"import numpy as np\n",
"from sklearn.datasets import make_moons, make_circles, make_classification\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"X, y = make_classification(\n",
" n_features=2, n_redundant=0, n_informative=2, random_state=1, n_clusters_per_class=1\n",
")\n",
"rng = np.random.RandomState(2)\n",
"X += 2 * rng.uniform(size=X.shape)\n",
"linearly_separable = (X, y)\n",
"\n",
"datasets = [\n",
" make_moons(n_samples=200, noise=0.1, random_state=0),\n",
" make_circles(n_samples=200, noise=0.1, factor=0.5, random_state=1),\n",
" linearly_separable,\n",
"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Create training and test data set"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# datasets: 0 = moons, 1 = circles, 2 = linearly separable\n",
"X, y = datasets[2]\n",
"X_train, X_test, y_train, y_test = train_test_split(\n",
" X, y, test_size=0.4, random_state=42\n",
")\n",
"\n",
"x_min, x_max = X[:, 0].min() - 0.5, X[:, 0].max() + 0.5\n",
"y_min, y_max = X[:, 1].min() - 0.5, X[:, 1].max() + 0.5\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Train the model"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"loss: 1.5092056183794516\n",
"loss: 1.4052784494597108\n",
"loss: 1.3761345024941465\n",
"loss: 1.3630869405843156\n",
"loss: 1.3552464258406522\n",
"loss: 1.349123229154225\n",
"loss: 1.3431256808094927\n",
"loss: 1.3362112078682549\n",
"loss: 1.3274290190660858\n",
"loss: 1.3157172798860284\n",
"loss: 1.2997766928634182\n",
"loss: 1.2779742323824104\n",
"loss: 1.248342224204477\n",
"loss: 1.2087304881550036\n",
"loss: 1.1568998315568928\n",
"loss: 1.0918494643262784\n",
"loss: 1.0195590120436018\n",
"loss: 0.9457951095568214\n",
"loss: 0.8639430268192831\n",
"loss: 0.7647525848536025\n",
"loss: 0.6601823271635829\n",
"loss: 0.5661102798883235\n",
"loss: 0.4876160842592251\n",
"loss: 0.42308931814025225\n",
"loss: 0.3689552054429172\n",
"loss: 0.322923845307499\n",
"loss: 0.2836730132074807\n",
"loss: 0.25021217779446325\n",
"loss: 0.2216908592235331\n",
"loss: 0.19736061670901392\n",
"loss: 0.17656800134779702\n",
"loss: 0.15875047386321026\n",
"loss: 0.14343006843804879\n",
"loss: 0.13020485209452098\n",
"loss: 0.11873929885637868\n",
"loss: 0.1087546158367117\n",
"loss: 0.10001970854126296\n",
"loss: 0.09234314187572212\n",
"loss: 0.08556621671860219\n",
"loss: 0.07955713786272206\n",
"loss: 0.07420617412092073\n",
"loss: 0.06942168113049421\n",
"loss: 0.06512685317691934\n",
"loss: 0.061257079645640714\n",
"loss: 0.057757796816077525\n",
"loss: 0.0545827422872365\n",
"loss: 0.05169253512728908\n",
"loss: 0.049053518872192836\n",
"loss: 0.0466368164533751\n",
"loss: 0.044417556058805666\n",
"loss: 0.042374235033703585\n",
"loss: 0.0404881954696185\n",
"loss: 0.03874319037731412\n",
"loss: 0.037125023528634525\n",
"loss: 0.03562124939086246\n",
"loss: 0.03422092223474082\n",
"loss: 0.0329143856137121\n",
"loss: 0.03169309509897928\n",
"loss: 0.03054946850197824\n",
"loss: 0.029476758893627206\n",
"loss: 0.028468946594046432\n",
"loss: 0.027520647001531773\n",
"loss: 0.026627031690069865\n",
"loss: 0.02578376065800892\n",
"loss: 0.02498692397825075\n",
"loss: 0.024232991399588854\n",
"loss: 0.023518768693064036\n",
"loss: 0.022841359737221362\n",
"loss: 0.02219813350040272\n",
"loss: 0.021586695213538353\n",
"loss: 0.021004861138723636\n",
"loss: 0.020450636431551754\n",
"loss: 0.019922195672210468\n",
"loss: 0.019417865704588127\n",
"loss: 0.018936110476324985\n",
"loss: 0.018475517617776015\n",
"loss: 0.01803478653568396\n",
"loss: 0.017612717829281064\n",
"loss: 0.017208203863488286\n",
"loss: 0.016820220356725514\n",
"loss: 0.016447818860260015\n",
"loss: 0.0160901200225292\n",
"loss: 0.015746307545989182\n",
"loss: 0.01541562275609709\n",
"loss: 0.015097359712380383\n",
"loss: 0.014790860800440717\n",
"loss: 0.014495512751390105\n",
"loss: 0.014210743041828075\n",
"loss: 0.013936016633182046\n",
"loss: 0.013670833014182088\n",
"loss: 0.013414723514545273\n",
"loss: 0.013167248861679233\n",
"loss: 0.012927996955477793\n",
"loss: 0.012696580839129182\n",
"loss: 0.012472636846338482\n",
"loss: 0.012255822907560525\n",
"loss: 0.012045816999741691\n",
"loss: 0.011842315725759325\n",
"loss: 0.011645033011228996\n",
"loss: 0.011453698907655463\n"
]
}
],
"source": [
"y_train = y_train.reshape(-1, 1)\n",
"\n",
"nn = NeuralNetwork(X_train, y_train)\n",
"\n",
"for i in range(100000):\n",
" nn.train(X_train, y_train)\n"
]
},
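{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a rough performance check (not part of the original notebook), the network outputs can be thresholded at 0.5 to estimate the classification accuracy on the held-out test set:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# threshold the network output at 0.5 to obtain class predictions\n",
"y_pred = nn.predict(X_test)\n",
"accuracy = np.mean((y_pred.reshape(-1) > 0.5) == y_test)\n",
"print(f\"test accuracy: {accuracy:.3f}\")\n"
]
},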
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Plot the loss vs. the number of epochs"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Text(0, 0.5, 'loss')"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAjcAAAGwCAYAAABVdURTAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/P9b71AAAACXBIWXMAAA9hAAAPYQGoP6dpAABHfklEQVR4nO3deXhU1f0/8Pfsk3WyJ2Ql7IGwxICYAG5IKJvar1YqFtSKGisCxhV3qTXUVsvPKriitSJgBdFqikRQFlmUkLCGPZCFLCSEmayznt8fSQbGLIRkkpvMvF/Pc5/JnDl35jNX23l77znnyoQQAkREREQuQi51AURERETOxHBDRERELoXhhoiIiFwKww0RERG5FIYbIiIicikMN0RERORSGG6IiIjIpSilLqC72Ww2nD17Fj4+PpDJZFKXQ0RERO0ghEBVVRXCw8Mhl7d9bsbtws3Zs2cRFRUldRlERETUAQUFBYiMjGyzj9uFGx8fHwANB8fX11fiaoiIiKg9DAYDoqKi7L/jbXG7cNN0KcrX15fhhoiIqJdpz5ASDigmIiIil8JwQ0RERC6F4YaIiIhcCsMNERERuRSGGyIiInIpDDdERETkUhhuiIiIyKUw3BAREZFLYbghIiIil8JwQ0RERC6F4YaIiIhcCsMNERERuRSGGyfS15lx+KxB6jKIiIjcGsONkxwpMWDkyxsx64NdEEJIXQ4REZHbYrhxktggL6gUMlyoNaPoQp3U5RAREbkthhsn0SgVGBTqAwA4WMRLU0RERFJhuHGi+HAdAODQWb3ElRAREbkvhhsnio/wBQAcLGK4ISIikgrDjRMNi2g4c3OQM6aIiIgkw3DjRHFhvpDLgHNVRpQZ6qUuh4iIyC0x3DiRh1qBASHeAICDHHdDREQkCYYbJ2saVMwZU0RERNJguHEy+7gbDiomIiKSBMONk8WHN8yYOsRBxURERJJguHGyoY3hpuhCHc7XmCSuhoiIyP0w3DiZj1aF2CAvAFzMj4iISAoMN11gWHjTYn68NEVERNTdGG66QLx9MT+euSEiIupuDDddwH6PKc6YIiIi6nYMN12g6bLU6YpaGOrNEldDRETkXhhuuoC/lxoRfh4AgMOcEk5ERNStGG66CO8QTkREJA2Gmy5iH3fDMzdERETdiuGmi8TzNgxERESSYLjpIsMaL0udPFeNWpNF4mqIiIjch6ThZuvWrZgxYwbCw8Mhk8mwfv36du/7008/QalUYtSoUV1WX2eE+GgR4qOBTQC5xVVSl0NEROQ2JA03NTU1GDlyJN56660r2k+v12POnDmYOHFiF1XmHE2XpvYXXpC2ECIiIjeilPLDp0yZgilTplzxfg8++CBmzZoFhUJx2bM9RqMRRqPR/txg6L4BvqP7+mPzkTK8s+UkfpsQAT9Pdbd9NhERkbvqdWNuPvroI5w8eRIvvvhiu/qnp6dDp9PZt6ioqC6u8KJ7kvsiNsgLpQYjnv/qULd9LhERkTvrVeHm+PHjePrpp7Fy5Uoole076bRo0SLo9Xr7VlBQ0MVVXuSpVuIfM0dBIZfhv/vO4qucom77bCIiInfVa8KN1WrFrFmz8PLLL2PQoEHt3k+j0cDX19dh606jovzwyI0DAADPrz+IsxfquvXziYiI3E2vCTdVVVXYs2cP5s2bB6VSCaVSicWLF2Pfvn1QKpXYvHmz1CW26uEbBmBklB8M9RY88cU+2GxC6pKIiIhcVq8JN76+vjhw4ABycnLsW2pqKgYPHoycnByMHTtW6hJbpVLI8Y87RsJDpcBPJyrw8Y7TUpdERETksiSdLVVdXY0TJ07Yn+fl5SEnJwcBAQGIjo7GokWLUFRUhE8++QRyuRzx8fEO+4eEhECr1TZr74n6BXvj2WlxeG79QfwlIxelVfVYOHEQPNQKqUsjIiJyKZKeudmzZw8SEhKQkJAAAEhLS0NCQgJeeOEFAEBxcTHy8/OlLNGp7hobjZmjo2
C1Cby75RR+8/+2YsfJcqnLIiIicikyIYRbDQAxGAzQ6XTQ6/XdPri4yfeHS/H8VwdRrK8HAMwcHYXHJg9CiI9WknqIiIh6uiv5/Wa4kUhVvRmvbTiKf+86AwBQyGW4cUgIZo6OwvWDg6FU9JrhUERERF2O4aYNPSXcNPnl9Hks+d8RZJ2ptLeF+Gjw24QIpAwLxagofyjkMgkrJCIikh7DTRt6Wrhpcry0Cp/vKcC6vUWoqDHZ2wO81LhhcAhuigvBhEHB8NZIOgaciIhIEgw3beip4aaJyWLD5iOlyDhQgh+PlsFQb7G/plLIcE2/QNwwOAQT40IQE+glYaVERETdh+GmDT093FzKbLVhz+lKfJ9bik25pThdUevwev9gL0wb3gczRoZjYKiPRFUSERF1PYabNvSmcPNrp85VY/ORMmzKLcMvp8/DcslKx4NDfTB9RB/cmhCBqABPCaskIiJyPoabNvTmcHMpQ70Zm3PL8M3+s9hy7BzM1oZ/jDIZcFNcKO4d1xdJ/QIhk3EwMhER9X4MN21wlXBzKX2tGd8dLsFXOUX46USFvX1ImA/+OC4Wv70qAipOLSciol6M4aYNrhhuLnWirAof7ziNtVlFqDNbATSMzXl++lBcPzhE4uqIiIg6huGmDa4ebproa81Ysycf7245ZZ9afsPgYDw3fSj6B3tLXB0REdGVYbhpg7uEmyaGejP+uek4Pt5xGmargFIuw9wJ/fBYyiBeqiIiol7jSn6/+evm4ny1Kjw7bSi+W3gtJg4JgcUm8M6Wk7jrg90oq6qXujwiIiKnY7hxE/2CvfHhPWOw/K6r4K1R4ue885jxz+3IOnNe6tKIiIiciuHGzUwZ3gdfzRuHgSHeKDUY8fv3duHfO0/Dza5OEhGRC2O4cUP9g72x/uFxmDa8D8xWgee/OoQl/zvCgENERC6B4cZNeWmUeGtWAhZNGQIAeHfrKSz78aTEVREREXUew40bk8lkePC6/nhuWhwA4G/fHcWnu85IXBUREVHnMNwQ5k7oh3k3DAAAPP/VQXyVUyRxRURERB3HcEMAgMdSBmH2NTEQAnjs83344UiZ1CURERF1CMMNAWi4RPXyzcNwy6hwWGwCD63Mwslz1VKXRUREdMUYbshOLpfh778biaR+gag325D2+T5YrDapyyIiIroiDDfkQKWQ4/U7RsJHq8S+ggt4ZwtnUBERUe/CcEPNhPt54OWbhwEAln5/HAeL9BJXRERE1H4MN9Si3yZEYPKwUFhsAo99vg9Gi1XqkoiIiNqF4YZaJJPJ8OpvhyPIW42jpVV4I/OY1CURERG1C8MNtSrQW4NXfzscAPDe1lPYc5o32SQiop6P4YbalDIsDLcnRkII4KX/HuL9p4iIqMdjuKHLemZqHLzUChwsMuC7Q6VSl0NERNQmhhu6rAAvNe4dFwsA+EfmMdhsPHtDREQ9F8MNtcv9E/rBR6vE0dIqfHugWOpyiIiIWsVwQ+2i81Rh7vh+AIB/fH+MKxcTEVGPxXBD7fbH8X3h56nCqXM1+CrnrNTlEBERtYjhhtrNR6vCg9f2BwD8v03HYebZGyIi6oEYbuiK3J0cgyBvNfLP12JtVqHU5RARETUjabjZunUrZsyYgfDwcMhkMqxfv77N/uvWrcOkSZMQHBwMX19fJCUl4bvvvuueYgkA4KlWIvW6hrM3/9x8grdlICKiHkfScFNTU4ORI0firbfealf/rVu3YtKkScjIyEBWVhZuuOEGzJgxA9nZ2V1cKV3qD9fEIMRHg6ILdVz3hoiIehyZ6CFLzspkMnz55Ze49dZbr2i/YcOGYebMmXjhhRdafN1oNMJoNNqfGwwGREVFQa/Xw9fXtzMlu7U3Mo/hzU3HMW5AIFbOvUbqcoiIyMUZDAbodLp2/X736jE3NpsNVVVVCAgIaLVPeno6dDqdfYuKiurGCl3X7xIjIZMBP52oQH5FrdTlEBER2fXqcPP666+jpqYGd9xxR6t9Fi
1aBL1eb98KCgq6sULXFRXgifEDggAAn+/hMSUiop6j14abVatW4aWXXsKaNWsQEhLSaj+NRgNfX1+HjZzjzqujAQD
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"plt.plot(nn.loss_history)\n",
"plt.xlabel(\"# epochs / 1000\")\n",
"plt.ylabel(\"loss\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": []
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.colorbar.Colorbar at 0x7f4c4c86f460>"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAsgAAAI9CAYAAAAn9I7cAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/P9b71AAAACXBIWXMAAA9hAAAPYQGoP6dpAAC5e0lEQVR4nOzdd3hUZfrG8e/MpPdKCkkg9N47CigKgqIiKpbFXrAra1ksq+j6Q137irgoiCwqFsSKCKJ0UFrovSVAQgqk9znn90cgJpCEhJRJwv25rlyaU59RmNx55znvazFN00RERERERACwOroAEREREZH6RAFZRERERKQEBWQRERERkRIUkEVERERESlBAFhEREREpQQFZRERERKQEBWQRERERkRIUkEVERERESlBAFhEREREpQQFZRERERKQEBWQRERERqXPLli1j1KhRhIeHY7FY+Pbbb896ztKlS+nZsydubm60aNGCDz74oNT+mTNnYrFYzvjKzc2tUm0KyCIiIiJS57KysujatSvvvfdepY4/cOAAI0eO5MILL2Tjxo08/fTTPPzww8ydO7fUcT4+PsTHx5f6cnNzq1JtTlU6WkRERESkBowYMYIRI0ZU+vgPPviAqKgo3n77bQDat2/PunXreP311xkzZkzxcRaLhdDQ0GrVphFkEREREan3Vq9ezbBhw0ptGz58OOvWraOgoKB4W2ZmJs2aNSMiIoIrrriCjRs3VvleGkEWERERacRyc3PJz8+vk3uZponFYim1zdXVFVdX12pfOyEhgZCQkFLbQkJCKCwsJDk5mbCwMNq1a8fMmTPp3Lkz6enpvPPOOwwcOJBNmzbRunXrSt9LAVlERESkkcrNzaVp0+YcP36sTu7n5eVFZmZmqW3PP/88L7zwQo1c//TwbZpmqe39+vWjX79+xfsHDhxIjx49+M9//sO7775b6fsoIIuIiIg0Uvn5+Rw/fozZc7fh4eldq/fKzsrgb2M6EhcXh4+PT/H2mhg9BggNDSUhIaHUtsTERJycnAgMDCzzHKvVSu/evdmzZ0+V7qWALCIiItLIeXh64+npc/YDa4CPj0+pgFxT+vfvzw8//FBq28KFC+nVqxfOzs5lnmOaJjExMXTu3LlK99JDeiIiIiJS5zIzM4mJiSEmJgYomsYtJiaG2NhYACZOnMgtt9xSfPz48eM5dOgQEyZMYMeOHcyYMYPp06fz+OOPFx8zadIkfvnlF/bv309MTAx33nknMTExjB8/vkq1aQRZREREROrcunXruOiii4q/nzBhAgC33norM2fOJD4+vjgsA0RHRzN//nwee+wxpkyZQnh4OO+++26pKd5SU1O55557SEhIwNfXl+7du7Ns2TL69OlTpdos5qnuZhERERFpVNLT0/H19eWbBbG13mKRlZXONZdFkZaWVistFnVJLRYiIiIiIiUoIIuIiIiIlKCALCIiIiJSggKyiIiIiEgJCsgiIiIiIiUoIIuIiIiIlKCALCIiIiJSggKyiIiIiEgJCsgiIiIiIiUoIIuIiIiIlKCALCIiIiJSggKyiIiIiEgJCsgiIiIiIiUoIIuIiIiIlKCALCIiIiJSggKyiIiIiEgJCsgiIiIiIiUoIIuIiIiIlKCALCIiIiJSggKyiIiIiEgJCsgiIiIiIiUoIIuIiIiIlKCALCIiIiJSggKyiIiIiEgJCsgiIiIiIiUoIIuIiIiIlKCALCIiIiJSggKyiIiIiEgJCsgiIiIiIiUoIIuIiIiIlKCALCIiIiJSggKyiIiIiEgJCsgiIiIiIiUoIIuIiIiIlKCALCIiIiJSggKyiIiIiEgJCsgiIiIiIiUoIIuIiIiIlKCALCIiIiJSggKyiIiIiEgJCsgiIiIiIiUoIIuIiIiIlKCALCIiIiJSggKyiIiIiEgJCsgiIiIiIiUoIIuIiIiIlKCALCIiIiJSggKyiIiIiEgJCsgiIiIiIiUoIIuIiIiIlKCALC
IiIiJSggKyiIiIiEgJCsgiIiIiIiUoIIuIiIiIlKCALCIiIiJSggKyiIiIiEgJCsgiIiIiIiUoIIuIiIiIlODk6AIqwzAMjh49ire3NxaLxdHliIiIiJTLNE0yMjIIDw/HatVYZEPUIALy0aNHiYyMdHQZIiIiIpUWFxdHRESEo8uQc9AgArK3tzcAs+duw8PT28HViIiIVE9ckglA144+5R5zPMMAoHVE+SOQubk5AIT62Cu8nz0tCQDvCn7qFx45AIDV1avcYzK3bQbA9Gta4f3OdxnZ2XS9/a7i/CINT4MIyKfaKjw8vfH0LP/NREREpD47lFgUjAf2rfhnWUqGgacXtI0sPxzn5Obg5exMuG/54diemgiAt3/FQa3w8D7w8qwwHGdsjsHb3R3TXyOilaW20IarQQRkERGRhu5UOO7RufxwnHJy1LiiYAxF4RioXDiuaNT48D6g4lFjKArHgMKxnDcUkEVERGpZTYXjygRjqLlwrGAs5ysFZBERkVpSmWAMNReOKxOMQeFY5GwUkEVERGqBWipEGi5NziciIlLDarKlIic3h3Bfe52E44zNMWRsjsH0j1A4ljrx/vvvEx0djZubGz179mT58uUVHj9lyhTat2+Pu7s7bdu2ZdasWWccM3fuXDp06ICrqysdOnRg3rx5Va5LAVlERKSGHEo0OZRo0qOzT530G9tTE7GnJuLtpH5jaXi++OILHn30UZ555hk2btzIhRdeyIgRI4iNjS3z+KlTpzJx4kReeOEFtm3bxqRJk3jggQf44Ycfio9ZvXo1Y8eOZdy4cWzatIlx48Zx/fXX88cff1SpNotpmma1Xl0dSE9Px9fXl28WxGqaNxERqZfUUiGnZGRn02LsTaSlpeHj49jcUpcZKisrnWsui6r06+7bty89evRg6tSpxdvat2/P1VdfzeTJk884fsCAAQwcOJB///vfxdseffRR1q1bx4oVKwAYO3Ys6enp/Pzzz8XHXHbZZfj7+/P5559X+rVoBFlERKQaTo0ag1oqRCorPz+f9evXM2zYsFLbhw0bxqpVq8o8Jy8vDzc3t1Lb3N3d+fPPPykoKACKRpBPv+bw4cPLvWZ5FJBFRETOUclgrJYKkSLp6emlvvLy8s44Jjk5GbvdTkhISKntISEhJCQklHnd4cOH89FHH7F+/XpM02TdunXMmDGDgoICkpOTAUhISKjSNcujWSxERETOgVoqpCEJaOKNl1ftLn3tmln0dyIyMrLU9ueff54XXnihzHNOX23QNM1yVyB87rnnSEhIoF+/fpimSUhICLfddhuvvfYaNpvtnK5ZHgVkERGRKjgVjEELf4iUJS4urlQPsqur6xnHBAUFYbPZzhjZTUxMPGME+BR3d3dmzJjBf//7X44dO0ZYWBjTpk3D29uboKAgAEJDQ6t0zfKoxUJERKSS1FIhcnY+Pj6lvsoKyC4uLvTs2ZNFixaV2r5o0SIGDBhQ4fWdnZ2JiIjAZrMxZ84crrjiCqzWor9r/fv3P+OaCxcuPOs1T6cRZBERkUpQS4VIzZowYQLjxo2jV69e9O/fn2nTphEbG8v48eMBmDhxIkeOHCme63j37t38+eef9O3blxMnTvDmm2+ydetWPvnkk+JrPvLIIwwaNIhXX32Vq666iu+++45ff/21eJaLylJAFhERqUBdLxcNaqmQ88PYsWNJSUnhxRdfJD4+nk6dOjF//nyaNWsGQHx8fKk5ke12O2+88Qa7du3C2dmZiy66iFWrVtG8efPiYwYMGMCcOXN49tlnee6552jZsiVffPEFffv2rVJtmgdZRESkHFUZNYaaaamAioMxKBzXd/VxHuQlG07g5VW7tWRmpjOkh3+9eN3VpRFkERGRMqilQuT8pYAsIiJSgloqREQBWURE5CS1VIgIKCCLiIjU6NzGoJYKkYZOAVlERM5raqkQkdMpIIuIyHmrPj6IBwrHIo6mgCwiIued2hg1BrVUiDQWCsgiInJeUU
uFiJyNArKIiJw31FIhIpWhgCwiIo2eWipEpCoUkEVEpFGr67mNQS0VIg2dArKIiDRaaqkQkXOhgCwiIo1OQ34QDxS
"text/plain": [
"<Figure size 900x700 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"from matplotlib.colors import ListedColormap\n",
"\n",
"cm = plt.cm.RdBu\n",
"cm_bright = ListedColormap([\"#FF0000\", \"#0000FF\"])\n",
"\n",
"xv = np.linspace(x_min, x_max, 10)\n",
"yv = np.linspace(y_min, y_max, 10)\n",
"Xv, Yv = np.meshgrid(xv, yv)\n",
"XYpairs = np.vstack([ Xv.reshape(-1), Yv.reshape(-1)])\n",
"zv = nn.predict(XYpairs.T)\n",
"Zv = zv.reshape(Xv.shape)\n",
"\n",
"fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(9, 7))\n",
"ax.set_aspect(1)\n",
"cn = ax.contourf(Xv, Yv, Zv, cmap=\"coolwarm_r\", alpha=0.4)\n",
"\n",
"ax.scatter(X_train[:, 0], X_train[:, 1], c=y_train, cmap=cm_bright, edgecolors=\"k\")\n",
"\n",
"# Plot the testing points\n",
"ax.scatter(X_test[:, 0], X_test[:, 1], c=y_test, cmap=cm_bright, alpha=0.4, edgecolors=\"k\")\n",
"\n",
"ax.set_xlim(x_min, x_max)\n",
"ax.set_ylim(y_min, y_max)\n",
"# ax.set_xticks(())\n",
"# ax.set_yticks(())\n",
"\n",
"fig.colorbar(cn)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.16"
},
"vscode": {
"interpreter": {
"hash": "b0fa6594d8f4cbf19f97940f81e996739fb7646882a419484c72d19e05852a7e"
}
}
},
"nbformat": 4,
"nbformat_minor": 4
}