|
|
{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# A simple neural network with one hidden layer in pure Python (NumPy only)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## A simple neural network class with ReLU activation function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A simple feed-forward neural network with one hidden layer\n",
    "# see also https://towardsdatascience.com/how-to-build-your-own-neural-network-from-scratch-in-python-68998a08e4f6\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "class NeuralNetwork:\n",
    "    \"\"\"Feed-forward network: inputs -> ReLU hidden layer -> one ReLU output unit.\n",
    "\n",
    "    Trained by full-batch gradient descent on the sum-of-squares loss.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    x : ndarray, shape (m, n_features)\n",
    "        Initial inputs; the number of columns sizes `weights1`.\n",
    "    y : ndarray, shape (m, 1)\n",
    "        Initial targets. Must be a column vector: with shape (m,),\n",
    "        `y - output` would broadcast to an (m, m) matrix.\n",
    "    n_hidden : int, default 4\n",
    "        Number of neurons in the hidden layer.\n",
    "    learning_rate : float, default 1e-5\n",
    "        Gradient-descent step size.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, x, y, n_hidden=4, learning_rate=0.00001):\n",
    "        n1 = n_hidden  # number of neurons in the hidden layer\n",
    "        self.input = x\n",
    "        self.weights1 = np.random.rand(self.input.shape[1], n1)\n",
    "        self.bias1 = np.random.rand(n1)\n",
    "        self.weights2 = np.random.rand(n1, 1)\n",
    "        self.bias2 = np.random.rand(1)\n",
    "        self.y = y\n",
    "        self.output = np.zeros(y.shape)\n",
    "        self.learning_rate = learning_rate\n",
    "        self.n_train = 0\n",
    "        # Losses logged by train(). This is a plain attribute: read it as\n",
    "        # `nn.loss_history`, not as a method call.\n",
    "        self.loss_history = []\n",
    "\n",
    "    def relu(self, x):\n",
    "        \"\"\"Element-wise ReLU: x where x > 0, else 0.\"\"\"\n",
    "        return np.where(x > 0, x, 0)\n",
    "\n",
    "    def relu_derivative(self, x):\n",
    "        \"\"\"Element-wise ReLU derivative: 1 where x > 0, else 0.\n",
    "\n",
    "        backprop() passes ReLU *outputs* here; because relu(z) > 0 exactly\n",
    "        when z > 0, this equals the derivative w.r.t. the pre-activation z.\n",
    "        \"\"\"\n",
    "        return np.where(x > 0, 1, 0)\n",
    "\n",
    "    def feedforward(self):\n",
    "        \"\"\"Forward pass on `self.input`; sets `self.layer1` and `self.output`.\"\"\"\n",
    "        self.layer1 = self.relu(self.input @ self.weights1 + self.bias1)\n",
    "        self.output = self.relu(self.layer1 @ self.weights2 + self.bias2)\n",
    "\n",
    "    def backprop(self):\n",
    "        \"\"\"Apply one gradient-descent step using the last feedforward() pass.\n",
    "\n",
    "        The deltas hold the *negative* gradient of sum((y - output)**2),\n",
    "        which is why the parameters are updated with `+=`.\n",
    "        \"\"\"\n",
    "        # delta1: [m, 1], m = number of training data\n",
    "        delta1 = 2 * (self.y - self.output) * self.relu_derivative(self.output)\n",
    "\n",
    "        # Gradient w.r.t. weights from hidden to output layer: [n1, 1] matrix, n1 = # neurons in hidden layer\n",
    "        d_weights2 = self.layer1.T @ delta1\n",
    "        d_bias2 = np.sum(delta1)\n",
    "\n",
    "        # shape of delta2: [m, n1]; uses weights2 *before* it is updated below\n",
    "        delta2 = (delta1 @ self.weights2.T) * self.relu_derivative(self.layer1)\n",
    "        d_weights1 = self.input.T @ delta2\n",
    "        d_bias1 = delta2.sum(axis=0)  # sum over training samples -> [n1]\n",
    "\n",
    "        # update weights and biases\n",
    "        self.weights1 += self.learning_rate * d_weights1\n",
    "        self.weights2 += self.learning_rate * d_weights2\n",
    "\n",
    "        self.bias1 += self.learning_rate * d_bias1\n",
    "        self.bias2 += self.learning_rate * d_bias2\n",
    "\n",
    "    def train(self, X, y, log_every=1000):\n",
    "        \"\"\"Run one full-batch training step on (X, y).\n",
    "\n",
    "        Every `log_every` steps, prints the sum-of-squares loss of this\n",
    "        step's forward pass (computed before the update) and appends it\n",
    "        to `self.loss_history`.\n",
    "        \"\"\"\n",
    "        self.input = X\n",
    "        self.y = y\n",
    "        self.feedforward()  # overwrites self.output, so no reset is needed\n",
    "        self.backprop()\n",
    "        self.n_train += 1\n",
    "        if self.n_train % log_every == 0:\n",
    "            loss = np.sum((self.y - self.output) ** 2)\n",
    "            print(\"loss: \", loss)\n",
    "            self.loss_history.append(loss)\n",
    "\n",
    "    def predict(self, X):\n",
    "        \"\"\"Return the network output for X as an array of shape (len(X), 1).\n",
    "\n",
    "        Overwrites `self.input`, `self.layer1` and `self.output`; train()\n",
    "        re-assigns its inputs on every call, so training is unaffected.\n",
    "        \"\"\"\n",
    "        # Previously reset self.output from the *global* `y`; that line was\n",
    "        # dead (feedforward overwrites it) and raised NameError without it.\n",
    "        self.input = X\n",
    "        self.feedforward()\n",
    "        return self.output"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create toy data\n",
    "We create three toy data sets\n",
    "1. two moon-like distributions\n",
    "2. circles\n",
    "3. 
linearly separable data sets" ] }, { "cell_type": "code", "execution_count": 78, "metadata": {}, "outputs": [], "source": [ "# https://scikit-learn.org/stable/auto_examples/classification/plot_classifier_comparison.html#sphx-glr-auto-examples-classification-plot-classifier-comparison-py\n", "import numpy as np\n", "from sklearn.datasets import make_moons, make_circles, make_classification\n", "from sklearn.model_selection import train_test_split\n", "\n", "X, y = make_classification(\n", " n_features=2, n_redundant=0, n_informative=2, random_state=1, n_clusters_per_class=1\n", ")\n", "rng = np.random.RandomState(2)\n", "X += 2 * rng.uniform(size=X.shape)\n", "linearly_separable = (X, y)\n", "\n", "datasets = [\n", " make_moons(n_samples=200, noise=0.1, random_state=0),\n", " make_circles(n_samples=200, noise=0.1, factor=0.5, random_state=1),\n", " linearly_separable,\n", "]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Create training and test data set" ] }, { "cell_type": "code", "execution_count": 79, "metadata": {}, "outputs": [], "source": [ "# datasets: 0 = moons, 1 = circles, 2 = linearly separable\n", "X, y = datasets[1]\n", "X_train, X_test, y_train, y_test = train_test_split(\n", " X, y, test_size=0.4, random_state=42\n", ")\n", "\n", "x_min, x_max = X[:, 0].min() - 0.5, X[:, 0].max() + 0.5\n", "y_min, y_max = X[:, 1].min() - 0.5, X[:, 1].max() + 0.5\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Train the model" ] }, { "cell_type": "code", "execution_count": 80, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "loss: 34.671913102152374\n", "loss: 31.424782860564203\n", "loss: 29.772915496524135\n", "loss: 28.762023680772913\n", "loss: 28.00200726838712\n", "loss: 27.32590942137339\n", "loss: 26.752368734071535\n", "loss: 26.230440689447903\n", "loss: 25.673463509689576\n", "loss: 25.012834504148312\n", "loss: 24.289682045629544\n", "loss: 23.555645514965384\n", "loss: 
22.76462670346343\n", "loss: 21.904104226889068\n", "loss: 20.943637847221698\n", "loss: 19.89434572572985\n", "loss: 18.727285500049177\n", "loss: 17.485616842253226\n", "loss: 16.142413632344777\n", "loss: 14.852364407067075\n", "loss: 13.635545514668182\n", "loss: 12.456629856179049\n", "loss: 11.347265353073684\n", "loss: 10.419340643305858\n", "loss: 9.610799938794724\n", "loss: 8.897580679158944\n", "loss: 8.258004600111189\n", "loss: 7.684500186535497\n", "loss: 7.1748549390018574\n", "loss: 6.718209468903557\n", "loss: 6.309864315153381\n", "loss: 5.944399105330259\n", "loss: 5.621553827962666\n", "loss: 5.333909699361839\n", "loss: 5.077286239602076\n", "loss: 4.84889061532151\n", "loss: 4.646022685947024\n", "loss: 4.465758350759858\n", "loss: 4.305913173647123\n", "loss: 4.1640718175126095\n", "loss: 4.039102963682375\n", "loss: 3.9297524332426623\n", "loss: 3.832402269592843\n", "loss: 3.7453438160159322\n", "loss: 3.6674507148484525\n", "loss: 3.5977884037626\n", "loss: 3.535827925326507\n", "loss: 3.479967988008011\n", "loss: 3.4291169910556207\n", "loss: 3.3829982181684475\n", "loss: 3.3416403251567335\n", "loss: 3.304229099924413\n", "loss: 3.2684585556590573\n", "loss: 3.2354701067381977\n", "loss: 3.2051369151544105\n", "loss: 3.1772080296379404\n", "loss: 3.1476337866558435\n", "loss: 3.1193880834750365\n", "loss: 3.0930575696528475\n", "loss: 3.0684279860561725\n", "loss: 3.0453159280825632\n", "loss: 3.023566344470656\n", "loss: 3.0031494958762517\n", "loss: 2.9838385760074786\n", "loss: 2.965539553283577\n", "loss: 2.9481708868131866\n", "loss: 2.9316489223756497\n", "loss: 2.9158898824688726\n", "loss: 2.900839148101508\n", "loss: 2.8864469332954905\n", "loss: 2.8726683402256388\n", "loss: 2.8594627350494104\n", "loss: 2.846788596539098\n", "loss: 2.834699126928559\n", "loss: 2.8238134101918453\n", "loss: 2.8133709056375507\n", "loss: 2.8032597976075766\n", "loss: 2.793428392265398\n", "loss: 2.7839305952541142\n", "loss: 
2.7747494784865347\n", "loss: 2.765836907330045\n", "loss: 2.757121848680945\n", "loss: 2.7486569559419625\n", "loss: 2.7404305726008564\n", "loss: 2.732433904640994\n", "loss: 2.724658913060166\n", "loss: 2.717098191391666\n", "loss: 2.709744869114474\n", "loss: 2.7025925339845283\n", "loss: 2.695634579597031\n", "loss: 2.688854332606485\n", "loss: 2.682427203225914\n", "loss: 2.6764899407900025\n", "loss: 2.6707857438883167\n", "loss: 2.6652986297014336\n", "loss: 2.6600185764704065\n", "loss: 2.654938846386348\n", "loss: 2.650048013830516\n", "loss: 2.6434375100435394\n", "loss: 2.6372726876694945\n", "loss: 2.6315142701295082\n", "loss: 2.6260949485933374\n", "loss: 2.620971039483914\n", "loss: 2.61174470987536\n", "loss: 2.5984776428986\n", "loss: 2.587556467070711\n", "loss: 2.5779657263762803\n", "loss: 2.5692396803357767\n", "loss: 2.5611227223903583\n", "loss: 2.5534830205555155\n", "loss: 2.5462350514450636\n", "loss: 2.539331795274396\n", "loss: 2.5327378554215794\n", "loss: 2.5264333984648015\n", "loss: 2.5203974791572867\n", "loss: 2.5146177883938456\n", "loss: 2.5090780321819945\n", "loss: 2.5037668178628807\n", "loss: 2.4986865327324224\n", "loss: 2.4938357429187015\n", "loss: 2.4892125195154855\n", "loss: 2.484814217154528\n", "loss: 2.480624991605507\n", "loss: 2.47663936164734\n", "loss: 2.4728453322029287\n", "loss: 2.469227702733714\n", "loss: 2.465784971369546\n", "loss: 2.4625031729417586\n", "loss: 2.459373586417242\n", "loss: 2.456395742549959\n", "loss: 2.453556149147296\n", "loss: 2.450848516477792\n", "loss: 2.448257826886593\n", "loss: 2.4457793902422473\n", "loss: 2.4434057564290246\n", "loss: 2.4411335977457957\n", "loss: 2.43895711642629\n", "loss: 2.4368736364557093\n", "loss: 2.4348749182486014\n", "loss: 2.432961782234024\n", "loss: 2.4311267156550267\n", "loss: 2.4293686022331507\n", "loss: 2.4276818857227433\n", "loss: 2.4260653790156557\n", "loss: 2.4245172581509893\n", "loss: 2.423035547508019\n", "loss: 2.4216137562678623\n", 
"loss: 2.420252690046535\n", "loss: 2.4189495881209595\n", "loss: 2.4176996906636115\n", "loss: 2.4165033152293676\n", "loss: 2.415355506487388\n", "loss: 2.414254264583601\n", "loss: 2.4131997637552898\n", "loss: 2.4121900926116586\n", "loss: 2.411220865950546\n", "loss: 2.410291344402852\n", "loss: 2.4094025182811554\n", "loss: 2.4085499517354947\n", "loss: 2.407731441094278\n", "loss: 2.406946144975749\n", "loss: 2.4061932934032764\n", "loss: 2.4054693005333485\n", "loss: 2.4047725159312234\n", "loss: 2.404104507441233\n", "loss: 2.4034628933277684\n", "loss: 2.4028442537087553\n", "loss: 2.4022504549932577\n", "loss: 2.4016803840643326\n", "loss: 2.4011319551366914\n", "loss: 2.4006060782593357\n", "loss: 2.400098897713013\n", "loss: 2.39961231412508\n", "loss: 2.3991458381267075\n", "loss: 2.3986954028524448\n", "loss: 2.3982632819114857\n", "loss: 2.3978461833600178\n", "loss: 2.3974452187294\n", "loss: 2.3970596690136103\n", "loss: 2.396686140489889\n", "loss: 2.396328903510832\n", "loss: 2.3959835555706825\n", "loss: 2.39565059757921\n", "loss: 2.3953300040263765\n", "loss: 2.3950220475433692\n", "loss: 2.3947263525017353\n", "loss: 2.3944396560020063\n", "loss: 2.3941632766670184\n", "loss: 2.3938993096058816\n", "loss: 2.39364294492824\n", "loss: 2.3933965389172958\n", "loss: 2.3931594462236814\n", "loss: 2.392932651756755\n", "loss: 2.392711772531679\n", "loss: 2.3925004166785255\n", "loss: 2.392296083041902\n", "loss: 2.392100842332238\n", "loss: 2.3919111950972347\n", "loss: 2.3917290333380343\n", "loss: 2.391553470675823\n", "loss: 2.3913848103373843\n", "loss: 2.391220583744764\n", "loss: 2.3910649553444348\n", "loss: 2.390914207314057\n", "loss: 2.3907697541809365\n", "loss: 2.390630728668481\n", "loss: 2.390498495256061\n", "loss: 2.390369937885896\n", "loss: 2.390247045664548\n", "loss: 2.3901273102670126\n", "loss: 2.3900148760237014\n", "loss: 2.389905238354612\n", "loss: 2.3898001952359604\n", "loss: 2.3896980625916764\n", "loss: 
2.3896010024421126\n", "loss: 2.3895086082689625\n", "loss: 2.389417860716362\n", "loss: 2.3893319903283574\n", "loss: 2.3892491959225457\n", "loss: 2.3891687694660892\n", "loss: 2.3890934197829052\n", "loss: 2.389019136431073\n", "loss: 2.388947788821395\n", "loss: 2.3888805609783272\n", "loss: 2.3888150688462746\n", "loss: 2.3887518707039823\n", "loss: 2.388691911298264\n", "loss: 2.388634014187227\n", "loss: 2.38857805590494\n", "loss: 2.388523918864406\n", "loss: 2.388473510627236\n", "loss: 2.388423436117929\n", "loss: 2.3883753262717216\n", "loss: 2.3883303602155683\n", "loss: 2.3882866724515304\n", "loss: 2.388244442215278\n", "loss: 2.3882033453528138\n", "loss: 2.3881652931929307\n", "loss: 2.38812677875453\n", "loss: 2.388091060218199\n", "loss: 2.388056092752744\n", "loss: 2.3880237467127237\n", "loss: 2.3879919924582405\n", "loss: 2.387960763231075\n", "loss: 2.387931658872838\n", "loss: 2.3879034749475188\n", "loss: 2.387875347245705\n", "loss: 2.387849402809727\n", "loss: 2.387823662709372\n", "loss: 2.387799967311831\n", "loss: 2.387776360266166\n", "loss: 2.387754683818766\n", "loss: 2.3877329827919898\n", "loss: 2.3877121910475196\n", "loss: 2.3876929899702644\n", "loss: 2.3876729319637\n", "loss: 2.387654431642712\n", "loss: 2.387637541283918\n", "loss: 2.387620354216107\n", "loss: 2.387604686551482\n", "loss: 2.3875886398363733\n", "loss: 2.387573984851815\n", "loss: 2.3875588098577873\n", "loss: 2.3875448805324915\n", "loss: 2.387532242680253\n", "loss: 2.3875189700442068\n", "loss: 2.387506861417358\n", "loss: 2.387495291770735\n", "loss: 2.387484231905034\n", "loss: 2.3874731451723217\n", "loss: 2.387462274358618\n", "loss: 2.3874519373951664\n", "loss: 2.3874425920575746\n", "loss: 2.3874333869787168\n", "loss: 2.3874249670575987\n", "loss: 2.3874162377994788\n", "loss: 2.3874073874920074\n", "loss: 2.3873990135367182\n", "loss: 2.3873914786024937\n", "loss: 2.387384759212198\n", "loss: 2.387377049938677\n", "loss: 2.387371052646234\n", 
"loss: 2.387363939101346\n", "loss: 2.3873584890193866\n", "loss: 2.3873519773705674\n", "loss: 2.3873461623087593\n", "loss: 2.387341023755127\n", "loss: 2.387336250114317\n", "loss: 2.3873309349740612\n", "loss: 2.387325950896895\n", "loss: 2.3873215752647576\n", "loss: 2.3873160311256796\n", "loss: 2.3873128208612444\n", "loss: 2.3873084158082647\n", "loss: 2.387304558605306\n", "loss: 2.3872995395407717\n", "loss: 2.387296679993435\n", "loss: 2.387292629417728\n", "loss: 2.3872890739934816\n", "loss: 2.387285997937587\n", "loss: 2.387283389935851\n", "loss: 2.3872794993942383\n", "loss: 2.387277790051124\n", "loss: 2.387274775181382\n", "loss: 2.387272183856358\n", "loss: 2.387269056421407\n", "loss: 2.3872664925693012\n", "loss: 2.387264376208358\n", "loss: 2.3872623679945018\n", "loss: 2.3872600043678207\n", "loss: 2.3872580041281823\n", "loss: 2.387256273384648\n", "loss: 2.3872533408675327\n", "loss: 2.3872523812347737\n", "loss: 2.3872500233166005\n", "loss: 2.387247988754982\n", "loss: 2.387246267292523\n", "loss: 2.3872448527441414\n", "loss: 2.387243735408723\n", "loss: 2.387242025422281\n", "loss: 2.3872406504464334\n", "loss: 2.3872386748070245\n", "loss: 2.3872386165309805\n", "loss: 2.387237233810461\n", "loss: 2.3872360479643593\n", "loss: 2.38723476536898\n", "loss: 2.387232724485056\n", "loss: 2.3872322755810025\n", "loss: 2.3872308935896562\n", "loss: 2.387230374354215\n", "loss: 2.3872289131536952\n", "loss: 2.387227666907803\n", "loss: 2.387227223159598\n", "loss: 2.3872263094643955\n", "loss: 2.3872251968418357\n", "loss: 2.387224775014026\n", "loss: 2.387224547860667\n", "loss: 2.3872230838809783\n", "loss: 2.3872229638510785\n", "loss: 2.387221601841123\n", "loss: 2.3872221053237612\n", "loss: 2.3872210930668896\n", "loss: 2.3872202490512318\n", "loss: 2.3872195705964097\n", "loss: 2.387219054999678\n", "loss: 2.387218694718085\n", "loss: 2.3872184895236703\n", "loss: 2.3872184361150453\n", "loss: 2.3872168420512097\n", "loss: 
2.38721707828012\n", "loss: 2.3872157746881033\n", "loss: 2.3872162871830027\n", "loss: 2.3872152565137865\n", "loss: 2.3872156550203423\n", "loss: 2.387215262058789\n", "loss: 2.3872146171203603\n", "loss: 2.3872140924465817\n", "loss: 2.387213692303668\n", "loss: 2.3872134068601816\n", "loss: 2.3872132382331945\n", "loss: 2.3872131800547116\n", "loss: 2.3872132336735374\n", "loss: 2.3872122046766457\n", "loss: 2.387211985620091\n", "loss: 2.3872123543033004\n", "loss: 2.3872111512061682\n", "loss: 2.3872117185255872\n", "loss: 2.387210711934733\n", "loss: 2.3872114695614397\n", "loss: 2.3872106483724744\n", "loss: 2.3872106246269746\n", "loss: 2.387210949823621\n", "loss: 2.3872103921805325\n", "loss: 2.387209923070608\n", "loss: 2.387209539035157\n", "loss: 2.387209233578181\n", "loss: 2.387209278909502\n", "loss: 2.3872091575872285\n", "loss: 2.3872087902003143\n", "loss: 2.387208796135135\n", "loss: 2.387208863491243\n", "loss: 2.3872090209788177\n", "loss: 2.3872092395967757\n", "loss: 2.387209509667974\n", "loss: 2.3872081941411776\n", "loss: 2.387208614493238\n", "loss: 2.387209092011758\n", "loss: 2.3872079707116916\n", "loss: 2.387208566693145\n", "loss: 2.3872075814397924\n", "loss: 2.3872082912786885\n", "loss: 2.3872074055895007\n", "loss: 2.3872082394740444\n", "loss: 2.3872074711869145\n", "loss: 2.38720841308389\n", "loss: 2.3872077482796317\n", "loss: 2.3872071369801997\n", "loss: 2.387208234823193\n", "loss: 2.3872077218609924\n", "loss: 2.3872072579088783\n", "loss: 2.387206844007732\n", "loss: 2.3872079496309033\n", "loss: 2.3872078102348167\n", "loss: 2.387207529575949\n", "loss: 2.387207296171491\n", "loss: 2.3872071039581373\n", "loss: 2.38720695276087\n", "loss: 2.3872068462090654\n", "loss: 2.387206777400199\n", "loss: 2.3872067479488583\n", "loss: 2.3872067565550017\n", "loss: 2.3872068036021385\n", "loss: 2.3872068874369434\n", "loss: 2.3872070087213757\n", "loss: 2.3872071655473297\n", "loss: 2.3872073550964092\n", "loss: 
2.387207093581078\n", "loss: 2.3872061834452403\n", "loss: 2.387206473681233\n", "loss: 2.3872067944256044\n", "loss: 2.387207150189065\n", "loss: 2.3872065584074114\n", "loss: 2.3872062931664226\n", "loss: 2.3872067377036283\n", "loss: 2.387207211254447\n", "loss: 2.3872060575595855\n", "loss: 2.3872065866005077\n", "loss: 2.387207143362015\n", "loss: 2.387206071870889\n", "loss: 2.3872066797178375\n", "loss: 2.3872065479626627\n", "loss: 2.387206319019361\n", "loss: 2.3872070019550087\n", "loss: 2.3872060569372797\n", "loss: 2.38720678668193\n", "loss: 2.387205887930111\n", "loss: 2.387206664076764\n", "loss: 2.3872058107150407\n", "loss: 2.3872066293892216\n", "loss: 2.387205818746259\n", "loss: 2.3872066801575853\n", "loss: 2.3872059111197688\n", "loss: 2.3872068126406196\n", "loss: 2.387206081035995\n", "loss: 2.3872068669685325\n", "loss: 2.3872063280795652\n", "loss: 2.3872056557817367\n", "loss: 2.3872066491216897\n", "loss: 2.387206011860928\n", "loss: 2.387206535957502\n", "loss: 2.3872064376325017\n", "loss: 2.3872058508160308\n", "loss: 2.3872069163305536\n", "loss: 2.3872063735393425\n", "loss: 2.387205832345839\n", "loss: 2.387206677283909\n", "loss: 2.3872064498889176\n", "loss: 2.387205955600857\n", "loss: 2.387205820939399\n", "loss: 2.387206660636184\n", "loss: 2.387206206604038\n", "loss: 2.38720576900394\n", "loss: 2.3872063117983804\n", "loss: 2.3872065813147874\n", "loss: 2.387206181466505\n", "loss: 2.3872057946984633\n", "loss: 2.3872058708004715\n", "loss: 2.387206706121622\n", "loss: 2.3872063554658762\n", "loss: 2.3872060160821453\n", "loss: 2.3872056878815267\n", "loss: 2.3872059982983496\n", "loss: 2.387206714970821\n", "loss: 2.3872064188808673\n", "loss: 2.3872061334063455\n", "loss: 2.3872058573779524\n", "loss: 2.3872055929594205\n", "loss: 2.3872060525929197\n", "loss: 2.3872067406100412\n", "loss: 2.387206502734586\n", "loss: 2.3872062762370385\n", "loss: 2.387206057098045\n", "loss: 2.3872058489669072\n", "loss: 
2.3872056488912214\n", "loss: 2.3872054563548764\n", "loss: 2.3872062535038063\n", "loss: 2.387206745595907\n", "loss: 2.3872065783271466\n", "loss: 2.38720641696552\n", "loss: 2.387206265759137\n", "loss: 2.387206119480711\n", "loss: 2.3872059828172514\n", "loss: 2.387205852070047\n", "loss: 2.387205730140802\n", "loss: 2.387205613674655\n", "loss: 2.3872055054664263\n", "loss: 2.3872055703172705\n", "loss: 2.387206001531239\n", "loss: 2.3872064059032465\n", "loss: 2.3872067748722134\n", "loss: 2.387206704414516\n", "loss: 2.387206632138044\n", "loss: 2.3872065672351592\n", "loss: 2.3872065077955\n", "loss: 2.387206453507227\n", "loss: 2.387206405811236\n", "loss: 2.3872063624448145\n", "loss: 2.387206324770439\n", "loss: 2.387206293044411\n", "loss: 2.387206266172666\n", "loss: 2.387206246548697\n", "loss: 2.3872062269091705\n", "loss: 2.3872062166383956\n", "loss: 2.3872062117393194\n", "loss: 2.3872062110609655\n", "loss: 2.3872062139535695\n", "loss: 2.387206219909649\n", "loss: 2.3872062285052236\n", "loss: 2.387206240576973\n", "loss: 2.3872062606586537\n", "loss: 2.3872062832874716\n", "loss: 2.3872063082635315\n", "loss: 2.3872063391881677\n", "loss: 2.387206373803312\n", "loss: 2.3872064127175974\n", "loss: 2.3872064540621514\n", "loss: 2.3872064987956327\n", "loss: 2.3872065498963284\n", "loss: 2.3872066013767226\n", "loss: 2.387206657950146\n", "loss: 2.3872067182049292\n", "loss: 2.387206672867762\n", "loss: 2.3872063672970487\n", "loss: 2.387206055011894\n", "loss: 2.3872057237905118\n", "loss: 2.3872054165649037\n", "loss: 2.3872054954444497\n", "loss: 2.387205575891558\n", "loss: 2.3872056608606287\n", "loss: 2.3872057469958534\n", "loss: 2.387205835730483\n", "loss: 2.3872059291294483\n", "loss: 2.387206023172342\n", "loss: 2.387206121095825\n", "loss: 2.3872062219751333\n", "loss: 2.3872063233856755\n", "loss: 2.387206428473149\n", "loss: 2.3872065378135736\n", "loss: 2.387206647283522\n", "loss: 2.387206749735465\n", "loss: 2.3872062240356984\n", 
"loss: 2.38720569664805\n", "loss: 2.3872054625336876\n", "loss: 2.387205583157073\n", "loss: 2.3872057046336925\n", "loss: 2.387205829508387\n", "loss: 2.387205957406746\n", "loss: 2.387206085740676\n", "loss: 2.3872062170332273\n", "loss: 2.387206349757113\n", "loss: 2.387206484077489\n", "loss: 2.387206620076463\n", "loss: 2.387206742599622\n", "loss: 2.38720610614635\n", "loss: 2.3872054598822556\n", "loss: 2.3872055360296702\n", "loss: 2.3872056808451143\n", "loss: 2.3872058264733313\n", "loss: 2.3872059747774506\n", "loss: 2.387206123890669\n", "loss: 2.3872062756016943\n", "loss: 2.387206426949868\n", "loss: 2.3872065819415873\n", "loss: 2.387206737478283\n", "loss: 2.387206126865435\n", "loss: 2.3872054054434018\n", "loss: 2.3872055638677026\n", "loss: 2.387205724088219\n", "loss: 2.38720588669721\n", "loss: 2.387206049768166\n", "loss: 2.3872062148316706\n", "loss: 2.3872063803883172\n", "loss: 2.3872065476541806\n", "loss: 2.3872067152675234\n", "loss: 2.3872061646400056\n", "loss: 2.3872054078993044\n", "loss: 2.3872055802952357\n", "loss: 2.3872057516280947\n", "loss: 2.3872059245622954\n", "loss: 2.3872060978440746\n", "loss: 2.387206273640891\n", "loss: 2.387206448981906\n", "loss: 2.3872066283894027\n", "loss: 2.387206516906132\n", "loss: 2.387205699504398\n", "loss: 2.3872055196151685\n", "loss: 2.3872056996032054\n", "loss: 2.3872058817924033\n", "loss: 2.387206063764966\n", "loss: 2.387206248517262\n", "loss: 2.387206432649466\n", "loss: 2.3872066179179283\n", "loss: 2.387206522287034\n", "loss: 2.3872056702282123\n", "loss: 2.3872055324647627\n", "loss: 2.38720572034085\n", "loss: 2.3872059096035927\n", "loss: 2.3872060985335164\n", "loss: 2.387206289705417\n", "loss: 2.387206481657449\n", "loss: 2.3872066741473743\n", "loss: 2.387206243340186\n", "loss: 2.387205412579381\n", "loss: 2.387205607172106\n", "loss: 2.3872058001813334\n", "loss: 2.3872059950155946\n", "loss: 2.3872061909696214\n", "loss: 2.387206387319673\n", "loss: 
2.3872065845641055\n", "loss: 2.387206628280637\n", "loss: 2.3872057223715246\n", "loss: 2.3872055317401317\n", "loss: 2.387205729959452\n", "loss: 2.3872059297733506\n", "loss: 2.3872061287534345\n", "loss: 2.3872063300590187\n", "loss: 2.3872065319453286\n", "loss: 2.3872067329639384\n", "loss: 2.3872059293532697\n", "loss: 2.3872054909066573\n", "loss: 2.3872056929994176\n", "loss: 2.3872058959012414\n", "loss: 2.3872060991225794\n", "loss: 2.387206303755043\n", "loss: 2.3872065090776227\n", "loss: 2.387206713881384\n", "loss: 2.3872060034612277\n", "loss: 2.387205479096127\n", "loss: 2.3872056842701372\n", "loss: 2.3872058905037483\n", "loss: 2.3872060971157936\n", "loss: 2.3872063046852645\n", "loss: 2.3872065130746725\n", "loss: 2.387206720682877\n", "loss: 2.3872059543810957\n", "loss: 2.3872054920341013\n", "loss: 2.387205699860994\n", "loss: 2.387205909361374\n", "loss: 2.387206118592575\n", "loss: 2.387206328727607\n", "loss: 2.38720654022318\n", "loss: 2.3872067496108995\n", "loss: 2.38720580949524\n", "loss: 2.387205525406542\n", "loss: 2.3872057364936534\n", "loss: 2.3872059484123977\n", "loss: 2.387206158732515\n", "loss: 2.3872063716998198\n", "loss: 2.387206584128892\n", "loss: 2.3872065550645125\n", "loss: 2.387205585137984\n", "loss: 2.387205576467598\n", "loss: 2.3872057895930716\n", "loss: 2.3872060024345876\n", "loss: 2.387206217104482\n", "loss: 2.38720643053896\n", "loss: 2.3872066462470087\n", "loss: 2.3872062650681847\n", "loss: 2.3872054287829734\n", "loss: 2.387205644040444\n", "loss: 2.387205857402579\n", "loss: 2.3872060732060456\n", "loss: 2.3872062890605856\n", "loss: 2.3872065058667893\n", "loss: 2.387206721085871\n", "loss: 2.3872059173144433\n", "loss: 2.3872055071105676\n", "loss: 2.387205722800243\n", "loss: 2.387205939850607\n", "loss: 2.387206155334122\n", "loss: 2.3872063740586773\n", "loss: 2.38720659063653\n", "loss: 2.3872065025409945\n", "loss: 2.3872055102865852\n", "loss: 2.387205597163091\n", "loss: 
2.3872058144515096\n", "loss: 2.3872060323445647\n", "loss: 2.38720625026218\n", "loss: 2.387206468189427\n", "loss: 2.3872066869296926\n", "loss: 2.387206064964853\n", "loss: 2.38720547798433\n", "loss: 2.387205695647846\n", "loss: 2.3872059148036042\n", "loss: 2.387206132165185\n", "loss: 2.3872063520034\n", "loss: 2.387206571675985\n", "loss: 2.387206580623468\n", "loss: 2.387205582588098\n", "loss: 2.3872055839413213\n", "loss: 2.3872058020374514\n", "loss: 2.3872060214815694\n", "loss: 2.38720624240693\n", "loss: 2.387206462370405\n", "loss: 2.387206683255392\n", "loss: 2.3872060762465175\n", "loss: 2.3872054775249913\n", "loss: 2.387205696602586\n", "loss: 2.3872059172419493\n", "loss: 2.3872061373373787\n", "loss: 2.387206357728515\n", "loss: 2.387206579560316\n", "loss: 2.387206536499058\n", "loss: 2.387205534906743\n", "loss: 2.387205596098073\n", "loss: 2.387205817218552\n", "loss: 2.387206037559115\n", "loss: 2.3872062589616636\n", "loss: 2.3872064810746823\n", "loss: 2.387206702376931\n", "loss: 2.3872059789399236\n", "loss: 2.3872054989934695\n", "loss: 2.3872057193198266\n", "loss: 2.3872059415304925\n", "loss: 2.387206163057817\n", "loss: 2.387206385045773\n", "loss: 2.3872066064853166\n", "loss: 2.38720640924132\n", "loss: 2.3872054050281237\n", "loss: 2.3872056266849664\n", "loss: 2.3872058475459403\n", "loss: 2.387206070207055\n", "loss: 2.387206292224948\n", "loss: 2.3872065152244177\n", "loss: 2.3872067373632193\n", "loss: 2.387205812795538\n", "loss: 2.387205535628719\n", "loss: 2.3872057581142827\n", "loss: 2.387205979871331\n", "loss: 2.387206202847354\n", "loss: 2.38720642492102\n", "loss: 2.3872066486900834\n", "loss: 2.3872062143940815\n", "loss: 2.3872054479581615\n", "loss: 2.3872056700624618\n", "loss: 2.387205892745017\n", "loss: 2.3872061156322637\n", "loss: 2.3872063392335336\n", "loss: 2.3872065621484544\n", "loss: 2.3872066127839506\n", "loss: 2.387205593010971\n", "loss: 2.3872055852578473\n", "loss: 2.3872058079900116\n", "loss: 
2.3872060312398813\n", "loss: 2.387206254831415\n", "loss: 2.387206477626406\n", "loss: 2.3872067013641285\n", "loss: 2.387205974948894\n", "loss: 2.3872055015373546\n", "loss: 2.387205724165402\n", "loss: 2.387205948609596\n", "loss: 2.3872061711547934\n", "loss: 2.3872063942990995\n", "loss: 2.387206617976577\n", "loss: 2.3872063484996606\n", "loss: 2.387205419784946\n", "loss: 2.387205643696668\n", "loss: 2.38720586667547\n", "loss: 2.387206089943084\n", "loss: 2.387206312830524\n", "loss: 2.3872065377208713\n", "loss: 2.3872067192036486\n", "loss: 2.3872056983399608\n", "loss: 2.387205562202803\n", "loss: 2.3872057863378924\n", "loss: 2.387206009488842\n", "loss: 2.3872062344077247\n", "loss: 2.3872064585590174\n", "loss: 2.3872066828857177\n", "loss: 2.387206060927091\n", "loss: 2.387205483917369\n", "loss: 2.387205706729346\n", "loss: 2.3872059309762115\n", "loss: 2.387206154167569\n", "loss: 2.387206379212923\n", "loss: 2.387206603498562\n", "loss: 2.387206412786183\n", "loss: 2.38720540594447\n", "loss: 2.387205630194073\n", "loss: 2.3872058532972202\n", "loss: 2.387206078179721\n", "loss: 2.3872063011757954\n", "loss: 2.3872065261212008\n", "loss: 2.387206750494888\n", "loss: 2.38720574408085\n", "loss: 2.3872055526202605\n", "loss: 2.387205776875942\n", "loss: 2.3872060000469304\n", "loss: 2.3872062251849693\n", "loss: 2.3872064490769467\n", "loss: 2.38720667505617\n", "loss: 2.387206092621105\n", "loss: 2.3872054778245344\n", "loss: 2.387205700855323\n", "loss: 2.387205925499759\n", "loss: 2.387206149256795\n", "loss: 2.3872063747392476\n", "loss: 2.3872065982742674\n", "loss: 2.387206435850997\n", "loss: 2.3872054149088293\n", "loss: 2.3872056262571535\n", "loss: 2.387205849166078\n", "loss: 2.3872060742134376\n", "loss: 2.3872062979103035\n", "loss: 2.387206523614174\n", "loss: 2.3872067479852594\n", "loss: 2.3872057530264144\n", "loss: 2.387205550630556\n", "loss: 2.3872057753165112\n", "loss: 2.387205999046393\n", "loss: 2.3872062242513152\n", "loss: 
2.387206448562063\n", "loss: 2.3872066745566825\n", "loss: 2.3872060933404238\n", "loss: 2.387205477652032\n", "loss: 2.38720570096479\n", "loss: 2.3872059260253913\n", "loss: 2.387206149856662\n", "loss: 2.3872063756569637\n", "loss: 2.3872066002677177\n", "loss: 2.3872064248665876\n", "loss: 2.3872054055763066\n", "loss: 2.3872056285152192\n", "loss: 2.387205852129036\n", "loss: 2.387206076714375\n", "loss: 2.3872063008900715\n", "loss: 2.3872065262982316\n", "loss: 2.3872067510414143\n", "loss: 2.387205738066258\n", "loss: 2.387205553932228\n", "loss: 2.3872057787560985\n", "loss: 2.387206002620486\n", "loss: 2.3872062278227126\n", "loss: 2.387206452219362\n", "loss: 2.3872066785656476\n", "loss: 2.3872060756651354\n", "loss: 2.3872054817901702\n", "loss: 2.387205705294623\n", "loss: 2.387205930184685\n", "loss: 2.3872061542181156\n", "loss: 2.387206380162742\n", "loss: 2.3872066048043825\n", "loss: 2.3872064037912457\n", "loss: 2.387205408672593\n", "loss: 2.3872056335074885\n", "loss: 2.38720585704452\n", "loss: 2.3872060825534622\n", "loss: 2.3872063070962546\n", "loss: 2.3872065331763173\n", "loss: 2.387206736126897\n", "loss: 2.3872057102550297\n", "loss: 2.387205560708403\n", "loss: 2.387205785698165\n", "loss: 2.3872060098092085\n", "loss: 2.3872062357191077\n", "loss: 2.3872064606568575\n", "loss: 2.3872066856260075\n", "loss: 2.3872060412743403\n", "loss: 2.3872054896579566\n", "loss: 2.387205713503783\n", "loss: 2.387205938855604\n", "loss: 2.3872061633253967\n", "loss: 2.3872063880939898\n", "loss: 2.387206613272621\n", "loss: 2.3872063644496766\n", "loss: 2.387205417523893\n", "loss: 2.387205642655334\n", "loss: 2.3872058668390554\n", "loss: 2.3872060910716635\n", "loss: 2.3872063157741197\n", "loss: 2.387206542322383\n", "loss: 2.387206692745429\n", "loss: 2.3872056660868344\n", "loss: 2.3872055706534514\n", "loss: 2.3872057946635303\n", "loss: 2.387206019070375\n", "loss: 2.387206245277133\n", "loss: 2.387206470316894\n", "loss: 
2.3872066957452835\n", "loss: 2.3872059933352743\n", "loss: 2.3872054990132523\n", "loss: 2.387205722878048\n", "loss: 2.387205948575221\n", "loss: 2.387206173353647\n", "loss: 2.3872063984591825\n", "loss: 2.3872066226559827\n", "loss: 2.3872063207497023\n", "loss: 2.3872054272954717\n", "loss: 2.3872056526531233\n", "loss: 2.3872058770925495\n", "loss: 2.3872061007027394\n", "loss: 2.387206325498709\n", "loss: 2.387206552189155\n", "loss: 2.3872066470668103\n", "loss: 2.387205618564661\n", "loss: 2.387205580450609\n", "loss: 2.3872058044940334\n", "loss: 2.3872060289891905\n", "loss: 2.387206255311463\n", "loss: 2.387206480708284\n", "loss: 2.387206705586487\n", "loss: 2.3872059484852777\n", "loss: 2.387205508896236\n", "loss: 2.387205733053314\n", "loss: 2.3872059590012453\n", "loss: 2.38720618326989\n", "loss: 2.387206408350126\n", "loss: 2.3872066339450995\n", "loss: 2.3872062743922324\n", "loss: 2.3872054376631935\n", "loss: 2.3872056632364425\n", "loss: 2.3872058871467248\n", "loss: 2.387206111855699\n", "loss: 2.387206337069455\n", "loss: 2.3872065625907197\n", "loss: 2.387206598875525\n", "loss: 2.3872055726406853\n", "loss: 2.3872055914219663\n", "loss: 2.3872058158103724\n", "loss: 2.3872060406633677\n", "loss: 2.3872062658126643\n", "loss: 2.3872064913732993\n", "loss: 2.3872067166127193\n", "loss: 2.38720589721323\n", "loss: 2.3872055203865594\n", "loss: 2.3872057448636044\n", "loss: 2.387205969632241\n", "loss: 2.3872061948051737\n", "loss: 2.387206419656803\n", "loss: 2.3872066454234093\n", "loss: 2.38720622109226\n", "loss: 2.3872054496058155\n", "loss: 2.387205673994574\n", "loss: 2.387205898781459\n", "loss: 2.3872061232475654\n", "loss: 2.3872063486162522\n", "loss: 2.3872065743640327\n", "loss: 2.38720654403819\n", "loss: 2.3872055202172575\n", "loss: 2.3872056032796642\n", "loss: 2.3872058273645207\n", "loss: 2.387206052339985\n", "loss: 2.3872062776900504\n", "loss: 2.3872065035141774\n", "loss: 2.3872067282141374\n", "loss: 
2.3872058427484215\n", "loss: 2.387205531994678\n", "loss: 2.387205756582158\n", "loss: 2.387205981539894\n", "loss: 2.3872062069646356\n", "loss: 2.387206431282743\n", "loss: 2.387206657412996\n", "loss: 2.3872061679858723\n", "loss: 2.3872054613328837\n", "loss: 2.3872056859037873\n", "loss: 2.3872059109348127\n", "loss: 2.387206134876376\n", "loss: 2.387206360603658\n", "loss: 2.3872065860311924\n", "loss: 2.3872064904558536\n", "loss: 2.3872054658487727\n", "loss: 2.3872056154159624\n", "loss: 2.387205838986268\n", "loss: 2.3872060643164303\n", "loss: 2.38720629000679\n", "loss: 2.387206515465516\n", "loss: 2.3872067402918593\n", "loss: 2.3872057868542456\n", "loss: 2.3872055436974926\n", "loss: 2.387205768622451\n", "loss: 2.3872059939050634\n", "loss: 2.3872062189566643\n", "loss: 2.387206443389059\n", "loss: 2.387206669729105\n", "loss: 2.387206114421231\n", "loss: 2.3872054733876444\n", "loss: 2.3872056982757286\n", "loss: 2.3872059236497964\n", "loss: 2.3872061471739836\n", "loss: 2.38720637304622\n", "loss: 2.387206597910911\n", "loss: 2.3872064349529696\n", "loss: 2.3872054090077848\n", "loss: 2.3872056281842866\n", "loss: 2.387205851329026\n", "loss: 2.387206076790427\n", "loss: 2.387206301261524\n", "loss: 2.387206527494985\n", "loss: 2.3872067526753113\n", "loss: 2.387205728959938\n" ] } ], "source": [ "y_train = y_train.reshape(-1, 1)\n", "\n", "nn = NeuralNetwork(X_train, y_train)\n", "\n", "for i in range(1000000):\n", " nn.train(X_train, y_train)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Plot the loss vs. 
the number of epochs" ] }, { "cell_type": "code", "execution_count": 83, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "Text(0, 0.5, 'loss')" ] }, "execution_count": 83, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAjIAAAGwCAYAAACzXI8XAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy89olMNAAAACXBIWXMAAA9hAAAPYQGoP6dpAAA3LklEQVR4nO3de3RU5b3/8c/kMpOEZCYkkISYBBEs91CLiBHlaEEQEG+0tspp8RwvCxsvgD+xtGq1rQY9p1Y9pbTHC+iqyCkVsN6gCBKqBpRL5CKNBqKgkKBAMiGQ6zy/P8IMjFydzJ6dCe/XWnuR7L0z852HBfms5/nuvR3GGCMAAIAoFGN3AQAAAKEiyAAAgKhFkAEAAFGLIAMAAKIWQQYAAEQtggwAAIhaBBkAABC14uwuwGo+n0+7du1SSkqKHA6H3eUAAIDTYIxRbW2tsrOzFRNz4nmXDh9kdu3apdzcXLvLAAAAIdi5c6dycnJOeLzDB5mUlBRJrQPhdrttrgYAAJwOr9er3NzcwO/xE+nwQca/nOR2uwkyAABEmVO1hdDsCwAAohZBBgAARC2CDAAAiFoEGQAAELUIMgAAIGoRZAAAQNQiyAAAgKhla5CZPXu28vPzA/d4KSgo0FtvvRU4fumll8rhcARtkydPtrFiAADQnth6Q7ycnBzNnDlT5557rowxeuGFF3T11Vdrw4YN6t+/vyTp1ltv1a9//evAzyQlJdlVLgAAaGdsDTLjx48P+v6RRx7R7NmztXr16kCQSUpKUlZWlh3lAQCAdq7d9Mi0tLRo/vz5qqurU0FBQWD/Sy+9pC5dumjAgAGaMWOGDh48eNLXaWhokNfrDdoAAEDHZPuzljZt2qSCggLV19crOTlZixYtUr9+/SRJN954o7p3767s7Gxt3LhR9913n8rKyrRw4cITvl5RUZEefvjhSJUPAABs5DDGGDsLaGxs1I4dO1RTU6O//e1vevbZZ1VcXBwIM0dbsWKFRowYofLycvXs2fO4r9fQ0KCGhobA9/6nZ9bU1IT1oZHVBxtVW98sd2K8PInxYXtdAADQ+vvb4/Gc8ve37UtLTqdTvXr10uDBg1VUVKRBgwbpqaeeOu65Q4cOlSSVl5ef8PVcLlfgKigrn3j92JJ/6ZLH39GL739myesDAIBTsz3IfJPP5wuaUTlaaWmpJKlbt24RrOj4YmNaHyveYu+EFgAAZzRbe2RmzJihMWPGKC8vT7W1tZo3b55WrlyppUuXatu2bZo3b57Gjh2r9PR0bdy4UVOnTtXw4cOVn59vZ9mSpFjH4SDjI8gAAGAXW4PMnj179NOf/lS7d++Wx+NRfn6+li5dqssvv1w7d+7U22+/rSeffFJ1dXXKzc3VhAkTdP/999tZckBsTOtkFkEGAAD72BpknnvuuRMey83NVXFxcQSr+XZiDy/KEWQAALBPu+uRiRYxMSwtAQBgN4JMiOJo9gUAwHYEmRDR7AsAgP0IMiGi2RcAAPsRZEJEsy8AAPYjyISIZl8AAOxHkAlRHEEGAADbEWRCFOPgqiUAAOxGkAkRMzIAANiPIBOiWIIMAAC2I8iEiMuvAQCwH0EmRFx+DQCA/QgyIaLZFwAA+xFkQhQXS48MAAB2I8iEKIZnLQEAYDuCTIjiaPYFAMB2BJkQ0ewLAID9CDIhotkXAAD7EWRCRLMvAAD2I8iEiGZfAADsR5AJEc2+AADYjyATohiafQEAsB1BJkSxNPsCAG
A7gkyIaPYFAMB+BJkQ0ewLAID9CDIh8jf7+ggyAADYhiATIn+zbzNBBgAA2xBkQhSYkaHZFwAA2xBkQhTLjAwAALYjyIQo1n9DvBaCDAAAdiHIhCgupvWqpSafz+ZKAAA4cxFkQuSKax26xmaCDAAAdiHIhMgZ52/2lZpbCDMAANiBIBOi+NgjQ9dEnwwAALYgyITIPyMjsbwEAIBdCDIhiotx6PBTCtTQ0mJvMQAAnKEIMiFyOByB5SWWlgAAsAdBpg1csVy5BACAnQgybeDkEmwAAGxFkGmDI0tLBBkAAOxAkGkD/4xMAzMyAADYwtYgM3v2bOXn58vtdsvtdqugoEBvvfVW4Hh9fb0KCwuVnp6u5ORkTZgwQVVVVTZWHIylJQAA7GVrkMnJydHMmTO1bt06rV27Vt///vd19dVXa8uWLZKkqVOn6rXXXtOCBQtUXFysXbt26brrrrOz5CD+paVGlpYAALBFnJ1vPn78+KDvH3nkEc2ePVurV69WTk6OnnvuOc2bN0/f//73JUlz5sxR3759tXr1al144YV2lBzEPyPTxIwMAAC2aDc9Mi0tLZo/f77q6upUUFCgdevWqampSSNHjgyc06dPH+Xl5amkpOSEr9PQ0CCv1xu0WcXFjAwAALayPchs2rRJycnJcrlcmjx5shYtWqR+/fqpsrJSTqdTqampQednZmaqsrLyhK9XVFQkj8cT2HJzcy2rPT6u9da+9MgAAGAP24NM7969VVpaqjVr1uj222/XpEmT9PHHH4f8ejNmzFBNTU1g27lzZxirDebkhngAANjK1h4ZSXI6nerVq5ckafDgwfrwww/11FNP6Uc/+pEaGxtVXV0dNCtTVVWlrKysE76ey+WSy+WyumxJUkJ8rCSpvplnLQEAYAfbZ2S+yefzqaGhQYMHD1Z8fLyWL18eOFZWVqYdO3aooKDAxgqPSHS2BpmDjQQZAADsYOuMzIwZMzRmzBjl5eWptrZW8+bN08qVK7V06VJ5PB7dfPPNmjZtmtLS0uR2u3XnnXeqoKCgXVyxJElJBBkAAGxla5DZs2ePfvrTn2r37t3yeDzKz8/X0qVLdfnll0uSfv/73ysmJkYTJkxQQ0ODRo8erT/+8Y92lhwkydk6fIcam22uBACAM5OtQea555476fGEhATNmjVLs2bNilBF305iPDMyAADYqd31yEQT/9LSIYIMAAC2IMi0AT0yAADYiyDTBomHe2QONhFkAACwA0GmDY4sLdHsCwCAHQgybcB9ZAAAsBdBpg06+ZeWCDIAANiCINMGKQmtQaa2vsnmSgAAODMRZNrAH2S8h+iRAQDADgSZNnAnxkuSGlt8qufKJQAAIo4g0wbJzjg5HK1fe1leAgAg4ggybRAT41Cyy98nw/ISAACRRpBpI3dC6/KS9xAzMgAARBpBpo2OXLnEjAwAAJFGkGkj/4wMQQYAgMgjyLRR4BJsmn0BAIg4gkwb+S/B5qZ4AABEHkGmjbgpHgAA9iHItNGRHhlmZAAAiDSCTBsd6ZFhRgYAgEgjyLQRPTIAANiHINNG/qWlGm6IBwBAxBFk2ig1qTXI7D9IkAEAINIIMm3UOckpSao+2GhzJQAAnHkIMm3UuVPrjEz1wSYZY2yuBgCAMwtBpo38MzLNPqPaBq5cAgAgkggybZQQH6uE+NZh3F/H8hIAAJFEkAmDtMOzMjT8AgAQWQSZMEgNBBlmZAAAiCSCTBgcafglyAAAEEkEmTAIzMjUsbQEAEAkEWTCII2lJQAAbEGQCYPOgbv7EmQAAIgkgkwYpHLVEgAAtiDIhAHNvgAA2IMgEwY0+wIAYA+CTBjQ7AsAgD0IMmHQmSADAIAtCDJhkHq4R6a+yaf6phabqwEA4MxBkAmDFFec4mIckpiVAQAgkggyYeBwOGj4BQDABgSZMPHfFI9LsAEAiBxbg0xRUZGGDBmilJQUZWRk6JprrlFZWVnQOZdeeqkcDkfQNnnyZJsqPjF/w+
8+ggwAABFja5ApLi5WYWGhVq9erWXLlqmpqUmjRo1SXV1d0Hm33nqrdu/eHdgef/xxmyo+sdTAYwpYWgIAIFLi7HzzJUuWBH0/d+5cZWRkaN26dRo+fHhgf1JSkrKysiJd3reS1ql1Rqa6jhkZAAAipV31yNTU1EiS0tLSgva/9NJL6tKliwYMGKAZM2bo4MGDJ3yNhoYGeb3eoC0SeN4SAACRZ+uMzNF8Pp+mTJmiYcOGacCAAYH9N954o7p3767s7Gxt3LhR9913n8rKyrRw4cLjvk5RUZEefvjhSJUdkJLQOpS19QQZAAAipd0EmcLCQm3evFnvvvt "text/plain": [ "<Figure size 640x480 with 1 Axes>" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import matplotlib.pyplot as plt\n", "plt.plot(nn.loss_history)\n", "plt.xlabel(\"# epochs / 1000\")\n", "plt.ylabel(\"loss\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [] }, { "cell_type": "code", "execution_count": 84, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "<matplotlib.colorbar.Colorbar at 0x147ecf700>" ] }, "execution_count": 84, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAsIAAAJMCAYAAADwqMBxAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy89olMNAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdd3gU1f7H8ffMpvce0kkIhN47gvQiIip2EezXdq+K/lRs2L3X3nsBlGJBmiBdpEjvPUBCCiQhvbfdmd8fCzEhG0jZNPJ9PU8eyJzZM2dDSD579jvnKLqu6wghhBBCCNHCqI09ACGEEEIIIRqDBGEhhBBCCNEiSRAWQgghhBAtkgRhIYQQQgjRIkkQFkIIIYQQLZIEYSGEEEII0SJJEBZCCCGEEC2SBGEhhBBCCNEiSRAWQgghhBAtkgRhIYQQQgjRIkkQFkIIIYQQjWrDhg1MmDCBwMBAFEVh0aJFl3xMcXExzz33HGFhYdjb29O6dWu+++67Gl3XppbjFUIIIYQQwiry8/Pp1q0bd999N9dff321HnPTTTeRkpLCt99+S2RkJElJSWiaVqPrShAWQgghhBCNaty4cYwbN67a569YsYK//vqLmJgYvLy8AGjdunWNr3vZBWFN0zhz5gyurq4oitLYwxFCCCFEC6brOrm5uQQGBqKqjV+RWlRURElJSYNcS9f1SlnM3t4ee3v7Ove9ZMkSevfuzVtvvcUPP/yAs7Mz11xzDa+++iqOjo7V7ueyC8JnzpwhJCSksYchhBBCCFEmISGB4ODgRh1DUVERgUGtycxIaZDrubi4kJeXV+HYjBkzeOmll+rcd0xMDJs2bcLBwYGFCxeSlpbGQw89RHp6Ot9//321+7nsgrCrqysAe7/9Glcnp0YejRBCCCFastyCArrfc19ZPmlMJSUlZGakMPuXAzg51+94CvJzmXJjFxISEnBzcys7bo3ZYDBXACiKwpw5c3B3dwfgvffe44YbbuCzzz6r9qzwZReEz0/Buzo5SRAWQgghRJPQlMo1nZxdcXZ2u/SJVuDm5lYhCFtLQEAAQUFBZSEYoEOHDui6TmJiIm3btq1WP41frCKEEE
IIIUQNDBo0iDNnzlQovYiOjkZV1RqVoEgQFkIIIYQQjSovL4+9e/eyd+9eAGJjY9m7dy/x8fEATJ8+nSlTppSdf9ttt+Ht7c1dd93F4cOH2bBhA//3f//H3XffXaOb5SQICyGEEEKIRrVz50569OhBjx49AJg2bRo9evTgxRdfBCApKaksFIP5RrzVq1eTlZVF7969uf3225kwYQIfffRRja572dUICyGEEEKI5mXo0KHoul5l+8yZMysda9++PatXr67TdWVGWAghhBBCtEgShIUQQgghRIskQVgIIYQQQrRIEoSFEEIIIUSLJEFYCCGEEEK0SBKEhRBCCCFEiyRBWAghhBBCtEgShIUQQgghRIskQVgIIYQQQrRIEoSFEEIIIUSLJEFYCCGEEEK0SBKEhRBCCCFEiyRBWAghhBBCtEgShIUQQgghRIskQVgIIYQQQrRIEoSFEEIIIUSLJEFYCCGEEEK0SBKEhRBCCCFEiyRBWAghhBBCtEj1GoQ3bNjAhAkTCAwMRFEUFi1adNHz169fj6IolT6Sk5Prc5hCCCGEEKIFqtcgnJ+fT7du3fj0009r9Lhjx46RlJRU9uHn51dPIxRCCCGEEC2VTX12Pm7cOMaNG1fjx/n5+eHh4WH9AQkhhBBCCHFOk6wR7t69OwEBAYwaNYrNmzdf9Nzi4mJycnIqfAghhBBCCHEpTSoIBwQE8MUXX7BgwQIWLFhASEgIQ4cOZffu3VU+5s0338Td3b3sIyQkpAFHLIQQQgghmqt6LY2oqaioKKKioso+HzhwICdPnuT999/nhx9+sPiY6dOnM23atLLPc3JyJAwLIYQQQohLalJB2JK+ffuyadOmKtvt7e2xt7dvwBEJIYQQQojLQZMqjbBk7969BAQENPYwhBBCCCHEZaZeZ4Tz8vI4ceJE2eexsbHs3bsXLy8vQkNDmT59OqdPn2b27NkAfPDBB4SHh9OpUyeKior45ptvWLduHatWrarPYQohhBBCiBaoXoPwzp07GTZsWNnn52t5p06dysyZM0lKSiI+Pr6svaSkhCeeeILTp0/j5ORE165dWbNmTYU+hBBCCCGEsIZ6DcJDhw5F1/Uq22fOnFnh86eeeoqnnnqqPockhBBCCCEE0AxqhIUQQgghhKgPEoSFEEIIIUSLJEFYCCGEEEK0SBKEhRBCCCFEiyRBWAghhBBCtEgShIUQQgghRIskQVgIIYQQQrRIEoSFEEIIIUSLJEFYCCGEEEK0SBKEhRBCCCFEiyRBWAghhBBCtEgShIUQQgghRIskQVgIIYQQQrRIEoSFEEIIIUSLJEFYCCGEEEK0SBKEhRBCCCFEiyRBWAghhBBCtEgShIUQQgghRIskQVgIIYQQQrRIEoSFEEIIIUSLJEFYCCGEEEK0SBKEhRBCCCFEiyRBWAghhBBCtEgShIUQQgghRIskQVgIIYQQQrRIEoSFEEIIIUSLJEFYCCGEEEI0qg0bNjBhwgQCAwNRFIVFixZV+7GbN2/GxsaG7t271/i6EoSFEEIIIUSjys/Pp1u3bnz66ac1elxWVhZTpkxhxIgRtbquTa0eJYQQQgghhJWMGzeOcePG1fhxDzzwALfddhsGg6FGs8jnyYywEEIIIYRodr7//ntiYmKYMWNGrfuQGWEhhBBCCFEvcnJyKnxub2+Pvb19nfs9fvw4zzzzDBs3bsTGpvZxVoKwEEIIIUQL4unjirOLa71ew85RByAkJKTC8RkzZvDSSy/VqW+TycRtt93Gyy+/TLt27erUlwRhIYQQQghRLxISEnBzcyv73Bqzwbm5uezcuZM9e/bwyCOPAKBpGrquY2Njw6pVqxg+fHi1+pIgLIQQQggh6oWbm1uFIGytPg8cOFDh2Geffca6dev49ddfCQ8Pr3ZfEoSFEEIIIUSjysvL48SJE2Wfx8bGsnfvXry8vAgNDWX69OmcPn2a2bNno6oqnTt3rvB4Pz8/HBwcKh2/FAnCQgghhBCiUe
3cuZNhw4aVfT5t2jQApk6dysyZM0lKSiI+Pt7q11V0Xdet3msjysnJwd3dnZPz5uDq5NTYwxFCCCFEC5ZbUECbW28nOzvb6iUCNXU+I63dno6zS/2OJT8vhxF9vZvE874YWUdYCCGEEEK0SBKEhRBCCCFEiyRBWAghhBBCtEgShIUQQgghRIskq0YIIcQlaJpGXEoKcSkpGEtKcXV1oV1wMJ6u9bszkxBCiPolQVgIIS6ioLiYtVu3kh0bS6uSUpxVA0kmE8fd3ejUqyc92rZt7CEKIYSoJQnCQghxEet37kQ7Gs11/n54OzoCoOk6B9PT2b51K66OTkQGBzXyKIUQQtSG1AgLIUQVUjIzyYg9xRU+3mUhGEBVFLr6+BBeXMKh48cbcYRCCCHqQoKwEEJU4UxaGs6FhQQ6O1tsj3R3Jzc5mbzCwgYemRBCCGuQICyEEFXQNB0bQFEUi+22qgq6zmW2QacQQrQYEoSFEKIK3u5uZNvakFFUZLE9PjcXe09PnB0cGnhkQgghrEGCsBBCVCHUzw/HwCC2pJyl1GSq0Jacn89Ro4m2kZGoqvwoFUKI5khWjRBCiCqoqsrgvn1YV1zML4mJtHVwwNnGhqSCAuJUA75dOtMlIryxhymEEKKWJAgLIVoETdMoLi3Fwc6uyppfS/w8PBg/YjiH4+I4GnsKY0kJLkFB9AoPp21wkMwGCyFEMyZBWAhxWYtOSOCT3xayeOMmCktL8Xdz4/axY3jo2om4VbEaxIVcnZzo16ED/Tp0qOfRCiGEaEgShIUQl60dR49y8wsz8DIama5ptAa25OTw5a8L+OPvv1n03zdlm2QhhGjB5D09IcRlyWQy8eBb79C9tJRDmsYLwB3AZ8AOTSPlTBKvzfqhkUcphBCiMUkQFkJcltbt2UN8ejrv6zouF7R1AB7XNH5d/yc5+fmNMTwhhBBNgARhIcRl6VDsKXwMBnpX0X4VUFh "text/plain": [ "<Figure size 900x700 with 2 Axes>" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import matplotlib.pyplot as plt\n", "from matplotlib.colors import ListedColormap\n", "\n", "cm = plt.cm.RdBu\n", "cm_bright = ListedColormap([\"#FF0000\", \"#0000FF\"])\n", "\n", "xv = np.linspace(x_min, x_max, 10)\n", "yv = np.linspace(y_min, y_max, 10)\n", "Xv, Yv = np.meshgrid(xv, yv)\n", "XYpairs = np.vstack([ Xv.reshape(-1), Yv.reshape(-1)])\n", "zv = nn.predict(XYpairs.T)\n", "Zv = zv.reshape(Xv.shape)\n", "\n", "fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(9, 7))\n", "ax.set_aspect(1)\n", "cn = ax.contourf(Xv, Yv, Zv, cmap=\"coolwarm_r\", alpha=0.4)\n", "\n", "ax.scatter(X_train[:, 0], X_train[:, 1], c=y_train, cmap=cm_bright, edgecolors=\"k\")\n", "\n", "# Plot the testing points\n", "ax.scatter(X_test[:, 0], X_test[:, 1], c=y_test, cmap=cm_bright, alpha=0.4, edgecolors=\"k\")\n", "\n", "ax.set_xlim(x_min, x_max)\n", "ax.set_ylim(y_min, y_max)\n", "# ax.set_xticks(())\n", 
"# ax.set_yticks(())\n", "\n", "fig.colorbar(cn)\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.9" }, "vscode": { "interpreter": { "hash": "b0fa6594d8f4cbf19f97940f81e996739fb7646882a419484c72d19e05852a7e" } } }, "nbformat": 4, "nbformat_minor": 4 }
|