Machine Learning course as part of the Studierendentage, summer semester 2023
  1. {
  2. "cells": [
  3. {
  4. "attachments": {},
  5. "cell_type": "markdown",
  6. "metadata": {},
  7. "source": [
  8. "# A simple neural network with one hidden layer in pure Python\n",
  9. "\n",
  10. "## Introduction\n",
  11. "We consider a simple feed-forward neural network with one hidden layer:"
  12. ]
  13. },
  14. {
  15. "attachments": {
  16. "48b1ed6e-8e2b-4883-82ac-a2bbed6e2885.png": {
  17. "image/png": "iVBORw0KGgoAAAANSUhEUgAA..."
  18. }
  19. },
  20. "cell_type": "markdown",
  21. "metadata": {},
  22. "source": [
  23. "![nn.png](attachment:48b1ed6e-8e2b-4883-82ac-a2bbed6e2885.png)"
  24. ]
  25. },
  26. {
  27. "attachments": {},
  28. "cell_type": "markdown",
  29. "metadata": {},
  30. "source": [
  31. "In this example the input vector of the neural network has two features, i.e., the input is a two-dimensional vector:\n",
  32. "\n",
  33. "$$\n",
  34. "\\mathbf x = (x_0, x_1).\n",
  35. "$$\n",
  36. "\n",
  37. "We consider a set of $n$ vectors as training data. The training data can therefore be written as an $n \\times 2$ matrix where each row represents a feature vector:\n",
  38. "\n",
  39. "$$ \n",
  40. "X = \n",
  41. "\\begin{pmatrix}\n",
  42. "x_{00} & x_{01} \\\\\n",
  43. "x_{10} & x_{11} \\\\\n",
  44. "\\vdots & \\vdots \\\\\n",
  45. "x_{n-1\\,0} & x_{n-1\\,1} \n",
  46. "\\end{pmatrix} $$\n",
  47. "\n",
  48. "The known labels (1 = 'signal', 0 = 'background') are stored in an $n$-dimensional column vector $\\mathbf y$.\n",
  49. "\n",
  50. "In the following, $n_1$ denotes the number of neurons in the hidden layer. The weights for the connections from the input layer (layer 0) to the hidden layer (layer 1) are given by the following matrix:\n",
  51. "\n",
  52. "$$\n",
  53. "W^{(1)} = \n",
  54. "\\begin{pmatrix}\n",
  55. "w_{00}^{(1)} \\dots w_{0 \\, n_1-1}^{(1)} \\\\\n",
  56. "w_{10}^{(1)} \\dots w_{1 \\, n_1-1}^{(1)} \n",
  57. "\\end{pmatrix}\n",
  58. "$$\n",
  59. "\n",
  60. "Each neuron in the hidden layer is assigned a bias $\\mathbf b^{(1)} = (b^{(1)}_0, \\ldots, b^{(1)}_{n_1-1})$. The neuron in the output layer has the bias $\\mathbf b^{(2)}$. With that, the output values of the network for the matrix $X$ of input feature vectors are given by\n",
  61. "\n",
  62. "$$\n",
  63. "\\begin{align}\n",
  64. "Z^{(1)} &= X W^{(1)} + \\mathbf b^{(1)} \\\\\n",
  65. "A^{(1)} &= \\sigma(Z^{(1)}) \\\\\n",
  66. "Z^{(2)} &= A^{(1)} W^{(2)} + \\mathbf b^{(2)} \\\\\n",
  67. "A^{(2)} &= \\sigma(Z^{(2)})\n",
  68. "\\end{align}\n",
  69. "$$\n",
  70. "\n",
  71. "The loss function for a given set of weights is given by\n",
  72. "\n",
  73. "$$ L = \\sum_{i=0}^{n-1} (y_{\\mathrm{pred},i} - y_{\\mathrm{true},i})^2 $$\n",
  74. "\n",
  75. "We can now calculate the gradient of the loss function w.r.t. the weights. With the definition $\\tilde L = (y_\\mathrm{pred} - y_\\mathrm{true})^2$, the gradients for the weights from the hidden layer to the output layer are given by: \n",
  76. "\n",
  77. "$$ \\frac{\\partial \\tilde L}{\\partial w_i^{(2)}} = \\frac{\\partial \\tilde L}{\\partial a_k^{(2)}} \\frac{\\partial a_k^{(2)}}{\\partial w_i^{(2)}} = \\frac{\\partial \\tilde L}{\\partial a_k^{(2)}} \\frac{\\partial a_k^{(2)}}{\\partial z_k^{(2)}} \\frac{\\partial z_k^{(2)}}{\\partial w_i^{(2)}} = 2 (a_k^{(2)} - y_k) a_k^{(2)} (1 - a_k^{(2)}) a_{k,i}^{(1)}$$\n",
  78. "\n",
  79. "Applying the chain rule further, the gradients for the weights from the input layer to the hidden layer read: \n",
  80. "\n",
  81. "$$ \\frac{\\partial \\tilde L}{\\partial w_{ij}^{(1)}} = \\frac{\\partial \\tilde L}{\\partial a_k^{(2)}} \\frac{\\partial a_k^{(2)}}{\\partial z_k^{(2)}} \\frac{\\partial z_k^{(2)}}{\\partial a_{k,j}^{(1)}} \\frac{\\partial a_{k,j}^{(1)}}{\\partial z_{k,j}^{(1)}} \\frac{\\partial z_{k,j}^{(1)}}{\\partial w_{ij}^{(1)}} $$"
  82. ]
  83. },
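The forward-pass equations above map directly onto NumPy operations. A minimal sketch, where the shapes are assumptions for illustration only (4 samples, 2 features, $n_1 = 3$ hidden neurons) and all names (`rng`, `W1`, `b1`, ...) are made up rather than taken from the notebook:

```python
import numpy as np

rng = np.random.default_rng(0)

X = rng.normal(size=(4, 2))    # input matrix: one two-dimensional feature vector per row
W1 = rng.normal(size=(2, 3))   # W^(1): weights from input layer to hidden layer
b1 = rng.normal(size=3)        # b^(1): hidden-layer biases
W2 = rng.normal(size=(3, 1))   # W^(2): weights from hidden layer to output layer
b2 = rng.normal(size=1)        # b^(2): output bias

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

Z1 = X @ W1 + b1   # Z^(1) = X W^(1) + b^(1)
A1 = sigmoid(Z1)   # A^(1) = sigma(Z^(1))
Z2 = A1 @ W2 + b2  # Z^(2) = A^(1) W^(2) + b^(2)
A2 = sigmoid(Z2)   # A^(2): one output per sample, each in (0, 1)

print(A2.shape)  # (4, 1)
```

Note that the bias vectors are broadcast across the rows of $Z^{(1)}$ and $Z^{(2)}$, which is exactly the per-sample addition written in the equations.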
  84. {
  85. "attachments": {},
  86. "cell_type": "markdown",
  87. "metadata": {},
  88. "source": [
  89. "## A simple neural network class"
  90. ]
  91. },
  92. {
  93. "cell_type": "code",
  94. "execution_count": 1,
  95. "metadata": {},
  96. "outputs": [],
  97. "source": [
  98. "# A simple feed-forward neural network with one hidden layer\n",
  99. "# see also https://towardsdatascience.com/how-to-build-your-own-neural-network-from-scratch-in-python-68998a08e4f6\n",
  100. "\n",
  101. "import numpy as np\n",
  102. "\n",
  103. "class NeuralNetwork:\n",
  104. " def __init__(self, x, y):\n",
  105. " n1 = 3 # number of neurons in the hidden layer\n",
  106. " self.input = x\n",
  107. " self.weights1 = np.random.rand(self.input.shape[1],n1)\n",
  108. " self.bias1 = np.random.rand(n1)\n",
  109. " self.weights2 = np.random.rand(n1,1)\n",
  110. " self.bias2 = np.random.rand(1) \n",
  111. " self.y = y\n",
  112. " self.output = np.zeros(y.shape)\n",
  113. " self.learning_rate = 0.01\n",
  114. " self.n_train = 0\n",
  115. " self.loss_history = []\n",
  116. "\n",
  117. " def sigmoid(self, x):\n",
  118. " return 1/(1+np.exp(-x))\n",
  119. "\n",
  120. " def sigmoid_derivative(self, x):\n",
  121. " return x * (1 - x) # note: x is expected to be the sigmoid *output*, so this equals sigma'(z)\n",
  122. "\n",
  123. " def feedforward(self):\n",
  124. " self.layer1 = self.sigmoid(self.input @ self.weights1 + self.bias1)\n",
  125. " self.output = self.sigmoid(self.layer1 @ self.weights2 + self.bias2)\n",
  126. "\n",
  127. " def backprop(self):\n",
  128. "\n",
  129. " # delta1: [m, 1], m = number of training data\n",
  130. " delta1 = 2 * (self.y - self.output) * self.sigmoid_derivative(self.output)\n",
  131. "\n",
  132. " # Gradient w.r.t. weights from hidden to output layer: [n1, 1] matrix, n1 = # neurons in hidden layer\n",
  133. " d_weights2 = self.layer1.T @ delta1\n",
  134. " d_bias2 = np.sum(delta1) \n",
  135. " \n",
  136. " # shape of delta2: [m, n1], m = number of training data, n1 = # neurons in hidden layer\n",
  137. " delta2 = (delta1 @ self.weights2.T) * self.sigmoid_derivative(self.layer1)\n",
  138. " d_weights1 = self.input.T @ delta2\n",
  139. " d_bias1 = np.ones(delta2.shape[0]) @ delta2 \n",
  140. " \n",
  141. " # update weights and biases\n",
  142. " self.weights1 += self.learning_rate * d_weights1\n",
  143. " self.weights2 += self.learning_rate * d_weights2\n",
  144. "\n",
  145. " self.bias1 += self.learning_rate * d_bias1\n",
  146. " self.bias2 += self.learning_rate * d_bias2\n",
  147. "\n",
  148. " def train(self, X, y):\n",
  149. " self.output = np.zeros(y.shape)\n",
  150. " self.input = X\n",
  151. " self.y = y\n",
  152. " self.feedforward()\n",
  153. " self.backprop()\n",
  154. " self.n_train += 1\n",
  155. " if self.n_train % 1000 == 0:\n",
  156. " loss = np.sum((self.y - self.output)**2)\n",
  157. " print(\"loss: \", loss)\n",
  158. " self.loss_history.append(loss)\n",
  159. " \n",
  160. " def predict(self, X):\n",
  162. " self.input = X\n",
  163. " self.feedforward()\n",
  164. " return self.output\n",
  165. " \n",
  166. " def get_loss_history(self): # renamed: a method named loss_history would be shadowed by the attribute set in __init__\n",
  167. " return self.loss_history\n"
  168. ]
  169. },
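A quick way to sanity-check the backpropagation formulas is a finite-difference comparison. The sketch below is self-contained rather than reusing the class above; `loss_fn`, `grad_W2`, and `numeric` are illustrative names, and the analytic gradient implements the chain-rule expression for $\partial \tilde L / \partial w_i^{(2)}$ from the derivation:

```python
import numpy as np

rng = np.random.default_rng(1)
X = rng.normal(size=(5, 2))                        # 5 training samples, 2 features
y = rng.integers(0, 2, size=(5, 1)).astype(float)  # binary labels
W1 = rng.normal(size=(2, 3)); b1 = rng.normal(size=3)
W2 = rng.normal(size=(3, 1)); b2 = rng.normal(size=1)

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def loss_fn(W2):
    A1 = sigmoid(X @ W1 + b1)
    A2 = sigmoid(A1 @ W2 + b2)
    return np.sum((A2 - y) ** 2)

# analytic gradient w.r.t. W2: 2 (a^(2) - y) a^(2) (1 - a^(2)),
# contracted against the hidden-layer activations
A1 = sigmoid(X @ W1 + b1)
A2 = sigmoid(A1 @ W2 + b2)
delta = 2 * (A2 - y) * A2 * (1 - A2)
grad_W2 = A1.T @ delta

# central finite differences, one weight at a time
eps = 1e-6
numeric = np.zeros_like(W2)
for i in range(W2.shape[0]):
    Wp = W2.copy(); Wp[i, 0] += eps
    Wm = W2.copy(); Wm[i, 0] -= eps
    numeric[i, 0] = (loss_fn(Wp) - loss_fn(Wm)) / (2 * eps)

print(np.max(np.abs(grad_W2 - numeric)))  # should be tiny
```

(The class's `backprop` uses `2 * (self.y - self.output)`, i.e. the negative of this gradient, which is why it updates the weights with `+=` rather than `-=`.)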
  170. {
  171. "cell_type": "markdown",
  172. "metadata": {},
  173. "source": [
  174. "## Create toy data\n",
  175. "We create three toy data sets\n",
  176. "1. two moon-like distributions\n",
  177. "2. circles\n",
  178. "3. linearly separable data sets"
  179. ]
  180. },
  181. {
  182. "cell_type": "code",
  183. "execution_count": 2,
  184. "metadata": {},
  185. "outputs": [],
  186. "source": [
  187. "# https://scikit-learn.org/stable/auto_examples/classification/plot_classifier_comparison.html#sphx-glr-auto-examples-classification-plot-classifier-comparison-py\n",
  188. "import numpy as np\n",
  189. "from sklearn.datasets import make_moons, make_circles, make_classification\n",
  190. "from sklearn.model_selection import train_test_split\n",
  191. "\n",
  192. "X, y = make_classification(\n",
  193. " n_features=2, n_redundant=0, n_informative=2, random_state=1, n_clusters_per_class=1\n",
  194. ")\n",
  195. "rng = np.random.RandomState(2)\n",
  196. "X += 2 * rng.uniform(size=X.shape)\n",
  197. "linearly_separable = (X, y)\n",
  198. "\n",
  199. "datasets = [\n",
  200. " make_moons(n_samples=200, noise=0.1, random_state=0),\n",
  201. " make_circles(n_samples=200, noise=0.1, factor=0.5, random_state=1),\n",
  202. " linearly_separable,\n",
  203. "]"
  204. ]
  205. },
  206. {
  207. "cell_type": "markdown",
  208. "metadata": {},
  209. "source": [
  210. "## Create training and test data set"
  211. ]
  212. },
  213. {
  214. "cell_type": "code",
  215. "execution_count": 3,
  216. "metadata": {},
  217. "outputs": [],
  218. "source": [
  219. "# datasets: 0 = moons, 1 = circles, 2 = linearly separable\n",
  220. "X, y = datasets[1]\n",
  221. "X_train, X_test, y_train, y_test = train_test_split(\n",
  222. " X, y, test_size=0.4, random_state=42\n",
  223. ")\n",
  224. "\n",
  225. "x_min, x_max = X[:, 0].min() - 0.5, X[:, 0].max() + 0.5\n",
  226. "y_min, y_max = X[:, 1].min() - 0.5, X[:, 1].max() + 0.5\n"
  227. ]
  228. },
  229. {
  230. "cell_type": "markdown",
  231. "metadata": {},
  232. "source": [
  233. "## Train the model"
  234. ]
  235. },
  236. {
  237. "cell_type": "code",
  238. "execution_count": 4,
  239. "metadata": {},
  240. "outputs": [
  241. {
  242. "name": "stdout",
  243. "output_type": "stream",
  244. "text": [
  245. "loss: 28.591431249971087\n",
  246. "loss: 19.174944855091578\n",
  247. "loss: 18.300519116661075\n",
  248. "loss: 5.44035901972833\n",
  249. "loss: 2.2654992441410906\n",
  250. "loss: 1.6923656607186892\n",
  251. "loss: 1.3715971480249087\n",
  252. "loss: 1.1473150221090382\n",
  253. "loss: 0.9774346378363713\n",
  254. "loss: 0.8457117685917934\n",
  255. "loss: 0.7429652120737472\n",
  256. "loss: 0.6621808985042399\n",
  257. "loss: 0.5977165926831687\n",
  258. "loss: 0.545283043346378\n",
  259. "loss: 0.5017902977940301\n",
  260. "loss: 0.46506515287723293\n",
  261. "loss: 0.4335772706016494\n",
  262. "loss: 0.40623169342909965\n",
  263. "loss: 0.3822273847227754\n",
  264. "loss: 0.36096446182458697\n",
  265. "loss: 0.3419836665195889\n",
  266. "loss: 0.3249263905044797\n",
  267. "loss: 0.3095077414631703\n",
  268. "loss: 0.29549797484687557\n",
  269. "loss: 0.282709394404349\n",
  270. "loss: 0.27098690712728085\n",
  271. "loss: 0.2602010759266338\n",
  272. "loss: 0.2502429170283057\n",
  273. "loss: 0.24101994107129043\n",
  274. "loss: 0.23245309736167535\n",
  275. "loss: 0.2244743850815736\n",
  276. "loss: 0.2170249645242441\n",
  277. "loss: 0.21005364833790718\n",
  278. "loss: 0.2035156851277511\n",
  279. "loss: 0.19737177048767093\n",
  280. "loss: 0.19158723674048994\n",
  281. "loss: 0.18613138439559326\n",
  282. "loss: 0.1809769269368725\n",
  283. "loss: 0.17609952694233805\n",
  284. "loss: 0.17147740633361158\n",
  285. "loss: 0.16709101719249247\n",
  286. "loss: 0.16292276236867193\n",
  287. "loss: 0.15895675725575403\n",
  288. "loss: 0.15517862578972155\n",
  289. "loss: 0.1515753250400271\n",
  290. "loss: 0.1481349938036084\n",
  291. "loss: 0.1448468214395504\n",
  292. "loss: 0.1417009338445398\n",
  293. "loss: 0.13868829400248622\n",
  294. "loss: 0.13580061497353096\n",
  295. "loss: 0.1330302835389617\n",
  296. "loss: 0.1303702930059422\n",
  297. "loss: 0.127814183912036\n",
  298. "loss: 0.12535599156436183\n",
  299. "loss: 0.12299019950967911\n",
  300. "loss: 0.1207116981660866\n",
  301. "loss: 0.11851574795923153\n",
  302. "loss: 0.11639794640004714\n",
  303. "loss: 0.11435419862018448\n",
  304. "loss: 0.11238069094815246\n",
  305. "loss: 0.11047386716576169\n",
  306. "loss: 0.10863040713256283\n",
  307. "loss: 0.10684720750694056\n",
  308. "loss: 0.1051213643275356\n",
  309. "loss: 0.10345015724866172\n",
  310. "loss: 0.10183103524916717\n",
  311. "loss: 0.10026160365640727\n",
  312. "loss: 0.09873961234613571\n",
  313. "loss: 0.09726294499576621\n",
  314. "loss: 0.09582960928281026\n",
  315. "loss: 0.09443772793285743\n",
  316. "loss: 0.09308553053235864\n",
  317. "loss: 0.0917713460310074\n",
  318. "loss: 0.09049359586685884\n",
  319. "loss: 0.08925078765463015\n",
  320. "loss: 0.08804150938406197\n",
  321. "loss: 0.08686442408087547\n",
  322. "loss: 0.08571826488782217\n",
  323. "loss: 0.08460183052776402\n",
  324. "loss: 0.08351398111459095\n",
  325. "loss: 0.08245363428124586\n",
  326. "loss: 0.08141976159718958\n",
  327. "loss: 0.08041138525037475\n",
  328. "loss: 0.07942757497120295\n",
  329. "loss: 0.07846744517812901\n",
  330. "loss: 0.07753015232648876\n",
  331. "loss: 0.0766148924438704\n",
  332. "loss: 0.0757208988368859\n",
  333. "loss: 0.07484743995559193\n",
  334. "loss: 0.07399381740306453\n",
  335. "loss: 0.07315936407873515\n",
  336. "loss: 0.07234344244512234\n",
  337. "loss: 0.07154544290849069\n",
  338. "loss: 0.07076478230479659\n",
  339. "loss: 0.07000090248301835\n",
  340. "loss: 0.06925326897863726\n",
  341. "loss: 0.06852136977065329\n",
  342. "loss: 0.06780471411604974\n",
  343. "loss: 0.06710283145614956\n",
  344. "loss: 0.06641527038972668\n"
  345. ]
  346. }
  347. ],
  348. "source": [
  349. "y_train = y_train.reshape(-1, 1)\n",
  350. "\n",
  351. "nn = NeuralNetwork(X_train, y_train)\n",
  352. "\n",
  353. "for i in range(100000):\n",
  354. " nn.train(X_train, y_train)\n"
  355. ]
  356. },
  357. {
  358. "cell_type": "markdown",
  359. "metadata": {},
  360. "source": [
  361. "## Plot the loss vs. the number of epochs"
  362. ]
  363. },
  364. {
  365. "cell_type": "code",
  366. "execution_count": 5,
  367. "metadata": {},
  368. "outputs": [
  369. {
  370. "data": {
  371. "text/plain": [
  372. "Text(0, 0.5, 'loss')"
  373. ]
  374. },
  375. "execution_count": 5,
  376. "metadata": {},
  377. "output_type": "execute_result"
  378. },
  379. {
  380. "data": {
  381. "image/png": "iVBORw0KGgoAAAANSUhEUgAA...",
  382. "text/plain": [
  383. "<Figure size 640x480 with 1 Axes>"
  384. ]
  385. },
  386. "metadata": {},
  387. "output_type": "display_data"
  388. }
  389. ],
  390. "source": [
  391. "import matplotlib.pyplot as plt\n",
  392. "plt.plot(nn.loss_history)\n",
  393. "plt.xlabel(\"# epochs / 1000\")\n",
  394. "plt.ylabel(\"loss\")"
  395. ]
  396. },
  397. {
  398. "cell_type": "markdown",
  399. "metadata": {},
  400. "source": []
  401. },
  402. {
  403. "cell_type": "code",
  404. "execution_count": 6,
  405. "metadata": {},
  406. "outputs": [
  407. {
  408. "data": {
  409. "text/plain": [
  410. "<matplotlib.colorbar.Colorbar at 0x12fe75c30>"
  411. ]
  412. },
  413. "execution_count": 6,
  414. "metadata": {},
  415. "output_type": "execute_result"
  416. },
  417. {
  418. "data": {
  419. "image/png": "iVBORw0KGgoAAAANSUhEUgAA...",
  420. "text/plain": [
  421. "<Figure size 900x700 with 2 Axes>"
  422. ]
  423. },
  424. "metadata": {},
  425. "output_type": "display_data"
  426. }
  427. ],
  428. "source": [
  429. "import matplotlib.pyplot as plt\n",
  430. "from matplotlib.colors import ListedColormap\n",
  431. "\n",
  432. "cm = plt.cm.RdBu\n",
  433. "cm_bright = ListedColormap([\"#FF0000\", \"#0000FF\"])\n",
  434. "\n",
  435. "xv = np.linspace(x_min, x_max, 10)\n",
  436. "yv = np.linspace(y_min, y_max, 10)\n",
  437. "Xv, Yv = np.meshgrid(xv, yv)\n",
  438. "XYpairs = np.vstack([ Xv.reshape(-1), Yv.reshape(-1)])\n",
  439. "zv = nn.predict(XYpairs.T)\n",
  440. "Zv = zv.reshape(Xv.shape)\n",
  441. "\n",
  442. "fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(9, 7))\n",
  443. "ax.set_aspect(1)\n",
  444. "cn = ax.contourf(Xv, Yv, Zv, cmap=\"coolwarm_r\", alpha=0.4)\n",
  445. "\n",
  446. "ax.scatter(X_train[:, 0], X_train[:, 1], c=y_train, cmap=cm_bright, edgecolors=\"k\")\n",
  447. "\n",
  448. "# Plot the testing points\n",
  449. "ax.scatter(X_test[:, 0], X_test[:, 1], c=y_test, cmap=cm_bright, alpha=0.4, edgecolors=\"k\")\n",
  450. "\n",
  451. "ax.set_xlim(x_min, x_max)\n",
  452. "ax.set_ylim(y_min, y_max)\n",
  453. "# ax.set_xticks(())\n",
  454. "# ax.set_yticks(())\n",
  455. "\n",
  456. "fig.colorbar(cn)\n"
  457. ]
  458. },
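Beyond the decision-boundary plot, it is natural to quantify performance with a test accuracy. A minimal sketch, assuming the sigmoid outputs are thresholded at 0.5; `y_prob` is a made-up stand-in for `nn.predict(X_test)`, and the values here are hypothetical:

```python
import numpy as np

y_prob = np.array([0.91, 0.12, 0.58, 0.34, 0.77])  # hypothetical network outputs
y_true = np.array([1, 0, 1, 1, 1])                 # hypothetical true labels

y_pred = (y_prob > 0.5).astype(int)  # predict class 1 if the output exceeds 0.5
accuracy = np.mean(y_pred == y_true)
print(accuracy)  # 0.8
```

In the notebook itself, the same idea would apply to `nn.predict(X_test)` (flattened to one dimension) and `y_test`.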
  459. {
  460. "cell_type": "code",
  461. "execution_count": null,
  462. "metadata": {},
  463. "outputs": [],
  464. "source": []
  465. }
  466. ],
  467. "metadata": {
  468. "kernelspec": {
  469. "display_name": "Python 3 (ipykernel)",
  470. "language": "python",
  471. "name": "python3"
  472. },
  473. "language_info": {
  474. "codemirror_mode": {
  475. "name": "ipython",
  476. "version": 3
  477. },
  478. "file_extension": ".py",
  479. "mimetype": "text/x-python",
  480. "name": "python",
  481. "nbconvert_exporter": "python",
  482. "pygments_lexer": "ipython3",
  483. "version": "3.10.9"
  484. },
  485. "vscode": {
  486. "interpreter": {
  487. "hash": "b0fa6594d8f4cbf19f97940f81e996739fb7646882a419484c72d19e05852a7e"
  488. }
  489. }
  490. },
  491. "nbformat": 4,
  492. "nbformat_minor": 4
  493. }