ML-Kurs-SS2023/notebooks/simple_neural_network.ipynb

{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# A simple neural network with one hidden layer in pure Python\n",
"\n",
"## Introduction\n",
"We consider a simple feed-forward neural network with one hidden layer:"
]
},
{
"attachments": {
"48b1ed6e-8e2b-4883-82ac-a2bbed6e2885.png": {
"image/png": "<base64-encoded PNG of the network diagram nn.png, truncated in this copy>"
}
},
"cell_type": "markdown",
"metadata": {},
"source": [
"![nn.png](attachment:48b1ed6e-8e2b-4883-82ac-a2bbed6e2885.png)"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"In this example the input vector of the neural network has two features, i.e., the input is a two-dimensional vector:\n",
"\n",
"$$\n",
"\\mathbf x = (x_0, x_1).\n",
"$$\n",
"\n",
"We consider a set of $n$ vectors as training data. The training data can therefore be written as an $n \\times 2$ matrix where each row represents a feature vector:\n",
"\n",
"$$ \n",
"X = \n",
"\\begin{pmatrix}\n",
"x_{00} & x_{01} \\\\\n",
"x_{10} & x_{11} \\\\\n",
"\\vdots & \\vdots \\\\\n",
"x_{n-1\\,0} & x_{n-1\\,1} \n",
"\\end{pmatrix} $$\n",
"\n",
"The known labels (1 = 'signal', 0 = 'background') are stored in an $n$-dimensional column vector $\\mathbf y$.\n",
"\n",
"In the following, $n_1$ denotes the number of neurons in the hidden layer. The weights for the connections from the input layer (layer 0) to the hidden layer (layer 1) are given by the following $2 \\times n_1$ matrix:\n",
"\n",
"$$\n",
"W^{(1)} = \n",
"\\begin{pmatrix}\n",
"w_{00}^{(1)} & \\dots & w_{0 \\, n_1-1}^{(1)} \\\\\n",
"w_{10}^{(1)} & \\dots & w_{1 \\, n_1-1}^{(1)} \n",
"\\end{pmatrix}\n",
"$$\n",
"\n",
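"Each neuron in the hidden layer is assigned a bias, $\\mathbf b^{(1)} = (b^{(1)}_0, \\ldots, b^{(1)}_{n_1-1})$. The weights from the hidden layer to the single output neuron form the $n_1 \\times 1$ column vector $W^{(2)} = (w^{(2)}_0, \\ldots, w^{(2)}_{n_1-1})^T$, and the output neuron has the scalar bias $b^{(2)}$. With that, the output values of the network for the matrix $X$ of input feature vectors are given by\n",
"\n",
"$$\n",
"\\begin{align}\n",
"Z^{(1)} &= X W^{(1)} + \\mathbf b^{(1)} \\\\\n",
"A^{(1)} &= \\sigma(Z^{(1)}) \\\\\n",
"Z^{(2)} &= A^{(1)} W^{(2)} + b^{(2)} \\\\\n",
"A^{(2)} &= \\sigma(Z^{(2)})\n",
"\\end{align}\n",
"$$\n",
"\n",
"Here $\\sigma(z) = 1/(1 + e^{-z})$ is the sigmoid function, applied element-wise, and the biases are added to every row, i.e., to every training example. The $k$-th entry $a_k^{(2)}$ of $A^{(2)}$ is the prediction for the $k$-th feature vector, and $a_{k,j}^{(1)}$ denotes the activation of hidden neuron $j$ for that example. For a given set of weights, the loss function is\n",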
"\n",
"$$ L = \\sum_{k=0}^{n-1} \\left(y_{\\mathrm{pred},k} - y_{\\mathrm{true},k}\\right)^2 = \\sum_{k=0}^{n-1} \\left(a_k^{(2)} - y_k\\right)^2 $$\n",
"\n",
"We can now calculate the gradient of the loss function w.r.t. the weights. With the per-sample loss $\\tilde L = (a_k^{(2)} - y_k)^2$ of the $k$-th training example, the gradients for the weights from the hidden layer to the output layer are given by:\n",
"\n",
"$$ \\frac{\\partial \\tilde L}{\\partial w_i^{(2)}} = \\frac{\\partial \\tilde L}{\\partial a_k^{(2)}} \\frac{\\partial a_k^{(2)}}{\\partial w_i^{(2)}} = \\frac{\\partial \\tilde L}{\\partial a_k^{(2)}} \\frac{\\partial a_k^{(2)}}{\\partial z_k^{(2)}} \\frac{\\partial z_k^{(2)}}{\\partial w_i^{(2)}} = 2 (a_k^{(2)} - y_k)\\, a_k^{(2)} (1 - a_k^{(2)})\\, a_{k,i}^{(1)} $$\n",
"\n",
"Here we used $\\sigma'(z) = \\sigma(z)\\,(1 - \\sigma(z))$, i.e., $\\partial a / \\partial z = a (1 - a)$. Applying the chain rule further, we obtain the gradients for the weights from the input layer to the hidden layer:\n",
"\n",
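"$$ \\frac{\\partial \\tilde L}{\\partial w_{ij}^{(1)}} = \\frac{\\partial \\tilde L}{\\partial a_k^{(2)}} \\frac{\\partial a_k^{(2)}}{\\partial z_k^{(2)}} \\frac{\\partial z_k^{(2)}}{\\partial a_{k,j}^{(1)}} \\frac{\\partial a_{k,j}^{(1)}}{\\partial z_{k,j}^{(1)}} \\frac{\\partial z_{k,j}^{(1)}}{\\partial w_{ij}^{(1)}} = 2 (a_k^{(2)} - y_k)\\, a_k^{(2)} (1 - a_k^{(2)})\\, w_j^{(2)}\\, a_{k,j}^{(1)} (1 - a_{k,j}^{(1)})\\, x_{ki} $$\n",
"\n",
"The gradient of the total loss $L$ is the sum of these per-sample gradients over all training examples $k = 0, \\ldots, n-1$. The biases are treated in the same way, with $\\partial z_k^{(2)} / \\partial b^{(2)} = 1$ and $\\partial z_{k,j}^{(1)} / \\partial b_j^{(1)} = 1$. The weights are then updated by gradient descent, $w \\leftarrow w - \\eta\\, \\partial L / \\partial w$, with learning rate $\\eta$."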
]
},
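{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"The per-sample gradients derived above can be checked against central finite differences, $\\partial \\tilde L / \\partial w \\approx [\\tilde L(w + \\epsilon) - \\tilde L(w - \\epsilon)] / (2\\epsilon)$, perturbing one weight at a time. The following cell is a minimal, self-contained sketch of such a check for a single random training example; the sizes, the random seed, the step $\\epsilon = 10^{-6}$ and the tolerance are arbitrary assumptions and not part of the original derivation. Both comparisons should print `True` if the formulas are correct."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Finite-difference check of the per-sample gradient formulas (sketch; sizes,\n",
"# seed, step size and tolerance are arbitrary assumptions)\n",
"import numpy as np\n",
"\n",
"def sigmoid(z):\n",
"    return 1.0 / (1.0 + np.exp(-z))\n",
"\n",
"rng = np.random.default_rng(0)\n",
"n1 = 3                          # neurons in the hidden layer\n",
"x_k = rng.normal(size=2)        # one feature vector x_k\n",
"y_k = 1.0                       # its label y_k\n",
"W1 = rng.normal(size=(2, n1))   # W^(1)\n",
"b1 = rng.normal(size=n1)        # b^(1)\n",
"w2 = rng.normal(size=n1)        # W^(2), stored as a flat vector\n",
"b2 = rng.normal()               # scalar bias b^(2)\n",
"\n",
"def loss(W1, w2):\n",
"    # per-sample loss L~ = (a^(2) - y_k)^2 as a function of the weights\n",
"    a1 = sigmoid(x_k @ W1 + b1)\n",
"    a2 = sigmoid(a1 @ w2 + b2)\n",
"    return (a2 - y_k) ** 2\n",
"\n",
"# analytic gradients from the chain-rule expressions above\n",
"a1 = sigmoid(x_k @ W1 + b1)\n",
"a2 = sigmoid(a1 @ w2 + b2)\n",
"common = 2 * (a2 - y_k) * a2 * (1 - a2)\n",
"grad_w2 = common * a1                                    # dL~/dw_i^(2)\n",
"grad_W1 = np.outer(x_k, common * w2 * a1 * (1 - a1))     # dL~/dw_ij^(1)\n",
"\n",
"# central finite differences, one weight at a time\n",
"eps = 1e-6\n",
"num_w2 = np.zeros_like(w2)\n",
"for i in range(n1):\n",
"    d = np.zeros_like(w2)\n",
"    d[i] = eps\n",
"    num_w2[i] = (loss(W1, w2 + d) - loss(W1, w2 - d)) / (2 * eps)\n",
"\n",
"num_W1 = np.zeros_like(W1)\n",
"for i in range(2):\n",
"    for j in range(n1):\n",
"        d = np.zeros_like(W1)\n",
"        d[i, j] = eps\n",
"        num_W1[i, j] = (loss(W1 + d, w2) - loss(W1 - d, w2)) / (2 * eps)\n",
"\n",
"print(np.allclose(grad_w2, num_w2, atol=1e-8), np.allclose(grad_W1, num_W1, atol=1e-8))"
]
},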
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## A simple neural network class"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"# A simple feed-forward neural network with one hidden layer\n",
"# see also https://towardsdatascience.com/how-to-build-your-own-neural-network-from-scratch-in-python-68998a08e4f6\n",
"\n",
"import numpy as np\n",
"\n",
"class NeuralNetwork:\n",
"    def __init__(self, x, y):\n",
"        n1 = 3  # number of neurons in the hidden layer\n",
"        self.input = x\n",
"        self.weights1 = np.random.rand(self.input.shape[1], n1)  # W^(1): [n_features, n1]\n",
"        self.bias1 = np.random.rand(n1)                           # b^(1): [n1]\n",
"        self.weights2 = np.random.rand(n1, 1)                     # W^(2): [n1, 1]\n",
"        self.bias2 = np.random.rand(1)                            # b^(2): [1]\n",
"        self.y = y\n",
"        self.output = np.zeros(y.shape)\n",
"        self.learning_rate = 0.01\n",
"        self.n_train = 0\n",
"        self.loss_history = []  # loss recorded every 1000 training steps\n",
"\n",
"    def sigmoid(self, x):\n",
"        return 1 / (1 + np.exp(-x))\n",
"\n",
"    def sigmoid_derivative(self, a):\n",
"        # expects the sigmoid *output* a = sigmoid(z), since sigmoid'(z) = a * (1 - a)\n",
"        return a * (1 - a)\n",
"\n",
"    def feedforward(self):\n",
"        self.layer1 = self.sigmoid(self.input @ self.weights1 + self.bias1)\n",
"        self.output = self.sigmoid(self.layer1 @ self.weights2 + self.bias2)\n",
"\n",
"    def backprop(self):\n",
"        # delta1: [n, 1], n = number of training samples.\n",
"        # (y - output) is the *negative* of the gradient derived above, so the\n",
"        # updates below add it, which is equivalent to gradient descent.\n",
"        delta1 = 2 * (self.y - self.output) * self.sigmoid_derivative(self.output)\n",
"\n",
"        # gradient w.r.t. weights from hidden to output layer: [n1, 1], n1 = # neurons in hidden layer\n",
"        d_weights2 = self.layer1.T @ delta1\n",
"        d_bias2 = np.sum(delta1)\n",
"\n",
"        # shape of delta2: [n, n1], n = number of training samples, n1 = # neurons in hidden layer\n",
"        delta2 = (delta1 @ self.weights2.T) * self.sigmoid_derivative(self.layer1)\n",
"        d_weights1 = self.input.T @ delta2\n",
"        d_bias1 = np.sum(delta2, axis=0)  # sum over training samples\n",
"\n",
"        # update weights and biases\n",
"        self.weights1 += self.learning_rate * d_weights1\n",
"        self.weights2 += self.learning_rate * d_weights2\n",
"\n",
"        self.bias1 += self.learning_rate * d_bias1\n",
"        self.bias2 += self.learning_rate * d_bias2\n",
"\n",
"    def train(self, X, y):\n",
"        # one full-batch gradient-descent step on (X, y)\n",
"        self.input = X\n",
"        self.y = y\n",
"        self.feedforward()\n",
"        self.backprop()\n",
"        self.n_train += 1\n",
"        if self.n_train % 1000 == 0:\n",
"            loss = np.sum((self.y - self.output)**2)\n",
"            print(\"loss: \", loss)\n",
"            self.loss_history.append(loss)\n",
"\n",
"    def predict(self, X):\n",
"        # forward pass for new data; unlike feedforward(), it does not overwrite the training state\n",
"        layer1 = self.sigmoid(X @ self.weights1 + self.bias1)\n",
"        return self.sigmoid(layer1 @ self.weights2 + self.bias2)\n"
]
},
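{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"The derivation gives the gradient for a single training example $k$, whereas `backprop()` works on the whole batch at once. The following self-contained sketch (random data; the sizes, seed and variable names are assumptions chosen for illustration) checks that the matrix products used in `backprop()` equal the sum of the per-sample gradients over all samples. Unlike `backprop()`, it uses $(a^{(2)} - y)$, i.e., the gradient itself rather than its negative. Both comparisons should print `True`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: batch (matrix) gradients vs. an explicit sum over per-sample gradients\n",
"import numpy as np\n",
"\n",
"def sigmoid(z):\n",
"    return 1.0 / (1.0 + np.exp(-z))\n",
"\n",
"rng = np.random.default_rng(1)\n",
"n_demo, n1 = 5, 3\n",
"X_demo = rng.normal(size=(n_demo, 2))\n",
"y_demo = rng.integers(0, 2, size=(n_demo, 1)).astype(float)\n",
"W1, b1 = rng.normal(size=(2, n1)), rng.normal(size=n1)\n",
"W2, b2 = rng.normal(size=(n1, 1)), rng.normal(size=1)\n",
"\n",
"# forward pass for the whole batch\n",
"A1 = sigmoid(X_demo @ W1 + b1)   # shape (n_demo, n1)\n",
"A2 = sigmoid(A1 @ W2 + b2)       # shape (n_demo, 1)\n",
"\n",
"# batch gradients, same matrix products as in backprop()\n",
"delta1 = 2 * (A2 - y_demo) * A2 * (1 - A2)     # (n_demo, 1)\n",
"dW2 = A1.T @ delta1                             # (n1, 1)\n",
"delta2 = (delta1 @ W2.T) * A1 * (1 - A1)        # (n_demo, n1)\n",
"dW1 = X_demo.T @ delta2                         # (2, n1)\n",
"\n",
"# explicit loop over samples k using the per-sample formulas\n",
"dW2_loop = np.zeros_like(W2)\n",
"dW1_loop = np.zeros_like(W1)\n",
"for k in range(n_demo):\n",
"    c = 2 * (A2[k, 0] - y_demo[k, 0]) * A2[k, 0] * (1 - A2[k, 0])\n",
"    dW2_loop[:, 0] += c * A1[k]\n",
"    dW1_loop += np.outer(X_demo[k], c * W2[:, 0] * A1[k] * (1 - A1[k]))\n",
"\n",
"print(np.allclose(dW1, dW1_loop), np.allclose(dW2, dW2_loop))"
]
},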
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Create toy data\n",
"We create three toy data sets:\n",
"1. two moon-like distributions,\n",
"2. two concentric circles,\n",
"3. a linearly separable data set."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# https://scikit-learn.org/stable/auto_examples/classification/plot_classifier_comparison.html#sphx-glr-auto-examples-classification-plot-classifier-comparison-py\n",
"import numpy as np\n",
"from sklearn.datasets import make_moons, make_circles, make_classification\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"X, y = make_classification(\n",
" n_features=2, n_redundant=0, n_informative=2, random_state=1, n_clusters_per_class=1\n",
")\n",
"rng = np.random.RandomState(2)\n",
"X += 2 * rng.uniform(size=X.shape)\n",
"linearly_separable = (X, y)\n",
"\n",
"datasets = [\n",
" make_moons(n_samples=200, noise=0.1, random_state=0),\n",
" make_circles(n_samples=200, noise=0.1, factor=0.5, random_state=1),\n",
" linearly_separable,\n",
"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Create training and test data sets"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# datasets: 0 = moons, 1 = circles, 2 = linearly separable\n",
"X, y = datasets[1]\n",
"X_train, X_test, y_train, y_test = train_test_split(\n",
" X, y, test_size=0.4, random_state=42\n",
")\n",
"\n",
"x_min, x_max = X[:, 0].min() - 0.5, X[:, 0].max() + 0.5\n",
"y_min, y_max = X[:, 1].min() - 0.5, X[:, 1].max() + 0.5\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Train the model"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"loss: 28.591431249971087\n",
"loss: 19.174944855091578\n",
"loss: 18.300519116661075\n",
"loss: 5.44035901972833\n",
"loss: 2.2654992441410906\n",
"loss: 1.6923656607186892\n",
"loss: 1.3715971480249087\n",
"loss: 1.1473150221090382\n",
"loss: 0.9774346378363713\n",
"loss: 0.8457117685917934\n",
"loss: 0.7429652120737472\n",
"loss: 0.6621808985042399\n",
"loss: 0.5977165926831687\n",
"loss: 0.545283043346378\n",
"loss: 0.5017902977940301\n",
"loss: 0.46506515287723293\n",
"loss: 0.4335772706016494\n",
"loss: 0.40623169342909965\n",
"loss: 0.3822273847227754\n",
"loss: 0.36096446182458697\n",
"loss: 0.3419836665195889\n",
"loss: 0.3249263905044797\n",
"loss: 0.3095077414631703\n",
"loss: 0.29549797484687557\n",
"loss: 0.282709394404349\n",
"loss: 0.27098690712728085\n",
"loss: 0.2602010759266338\n",
"loss: 0.2502429170283057\n",
"loss: 0.24101994107129043\n",
"loss: 0.23245309736167535\n",
"loss: 0.2244743850815736\n",
"loss: 0.2170249645242441\n",
"loss: 0.21005364833790718\n",
"loss: 0.2035156851277511\n",
"loss: 0.19737177048767093\n",
"loss: 0.19158723674048994\n",
"loss: 0.18613138439559326\n",
"loss: 0.1809769269368725\n",
"loss: 0.17609952694233805\n",
"loss: 0.17147740633361158\n",
"loss: 0.16709101719249247\n",
"loss: 0.16292276236867193\n",
"loss: 0.15895675725575403\n",
"loss: 0.15517862578972155\n",
"loss: 0.1515753250400271\n",
"loss: 0.1481349938036084\n",
"loss: 0.1448468214395504\n",
"loss: 0.1417009338445398\n",
"loss: 0.13868829400248622\n",
"loss: 0.13580061497353096\n",
"loss: 0.1330302835389617\n",
"loss: 0.1303702930059422\n",
"loss: 0.127814183912036\n",
"loss: 0.12535599156436183\n",
"loss: 0.12299019950967911\n",
"loss: 0.1207116981660866\n",
"loss: 0.11851574795923153\n",
"loss: 0.11639794640004714\n",
"loss: 0.11435419862018448\n",
"loss: 0.11238069094815246\n",
"loss: 0.11047386716576169\n",
"loss: 0.10863040713256283\n",
"loss: 0.10684720750694056\n",
"loss: 0.1051213643275356\n",
"loss: 0.10345015724866172\n",
"loss: 0.10183103524916717\n",
"loss: 0.10026160365640727\n",
"loss: 0.09873961234613571\n",
"loss: 0.09726294499576621\n",
"loss: 0.09582960928281026\n",
"loss: 0.09443772793285743\n",
"loss: 0.09308553053235864\n",
"loss: 0.0917713460310074\n",
"loss: 0.09049359586685884\n",
"loss: 0.08925078765463015\n",
"loss: 0.08804150938406197\n",
"loss: 0.08686442408087547\n",
"loss: 0.08571826488782217\n",
"loss: 0.08460183052776402\n",
"loss: 0.08351398111459095\n",
"loss: 0.08245363428124586\n",
"loss: 0.08141976159718958\n",
"loss: 0.08041138525037475\n",
"loss: 0.07942757497120295\n",
"loss: 0.07846744517812901\n",
"loss: 0.07753015232648876\n",
"loss: 0.0766148924438704\n",
"loss: 0.0757208988368859\n",
"loss: 0.07484743995559193\n",
"loss: 0.07399381740306453\n",
"loss: 0.07315936407873515\n",
"loss: 0.07234344244512234\n",
"loss: 0.07154544290849069\n",
"loss: 0.07076478230479659\n",
"loss: 0.07000090248301835\n",
"loss: 0.06925326897863726\n",
"loss: 0.06852136977065329\n",
"loss: 0.06780471411604974\n",
"loss: 0.06710283145614956\n",
"loss: 0.06641527038972668\n"
]
}
],
"source": [
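"# the network expects the labels as a column vector of shape (n, 1)\n",
"y_train = y_train.reshape(-1, 1)\n",
"\n",
"nn = NeuralNetwork(X_train, y_train)\n",
"\n",
"# 100000 full-batch gradient-descent steps; the loss is printed every 1000 steps\n",
"for i in range(100000):\n",
"    nn.train(X_train, y_train)\n"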
]
},
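{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"The notebook splits off a test set (40 % of the data) but only inspects it visually. A minimal sketch of a numerical check is the classification accuracy, obtained by thresholding the sigmoid output at 0.5; the threshold and the helper name `accuracy` are assumptions, not part of the original notebook. The commented-out last line shows how it could be applied to the trained network from the cells above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: classification accuracy with an (assumed) decision threshold of 0.5\n",
"import numpy as np\n",
"\n",
"def accuracy(y_true, y_prob, threshold=0.5):\n",
"    # fraction of samples whose thresholded prediction equals the label\n",
"    y_pred = (np.ravel(y_prob) >= threshold).astype(int)\n",
"    return np.mean(y_pred == np.ravel(y_true))\n",
"\n",
"# self-contained example with made-up numbers: 3 of 4 predictions are correct\n",
"print(accuracy(np.array([1, 0, 1, 0]), np.array([0.9, 0.2, 0.4, 0.1])))  # 0.75\n",
"\n",
"# applied to the trained network from the cells above, e.g.:\n",
"# print(accuracy(y_test, nn.predict(X_test)))"
]
},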
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Plot the loss vs. the number of epochs"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Text(0, 0.5, 'loss')"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "<base64-encoded PNG of the loss-vs-epochs plot, truncated in this copy>",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"plt.plot(nn.loss_history)\n",
"plt.xlabel(\"# epochs / 1000\")\n",
"plt.ylabel(\"loss\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Plot the network output together with the training and test data"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.colorbar.Colorbar at 0x12fe75c30>"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "<base64-encoded PNG of the network-output contour plot with the data points, truncated in this copy>",
"text/plain": [
"<Figure size 900x700 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"from matplotlib.colors import ListedColormap\n",
"\n",
"cm = plt.cm.RdBu\n",
"cm_bright = ListedColormap([\"#FF0000\", \"#0000FF\"])\n",
"\n",
"xv = np.linspace(x_min, x_max, 10)\n",
"yv = np.linspace(y_min, y_max, 10)\n",
"Xv, Yv = np.meshgrid(xv, yv)\n",
"XYpairs = np.vstack([ Xv.reshape(-1), Yv.reshape(-1)])\n",
"zv = nn.predict(XYpairs.T)\n",
"Zv = zv.reshape(Xv.shape)\n",
"\n",
"fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(9, 7))\n",
"ax.set_aspect(1)\n",
"cn = ax.contourf(Xv, Yv, Zv, cmap=\"coolwarm_r\", alpha=0.4)\n",
"\n",
"ax.scatter(X_train[:, 0], X_train[:, 1], c=y_train, cmap=cm_bright, edgecolors=\"k\")\n",
"\n",
"# Plot the testing points\n",
"ax.scatter(X_test[:, 0], X_test[:, 1], c=y_test, cmap=cm_bright, alpha=0.4, edgecolors=\"k\")\n",
"\n",
"ax.set_xlim(x_min, x_max)\n",
"ax.set_ylim(y_min, y_max)\n",
"# ax.set_xticks(())\n",
"# ax.set_yticks(())\n",
"\n",
"fig.colorbar(cn)\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
},
"vscode": {
"interpreter": {
"hash": "b0fa6594d8f4cbf19f97940f81e996739fb7646882a419484c72d19e05852a7e"
}
}
},
"nbformat": 4,
"nbformat_minor": 4
}