ML-Kurs-SS2023/notebooks/simple_neural_network_exercise_solution.ipynb

1321 lines
162 KiB
Plaintext
Raw Normal View History

{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"# A simple neural network with one hidden layer in plain Python and NumPy"
]
},
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## A simple neural network class with ReLU activation function"
]
},
{
"cell_type": "code",
"execution_count": 77,
"metadata": {},
"outputs": [],
"source": [
"# A simple feed-forward neural network with one hidden layer\n",
"# see also https://towardsdatascience.com/how-to-build-your-own-neural-network-from-scratch-in-python-68998a08e4f6\n",
"\n",
"import numpy as np\n",
"\n",
"class NeuralNetwork:\n",
"    \"\"\"Feed-forward network: inputs -> ReLU hidden layer -> one ReLU output neuron.\n",
"\n",
"    Trained by full-batch gradient descent on the sum-of-squared-errors loss\n",
"    sum((y - output)**2).\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    x : ndarray, shape (m, n_features)\n",
"        Training inputs; x.shape[1] sizes the first weight matrix.\n",
"    y : array-like, shape (m,) or (m, 1)\n",
"        Training targets. A 1-D array is reshaped to a column vector so that\n",
"        y - output does not broadcast to an (m, m) matrix.\n",
"    n_hidden : int, default 4\n",
"        Number of neurons in the hidden layer.\n",
"    learning_rate : float, default 1e-5\n",
"        Step size of the gradient-descent update.\n",
"    log_interval : int, default 1000\n",
"        train() prints and records the loss every `log_interval` calls.\n",
"\n",
"    Weights and biases are drawn with np.random.rand (uniform in [0, 1)) from\n",
"    NumPy's global RNG; seed it beforehand for reproducible runs.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, x, y, n_hidden=4, learning_rate=0.00001, log_interval=1000):\n",
"        n1 = n_hidden  # number of neurons in the hidden layer\n",
"        self.input = x\n",
"        self.weights1 = np.random.rand(self.input.shape[1], n1)\n",
"        self.bias1 = np.random.rand(n1)\n",
"        self.weights2 = np.random.rand(n1, 1)\n",
"        self.bias2 = np.random.rand(1)\n",
"        self.y = self._as_column(y)\n",
"        self.output = np.zeros(self.y.shape)\n",
"        self.learning_rate = learning_rate\n",
"        self.log_interval = log_interval\n",
"        self.n_train = 0\n",
"        # Losses recorded by train(); read it as the attribute nn.loss_history.\n",
"        # (No method of the same name: this instance attribute would shadow it.)\n",
"        self.loss_history = []\n",
"\n",
"    @staticmethod\n",
"    def _as_column(y):\n",
"        \"\"\"Return y as a 2-D array with one row per sample: (m,) -> (m, 1).\"\"\"\n",
"        y = np.asarray(y)\n",
"        return y.reshape(len(y), -1)\n",
"\n",
"    def relu(self, x):\n",
"        \"\"\"Element-wise ReLU: x where x > 0, else 0.\"\"\"\n",
"        return np.where(x > 0, x, 0)\n",
"\n",
"    def relu_derivative(self, x):\n",
"        \"\"\"Element-wise ReLU derivative: 1 where x > 0, else 0.\"\"\"\n",
"        return np.where(x > 0, 1, 0)\n",
"\n",
"    def feedforward(self):\n",
"        \"\"\"Forward pass on self.input; stores self.layer1 and self.output.\"\"\"\n",
"        self.layer1 = self.relu(self.input @ self.weights1 + self.bias1)\n",
"        self.output = self.relu(self.layer1 @ self.weights2 + self.bias2)\n",
"\n",
"    def backprop(self):\n",
"        \"\"\"Take one gradient-descent step using the activations of the last feedforward().\"\"\"\n",
"        # delta1: [m, 1], m = number of training data; delta1 = -dLoss/dz2 where z2 is\n",
"        # the output pre-activation. relu(z) > 0 exactly when z > 0, so the derivative\n",
"        # mask can be taken from the post-activation output itself.\n",
"        delta1 = 2 * (self.y - self.output) * self.relu_derivative(self.output)\n",
"\n",
"        # Gradient w.r.t. weights from hidden to output layer: [n1, 1], n1 = # hidden neurons\n",
"        d_weights2 = self.layer1.T @ delta1\n",
"        d_bias2 = delta1.sum(axis=0)\n",
"\n",
"        # delta2: [m, n1] = -dLoss/dz1 for the hidden-layer pre-activations\n",
"        delta2 = (delta1 @ self.weights2.T) * self.relu_derivative(self.layer1)\n",
"        d_weights1 = self.input.T @ delta2\n",
"        d_bias1 = delta2.sum(axis=0)\n",
"\n",
"        # The deltas are negative gradients, so adding them descends the loss.\n",
"        self.weights1 += self.learning_rate * d_weights1\n",
"        self.weights2 += self.learning_rate * d_weights2\n",
"        self.bias1 += self.learning_rate * d_bias1\n",
"        self.bias2 += self.learning_rate * d_bias2\n",
"\n",
"    def train(self, X, y):\n",
"        \"\"\"Run one forward pass and one parameter update on (X, y).\n",
"\n",
"        Every `log_interval`-th call prints the loss of this call's forward pass\n",
"        (computed before the update) and appends it to self.loss_history.\n",
"        \"\"\"\n",
"        self.input = X\n",
"        self.y = self._as_column(y)\n",
"        self.feedforward()\n",
"        self.backprop()\n",
"        self.n_train += 1\n",
"        if self.n_train % self.log_interval == 0:\n",
"            loss = np.sum((self.y - self.output) ** 2)\n",
"            print(\"loss: \", loss)\n",
"            self.loss_history.append(loss)\n",
"\n",
"    def predict(self, X):\n",
"        \"\"\"Return the network output for X, shape (len(X), 1).\n",
"\n",
"        Overwrites self.input, self.layer1 and self.output with the values for X.\n",
"        \"\"\"\n",
"        self.input = X\n",
"        self.feedforward()\n",
"        return self.output\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Create toy data\n",
"We create three toy data sets:\n",
"1. two interleaving half-moons\n",
"2. two concentric circles\n",
"3. a linearly separable data set"
]
},
{
"cell_type": "code",
"execution_count": 78,
"metadata": {},
"outputs": [],
"source": [
"# Toy data sets, following the scikit-learn classifier-comparison example:\n",
"# https://scikit-learn.org/stable/auto_examples/classification/plot_classifier_comparison.html#sphx-glr-auto-examples-classification-plot-classifier-comparison-py\n",
"import numpy as np\n",
"from sklearn.datasets import make_circles, make_classification, make_moons\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"# Two-feature problem; afterwards every coordinate is shifted by 2 * U[0, 1)\n",
"# noise drawn from a fixed-seed generator.\n",
"X, y = make_classification(\n",
"    n_features=2,\n",
"    n_informative=2,\n",
"    n_redundant=0,\n",
"    n_clusters_per_class=1,\n",
"    random_state=1,\n",
")\n",
"rng = np.random.RandomState(2)\n",
"X += 2 * rng.uniform(size=X.shape)\n",
"linearly_separable = (X, y)\n",
"\n",
"# The list position is the data-set index used when selecting a data set below.\n",
"datasets = [\n",
"    make_moons(n_samples=200, noise=0.1, random_state=0),  # 0: moons\n",
"    make_circles(n_samples=200, noise=0.1, factor=0.5, random_state=1),  # 1: circles\n",
"    linearly_separable,  # 2: linearly separable\n",
"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Create training and test data sets"
]
},
{
"cell_type": "code",
"execution_count": 79,
"metadata": {},
"outputs": [],
"source": [
"# Select the data set: 0 = moons, 1 = circles, 2 = linearly separable\n",
"X, y = datasets[1]\n",
"\n",
"# Hold out 40 % of the samples for testing, with a fixed split seed.\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4, random_state=42)\n",
"\n",
"# Bounding box of the first two features, padded by 0.5 on each side\n",
"# (presumably for plotting decision regions later -- confirm with the plotting cell).\n",
"# x_* bounds belong to feature 0 and y_* bounds to feature 1, not to the labels y.\n",
"x_min, y_min = X[:, :2].min(axis=0) - 0.5\n",
"x_max, y_max = X[:, :2].max(axis=0) + 0.5\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Train the model"
]
},
{
"cell_type": "code",
"execution_count": 80,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"loss: 34.671913102152374\n",
"loss: 31.424782860564203\n",
"loss: 29.772915496524135\n",
"loss: 28.762023680772913\n",
"loss: 28.00200726838712\n",
"loss: 27.32590942137339\n",
"loss: 26.752368734071535\n",
"loss: 26.230440689447903\n",
"loss: 25.673463509689576\n",
"loss: 25.012834504148312\n",
"loss: 24.289682045629544\n",
"loss: 23.555645514965384\n",
"loss: 22.76462670346343\n",
"loss: 21.904104226889068\n",
"loss: 20.943637847221698\n",
"loss: 19.89434572572985\n",
"loss: 18.727285500049177\n",
"loss: 17.485616842253226\n",
"loss: 16.142413632344777\n",
"loss: 14.852364407067075\n",
"loss: 13.635545514668182\n",
"loss: 12.456629856179049\n",
"loss: 11.347265353073684\n",
"loss: 10.419340643305858\n",
"loss: 9.610799938794724\n",
"loss: 8.897580679158944\n",
"loss: 8.258004600111189\n",
"loss: 7.684500186535497\n",
"loss: 7.1748549390018574\n",
"loss: 6.718209468903557\n",
"loss: 6.309864315153381\n",
"loss: 5.944399105330259\n",
"loss: 5.621553827962666\n",
"loss: 5.333909699361839\n",
"loss: 5.077286239602076\n",
"loss: 4.84889061532151\n",
"loss: 4.646022685947024\n",
"loss: 4.465758350759858\n",
"loss: 4.305913173647123\n",
"loss: 4.1640718175126095\n",
"loss: 4.039102963682375\n",
"loss: 3.9297524332426623\n",
"loss: 3.832402269592843\n",
"loss: 3.7453438160159322\n",
"loss: 3.6674507148484525\n",
"loss: 3.5977884037626\n",
"loss: 3.535827925326507\n",
"loss: 3.479967988008011\n",
"loss: 3.4291169910556207\n",
"loss: 3.3829982181684475\n",
"loss: 3.3416403251567335\n",
"loss: 3.304229099924413\n",
"loss: 3.2684585556590573\n",
"loss: 3.2354701067381977\n",
"loss: 3.2051369151544105\n",
"loss: 3.1772080296379404\n",
"loss: 3.1476337866558435\n",
"loss: 3.1193880834750365\n",
"loss: 3.0930575696528475\n",
"loss: 3.0684279860561725\n",
"loss: 3.0453159280825632\n",
"loss: 3.023566344470656\n",
"loss: 3.0031494958762517\n",
"loss: 2.9838385760074786\n",
"loss: 2.965539553283577\n",
"loss: 2.9481708868131866\n",
"loss: 2.9316489223756497\n",
"loss: 2.9158898824688726\n",
"loss: 2.900839148101508\n",
"loss: 2.8864469332954905\n",
"loss: 2.8726683402256388\n",
"loss: 2.8594627350494104\n",
"loss: 2.846788596539098\n",
"loss: 2.834699126928559\n",
"loss: 2.8238134101918453\n",
"loss: 2.8133709056375507\n",
"loss: 2.8032597976075766\n",
"loss: 2.793428392265398\n",
"loss: 2.7839305952541142\n",
"loss: 2.7747494784865347\n",
"loss: 2.765836907330045\n",
"loss: 2.757121848680945\n",
"loss: 2.7486569559419625\n",
"loss: 2.7404305726008564\n",
"loss: 2.732433904640994\n",
"loss: 2.724658913060166\n",
"loss: 2.717098191391666\n",
"loss: 2.709744869114474\n",
"loss: 2.7025925339845283\n",
"loss: 2.695634579597031\n",
"loss: 2.688854332606485\n",
"loss: 2.682427203225914\n",
"loss: 2.6764899407900025\n",
"loss: 2.6707857438883167\n",
"loss: 2.6652986297014336\n",
"loss: 2.6600185764704065\n",
"loss: 2.654938846386348\n",
"loss: 2.650048013830516\n",
"loss: 2.6434375100435394\n",
"loss: 2.6372726876694945\n",
"loss: 2.6315142701295082\n",
"loss: 2.6260949485933374\n",
"loss: 2.620971039483914\n",
"loss: 2.61174470987536\n",
"loss: 2.5984776428986\n",
"loss: 2.587556467070711\n",
"loss: 2.5779657263762803\n",
"loss: 2.5692396803357767\n",
"loss: 2.5611227223903583\n",
"loss: 2.5534830205555155\n",
"loss: 2.5462350514450636\n",
"loss: 2.539331795274396\n",
"loss: 2.5327378554215794\n",
"loss: 2.5264333984648015\n",
"loss: 2.5203974791572867\n",
"loss: 2.5146177883938456\n",
"loss: 2.5090780321819945\n",
"loss: 2.5037668178628807\n",
"loss: 2.4986865327324224\n",
"loss: 2.4938357429187015\n",
"loss: 2.4892125195154855\n",
"loss: 2.484814217154528\n",
"loss: 2.480624991605507\n",
"loss: 2.47663936164734\n",
"loss: 2.4728453322029287\n",
"loss: 2.469227702733714\n",
"loss: 2.465784971369546\n",
"loss: 2.4625031729417586\n",
"loss: 2.459373586417242\n",
"loss: 2.456395742549959\n",
"loss: 2.453556149147296\n",
"loss: 2.450848516477792\n",
"loss: 2.448257826886593\n",
"loss: 2.4457793902422473\n",
"loss: 2.4434057564290246\n",
"loss: 2.4411335977457957\n",
"loss: 2.43895711642629\n",
"loss: 2.4368736364557093\n",
"loss: 2.4348749182486014\n",
"loss: 2.432961782234024\n",
"loss: 2.4311267156550267\n",
"loss: 2.4293686022331507\n",
"loss: 2.4276818857227433\n",
"loss: 2.4260653790156557\n",
"loss: 2.4245172581509893\n",
"loss: 2.423035547508019\n",
"loss: 2.4216137562678623\n",
"loss: 2.420252690046535\n",
"loss: 2.4189495881209595\n",
"loss: 2.4176996906636115\n",
"loss: 2.4165033152293676\n",
"loss: 2.415355506487388\n",
"loss: 2.414254264583601\n",
"loss: 2.4131997637552898\n",
"loss: 2.4121900926116586\n",
"loss: 2.411220865950546\n",
"loss: 2.410291344402852\n",
"loss: 2.4094025182811554\n",
"loss: 2.4085499517354947\n",
"loss: 2.407731441094278\n",
"loss: 2.406946144975749\n",
"loss: 2.4061932934032764\n",
"loss: 2.4054693005333485\n",
"loss: 2.4047725159312234\n",
"loss: 2.404104507441233\n",
"loss: 2.4034628933277684\n",
"loss: 2.4028442537087553\n",
"loss: 2.4022504549932577\n",
"loss: 2.4016803840643326\n",
"loss: 2.4011319551366914\n",
"loss: 2.4006060782593357\n",
"loss: 2.400098897713013\n",
"loss: 2.39961231412508\n",
"loss: 2.3991458381267075\n",
"loss: 2.3986954028524448\n",
"loss: 2.3982632819114857\n",
"loss: 2.3978461833600178\n",
"loss: 2.3974452187294\n",
"loss: 2.3970596690136103\n",
"loss: 2.396686140489889\n",
"loss: 2.396328903510832\n",
"loss: 2.3959835555706825\n",
"loss: 2.39565059757921\n",
"loss: 2.3953300040263765\n",
"loss: 2.3950220475433692\n",
"loss: 2.3947263525017353\n",
"loss: 2.3944396560020063\n",
"loss: 2.3941632766670184\n",
"loss: 2.3938993096058816\n",
"loss: 2.39364294492824\n",
"loss: 2.3933965389172958\n",
"loss: 2.3931594462236814\n",
"loss: 2.392932651756755\n",
"loss: 2.392711772531679\n",
"loss: 2.3925004166785255\n",
"loss: 2.392296083041902\n",
"loss: 2.392100842332238\n",
"loss: 2.3919111950972347\n",
"loss: 2.3917290333380343\n",
"loss: 2.391553470675823\n",
"loss: 2.3913848103373843\n",
"loss: 2.391220583744764\n",
"loss: 2.3910649553444348\n",
"loss: 2.390914207314057\n",
"loss: 2.3907697541809365\n",
"loss: 2.390630728668481\n",
"loss: 2.390498495256061\n",
"loss: 2.390369937885896\n",
"loss: 2.390247045664548\n",
"loss: 2.3901273102670126\n",
"loss: 2.3900148760237014\n",
"loss: 2.389905238354612\n",
"loss: 2.3898001952359604\n",
"loss: 2.3896980625916764\n",
"loss: 2.3896010024421126\n",
"loss: 2.3895086082689625\n",
"loss: 2.389417860716362\n",
"loss: 2.3893319903283574\n",
"loss: 2.3892491959225457\n",
"loss: 2.3891687694660892\n",
"loss: 2.3890934197829052\n",
"loss: 2.389019136431073\n",
"loss: 2.388947788821395\n",
"loss: 2.3888805609783272\n",
"loss: 2.3888150688462746\n",
"loss: 2.3887518707039823\n",
"loss: 2.388691911298264\n",
"loss: 2.388634014187227\n",
"loss: 2.38857805590494\n",
"loss: 2.388523918864406\n",
"loss: 2.388473510627236\n",
"loss: 2.388423436117929\n",
"loss: 2.3883753262717216\n",
"loss: 2.3883303602155683\n",
"loss: 2.3882866724515304\n",
"loss: 2.388244442215278\n",
"loss: 2.3882033453528138\n",
"loss: 2.3881652931929307\n",
"loss: 2.38812677875453\n",
"loss: 2.388091060218199\n",
"loss: 2.388056092752744\n",
"loss: 2.3880237467127237\n",
"loss: 2.3879919924582405\n",
"loss: 2.387960763231075\n",
"loss: 2.387931658872838\n",
"loss: 2.3879034749475188\n",
"loss: 2.387875347245705\n",
"loss: 2.387849402809727\n",
"loss: 2.387823662709372\n",
"loss: 2.387799967311831\n",
"loss: 2.387776360266166\n",
"loss: 2.387754683818766\n",
"loss: 2.3877329827919898\n",
"loss: 2.3877121910475196\n",
"loss: 2.3876929899702644\n",
"loss: 2.3876729319637\n",
"loss: 2.387654431642712\n",
"loss: 2.387637541283918\n",
"loss: 2.387620354216107\n",
"loss: 2.387604686551482\n",
"loss: 2.3875886398363733\n",
"loss: 2.387573984851815\n",
"loss: 2.3875588098577873\n",
"loss: 2.3875448805324915\n",
"loss: 2.387532242680253\n",
"loss: 2.3875189700442068\n",
"loss: 2.387506861417358\n",
"loss: 2.387495291770735\n",
"loss: 2.387484231905034\n",
"loss: 2.3874731451723217\n",
"loss: 2.387462274358618\n",
"loss: 2.3874519373951664\n",
"loss: 2.3874425920575746\n",
"loss: 2.3874333869787168\n",
"loss: 2.3874249670575987\n",
"loss: 2.3874162377994788\n",
"loss: 2.3874073874920074\n",
"loss: 2.3873990135367182\n",
"loss: 2.3873914786024937\n",
"loss: 2.387384759212198\n",
"loss: 2.387377049938677\n",
"loss: 2.387371052646234\n",
"loss: 2.387363939101346\n",
"loss: 2.3873584890193866\n",
"loss: 2.3873519773705674\n",
"loss: 2.3873461623087593\n",
"loss: 2.387341023755127\n",
"loss: 2.387336250114317\n",
"loss: 2.3873309349740612\n",
"loss: 2.387325950896895\n",
"loss: 2.3873215752647576\n",
"loss: 2.3873160311256796\n",
"loss: 2.3873128208612444\n",
"loss: 2.3873084158082647\n",
"loss: 2.387304558605306\n",
"loss: 2.3872995395407717\n",
"loss: 2.387296679993435\n",
"loss: 2.387292629417728\n",
"loss: 2.3872890739934816\n",
"loss: 2.387285997937587\n",
"loss: 2.387283389935851\n",
"loss: 2.3872794993942383\n",
"loss: 2.387277790051124\n",
"loss: 2.387274775181382\n",
"loss: 2.387272183856358\n",
"loss: 2.387269056421407\n",
"loss: 2.3872664925693012\n",
"loss: 2.387264376208358\n",
"loss: 2.3872623679945018\n",
"loss: 2.3872600043678207\n",
"loss: 2.3872580041281823\n",
"loss: 2.387256273384648\n",
"loss: 2.3872533408675327\n",
"loss: 2.3872523812347737\n",
"loss: 2.3872500233166005\n",
"loss: 2.387247988754982\n",
"loss: 2.387246267292523\n",
"loss: 2.3872448527441414\n",
"loss: 2.387243735408723\n",
"loss: 2.387242025422281\n",
"loss: 2.3872406504464334\n",
"loss: 2.3872386748070245\n",
"loss: 2.3872386165309805\n",
"loss: 2.387237233810461\n",
"loss: 2.3872360479643593\n",
"loss: 2.38723476536898\n",
"loss: 2.387232724485056\n",
"loss: 2.3872322755810025\n",
"loss: 2.3872308935896562\n",
"loss: 2.387230374354215\n",
"loss: 2.3872289131536952\n",
"loss: 2.387227666907803\n",
"loss: 2.387227223159598\n",
"loss: 2.3872263094643955\n",
"loss: 2.3872251968418357\n",
"loss: 2.387224775014026\n",
"loss: 2.387224547860667\n",
"loss: 2.3872230838809783\n",
"loss: 2.3872229638510785\n",
"loss: 2.387221601841123\n",
"loss: 2.3872221053237612\n",
"loss: 2.3872210930668896\n",
"loss: 2.3872202490512318\n",
"loss: 2.3872195705964097\n",
"loss: 2.387219054999678\n",
"loss: 2.387218694718085\n",
"loss: 2.3872184895236703\n",
"loss: 2.3872184361150453\n",
"loss: 2.3872168420512097\n",
"loss: 2.38721707828012\n",
"loss: 2.3872157746881033\n",
"loss: 2.3872162871830027\n",
"loss: 2.3872152565137865\n",
"loss: 2.3872156550203423\n",
"loss: 2.387215262058789\n",
"loss: 2.3872146171203603\n",
"loss: 2.3872140924465817\n",
"loss: 2.387213692303668\n",
"loss: 2.3872134068601816\n",
"loss: 2.3872132382331945\n",
"loss: 2.3872131800547116\n",
"loss: 2.3872132336735374\n",
"loss: 2.3872122046766457\n",
"loss: 2.387211985620091\n",
"loss: 2.3872123543033004\n",
"loss: 2.3872111512061682\n",
"loss: 2.3872117185255872\n",
"loss: 2.387210711934733\n",
"loss: 2.3872114695614397\n",
"loss: 2.3872106483724744\n",
"loss: 2.3872106246269746\n",
"loss: 2.387210949823621\n",
"loss: 2.3872103921805325\n",
"loss: 2.387209923070608\n",
"loss: 2.387209539035157\n",
"loss: 2.387209233578181\n",
"loss: 2.387209278909502\n",
"loss: 2.3872091575872285\n",
"loss: 2.3872087902003143\n",
"loss: 2.387208796135135\n",
"loss: 2.387208863491243\n",
"loss: 2.3872090209788177\n",
"loss: 2.3872092395967757\n",
"loss: 2.387209509667974\n",
"loss: 2.3872081941411776\n",
"loss: 2.387208614493238\n",
"loss: 2.387209092011758\n",
"loss: 2.3872079707116916\n",
"loss: 2.387208566693145\n",
"loss: 2.3872075814397924\n",
"loss: 2.3872082912786885\n",
"loss: 2.3872074055895007\n",
"loss: 2.3872082394740444\n",
"loss: 2.3872074711869145\n",
"loss: 2.38720841308389\n",
"loss: 2.3872077482796317\n",
"loss: 2.3872071369801997\n",
"loss: 2.387208234823193\n",
"loss: 2.3872077218609924\n",
"loss: 2.3872072579088783\n",
"loss: 2.387206844007732\n",
"loss: 2.3872079496309033\n",
"loss: 2.3872078102348167\n",
"loss: 2.387207529575949\n",
"loss: 2.387207296171491\n",
"loss: 2.3872071039581373\n",
"loss: 2.38720695276087\n",
"loss: 2.3872068462090654\n",
"loss: 2.387206777400199\n",
"loss: 2.3872067479488583\n",
"loss: 2.3872067565550017\n",
"loss: 2.3872068036021385\n",
"loss: 2.3872068874369434\n",
"loss: 2.3872070087213757\n",
"loss: 2.3872071655473297\n",
"loss: 2.3872073550964092\n",
"loss: 2.387207093581078\n",
"loss: 2.3872061834452403\n",
"loss: 2.387206473681233\n",
"loss: 2.3872067944256044\n",
"loss: 2.387207150189065\n",
"loss: 2.3872065584074114\n",
"loss: 2.3872062931664226\n",
"loss: 2.3872067377036283\n",
"loss: 2.387207211254447\n",
"loss: 2.3872060575595855\n",
"loss: 2.3872065866005077\n",
"loss: 2.387207143362015\n",
"loss: 2.387206071870889\n",
"loss: 2.3872066797178375\n",
"loss: 2.3872065479626627\n",
"loss: 2.387206319019361\n",
"loss: 2.3872070019550087\n",
"loss: 2.3872060569372797\n",
"loss: 2.38720678668193\n",
"loss: 2.387205887930111\n",
"loss: 2.387206664076764\n",
"loss: 2.3872058107150407\n",
"loss: 2.3872066293892216\n",
"loss: 2.387205818746259\n",
"loss: 2.3872066801575853\n",
"loss: 2.3872059111197688\n",
"loss: 2.3872068126406196\n",
"loss: 2.387206081035995\n",
"loss: 2.3872068669685325\n",
"loss: 2.3872063280795652\n",
"loss: 2.3872056557817367\n",
"loss: 2.3872066491216897\n",
"loss: 2.387206011860928\n",
"loss: 2.387206535957502\n",
"loss: 2.3872064376325017\n",
"loss: 2.3872058508160308\n",
"loss: 2.3872069163305536\n",
"loss: 2.3872063735393425\n",
"loss: 2.387205832345839\n",
"loss: 2.387206677283909\n",
"loss: 2.3872064498889176\n",
"loss: 2.387205955600857\n",
"loss: 2.387205820939399\n",
"loss: 2.387206660636184\n",
"loss: 2.387206206604038\n",
"loss: 2.38720576900394\n",
"loss: 2.3872063117983804\n",
"loss: 2.3872065813147874\n",
"loss: 2.387206181466505\n",
"loss: 2.3872057946984633\n",
"loss: 2.3872058708004715\n",
"loss: 2.387206706121622\n",
"loss: 2.3872063554658762\n",
"loss: 2.3872060160821453\n",
"loss: 2.3872056878815267\n",
"loss: 2.3872059982983496\n",
"loss: 2.387206714970821\n",
"loss: 2.3872064188808673\n",
"loss: 2.3872061334063455\n",
"loss: 2.3872058573779524\n",
"loss: 2.3872055929594205\n",
"loss: 2.3872060525929197\n",
"loss: 2.3872067406100412\n",
"loss: 2.387206502734586\n",
"loss: 2.3872062762370385\n",
"loss: 2.387206057098045\n",
"loss: 2.3872058489669072\n",
"loss: 2.3872056488912214\n",
"loss: 2.3872054563548764\n",
"loss: 2.3872062535038063\n",
"loss: 2.387206745595907\n",
"loss: 2.3872065783271466\n",
"loss: 2.38720641696552\n",
"loss: 2.387206265759137\n",
"loss: 2.387206119480711\n",
"loss: 2.3872059828172514\n",
"loss: 2.387205852070047\n",
"loss: 2.387205730140802\n",
"loss: 2.387205613674655\n",
"loss: 2.3872055054664263\n",
"loss: 2.3872055703172705\n",
"loss: 2.387206001531239\n",
"loss: 2.3872064059032465\n",
"loss: 2.3872067748722134\n",
"loss: 2.387206704414516\n",
"loss: 2.387206632138044\n",
"loss: 2.3872065672351592\n",
"loss: 2.3872065077955\n",
"loss: 2.387206453507227\n",
"loss: 2.387206405811236\n",
"loss: 2.3872063624448145\n",
"loss: 2.387206324770439\n",
"loss: 2.387206293044411\n",
"loss: 2.387206266172666\n",
"loss: 2.387206246548697\n",
"loss: 2.3872062269091705\n",
"loss: 2.3872062166383956\n",
"loss: 2.3872062117393194\n",
"loss: 2.3872062110609655\n",
"loss: 2.3872062139535695\n",
"loss: 2.387206219909649\n",
"loss: 2.3872062285052236\n",
"loss: 2.387206240576973\n",
"loss: 2.3872062606586537\n",
"loss: 2.3872062832874716\n",
"loss: 2.3872063082635315\n",
"loss: 2.3872063391881677\n",
"loss: 2.387206373803312\n",
"loss: 2.3872064127175974\n",
"loss: 2.3872064540621514\n",
"loss: 2.3872064987956327\n",
"loss: 2.3872065498963284\n",
"loss: 2.3872066013767226\n",
"loss: 2.387206657950146\n",
"loss: 2.3872067182049292\n",
"loss: 2.387206672867762\n",
"loss: 2.3872063672970487\n",
"loss: 2.387206055011894\n",
"loss: 2.3872057237905118\n",
"loss: 2.3872054165649037\n",
"loss: 2.3872054954444497\n",
"loss: 2.387205575891558\n",
"loss: 2.3872056608606287\n",
"loss: 2.3872057469958534\n",
"loss: 2.387205835730483\n",
"loss: 2.3872059291294483\n",
"loss: 2.387206023172342\n",
"loss: 2.387206121095825\n",
"loss: 2.3872062219751333\n",
"loss: 2.3872063233856755\n",
"loss: 2.387206428473149\n",
"loss: 2.3872065378135736\n",
"loss: 2.387206647283522\n",
"loss: 2.387206749735465\n",
"loss: 2.3872062240356984\n",
"loss: 2.38720569664805\n",
"loss: 2.3872054625336876\n",
"loss: 2.387205583157073\n",
"loss: 2.3872057046336925\n",
"loss: 2.387205829508387\n",
"loss: 2.387205957406746\n",
"loss: 2.387206085740676\n",
"loss: 2.3872062170332273\n",
"loss: 2.387206349757113\n",
"loss: 2.387206484077489\n",
"loss: 2.387206620076463\n",
"loss: 2.387206742599622\n",
"loss: 2.38720610614635\n",
"loss: 2.3872054598822556\n",
"loss: 2.3872055360296702\n",
"loss: 2.3872056808451143\n",
"loss: 2.3872058264733313\n",
"loss: 2.3872059747774506\n",
"loss: 2.387206123890669\n",
"loss: 2.3872062756016943\n",
"loss: 2.387206426949868\n",
"loss: 2.3872065819415873\n",
"loss: 2.387206737478283\n",
"loss: 2.387206126865435\n",
"loss: 2.3872054054434018\n",
"loss: 2.3872055638677026\n",
"loss: 2.387205724088219\n",
"loss: 2.38720588669721\n",
"loss: 2.387206049768166\n",
"loss: 2.3872062148316706\n",
"loss: 2.3872063803883172\n",
"loss: 2.3872065476541806\n",
"loss: 2.3872067152675234\n",
"loss: 2.3872061646400056\n",
"loss: 2.3872054078993044\n",
"loss: 2.3872055802952357\n",
"loss: 2.3872057516280947\n",
"loss: 2.3872059245622954\n",
"loss: 2.3872060978440746\n",
"loss: 2.387206273640891\n",
"loss: 2.387206448981906\n",
"loss: 2.3872066283894027\n",
"loss: 2.387206516906132\n",
"loss: 2.387205699504398\n",
"loss: 2.3872055196151685\n",
"loss: 2.3872056996032054\n",
"loss: 2.3872058817924033\n",
"loss: 2.387206063764966\n",
"loss: 2.387206248517262\n",
"loss: 2.387206432649466\n",
"loss: 2.3872066179179283\n",
"loss: 2.387206522287034\n",
"loss: 2.3872056702282123\n",
"loss: 2.3872055324647627\n",
"loss: 2.38720572034085\n",
"loss: 2.3872059096035927\n",
"loss: 2.3872060985335164\n",
"loss: 2.387206289705417\n",
"loss: 2.387206481657449\n",
"loss: 2.3872066741473743\n",
"loss: 2.387206243340186\n",
"loss: 2.387205412579381\n",
"loss: 2.387205607172106\n",
"loss: 2.3872058001813334\n",
"loss: 2.3872059950155946\n",
"loss: 2.3872061909696214\n",
"loss: 2.387206387319673\n",
"loss: 2.3872065845641055\n",
"loss: 2.387206628280637\n",
"loss: 2.3872057223715246\n",
"loss: 2.3872055317401317\n",
"loss: 2.387205729959452\n",
"loss: 2.3872059297733506\n",
"loss: 2.3872061287534345\n",
"loss: 2.3872063300590187\n",
"loss: 2.3872065319453286\n",
"loss: 2.3872067329639384\n",
"loss: 2.3872059293532697\n",
"loss: 2.3872054909066573\n",
"loss: 2.3872056929994176\n",
"loss: 2.3872058959012414\n",
"loss: 2.3872060991225794\n",
"loss: 2.387206303755043\n",
"loss: 2.3872065090776227\n",
"loss: 2.387206713881384\n",
"loss: 2.3872060034612277\n",
"loss: 2.387205479096127\n",
"loss: 2.3872056842701372\n",
"loss: 2.3872058905037483\n",
"loss: 2.3872060971157936\n",
"loss: 2.3872063046852645\n",
"loss: 2.3872065130746725\n",
"loss: 2.387206720682877\n",
"loss: 2.3872059543810957\n",
"loss: 2.3872054920341013\n",
"loss: 2.387205699860994\n",
"loss: 2.387205909361374\n",
"loss: 2.387206118592575\n",
"loss: 2.387206328727607\n",
"loss: 2.38720654022318\n",
"loss: 2.3872067496108995\n",
"loss: 2.38720580949524\n",
"loss: 2.387205525406542\n",
"loss: 2.3872057364936534\n",
"loss: 2.3872059484123977\n",
"loss: 2.387206158732515\n",
"loss: 2.3872063716998198\n",
"loss: 2.387206584128892\n",
"loss: 2.3872065550645125\n",
"loss: 2.387205585137984\n",
"loss: 2.387205576467598\n",
"loss: 2.3872057895930716\n",
"loss: 2.3872060024345876\n",
"loss: 2.387206217104482\n",
"loss: 2.38720643053896\n",
"loss: 2.3872066462470087\n",
"loss: 2.3872062650681847\n",
"loss: 2.3872054287829734\n",
"loss: 2.387205644040444\n",
"loss: 2.387205857402579\n",
"loss: 2.3872060732060456\n",
"loss: 2.3872062890605856\n",
"loss: 2.3872065058667893\n",
"loss: 2.387206721085871\n",
"loss: 2.3872059173144433\n",
"loss: 2.3872055071105676\n",
"loss: 2.387205722800243\n",
"loss: 2.387205939850607\n",
"loss: 2.387206155334122\n",
"loss: 2.3872063740586773\n",
"loss: 2.38720659063653\n",
"loss: 2.3872065025409945\n",
"loss: 2.3872055102865852\n",
"loss: 2.387205597163091\n",
"loss: 2.3872058144515096\n",
"loss: 2.3872060323445647\n",
"loss: 2.38720625026218\n",
"loss: 2.387206468189427\n",
"loss: 2.3872066869296926\n",
"loss: 2.387206064964853\n",
"loss: 2.38720547798433\n",
"loss: 2.387205695647846\n",
"loss: 2.3872059148036042\n",
"loss: 2.387206132165185\n",
"loss: 2.3872063520034\n",
"loss: 2.387206571675985\n",
"loss: 2.387206580623468\n",
"loss: 2.387205582588098\n",
"loss: 2.3872055839413213\n",
"loss: 2.3872058020374514\n",
"loss: 2.3872060214815694\n",
"loss: 2.38720624240693\n",
"loss: 2.387206462370405\n",
"loss: 2.387206683255392\n",
"loss: 2.3872060762465175\n",
"loss: 2.3872054775249913\n",
"loss: 2.387205696602586\n",
"loss: 2.3872059172419493\n",
"loss: 2.3872061373373787\n",
"loss: 2.387206357728515\n",
"loss: 2.387206579560316\n",
"loss: 2.387206536499058\n",
"loss: 2.387205534906743\n",
"loss: 2.387205596098073\n",
"loss: 2.387205817218552\n",
"loss: 2.387206037559115\n",
"loss: 2.3872062589616636\n",
"loss: 2.3872064810746823\n",
"loss: 2.387206702376931\n",
"loss: 2.3872059789399236\n",
"loss: 2.3872054989934695\n",
"loss: 2.3872057193198266\n",
"loss: 2.3872059415304925\n",
"loss: 2.387206163057817\n",
"loss: 2.387206385045773\n",
"loss: 2.3872066064853166\n",
"loss: 2.38720640924132\n",
"loss: 2.3872054050281237\n",
"loss: 2.3872056266849664\n",
"loss: 2.3872058475459403\n",
"loss: 2.387206070207055\n",
"loss: 2.387206292224948\n",
"loss: 2.3872065152244177\n",
"loss: 2.3872067373632193\n",
"loss: 2.387205812795538\n",
"loss: 2.387205535628719\n",
"loss: 2.3872057581142827\n",
"loss: 2.387205979871331\n",
"loss: 2.387206202847354\n",
"loss: 2.38720642492102\n",
"loss: 2.3872066486900834\n",
"loss: 2.3872062143940815\n",
"loss: 2.3872054479581615\n",
"loss: 2.3872056700624618\n",
"loss: 2.387205892745017\n",
"loss: 2.3872061156322637\n",
"loss: 2.3872063392335336\n",
"loss: 2.3872065621484544\n",
"loss: 2.3872066127839506\n",
"loss: 2.387205593010971\n",
"loss: 2.3872055852578473\n",
"loss: 2.3872058079900116\n",
"loss: 2.3872060312398813\n",
"loss: 2.387206254831415\n",
"loss: 2.387206477626406\n",
"loss: 2.3872067013641285\n",
"loss: 2.387205974948894\n",
"loss: 2.3872055015373546\n",
"loss: 2.387205724165402\n",
"loss: 2.387205948609596\n",
"loss: 2.3872061711547934\n",
"loss: 2.3872063942990995\n",
"loss: 2.387206617976577\n",
"loss: 2.3872063484996606\n",
"loss: 2.387205419784946\n",
"loss: 2.387205643696668\n",
"loss: 2.38720586667547\n",
"loss: 2.387206089943084\n",
"loss: 2.387206312830524\n",
"loss: 2.3872065377208713\n",
"loss: 2.3872067192036486\n",
"loss: 2.3872056983399608\n",
"loss: 2.387205562202803\n",
"loss: 2.3872057863378924\n",
"loss: 2.387206009488842\n",
"loss: 2.3872062344077247\n",
"loss: 2.3872064585590174\n",
"loss: 2.3872066828857177\n",
"loss: 2.387206060927091\n",
"loss: 2.387205483917369\n",
"loss: 2.387205706729346\n",
"loss: 2.3872059309762115\n",
"loss: 2.387206154167569\n",
"loss: 2.387206379212923\n",
"loss: 2.387206603498562\n",
"loss: 2.387206412786183\n",
"loss: 2.38720540594447\n",
"loss: 2.387205630194073\n",
"loss: 2.3872058532972202\n",
"loss: 2.387206078179721\n",
"loss: 2.3872063011757954\n",
"loss: 2.3872065261212008\n",
"loss: 2.387206750494888\n",
"loss: 2.38720574408085\n",
"loss: 2.3872055526202605\n",
"loss: 2.387205776875942\n",
"loss: 2.3872060000469304\n",
"loss: 2.3872062251849693\n",
"loss: 2.3872064490769467\n",
"loss: 2.38720667505617\n",
"loss: 2.387206092621105\n",
"loss: 2.3872054778245344\n",
"loss: 2.387205700855323\n",
"loss: 2.387205925499759\n",
"loss: 2.387206149256795\n",
"loss: 2.3872063747392476\n",
"loss: 2.3872065982742674\n",
"loss: 2.387206435850997\n",
"loss: 2.3872054149088293\n",
"loss: 2.3872056262571535\n",
"loss: 2.387205849166078\n",
"loss: 2.3872060742134376\n",
"loss: 2.3872062979103035\n",
"loss: 2.387206523614174\n",
"loss: 2.3872067479852594\n",
"loss: 2.3872057530264144\n",
"loss: 2.387205550630556\n",
"loss: 2.3872057753165112\n",
"loss: 2.387205999046393\n",
"loss: 2.3872062242513152\n",
"loss: 2.387206448562063\n",
"loss: 2.3872066745566825\n",
"loss: 2.3872060933404238\n",
"loss: 2.387205477652032\n",
"loss: 2.38720570096479\n",
"loss: 2.3872059260253913\n",
"loss: 2.387206149856662\n",
"loss: 2.3872063756569637\n",
"loss: 2.3872066002677177\n",
"loss: 2.3872064248665876\n",
"loss: 2.3872054055763066\n",
"loss: 2.3872056285152192\n",
"loss: 2.387205852129036\n",
"loss: 2.387206076714375\n",
"loss: 2.3872063008900715\n",
"loss: 2.3872065262982316\n",
"loss: 2.3872067510414143\n",
"loss: 2.387205738066258\n",
"loss: 2.387205553932228\n",
"loss: 2.3872057787560985\n",
"loss: 2.387206002620486\n",
"loss: 2.3872062278227126\n",
"loss: 2.387206452219362\n",
"loss: 2.3872066785656476\n",
"loss: 2.3872060756651354\n",
"loss: 2.3872054817901702\n",
"loss: 2.387205705294623\n",
"loss: 2.387205930184685\n",
"loss: 2.3872061542181156\n",
"loss: 2.387206380162742\n",
"loss: 2.3872066048043825\n",
"loss: 2.3872064037912457\n",
"loss: 2.387205408672593\n",
"loss: 2.3872056335074885\n",
"loss: 2.38720585704452\n",
"loss: 2.3872060825534622\n",
"loss: 2.3872063070962546\n",
"loss: 2.3872065331763173\n",
"loss: 2.387206736126897\n",
"loss: 2.3872057102550297\n",
"loss: 2.387205560708403\n",
"loss: 2.387205785698165\n",
"loss: 2.3872060098092085\n",
"loss: 2.3872062357191077\n",
"loss: 2.3872064606568575\n",
"loss: 2.3872066856260075\n",
"loss: 2.3872060412743403\n",
"loss: 2.3872054896579566\n",
"loss: 2.387205713503783\n",
"loss: 2.387205938855604\n",
"loss: 2.3872061633253967\n",
"loss: 2.3872063880939898\n",
"loss: 2.387206613272621\n",
"loss: 2.3872063644496766\n",
"loss: 2.387205417523893\n",
"loss: 2.387205642655334\n",
"loss: 2.3872058668390554\n",
"loss: 2.3872060910716635\n",
"loss: 2.3872063157741197\n",
"loss: 2.387206542322383\n",
"loss: 2.387206692745429\n",
"loss: 2.3872056660868344\n",
"loss: 2.3872055706534514\n",
"loss: 2.3872057946635303\n",
"loss: 2.387206019070375\n",
"loss: 2.387206245277133\n",
"loss: 2.387206470316894\n",
"loss: 2.3872066957452835\n",
"loss: 2.3872059933352743\n",
"loss: 2.3872054990132523\n",
"loss: 2.387205722878048\n",
"loss: 2.387205948575221\n",
"loss: 2.387206173353647\n",
"loss: 2.3872063984591825\n",
"loss: 2.3872066226559827\n",
"loss: 2.3872063207497023\n",
"loss: 2.3872054272954717\n",
"loss: 2.3872056526531233\n",
"loss: 2.3872058770925495\n",
"loss: 2.3872061007027394\n",
"loss: 2.387206325498709\n",
"loss: 2.387206552189155\n",
"loss: 2.3872066470668103\n",
"loss: 2.387205618564661\n",
"loss: 2.387205580450609\n",
"loss: 2.3872058044940334\n",
"loss: 2.3872060289891905\n",
"loss: 2.387206255311463\n",
"loss: 2.387206480708284\n",
"loss: 2.387206705586487\n",
"loss: 2.3872059484852777\n",
"loss: 2.387205508896236\n",
"loss: 2.387205733053314\n",
"loss: 2.3872059590012453\n",
"loss: 2.38720618326989\n",
"loss: 2.387206408350126\n",
"loss: 2.3872066339450995\n",
"loss: 2.3872062743922324\n",
"loss: 2.3872054376631935\n",
"loss: 2.3872056632364425\n",
"loss: 2.3872058871467248\n",
"loss: 2.387206111855699\n",
"loss: 2.387206337069455\n",
"loss: 2.3872065625907197\n",
"loss: 2.387206598875525\n",
"loss: 2.3872055726406853\n",
"loss: 2.3872055914219663\n",
"loss: 2.3872058158103724\n",
"loss: 2.3872060406633677\n",
"loss: 2.3872062658126643\n",
"loss: 2.3872064913732993\n",
"loss: 2.3872067166127193\n",
"loss: 2.38720589721323\n",
"loss: 2.3872055203865594\n",
"loss: 2.3872057448636044\n",
"loss: 2.387205969632241\n",
"loss: 2.3872061948051737\n",
"loss: 2.387206419656803\n",
"loss: 2.3872066454234093\n",
"loss: 2.38720622109226\n",
"loss: 2.3872054496058155\n",
"loss: 2.387205673994574\n",
"loss: 2.387205898781459\n",
"loss: 2.3872061232475654\n",
"loss: 2.3872063486162522\n",
"loss: 2.3872065743640327\n",
"loss: 2.38720654403819\n",
"loss: 2.3872055202172575\n",
"loss: 2.3872056032796642\n",
"loss: 2.3872058273645207\n",
"loss: 2.387206052339985\n",
"loss: 2.3872062776900504\n",
"loss: 2.3872065035141774\n",
"loss: 2.3872067282141374\n",
"loss: 2.3872058427484215\n",
"loss: 2.387205531994678\n",
"loss: 2.387205756582158\n",
"loss: 2.387205981539894\n",
"loss: 2.3872062069646356\n",
"loss: 2.387206431282743\n",
"loss: 2.387206657412996\n",
"loss: 2.3872061679858723\n",
"loss: 2.3872054613328837\n",
"loss: 2.3872056859037873\n",
"loss: 2.3872059109348127\n",
"loss: 2.387206134876376\n",
"loss: 2.387206360603658\n",
"loss: 2.3872065860311924\n",
"loss: 2.3872064904558536\n",
"loss: 2.3872054658487727\n",
"loss: 2.3872056154159624\n",
"loss: 2.387205838986268\n",
"loss: 2.3872060643164303\n",
"loss: 2.38720629000679\n",
"loss: 2.387206515465516\n",
"loss: 2.3872067402918593\n",
"loss: 2.3872057868542456\n",
"loss: 2.3872055436974926\n",
"loss: 2.387205768622451\n",
"loss: 2.3872059939050634\n",
"loss: 2.3872062189566643\n",
"loss: 2.387206443389059\n",
"loss: 2.387206669729105\n",
"loss: 2.387206114421231\n",
"loss: 2.3872054733876444\n",
"loss: 2.3872056982757286\n",
"loss: 2.3872059236497964\n",
"loss: 2.3872061471739836\n",
"loss: 2.38720637304622\n",
"loss: 2.387206597910911\n",
"loss: 2.3872064349529696\n",
"loss: 2.3872054090077848\n",
"loss: 2.3872056281842866\n",
"loss: 2.387205851329026\n",
"loss: 2.387206076790427\n",
"loss: 2.387206301261524\n",
"loss: 2.387206527494985\n",
"loss: 2.3872067526753113\n",
"loss: 2.387205728959938\n"
]
}
],
"source": [
"# The network's output has shape (m, 1), so the targets must be a column vector too:\n",
"# a 1-D y would broadcast to an (m, m) array in `self.y - self.output`.\n",
"# reshape(-1, 1) is idempotent, so re-running this cell is safe.\n",
"y_train = y_train.reshape(-1, 1)\n",
"\n",
"N_EPOCHS = 1_000_000  # number of calls to nn.train, each on the whole training set\n",
"\n",
"nn = NeuralNetwork(X_train, y_train)\n",
"\n",
"for _ in range(N_EPOCHS):\n",
"    nn.train(X_train, y_train)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Plot the loss vs. the number of epochs"
]
},
{
"cell_type": "code",
"execution_count": 83,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Text(0, 0.5, 'loss')"
]
},
"execution_count": 83,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAjIAAAGwCAYAAACzXI8XAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy89olMNAAAACXBIWXMAAA9hAAAPYQGoP6dpAAA3LklEQVR4nO3de3RU5b3/8c/kMpOEZCYkkISYBBEs91CLiBHlaEEQEG+0tspp8RwvCxsvgD+xtGq1rQY9p1Y9pbTHC+iqyCkVsN6gCBKqBpRL5CKNBqKgkKBAMiGQ6zy/P8IMjFydzJ6dCe/XWnuR7L0z852HBfms5/nuvR3GGCMAAIAoFGN3AQAAAKEiyAAAgKhFkAEAAFGLIAMAAKIWQQYAAEQtggwAAIhaBBkAABC14uwuwGo+n0+7du1SSkqKHA6H3eUAAIDTYIxRbW2tsrOzFRNz4nmXDh9kdu3apdzcXLvLAAAAIdi5c6dycnJOeLzDB5mUlBRJrQPhdrttrgYAAJwOr9er3NzcwO/xE+nwQca/nOR2uwkyAABEmVO1hdDsCwAAohZBBgAARC2CDAAAiFoEGQAAELUIMgAAIGoRZAAAQNQiyAAAgKhla5CZPXu28vPzA/d4KSgo0FtvvRU4fumll8rhcARtkydPtrFiAADQnth6Q7ycnBzNnDlT5557rowxeuGFF3T11Vdrw4YN6t+/vyTp1ltv1a9//evAzyQlJdlVLgAAaGdsDTLjx48P+v6RRx7R7NmztXr16kCQSUpKUlZWlh3lAQCAdq7d9Mi0tLRo/vz5qqurU0FBQWD/Sy+9pC5dumjAgAGaMWOGDh48eNLXaWhokNfrDdoAAEDHZPuzljZt2qSCggLV19crOTlZixYtUr9+/SRJN954o7p3767s7Gxt3LhR9913n8rKyrRw4cITvl5RUZEefvjhSJUPAABs5DDGGDsLaGxs1I4dO1RTU6O//e1vevbZZ1VcXBwIM0dbsWKFRowYofLycvXs2fO4r9fQ0KCGhobA9/6nZ9bU1IT1oZHVBxtVW98sd2K8PInxYXtdAADQ+vvb4/Gc8ve37UtLTqdTvXr10uDBg1VUVKRBgwbpqaeeOu65Q4cOlSSVl5ef8PVcLlfgKigrn3j92JJ/6ZLH39GL739myesDAIBTsz3IfJPP5wuaUTlaaWmpJKlbt24RrOj4YmNaHyveYu+EFgAAZzRbe2RmzJihMWPGKC8vT7W1tZo3b55WrlyppUuXatu2bZo3b57Gjh2r9PR0bdy4UVOnTtXw4cOVn59vZ9mSpFjH4SDjI8gAAGAXW4PMnj179NOf/lS7d++Wx+NRfn6+li5dqssvv1w7d+7U22+/rSeffFJ1dXXKzc3VhAkTdP/999tZckBsTOtkFkEGAAD72BpknnvuuRMey83NVXFxcQSr+XZiDy/KEWQAALBPu+uRiRYxMSwtAQBgN4JMiOJo9gUAwHYEmRDR7AsAgP0IMiGi2RcAAPsRZEJEsy8AAPYjyISIZl8AAOxHkAlRHEEGAADbEWRCFOPgqiUAAOxGkAkRMzIAANiPIBOiWIIMAAC2I8iEiMuvAQCwH0EmRFx+DQCA/QgyIaLZFwAA+xFkQhQXS48MAAB2I8iEKIZnLQEAYDuCTIjiaPYFAMB2BJkQ0ewLAID9CDIhotkXAAD7EWRCRLMvAAD2I8iEiGZfAADsR5AJEc2+AADYjyATohiafQEAsB1BJkSxNPsCAGA7gkyIaPYFAMB+BJkQ0ewLAID9CDIh8jf7+ggyAADYhiATIn+zbzNBBgAA2xBkQhSYkaHZFwAA2xBkQhTLjAwAALYjyIQo1n9DvBaCDAAAdiHIhCgupvWqpSafz+ZKAAA4cxFkQuSKax26xmaCDAAAdiHIhMgZ52/2lZpbCDMAANiBIBOi+NgjQ9dEnwwAALYgyITIPyMjsbwEAIBdCDIhiotx6PBTCtTQ0mJvMQAAnKEIMiFyOB
yB5SWWlgAAsAdBpg1csVy5BACAnQgybeDkEmwAAGxFkGmDI0tLBBkAAOxAkGkD/4xMAzMyAADYwtYgM3v2bOXn58vtdsvtdqugoEBvvfVW4Hh9fb0KCwuVnp6u5ORkTZgwQVVVVTZWHIylJQAA7GVrkMnJydHMmTO1bt06rV27Vt///vd19dVXa8uWLZKkqVOn6rXXXtOCBQtUXFysXbt26brrrrOz5CD+paVGlpYAALBFnJ1vPn78+KDvH3nkEc2ePVurV69WTk6OnnvuOc2bN0/f//73JUlz5sxR3759tXr1al144YV2lBzEPyPTxIwMAAC2aDc9Mi0tLZo/f77q6upUUFCgdevWqampSSNHjgyc06dPH+Xl5amkpOSEr9PQ0CCv1xu0WcXFjAwAALayPchs2rRJycnJcrlcmjx5shYtWqR+/fqpsrJSTqdTqampQednZmaqsrLyhK9XVFQkj8cT2HJzcy2rPT6u9da+9MgAAGAP24NM7969VVpaqjVr1uj222/XpEmT9PHHH4f8ejNmzFBNTU1g27lzZxirDebkhngAANjK1h4ZSXI6nerVq5ckafDgwfrwww/11FNP6Uc/+pEaGxtVXV0dNCtTVVWlrKysE76ey+WSy+WyumxJUkJ8rCSpvplnLQEAYAfbZ2S+yefzqaGhQYMHD1Z8fLyWL18eOFZWVqYdO3aooKDAxgqPSHS2BpmDjQQZAADsYOuMzIwZMzRmzBjl5eWptrZW8+bN08qVK7V06VJ5PB7dfPPNmjZtmtLS0uR2u3XnnXeqoKCgXVyxJElJBBkAAGxla5DZs2ePfvrTn2r37t3yeDzKz8/X0qVLdfnll0uSfv/73ysmJkYTJkxQQ0ODRo8erT/+8Y92lhwkydk6fIcam22uBACAM5OtQea555476fGEhATNmjVLs2bNilBF305iPDMyAADYqd31yEQT/9LSIYIMAAC2IMi0AT0yAADYiyDTBomHe2QONhFkAACwA0GmDY4sLdHsCwCAHQgybcB9ZAAAsBdBpg06+ZeWCDIAANiCINMGKQmtQaa2vsnmSgAAODMRZNrAH2S8h+iRAQDADgSZNnAnxkuSGlt8qufKJQAAIo4g0wbJzjg5HK1fe1leAgAg4ggybRAT41Cyy98nw/ISAACRRpBpI3dC6/KS9xAzMgAARBpBpo2OXLnEjAwAAJFGkGkj/4wMQQYAgMgjyLRR4BJsmn0BAIg4gkwb+S/B5qZ4AABEHkGmjbgpHgAA9iHItNGRHhlmZAAAiDSCTBsd6ZFhRgYAgEgjyLQRPTIAANiHINNG/qWlGm6IBwBAxBFk2ig1qTXI7D9IkAEAINIIMm3UOckpSao+2GhzJQAAnHkIMm3UuVPrjEz1wSYZY2yuBgCAMwtBpo38MzLNPqPaBq5cAgAgkggybZQQH6uE+NZh3F/H8hIAAJFEkAmDtMOzMjT8AgAQWQSZMEgNBBlmZAAAiCSCTBgcafglyAAAEEkEmTAIzMjUsbQEAEAkEWTCII2lJQAAbEGQCYPOgbv7EmQAAIgkgkwYpHLVEgAAtiDIhAHNvgAA2IMgEwY0+wIAYA+CTBjQ7AsAgD0IMmHQmSADAIAtCDJhkHq4R6a+yaf6phabqwEA4MxBkAmDFFec4mIckpiVAQAgkggyYeBwOGj4BQDABgSZMPHfFI9LsAEAiBxbg0xRUZGGDBmilJQUZWRk6JprrlFZWVnQOZdeeqkcDkfQNnnyZJsqPjF/w+8+ggwAABFja5ApLi5WYWGhVq9erWXLlqmpqUmjRo1SXV1d0Hm33nqrdu/eHdgef/xxmyo+sdTAYwpYWgIAIFLi7HzzJUuWBH0/d+5cZWRkaN26dRo+fHhgf1JSkrKysiJd3reS1ql1Rqa6jhkZAAAipV31yNTU1EiS0tLSgva/9NJL6tKliwYMGKAZM2bo4MGDJ3yNhoYGeb3eoC0SeN4SAACRZ+uMzNF8Pp+mTJmiYcOGacCAAY
H9N954o7p3767s7Gxt3LhR9913n8rKyrRw4cLjvk5RUZEefvjhSJUdkJLQOpS19QQZAAAipd0EmcLCQm3evFnvvvt
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"# Assumes nn.train appends to nn.loss_history once every 1000 epochs\n",
"# (this is what the x-axis label states) -- verify against NeuralNetwork.train.\n",
"fig, ax = plt.subplots()\n",
"ax.plot(nn.loss_history)\n",
"ax.set_xlabel(\"# epochs / 1000\")\n",
"ax.set_ylabel(\"loss\")\n",
"ax.set_title(\"Training loss\");\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Decision boundary of the trained network"
]
},
{
"cell_type": "code",
"execution_count": 84,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"<matplotlib.colorbar.Colorbar at 0x147ecf700>"
]
},
"execution_count": 84,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAsIAAAJMCAYAAADwqMBxAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy89olMNAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdd3gU1f7H8ffMpvce0kkIhN47gvQiIip2EezXdq+K/lRs2L3X3nsBlGJBmiBdpEjvPUBCCiQhvbfdmd8fCzEhG0jZNPJ9PU8eyJzZM2dDSD579jvnKLqu6wghhBBCCNHCqI09ACGEEEIIIRqDBGEhhBBCCNEiSRAWQgghhBAtkgRhIYQQQgjRIkkQFkIIIYQQLZIEYSGEEEII0SJJEBZCCCGEEC2SBGEhhBBCCNEiSRAWQgghhBAtkgRhIYQQQgjRIkkQFkIIIYQQjWrDhg1MmDCBwMBAFEVh0aJFl3xMcXExzz33HGFhYdjb29O6dWu+++67Gl3XppbjFUIIIYQQwiry8/Pp1q0bd999N9dff321HnPTTTeRkpLCt99+S2RkJElJSWiaVqPrShAWQgghhBCNaty4cYwbN67a569YsYK//vqLmJgYvLy8AGjdunWNr3vZBWFN0zhz5gyurq4oitLYwxFCCCFEC6brOrm5uQQGBqKqjV+RWlRURElJSYNcS9f1SlnM3t4ee3v7Ove9ZMkSevfuzVtvvcUPP/yAs7Mz11xzDa+++iqOjo7V7ueyC8JnzpwhJCSksYchhBBCCFEmISGB4ODgRh1DUVERgUGtycxIaZDrubi4kJeXV+HYjBkzeOmll+rcd0xMDJs2bcLBwYGFCxeSlpbGQw89RHp6Ot9//321+7nsgrCrqysAe7/9Glcnp0YejRBCCCFastyCArrfc19ZPmlMJSUlZGakMPuXAzg51+94CvJzmXJjFxISEnBzcys7bo3ZYDBXACiKwpw5c3B3dwfgvffe44YbbuCzzz6r9qzwZReEz0/Buzo5SRAWQgghRJPQlMo1nZxdcXZ2u/SJVuDm5lYhCFtLQEAAQUFBZSEYoEOHDui6TmJiIm3btq1WP41frCKEEEIIIUQNDBo0iDNnzlQovYiOjkZV1RqVoEgQFkIIIYQQjSovL4+9e/eyd+9eAGJjY9m7dy/x8fEATJ8+nSlTppSdf9ttt+Ht7c1dd93F4cOH2bBhA//3f//H3XffXaOb5SQICyGEEEKIRrVz50569OhBjx49AJg2bRo9evTgxRdfBCApKaksFIP5RrzVq1eTlZVF7969uf3225kwYQIfffRRja572dUICyGEEEKI5mXo0KHoul5l+8yZMysda9++PatXr67TdWVGWAghhBBCtEgShIUQQgghRIskQVgIIYQQQrRIEoSFEEIIIUSLJEFYCCGEEEK0SBKEhRBCCCFEiyRBWAghhBBCtEgShIUQQgghRIskQVgIIYQQQrRIEoSFEEIIIUSLJEFYCCGEEEK0SBKEhRBCCCFEiyRBWAghhBBCtEgShIUQQgghRIskQVgIIYQQQrRIEoSFEEIIIUSLJEFYCCGEEEK0SBKEhRBCCCFEiyRBWAghhBBCtEj1GoQ3bNjAhAkTCAwMRFEUFi1adNHz169fj6IolT6Sk5Prc5hCCCGEEKIFqtcgnJ+fT7du3fj0009r9Lhjx46RlJRU9uHn51dPIxRCCCGEEC2VTX12Pm7cOMaNG1fjx/n5+eHh4WH9AQkhhBBCCHFOk6wR7t69OwEBAYwaNYrNmzdf9Nzi4mJycnIqfAghhBBCCHEpTSoIBwQE8MUXX7BgwQIWLFhASEgIQ4cOZffu3VU+5s0338Td3b3sIyQkpAFHLIQQQgghmqt6LY2oqaioKKKioso+HzhwICdPnuT999/nhx9+sPiY6dOnM23atLLPc3JyJAwLIYQQQohLalJB2JK+ffuyadOmKtvt7e2xt7dvwBEJIYQQQojLQZMqjb
Bk7969BAQENPYwhBBCCCHEZaZeZ4Tz8vI4ceJE2eexsbHs3bsXLy8vQkNDmT59OqdPn2b27NkAfPDBB4SHh9OpUyeKior45ptvWLduHatWrarPYQohhBBCiBaoXoPwzp07GTZsWNnn52t5p06dysyZM0lKSiI+Pr6svaSkhCeeeILTp0/j5ORE165dWbNmTYU+hBBCCCGEsIZ6DcJDhw5F1/Uq22fOnFnh86eeeoqnnnqqPockhBBCCCEE0AxqhIUQQgghhKgPEoSFEEIIIUSLJEFYCCGEEEK0SBKEhRBCCCFEiyRBWAghhBBCtEgShIUQQgghRIskQVgIIYQQQrRIEoSFEEIIIUSLJEFYCCGEEEK0SBKEhRBCCCFEiyRBWAghhBBCtEgShIUQQgghRIskQVgIIYQQQrRIEoSFEEIIIUSLJEFYCCGEEEK0SBKEhRBCCCFEiyRBWAghhBBCtEgShIUQQgghRIskQVgIIYQQQrRIEoSFEEIIIUSLJEFYCCGEEEK0SBKEhRBCCCFEiyRBWAghhBBCtEgShIUQQgghRIskQVgIIYQQQrRIEoSFEEIIIUSLJEFYCCGEEEI0qg0bNjBhwgQCAwNRFIVFixZV+7GbN2/GxsaG7t271/i6EoSFEEIIIUSjys/Pp1u3bnz66ac1elxWVhZTpkxhxIgRtbquTa0eJYQQQgghhJWMGzeOcePG1fhxDzzwALfddhsGg6FGs8jnyYywEEIIIYRodr7//ntiYmKYMWNGrfuQGWEhhBBCCFEvcnJyKnxub2+Pvb19nfs9fvw4zzzzDBs3bsTGpvZxVoKwEEIIIUQL4unjirOLa71ew85RByAkJKTC8RkzZvDSSy/VqW+TycRtt93Gyy+/TLt27erUlwRhIYQQQghRLxISEnBzcyv73Bqzwbm5uezcuZM9e/bwyCOPAKBpGrquY2Njw6pVqxg+fHi1+pIgLIQQQggh6oWbm1uFIGytPg8cOFDh2Geffca6dev49ddfCQ8Pr3ZfEoSFEEIIIUSjysvL48SJE2Wfx8bGsnfvXry8vAgNDWX69OmcPn2a2bNno6oqnTt3rvB4Pz8/HBwcKh2/FAnCQgghhBCiUe3cuZNhw4aVfT5t2jQApk6dysyZM0lKSiI+Pt7q11V0Xdet3msjysnJwd3dnZPz5uDq5NTYwxFCCCFEC5ZbUECbW28nOzvb6iUCNXU+I63dno6zS/2OJT8vhxF9vZvE874YWUdYCCGEEEK0SBKEhRBCCCFEiyRBWAghhBBCtEgShIUQQgghRIskq0YIIcQlaJpGXEoKcSkpGEtKcXV1oV1wMJ6u9bszkxBCiPolQVgIIS6ioLiYtVu3kh0bS6uSUpxVA0kmE8fd3ejUqyc92rZt7CEKIYSoJQnCQghxEet37kQ7Gs11/n54OzoCoOk6B9PT2b51K66OTkQGBzXyKIUQQtSG1AgLIUQVUjIzyYg9xRU+3mUhGEBVFLr6+BBeXMKh48cbcYRCCCHqQoKwEEJU4UxaGs6FhQQ6O1tsj3R3Jzc5mbzCwgYemRBCCGuQICyEEFXQNB0bQFEUi+22qgq6zmW2QacQQrQYEoSFEKIK3u5uZNvakFFUZLE9PjcXe09PnB0cGnhkQgghrEGCsBBCVCHUzw/HwCC2pJyl1GSq0Jacn89Ro4m2kZGoqvwoFUKI5khWjRBCiCqoqsrgvn1YV1zML4mJtHVwwNnGhqSCAuJUA75dOtMlIryxhymEEKKWJAgLIVoETdMoLi3Fwc6uyppfS/w8PBg/YjiH4+I4GnsKY0kJLkFB9AoPp21wkMwGCyFEMyZBWAhxWYtOSOCT3xayeOMmCktL8Xdz4/axY3jo2om4VbEaxIVcnZzo16ED/Tp0qOfRCiGEaEgShIUQl60dR49y8wsz8DIama5ptAa25OTw5a8L+OPvv1n03zdlm2QhhGjB5D09IcRlyWQy8eBb79C9tJRDmsYLwB3AZ8AOTSPlTBKvzf
qhkUcphBCiMUkQFkJcltbt2UN8ejrv6zouF7R1AB7XNH5d/yc5+fmNMTwhhBBNgARhIcRl6VDsKXwMBnpX0X4VUFh
"text/plain": [
"<Figure size 900x700 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"import matplotlib.pyplot as plt\n",
"from matplotlib.colors import ListedColormap\n",
"\n",
"# Red/blue colour map used to colour the data points by their label y\n",
"cm_bright = ListedColormap([\"#FF0000\", \"#0000FF\"])\n",
"\n",
"# Evaluate the network on a regular grid covering the plotting range\n",
"N_GRID = 100  # grid points per axis; larger values give a smoother contour\n",
"xv = np.linspace(x_min, x_max, N_GRID)\n",
"yv = np.linspace(y_min, y_max, N_GRID)\n",
"Xv, Yv = np.meshgrid(xv, yv)\n",
"XYpairs = np.column_stack([Xv.ravel(), Yv.ravel()])  # shape (N_GRID**2, 2)\n",
"Zv = nn.predict(XYpairs).reshape(Xv.shape)\n",
"\n",
"fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(9, 7))\n",
"ax.set_aspect(1)\n",
"cn = ax.contourf(Xv, Yv, Zv, cmap=\"coolwarm_r\", alpha=0.4)\n",
"\n",
"# Training points (opaque)\n",
"ax.scatter(X_train[:, 0], X_train[:, 1], c=y_train, cmap=cm_bright, edgecolors=\"k\")\n",
"\n",
"# Test points (semi-transparent)\n",
"ax.scatter(X_test[:, 0], X_test[:, 1], c=y_test, cmap=cm_bright, alpha=0.4, edgecolors=\"k\")\n",
"\n",
"ax.set_xlim(x_min, x_max)\n",
"ax.set_ylim(y_min, y_max)\n",
"ax.set_title(\"Network output on a grid, with training and test data\")\n",
"\n",
"fig.colorbar(cn, ax=ax);\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
},
"vscode": {
"interpreter": {
"hash": "b0fa6594d8f4cbf19f97940f81e996739fb7646882a419484c72d19e05852a7e"
}
}
},
"nbformat": 4,
"nbformat_minor": 4
}