Machine Learning course as part of the Studierendentage (student days), summer semester 2023

  1. {
  2. "cells": [
  3. {
  4. "attachments": {},
  5. "cell_type": "markdown",
  6. "metadata": {},
  7. "source": [
  8. "# A simple neural network with one hidden layer in pure Python"
  9. ]
  10. },
  11. {
  12. "attachments": {},
  13. "cell_type": "markdown",
  14. "metadata": {},
  15. "source": [
  16. "## A simple neural network class with ReLU activation function"
  17. ]
  18. },
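{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"In formulas (a sketch of what the class below implements, with notation chosen here): the hidden layer computes $h = \\mathrm{ReLU}(x W_1 + b_1)$ and the output layer computes $\\hat{y} = \\mathrm{ReLU}(h W_2 + b_2)$, where $W_1$ is a $2 \\times n_1$ matrix, $W_2$ is an $n_1 \\times 1$ matrix and $n_1$ is the number of hidden neurons. Training minimizes the squared error $L = \\sum_i (y_i - \\hat{y}_i)^2$ with plain full-batch gradient descent."
]
},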
  19. {
  20. "cell_type": "code",
  21. "execution_count": 77,
  22. "metadata": {},
  23. "outputs": [],
  24. "source": [
  25. "# A simple feed-forward neutral network with on hidden layer\n",
  26. "# see also https://towardsdatascience.com/how-to-build-your-own-neural-network-from-scratch-in-python-68998a08e4f6\n",
  27. "\n",
  28. "import numpy as np\n",
  29. "\n",
  30. "class NeuralNetwork:\n",
  31. " def __init__(self, x, y):\n",
  32. " n1 = 4 # number of neurons in the hidden layer\n",
  33. " self.input = x\n",
  34. " self.weights1 = np.random.rand(self.input.shape[1],n1)\n",
  35. " self.bias1 = np.random.rand(n1)\n",
  36. " self.weights2 = np.random.rand(n1,1)\n",
  37. " self.bias2 = np.random.rand(1) \n",
  38. " self.y = y\n",
  39. " self.output = np.zeros(y.shape)\n",
  40. " self.learning_rate = 0.00001\n",
  41. " self.n_train = 0\n",
  42. " self.loss_history = []\n",
  43. "\n",
  44. " def relu(self, x):\n",
  45. " return np.where(x>0, x, 0)\n",
  46. " \n",
  47. " def relu_derivative(self, x):\n",
  48. " return np.where(x>0, 1, 0)\n",
  49. "\n",
  50. " def feedforward(self):\n",
  51. " self.layer1 = self.relu(self.input @ self.weights1 + self.bias1)\n",
  52. " self.output = self.relu(self.layer1 @ self.weights2 + self.bias2)\n",
  53. "\n",
  54. " def backprop(self):\n",
  55. "\n",
  56. " # delta1: [m, 1], m = number of training data\n",
  57. " delta1 = 2 * (self.y - self.output) * self.relu_derivative(self.output)\n",
  58. "\n",
  59. " # Gradient w.r.t. weights from hidden to output layer: [n1, 1] matrix, n1 = # neurons in hidden layer\n",
  60. " d_weights2 = self.layer1.T @ delta1\n",
  61. " d_bias2 = np.sum(delta1) \n",
  62. " \n",
  63. " # shape of delta2: [m, n1], m = number of training data, n1 = # neurons in hidden layer\n",
  64. " delta2 = (delta1 @ self.weights2.T) * self.relu_derivative(self.layer1)\n",
  65. " d_weights1 = self.input.T @ delta2\n",
  66. " d_bias1 = np.ones(delta2.shape[0]) @ delta2 \n",
  67. " \n",
  68. " # update weights and biases\n",
  69. " self.weights1 += self.learning_rate * d_weights1\n",
  70. " self.weights2 += self.learning_rate * d_weights2\n",
  71. "\n",
  72. " self.bias1 += self.learning_rate * d_bias1\n",
  73. " self.bias2 += self.learning_rate * d_bias2\n",
  74. "\n",
  75. " def train(self, X, y):\n",
  76. " self.output = np.zeros(y.shape)\n",
  77. " self.input = X\n",
  78. " self.y = y\n",
  79. " self.feedforward()\n",
  80. " self.backprop()\n",
  81. " self.n_train += 1\n",
  82. " if (self.n_train %1000 == 0):\n",
  83. " loss = np.sum((self.y - self.output)**2)\n",
  84. " print(\"loss: \", loss)\n",
  85. " self.loss_history.append(loss)\n",
  86. " \n",
  87. " def predict(self, X):\n",
  88. " self.output = np.zeros(y.shape)\n",
  89. " self.input = X\n",
  90. " self.feedforward()\n",
  91. " return self.output\n",
  92. " \n",
  93. " def loss_history(self):\n",
  94. " return self.loss_history\n"
  95. ]
  96. },
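{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"As a reading aid for `backprop` (a sketch in the notation above): with $z_1 = x W_1 + b_1$, $h = \\mathrm{ReLU}(z_1)$, $z_2 = h W_2 + b_2$ and $\\hat{y} = \\mathrm{ReLU}(z_2)$, the method computes\n",
"\n",
"$$\\delta_1 = 2\\,(y - \\hat{y}) \\odot \\mathrm{ReLU}'(z_2), \\qquad d W_2 = h^\\top \\delta_1, \\qquad d b_2 = \\sum_i (\\delta_1)_i$$\n",
"\n",
"$$\\delta_2 = (\\delta_1 W_2^\\top) \\odot \\mathrm{ReLU}'(z_1), \\qquad d W_1 = x^\\top \\delta_2, \\qquad d b_1 = \\sum_i (\\delta_2)_i$$\n",
"\n",
"These are the negative gradients of $L = \\sum_i (y_i - \\hat{y}_i)^2$, so the update `weights += learning_rate * d_weights` is a gradient-descent step."
]
},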
  97. {
  98. "cell_type": "markdown",
  99. "metadata": {},
  100. "source": [
  101. "## Create toy data\n",
  102. "We create three toy data sets\n",
  103. "1. two moon-like distributions\n",
  104. "2. circles\n",
  105. "3. linearly separable data sets"
  106. ]
  107. },
  108. {
  109. "cell_type": "code",
  110. "execution_count": 78,
  111. "metadata": {},
  112. "outputs": [],
  113. "source": [
  114. "# https://scikit-learn.org/stable/auto_examples/classification/plot_classifier_comparison.html#sphx-glr-auto-examples-classification-plot-classifier-comparison-py\n",
  115. "import numpy as np\n",
  116. "from sklearn.datasets import make_moons, make_circles, make_classification\n",
  117. "from sklearn.model_selection import train_test_split\n",
  118. "\n",
  119. "X, y = make_classification(\n",
  120. " n_features=2, n_redundant=0, n_informative=2, random_state=1, n_clusters_per_class=1\n",
  121. ")\n",
  122. "rng = np.random.RandomState(2)\n",
  123. "X += 2 * rng.uniform(size=X.shape)\n",
  124. "linearly_separable = (X, y)\n",
  125. "\n",
  126. "datasets = [\n",
  127. " make_moons(n_samples=200, noise=0.1, random_state=0),\n",
  128. " make_circles(n_samples=200, noise=0.1, factor=0.5, random_state=1),\n",
  129. " linearly_separable,\n",
  130. "]"
  131. ]
  132. },
  133. {
  134. "cell_type": "markdown",
  135. "metadata": {},
  136. "source": [
  137. "## Create training and test data set"
  138. ]
  139. },
  140. {
  141. "cell_type": "code",
  142. "execution_count": 79,
  143. "metadata": {},
  144. "outputs": [],
  145. "source": [
  146. "# datasets: 0 = moons, 1 = circles, 2 = linearly separable\n",
  147. "X, y = datasets[1]\n",
  148. "X_train, X_test, y_train, y_test = train_test_split(\n",
  149. " X, y, test_size=0.4, random_state=42\n",
  150. ")\n",
  151. "\n",
  152. "x_min, x_max = X[:, 0].min() - 0.5, X[:, 0].max() + 0.5\n",
  153. "y_min, y_max = X[:, 1].min() - 0.5, X[:, 1].max() + 0.5\n"
  154. ]
  155. },
  156. {
  157. "cell_type": "markdown",
  158. "metadata": {},
  159. "source": [
  160. "## Train the model"
  161. ]
  162. },
  163. {
  164. "cell_type": "code",
  165. "execution_count": 80,
  166. "metadata": {},
  167. "outputs": [
  168. {
  169. "name": "stdout",
  170. "output_type": "stream",
  171. "text": [
  172. "loss: 34.671913102152374\n",
  173. "loss: 31.424782860564203\n",
  174. "loss: 29.772915496524135\n",
  175. "loss: 28.762023680772913\n",
  176. "loss: 28.00200726838712\n",
  177. "loss: 27.32590942137339\n",
  178. "loss: 26.752368734071535\n",
  179. "loss: 26.230440689447903\n",
  180. "loss: 25.673463509689576\n",
  181. "loss: 25.012834504148312\n",
  182. "loss: 24.289682045629544\n",
  183. "loss: 23.555645514965384\n",
  184. "loss: 22.76462670346343\n",
  185. "loss: 21.904104226889068\n",
  186. "loss: 20.943637847221698\n",
  187. "loss: 19.89434572572985\n",
  188. "loss: 18.727285500049177\n",
  189. "loss: 17.485616842253226\n",
  190. "loss: 16.142413632344777\n",
  191. "loss: 14.852364407067075\n",
  192. "loss: 13.635545514668182\n",
  193. "loss: 12.456629856179049\n",
  194. "loss: 11.347265353073684\n",
  195. "loss: 10.419340643305858\n",
  196. "loss: 9.610799938794724\n",
  197. "loss: 8.897580679158944\n",
  198. "loss: 8.258004600111189\n",
  199. "loss: 7.684500186535497\n",
  200. "loss: 7.1748549390018574\n",
  201. "loss: 6.718209468903557\n",
  202. "loss: 6.309864315153381\n",
  203. "loss: 5.944399105330259\n",
  204. "loss: 5.621553827962666\n",
  205. "loss: 5.333909699361839\n",
  206. "loss: 5.077286239602076\n",
  207. "loss: 4.84889061532151\n",
  208. "loss: 4.646022685947024\n",
  209. "loss: 4.465758350759858\n",
  210. "loss: 4.305913173647123\n",
  211. "loss: 4.1640718175126095\n",
  212. "loss: 4.039102963682375\n",
  213. "loss: 3.9297524332426623\n",
  214. "loss: 3.832402269592843\n",
  215. "loss: 3.7453438160159322\n",
  216. "loss: 3.6674507148484525\n",
  217. "loss: 3.5977884037626\n",
  218. "loss: 3.535827925326507\n",
  219. "loss: 3.479967988008011\n",
  220. "loss: 3.4291169910556207\n",
  221. "loss: 3.3829982181684475\n",
  222. "loss: 3.3416403251567335\n",
  223. "loss: 3.304229099924413\n",
  224. "loss: 3.2684585556590573\n",
  225. "loss: 3.2354701067381977\n",
  226. "loss: 3.2051369151544105\n",
  227. "loss: 3.1772080296379404\n",
  228. "loss: 3.1476337866558435\n",
  229. "loss: 3.1193880834750365\n",
  230. "loss: 3.0930575696528475\n",
  231. "loss: 3.0684279860561725\n",
  232. "loss: 3.0453159280825632\n",
  233. "loss: 3.023566344470656\n",
  234. "loss: 3.0031494958762517\n",
  235. "loss: 2.9838385760074786\n",
  236. "loss: 2.965539553283577\n",
  237. "loss: 2.9481708868131866\n",
  238. "loss: 2.9316489223756497\n",
  239. "loss: 2.9158898824688726\n",
  240. "loss: 2.900839148101508\n",
  241. "loss: 2.8864469332954905\n",
  242. "loss: 2.8726683402256388\n",
  243. "loss: 2.8594627350494104\n",
  244. "loss: 2.846788596539098\n",
  245. "loss: 2.834699126928559\n",
  246. "loss: 2.8238134101918453\n",
  247. "loss: 2.8133709056375507\n",
  248. "loss: 2.8032597976075766\n",
  249. "loss: 2.793428392265398\n",
  250. "loss: 2.7839305952541142\n",
  251. "loss: 2.7747494784865347\n",
  252. "loss: 2.765836907330045\n",
  253. "loss: 2.757121848680945\n",
  254. "loss: 2.7486569559419625\n",
  255. "loss: 2.7404305726008564\n",
  256. "loss: 2.732433904640994\n",
  257. "loss: 2.724658913060166\n",
  258. "loss: 2.717098191391666\n",
  259. "loss: 2.709744869114474\n",
  260. "loss: 2.7025925339845283\n",
  261. "loss: 2.695634579597031\n",
  262. "loss: 2.688854332606485\n",
  263. "loss: 2.682427203225914\n",
  264. "loss: 2.6764899407900025\n",
  265. "loss: 2.6707857438883167\n",
  266. "loss: 2.6652986297014336\n",
  267. "loss: 2.6600185764704065\n",
  268. "loss: 2.654938846386348\n",
  269. "loss: 2.650048013830516\n",
  270. "loss: 2.6434375100435394\n",
  271. "loss: 2.6372726876694945\n",
  272. "loss: 2.6315142701295082\n",
  273. "loss: 2.6260949485933374\n",
  274. "loss: 2.620971039483914\n",
  275. "loss: 2.61174470987536\n",
  276. "loss: 2.5984776428986\n",
  277. "loss: 2.587556467070711\n",
  278. "loss: 2.5779657263762803\n",
  279. "loss: 2.5692396803357767\n",
  280. "loss: 2.5611227223903583\n",
  281. "loss: 2.5534830205555155\n",
  282. "loss: 2.5462350514450636\n",
  283. "loss: 2.539331795274396\n",
  284. "loss: 2.5327378554215794\n",
  285. "loss: 2.5264333984648015\n",
  286. "loss: 2.5203974791572867\n",
  287. "loss: 2.5146177883938456\n",
  288. "loss: 2.5090780321819945\n",
  289. "loss: 2.5037668178628807\n",
  290. "loss: 2.4986865327324224\n",
  291. "loss: 2.4938357429187015\n",
  292. "loss: 2.4892125195154855\n",
  293. "loss: 2.484814217154528\n",
  294. "loss: 2.480624991605507\n",
  295. "loss: 2.47663936164734\n",
  296. "loss: 2.4728453322029287\n",
  297. "loss: 2.469227702733714\n",
  298. "loss: 2.465784971369546\n",
  299. "loss: 2.4625031729417586\n",
  300. "loss: 2.459373586417242\n",
  301. "loss: 2.456395742549959\n",
  302. "loss: 2.453556149147296\n",
  303. "loss: 2.450848516477792\n",
  304. "loss: 2.448257826886593\n",
  305. "loss: 2.4457793902422473\n",
  306. "loss: 2.4434057564290246\n",
  307. "loss: 2.4411335977457957\n",
  308. "loss: 2.43895711642629\n",
  309. "loss: 2.4368736364557093\n",
  310. "loss: 2.4348749182486014\n",
  311. "loss: 2.432961782234024\n",
  312. "loss: 2.4311267156550267\n",
  313. "loss: 2.4293686022331507\n",
  314. "loss: 2.4276818857227433\n",
  315. "loss: 2.4260653790156557\n",
  316. "loss: 2.4245172581509893\n",
  317. "loss: 2.423035547508019\n",
  318. "loss: 2.4216137562678623\n",
  319. "loss: 2.420252690046535\n",
  320. "loss: 2.4189495881209595\n",
  321. "loss: 2.4176996906636115\n",
  322. "loss: 2.4165033152293676\n",
  323. "loss: 2.415355506487388\n",
  324. "loss: 2.414254264583601\n",
  325. "loss: 2.4131997637552898\n",
  326. "loss: 2.4121900926116586\n",
  327. "loss: 2.411220865950546\n",
  328. "loss: 2.410291344402852\n",
  329. "loss: 2.4094025182811554\n",
  330. "loss: 2.4085499517354947\n",
  331. "loss: 2.407731441094278\n",
  332. "loss: 2.406946144975749\n",
  333. "loss: 2.4061932934032764\n",
  334. "loss: 2.4054693005333485\n",
  335. "loss: 2.4047725159312234\n",
  336. "loss: 2.404104507441233\n",
  337. "loss: 2.4034628933277684\n",
  338. "loss: 2.4028442537087553\n",
  339. "loss: 2.4022504549932577\n",
  340. "loss: 2.4016803840643326\n",
  341. "loss: 2.4011319551366914\n",
  342. "loss: 2.4006060782593357\n",
  343. "loss: 2.400098897713013\n",
  344. "loss: 2.39961231412508\n",
  345. "loss: 2.3991458381267075\n",
  346. "loss: 2.3986954028524448\n",
  347. "loss: 2.3982632819114857\n",
  348. "loss: 2.3978461833600178\n",
  349. "loss: 2.3974452187294\n",
  350. "loss: 2.3970596690136103\n",
  351. "loss: 2.396686140489889\n",
  352. "loss: 2.396328903510832\n",
  353. "loss: 2.3959835555706825\n",
  354. "loss: 2.39565059757921\n",
  355. "loss: 2.3953300040263765\n",
  356. "loss: 2.3950220475433692\n",
  357. "loss: 2.3947263525017353\n",
  358. "loss: 2.3944396560020063\n",
  359. "loss: 2.3941632766670184\n",
  360. "loss: 2.3938993096058816\n",
  361. "loss: 2.39364294492824\n",
  362. "loss: 2.3933965389172958\n",
  363. "loss: 2.3931594462236814\n",
  364. "loss: 2.392932651756755\n",
  365. "loss: 2.392711772531679\n",
  366. "loss: 2.3925004166785255\n",
  367. "loss: 2.392296083041902\n",
  368. "loss: 2.392100842332238\n",
  369. "loss: 2.3919111950972347\n",
  370. "loss: 2.3917290333380343\n",
  371. "loss: 2.391553470675823\n",
  372. "loss: 2.3913848103373843\n",
  373. "loss: 2.391220583744764\n",
  374. "loss: 2.3910649553444348\n",
  375. "loss: 2.390914207314057\n",
  376. "loss: 2.3907697541809365\n",
  377. "loss: 2.390630728668481\n",
  378. "loss: 2.390498495256061\n",
  379. "loss: 2.390369937885896\n",
  380. "loss: 2.390247045664548\n",
  381. "loss: 2.3901273102670126\n",
  382. "loss: 2.3900148760237014\n",
  383. "loss: 2.389905238354612\n",
  384. "loss: 2.3898001952359604\n",
  385. "loss: 2.3896980625916764\n",
  386. "loss: 2.3896010024421126\n",
  387. "loss: 2.3895086082689625\n",
  388. "loss: 2.389417860716362\n",
  389. "loss: 2.3893319903283574\n",
  390. "loss: 2.3892491959225457\n",
  391. "loss: 2.3891687694660892\n",
  392. "loss: 2.3890934197829052\n",
  393. "loss: 2.389019136431073\n",
  394. "loss: 2.388947788821395\n",
  395. "loss: 2.3888805609783272\n",
  396. "loss: 2.3888150688462746\n",
  397. "loss: 2.3887518707039823\n",
  398. "loss: 2.388691911298264\n",
  399. "loss: 2.388634014187227\n",
  400. "loss: 2.38857805590494\n",
  401. "loss: 2.388523918864406\n",
  402. "loss: 2.388473510627236\n",
  403. "loss: 2.388423436117929\n",
  404. "loss: 2.3883753262717216\n",
  405. "loss: 2.3883303602155683\n",
  406. "loss: 2.3882866724515304\n",
  407. "loss: 2.388244442215278\n",
  408. "loss: 2.3882033453528138\n",
  409. "loss: 2.3881652931929307\n",
  410. "loss: 2.38812677875453\n",
  411. "loss: 2.388091060218199\n",
  412. "loss: 2.388056092752744\n",
  413. "loss: 2.3880237467127237\n",
  414. "loss: 2.3879919924582405\n",
  415. "loss: 2.387960763231075\n",
  416. "loss: 2.387931658872838\n",
  417. "loss: 2.3879034749475188\n",
  418. "loss: 2.387875347245705\n",
  419. "loss: 2.387849402809727\n",
  420. "loss: 2.387823662709372\n",
  421. "loss: 2.387799967311831\n",
  422. "loss: 2.387776360266166\n",
  423. "loss: 2.387754683818766\n",
  424. "loss: 2.3877329827919898\n",
  425. "loss: 2.3877121910475196\n",
  426. "loss: 2.3876929899702644\n",
  427. "loss: 2.3876729319637\n",
  428. "loss: 2.387654431642712\n",
  429. "loss: 2.387637541283918\n",
  430. "loss: 2.387620354216107\n",
  431. "loss: 2.387604686551482\n",
  432. "loss: 2.3875886398363733\n",
  433. "loss: 2.387573984851815\n",
  434. "loss: 2.3875588098577873\n",
  435. "loss: 2.3875448805324915\n",
  436. "loss: 2.387532242680253\n",
  437. "loss: 2.3875189700442068\n",
  438. "loss: 2.387506861417358\n",
  439. "loss: 2.387495291770735\n",
  440. "loss: 2.387484231905034\n",
  441. "loss: 2.3874731451723217\n",
  442. "loss: 2.387462274358618\n",
  443. "loss: 2.3874519373951664\n",
  444. "loss: 2.3874425920575746\n",
  445. "loss: 2.3874333869787168\n",
  446. "loss: 2.3874249670575987\n",
  447. "loss: 2.3874162377994788\n",
  448. "loss: 2.3874073874920074\n",
  449. "loss: 2.3873990135367182\n",
  450. "loss: 2.3873914786024937\n",
  451. "loss: 2.387384759212198\n",
  452. "loss: 2.387377049938677\n",
  453. "loss: 2.387371052646234\n",
  454. "loss: 2.387363939101346\n",
  455. "loss: 2.3873584890193866\n",
  456. "loss: 2.3873519773705674\n",
  457. "loss: 2.3873461623087593\n",
  458. "loss: 2.387341023755127\n",
  459. "loss: 2.387336250114317\n",
  460. "loss: 2.3873309349740612\n",
  461. "loss: 2.387325950896895\n",
  462. "loss: 2.3873215752647576\n",
  463. "loss: 2.3873160311256796\n",
  464. "loss: 2.3873128208612444\n",
  465. "loss: 2.3873084158082647\n",
  466. "loss: 2.387304558605306\n",
  467. "loss: 2.3872995395407717\n",
  468. "loss: 2.387296679993435\n",
  469. "loss: 2.387292629417728\n",
  470. "loss: 2.3872890739934816\n",
  471. "loss: 2.387285997937587\n",
  472. "loss: 2.387283389935851\n",
  473. "loss: 2.3872794993942383\n",
  474. "loss: 2.387277790051124\n",
  475. "loss: 2.387274775181382\n",
  476. "loss: 2.387272183856358\n",
  477. "loss: 2.387269056421407\n",
  478. "loss: 2.3872664925693012\n",
  479. "loss: 2.387264376208358\n",
  480. "loss: 2.3872623679945018\n",
  481. "loss: 2.3872600043678207\n",
  482. "loss: 2.3872580041281823\n",
  483. "loss: 2.387256273384648\n",
  484. "loss: 2.3872533408675327\n",
  485. "loss: 2.3872523812347737\n",
  486. "loss: 2.3872500233166005\n",
  487. "loss: 2.387247988754982\n",
  488. "loss: 2.387246267292523\n",
  489. "loss: 2.3872448527441414\n",
  490. "loss: 2.387243735408723\n",
  491. "loss: 2.387242025422281\n",
  492. "loss: 2.3872406504464334\n",
  493. "loss: 2.3872386748070245\n",
  494. "loss: 2.3872386165309805\n",
  495. "loss: 2.387237233810461\n",
  496. "loss: 2.3872360479643593\n",
  497. "loss: 2.38723476536898\n",
  498. "loss: 2.387232724485056\n",
  499. "loss: 2.3872322755810025\n",
  500. "loss: 2.3872308935896562\n",
  501. "loss: 2.387230374354215\n",
  502. "loss: 2.3872289131536952\n",
  503. "loss: 2.387227666907803\n",
  504. "loss: 2.387227223159598\n",
  505. "loss: 2.3872263094643955\n",
  506. "loss: 2.3872251968418357\n",
  507. "loss: 2.387224775014026\n",
  508. "loss: 2.387224547860667\n",
  509. "loss: 2.3872230838809783\n",
  510. "loss: 2.3872229638510785\n",
  511. "loss: 2.387221601841123\n",
  512. "loss: 2.3872221053237612\n",
  513. "loss: 2.3872210930668896\n",
  514. "loss: 2.3872202490512318\n",
  515. "loss: 2.3872195705964097\n",
  516. "loss: 2.387219054999678\n",
  517. "loss: 2.387218694718085\n",
  518. "loss: 2.3872184895236703\n",
  519. "loss: 2.3872184361150453\n",
  520. "loss: 2.3872168420512097\n",
  521. "loss: 2.38721707828012\n",
  522. "loss: 2.3872157746881033\n",
  523. "loss: 2.3872162871830027\n",
  524. "loss: 2.3872152565137865\n",
  525. "loss: 2.3872156550203423\n",
  526. "loss: 2.387215262058789\n",
  527. "loss: 2.3872146171203603\n",
  528. "loss: 2.3872140924465817\n",
  529. "loss: 2.387213692303668\n",
  530. "loss: 2.3872134068601816\n",
  531. "loss: 2.3872132382331945\n",
  532. "loss: 2.3872131800547116\n",
  533. "loss: 2.3872132336735374\n",
  534. "loss: 2.3872122046766457\n",
  535. "loss: 2.387211985620091\n",
  536. "loss: 2.3872123543033004\n",
  537. "loss: 2.3872111512061682\n",
  538. "loss: 2.3872117185255872\n",
  539. "loss: 2.387210711934733\n",
  540. "loss: 2.3872114695614397\n",
  541. "loss: 2.3872106483724744\n",
  542. "loss: 2.3872106246269746\n",
  543. "loss: 2.387210949823621\n",
  544. "loss: 2.3872103921805325\n",
  545. "loss: 2.387209923070608\n",
  546. "loss: 2.387209539035157\n",
  547. "loss: 2.387209233578181\n",
  548. "loss: 2.387209278909502\n",
  549. "loss: 2.3872091575872285\n",
  550. "loss: 2.3872087902003143\n",
  551. "loss: 2.387208796135135\n",
  552. "loss: 2.387208863491243\n",
  553. "loss: 2.3872090209788177\n",
  554. "loss: 2.3872092395967757\n",
  555. "loss: 2.387209509667974\n",
  556. "loss: 2.3872081941411776\n",
  557. "loss: 2.387208614493238\n",
  558. "loss: 2.387209092011758\n",
  559. "loss: 2.3872079707116916\n",
  560. "loss: 2.387208566693145\n",
  561. "loss: 2.3872075814397924\n",
  562. "loss: 2.3872082912786885\n",
  563. "loss: 2.3872074055895007\n",
  564. "loss: 2.3872082394740444\n",
  565. "loss: 2.3872074711869145\n",
  566. "loss: 2.38720841308389\n",
  567. "loss: 2.3872077482796317\n",
  568. "loss: 2.3872071369801997\n",
  569. "loss: 2.387208234823193\n",
  570. "loss: 2.3872077218609924\n",
  571. "loss: 2.3872072579088783\n",
  572. "loss: 2.387206844007732\n",
  573. "loss: 2.3872079496309033\n",
  574. "loss: 2.3872078102348167\n",
  575. "loss: 2.387207529575949\n",
  576. "loss: 2.387207296171491\n",
  577. "loss: 2.3872071039581373\n",
  578. "loss: 2.38720695276087\n",
  579. "loss: 2.3872068462090654\n",
  580. "loss: 2.387206777400199\n",
  581. "loss: 2.3872067479488583\n",
  582. "loss: 2.3872067565550017\n",
  583. "loss: 2.3872068036021385\n",
  584. "loss: 2.3872068874369434\n",
  585. "loss: 2.3872070087213757\n",
  586. "loss: 2.3872071655473297\n",
  587. "loss: 2.3872073550964092\n",
  588. "loss: 2.387207093581078\n",
  589. "loss: 2.3872061834452403\n",
  590. "loss: 2.387206473681233\n",
  591. "loss: 2.3872067944256044\n",
  592. "loss: 2.387207150189065\n",
  593. "loss: 2.3872065584074114\n",
  594. "loss: 2.3872062931664226\n",
  595. "loss: 2.3872067377036283\n",
  596. "loss: 2.387207211254447\n",
  597. "loss: 2.3872060575595855\n",
  598. "loss: 2.3872065866005077\n",
  599. "loss: 2.387207143362015\n",
  600. "loss: 2.387206071870889\n",
  601. "loss: 2.3872066797178375\n",
  602. "loss: 2.3872065479626627\n",
  603. "loss: 2.387206319019361\n",
  604. "loss: 2.3872070019550087\n",
  605. "loss: 2.3872060569372797\n",
  606. "loss: 2.38720678668193\n",
  607. "loss: 2.387205887930111\n",
  608. "loss: 2.387206664076764\n",
  609. "loss: 2.3872058107150407\n",
  610. "loss: 2.3872066293892216\n",
  611. "loss: 2.387205818746259\n",
  612. "loss: 2.3872066801575853\n",
  613. "loss: 2.3872059111197688\n",
  614. "loss: 2.3872068126406196\n",
  615. "loss: 2.387206081035995\n",
  616. "loss: 2.3872068669685325\n",
  617. "loss: 2.3872063280795652\n",
  618. "loss: 2.3872056557817367\n",
  619. "loss: 2.3872066491216897\n",
  620. "loss: 2.387206011860928\n",
  621. "loss: 2.387206535957502\n",
  622. "loss: 2.3872064376325017\n",
  623. "loss: 2.3872058508160308\n",
  624. "loss: 2.3872069163305536\n",
  625. "loss: 2.3872063735393425\n",
  626. "loss: 2.387205832345839\n",
  627. "loss: 2.387206677283909\n",
  628. "loss: 2.3872064498889176\n",
  629. "loss: 2.387205955600857\n",
  630. "loss: 2.387205820939399\n",
  631. "loss: 2.387206660636184\n",
  632. "loss: 2.387206206604038\n",
  633. "loss: 2.38720576900394\n",
  634. "loss: 2.3872063117983804\n",
  635. "loss: 2.3872065813147874\n",
  636. "loss: 2.387206181466505\n",
  637. "loss: 2.3872057946984633\n",
  638. "loss: 2.3872058708004715\n",
  639. "loss: 2.387206706121622\n",
  640. "loss: 2.3872063554658762\n",
  641. "loss: 2.3872060160821453\n",
  642. "loss: 2.3872056878815267\n",
  643. "loss: 2.3872059982983496\n",
  644. "loss: 2.387206714970821\n",
  645. "loss: 2.3872064188808673\n",
  646. "loss: 2.3872061334063455\n",
  647. "loss: 2.3872058573779524\n",
  648. "loss: 2.3872055929594205\n",
  649. "loss: 2.3872060525929197\n",
  650. "loss: 2.3872067406100412\n",
  651. "loss: 2.387206502734586\n",
  652. "loss: 2.3872062762370385\n",
  653. "loss: 2.387206057098045\n",
  654. "loss: 2.3872058489669072\n",
  655. "loss: 2.3872056488912214\n",
  656. "loss: 2.3872054563548764\n",
  657. "loss: 2.3872062535038063\n",
  658. "loss: 2.387206745595907\n",
  659. "loss: 2.3872065783271466\n",
  660. "loss: 2.38720641696552\n",
  661. "loss: 2.387206265759137\n",
  662. "loss: 2.387206119480711\n",
  663. "loss: 2.3872059828172514\n",
  664. "loss: 2.387205852070047\n",
  665. "loss: 2.387205730140802\n",
  666. "loss: 2.387205613674655\n",
  667. "loss: 2.3872055054664263\n",
  668. "loss: 2.3872055703172705\n",
  669. "loss: 2.387206001531239\n",
  670. "loss: 2.3872064059032465\n",
  671. "loss: 2.3872067748722134\n",
  672. "loss: 2.387206704414516\n",
  673. "loss: 2.387206632138044\n",
  674. "loss: 2.3872065672351592\n",
  675. "loss: 2.3872065077955\n",
  676. "loss: 2.387206453507227\n",
  677. "loss: 2.387206405811236\n",
  678. "loss: 2.3872063624448145\n",
  679. "loss: 2.387206324770439\n",
  680. "loss: 2.387206293044411\n",
  681. "loss: 2.387206266172666\n",
  682. "loss: 2.387206246548697\n",
  683. "loss: 2.3872062269091705\n",
  684. "loss: 2.3872062166383956\n",
  685. "loss: 2.3872062117393194\n",
  686. "loss: 2.3872062110609655\n",
  687. "loss: 2.3872062139535695\n",
  688. "loss: 2.387206219909649\n",
  689. "loss: 2.3872062285052236\n",
  690. "loss: 2.387206240576973\n",
  691. "loss: 2.3872062606586537\n",
  692. "loss: 2.3872062832874716\n",
  693. "loss: 2.3872063082635315\n",
  694. "loss: 2.3872063391881677\n",
  695. "loss: 2.387206373803312\n",
  696. "loss: 2.3872064127175974\n",
  697. "loss: 2.3872064540621514\n",
  698. "loss: 2.3872064987956327\n",
  699. "loss: 2.3872065498963284\n",
  700. "loss: 2.3872066013767226\n",
  701. "loss: 2.387206657950146\n",
  702. "loss: 2.3872067182049292\n",
  703. "loss: 2.387206672867762\n",
  704. "loss: 2.3872063672970487\n",
  705. "loss: 2.387206055011894\n",
  706. "loss: 2.3872057237905118\n",
  707. "loss: 2.3872054165649037\n",
  708. "loss: 2.3872054954444497\n",
  709. "loss: 2.387205575891558\n",
  710. "loss: 2.3872056608606287\n",
  711. "loss: 2.3872057469958534\n",
  712. "loss: 2.387205835730483\n",
  713. "loss: 2.3872059291294483\n",
  714. "loss: 2.387206023172342\n",
  715. "loss: 2.387206121095825\n",
  716. "loss: 2.3872062219751333\n",
  717. "loss: 2.3872063233856755\n",
  718. "loss: 2.387206428473149\n",
  719. "loss: 2.3872065378135736\n",
  720. "loss: 2.387206647283522\n",
  721. "loss: 2.387206749735465\n",
  722. "loss: 2.3872062240356984\n",
  723. "loss: 2.38720569664805\n",
  724. "loss: 2.3872054625336876\n",
  725. "loss: 2.387205583157073\n",
  726. "loss: 2.3872057046336925\n",
  727. "loss: 2.387205829508387\n",
  728. "loss: 2.387205957406746\n",
  729. "loss: 2.387206085740676\n",
  730. "loss: 2.3872062170332273\n",
  731. "loss: 2.387206349757113\n",
  732. "loss: 2.387206484077489\n",
  733. "loss: 2.387206620076463\n",
  734. "loss: 2.387206742599622\n",
  735. "loss: 2.38720610614635\n",
  736. "loss: 2.3872054598822556\n",
  737. "loss: 2.3872055360296702\n",
  738. "loss: 2.3872056808451143\n",
  739. "loss: 2.3872058264733313\n",
  740. "loss: 2.3872059747774506\n",
  741. "loss: 2.387206123890669\n",
  742. "loss: 2.3872062756016943\n",
  743. "loss: 2.387206426949868\n",
  744. "loss: 2.3872065819415873\n",
  745. "loss: 2.387206737478283\n",
  746. "loss: 2.387206126865435\n",
  747. "loss: 2.3872054054434018\n",
  748. "loss: 2.3872055638677026\n",
  749. "loss: 2.387205724088219\n",
  750. "loss: 2.38720588669721\n",
  751. "loss: 2.387206049768166\n",
  752. "loss: 2.3872062148316706\n",
  753. "loss: 2.3872063803883172\n",
  754. "loss: 2.3872065476541806\n",
  755. "loss: 2.3872067152675234\n",
  756. "loss: 2.3872061646400056\n",
  757. "loss: 2.3872054078993044\n",
  758. "loss: 2.3872055802952357\n",
  759. "loss: 2.3872057516280947\n",
  760. "loss: 2.3872059245622954\n",
  761. "loss: 2.3872060978440746\n",
  762. "loss: 2.387206273640891\n",
  763. "loss: 2.387206448981906\n",
  764. "loss: 2.3872066283894027\n",
  765. "loss: 2.387206516906132\n",
  766. "loss: 2.387205699504398\n",
  767. "loss: 2.3872055196151685\n",
  768. "loss: 2.3872056996032054\n",
  769. "loss: 2.3872058817924033\n",
  770. "loss: 2.387206063764966\n",
  771. "loss: 2.387206248517262\n",
  772. "loss: 2.387206432649466\n",
  773. "loss: 2.3872066179179283\n",
  774. "loss: 2.387206522287034\n",
  775. "loss: 2.3872056702282123\n",
  776. "loss: 2.3872055324647627\n",
  777. "loss: 2.38720572034085\n",
  778. "loss: 2.3872059096035927\n",
  779. "loss: 2.3872060985335164\n",
  780. "loss: 2.387206289705417\n",
  781. "loss: 2.387206481657449\n",
  782. "loss: 2.3872066741473743\n",
  783. "loss: 2.387206243340186\n",
  784. "loss: 2.387205412579381\n",
  785. "loss: 2.387205607172106\n",
  786. "loss: 2.3872058001813334\n",
  787. "loss: 2.3872059950155946\n",
  788. "loss: 2.3872061909696214\n",
  789. "loss: 2.387206387319673\n",
  790. "loss: 2.3872065845641055\n",
  791. "loss: 2.387206628280637\n",
  792. "loss: 2.3872057223715246\n",
  793. "loss: 2.3872055317401317\n",
  794. "loss: 2.387205729959452\n",
  795. "loss: 2.3872059297733506\n",
  796. "loss: 2.3872061287534345\n",
  797. "loss: 2.3872063300590187\n",
  798. "loss: 2.3872065319453286\n",
  799. "loss: 2.3872067329639384\n",
  800. "loss: 2.3872059293532697\n",
  801. "loss: 2.3872054909066573\n",
  802. "loss: 2.3872056929994176\n",
  803. "loss: 2.3872058959012414\n",
  804. "loss: 2.3872060991225794\n",
  805. "loss: 2.387206303755043\n",
  806. "loss: 2.3872065090776227\n",
  807. "loss: 2.387206713881384\n",
  808. "loss: 2.3872060034612277\n",
  809. "loss: 2.387205479096127\n",
  810. "loss: 2.3872056842701372\n",
  811. "loss: 2.3872058905037483\n",
  812. "loss: 2.3872060971157936\n",
  813. "loss: 2.3872063046852645\n",
  814. "loss: 2.3872065130746725\n",
  815. "loss: 2.387206720682877\n",
  816. "loss: 2.3872059543810957\n",
  817. "loss: 2.3872054920341013\n",
  818. "loss: 2.387205699860994\n",
  819. "loss: 2.387205909361374\n",
  820. "loss: 2.387206118592575\n",
  821. "loss: 2.387206328727607\n",
  822. "loss: 2.38720654022318\n",
  823. "loss: 2.3872067496108995\n",
  824. "loss: 2.38720580949524\n",
  825. "loss: 2.387205525406542\n",
  826. "loss: 2.3872057364936534\n",
  827. "loss: 2.3872059484123977\n",
  828. "loss: 2.387206158732515\n",
  829. "loss: 2.3872063716998198\n",
  830. "loss: 2.387206584128892\n",
  831. "loss: 2.3872065550645125\n",
  832. "loss: 2.387205585137984\n",
  833. "loss: 2.387205576467598\n",
  834. "loss: 2.3872057895930716\n",
  835. "loss: 2.3872060024345876\n",
  836. "loss: 2.387206217104482\n",
  837. "loss: 2.38720643053896\n",
  838. "loss: 2.3872066462470087\n",
  839. "loss: 2.3872062650681847\n",
  840. "loss: 2.3872054287829734\n",
  841. "loss: 2.387205644040444\n",
  842. "loss: 2.387205857402579\n",
  843. "loss: 2.3872060732060456\n",
  844. "loss: 2.3872062890605856\n",
  845. "loss: 2.3872065058667893\n",
  846. "loss: 2.387206721085871\n",
  847. "loss: 2.3872059173144433\n",
  848. "loss: 2.3872055071105676\n",
  849. "loss: 2.387205722800243\n",
  850. "loss: 2.387205939850607\n",
  851. "loss: 2.387206155334122\n",
  852. "loss: 2.3872063740586773\n",
  853. "loss: 2.38720659063653\n",
  854. "loss: 2.3872065025409945\n",
  855. "loss: 2.3872055102865852\n",
  856. "loss: 2.387205597163091\n",
  857. "loss: 2.3872058144515096\n",
  858. "loss: 2.3872060323445647\n",
  859. "loss: 2.38720625026218\n",
  860. "loss: 2.387206468189427\n",
  861. "loss: 2.3872066869296926\n",
  862. "loss: 2.387206064964853\n",
  863. "loss: 2.38720547798433\n",
  864. "loss: 2.387205695647846\n",
  865. "loss: 2.3872059148036042\n",
  866. "loss: 2.387206132165185\n",
  867. "loss: 2.3872063520034\n",
  868. "loss: 2.387206571675985\n",
  869. "loss: 2.387206580623468\n",
  870. "loss: 2.387205582588098\n",
  871. "loss: 2.3872055839413213\n",
  872. "loss: 2.3872058020374514\n",
  873. "loss: 2.3872060214815694\n",
  874. "loss: 2.38720624240693\n",
  875. "loss: 2.387206462370405\n",
  876. "loss: 2.387206683255392\n",
  877. "loss: 2.3872060762465175\n",
  878. "loss: 2.3872054775249913\n",
  879. "loss: 2.387205696602586\n",
  880. "loss: 2.3872059172419493\n",
  881. "loss: 2.3872061373373787\n",
  882. "loss: 2.387206357728515\n",
  883. "loss: 2.387206579560316\n",
  884. "loss: 2.387206536499058\n",
  885. "loss: 2.387205534906743\n",
  886. "loss: 2.387205596098073\n",
  887. "loss: 2.387205817218552\n",
  888. "loss: 2.387206037559115\n",
  889. "loss: 2.3872062589616636\n",
  890. "loss: 2.3872064810746823\n",
  891. "loss: 2.387206702376931\n",
  892. "loss: 2.3872059789399236\n",
  893. "loss: 2.3872054989934695\n",
  894. "loss: 2.3872057193198266\n",
  895. "loss: 2.3872059415304925\n",
  896. "loss: 2.387206163057817\n",
  897. "loss: 2.387206385045773\n",
  898. "loss: 2.3872066064853166\n",
  899. "loss: 2.38720640924132\n",
  900. "loss: 2.3872054050281237\n",
  901. "loss: 2.3872056266849664\n",
  902. "loss: 2.3872058475459403\n",
  903. "loss: 2.387206070207055\n",
  904. "loss: 2.387206292224948\n",
  905. "loss: 2.3872065152244177\n",
  906. "loss: 2.3872067373632193\n",
  907. "loss: 2.387205812795538\n",
  908. "loss: 2.387205535628719\n",
  909. "loss: 2.3872057581142827\n",
  910. "loss: 2.387205979871331\n",
  911. "loss: 2.387206202847354\n",
  912. "loss: 2.38720642492102\n",
  913. "loss: 2.3872066486900834\n",
  914. "loss: 2.3872062143940815\n",
  915. "loss: 2.3872054479581615\n",
  916. "loss: 2.3872056700624618\n",
  917. "loss: 2.387205892745017\n",
  918. "loss: 2.3872061156322637\n",
  919. "loss: 2.3872063392335336\n",
  920. "loss: 2.3872065621484544\n",
  921. "loss: 2.3872066127839506\n",
  922. "loss: 2.387205593010971\n",
  923. "loss: 2.3872055852578473\n",
  924. "loss: 2.3872058079900116\n",
  925. "loss: 2.3872060312398813\n",
  926. "loss: 2.387206254831415\n",
  927. "loss: 2.387206477626406\n",
  928. "loss: 2.3872067013641285\n",
  929. "loss: 2.387205974948894\n",
  930. "loss: 2.3872055015373546\n",
  931. "loss: 2.387205724165402\n",
  932. "loss: 2.387205948609596\n",
  933. "loss: 2.3872061711547934\n",
  934. "loss: 2.3872063942990995\n",
  935. "loss: 2.387206617976577\n",
  936. "loss: 2.3872063484996606\n",
  937. "loss: 2.387205419784946\n",
  938. "loss: 2.387205643696668\n",
  939. "loss: 2.38720586667547\n",
  940. "loss: 2.387206089943084\n",
  941. "loss: 2.387206312830524\n",
  942. "loss: 2.3872065377208713\n",
  943. "loss: 2.3872067192036486\n",
  944. "loss: 2.3872056983399608\n",
  945. "loss: 2.387205562202803\n",
  946. "loss: 2.3872057863378924\n",
  947. "loss: 2.387206009488842\n",
  948. "loss: 2.3872062344077247\n",
  949. "loss: 2.3872064585590174\n",
  950. "loss: 2.3872066828857177\n",
  951. "loss: 2.387206060927091\n",
  952. "loss: 2.387205483917369\n",
  953. "loss: 2.387205706729346\n",
  954. "loss: 2.3872059309762115\n",
  955. "loss: 2.387206154167569\n",
  956. "loss: 2.387206379212923\n",
  957. "loss: 2.387206603498562\n",
  958. "loss: 2.387206412786183\n",
  959. "loss: 2.38720540594447\n",
  960. "loss: 2.387205630194073\n",
  961. "loss: 2.3872058532972202\n",
  962. "loss: 2.387206078179721\n",
  963. "loss: 2.3872063011757954\n",
  964. "loss: 2.3872065261212008\n",
  965. "loss: 2.387206750494888\n",
  966. "loss: 2.38720574408085\n",
  967. "loss: 2.3872055526202605\n",
  968. "loss: 2.387205776875942\n",
  969. "loss: 2.3872060000469304\n",
  970. "loss: 2.3872062251849693\n",
  971. "loss: 2.3872064490769467\n",
  972. "loss: 2.38720667505617\n",
  973. "loss: 2.387206092621105\n",
  974. "loss: 2.3872054778245344\n",
  975. "loss: 2.387205700855323\n",
  976. "loss: 2.387205925499759\n",
  977. "loss: 2.387206149256795\n",
  978. "loss: 2.3872063747392476\n",
  979. "loss: 2.3872065982742674\n",
  980. "loss: 2.387206435850997\n",
  981. "loss: 2.3872054149088293\n",
  982. "loss: 2.3872056262571535\n",
  983. "loss: 2.387205849166078\n",
  984. "loss: 2.3872060742134376\n",
  985. "loss: 2.3872062979103035\n",
  986. "loss: 2.387206523614174\n",
  987. "loss: 2.3872067479852594\n",
  988. "loss: 2.3872057530264144\n",
  989. "loss: 2.387205550630556\n",
  990. "loss: 2.3872057753165112\n",
  991. "loss: 2.387205999046393\n",
  992. "loss: 2.3872062242513152\n",
  993. "loss: 2.387206448562063\n",
  994. "loss: 2.3872066745566825\n",
  995. "loss: 2.3872060933404238\n",
  996. "loss: 2.387205477652032\n",
  997. "loss: 2.38720570096479\n",
  998. "loss: 2.3872059260253913\n",
  999. "loss: 2.387206149856662\n",
  1000. "loss: 2.3872063756569637\n",
  1001. "loss: 2.3872066002677177\n",
  1002. "loss: 2.3872064248665876\n",
  1003. "loss: 2.3872054055763066\n",
  1004. "loss: 2.3872056285152192\n",
  1005. "loss: 2.387205852129036\n",
  1006. "loss: 2.387206076714375\n",
  1007. "loss: 2.3872063008900715\n",
  1008. "loss: 2.3872065262982316\n",
  1009. "loss: 2.3872067510414143\n",
  1010. "loss: 2.387205738066258\n",
  1011. "loss: 2.387205553932228\n",
  1012. "loss: 2.3872057787560985\n",
  1013. "loss: 2.387206002620486\n",
  1014. "loss: 2.3872062278227126\n",
  1015. "loss: 2.387206452219362\n",
  1016. "loss: 2.3872066785656476\n",
  1017. "loss: 2.3872060756651354\n",
  1018. "loss: 2.3872054817901702\n",
  1019. "loss: 2.387205705294623\n",
  1020. "loss: 2.387205930184685\n",
  1021. "loss: 2.3872061542181156\n",
  1022. "loss: 2.387206380162742\n",
  1023. "loss: 2.3872066048043825\n",
  1024. "loss: 2.3872064037912457\n",
  1025. "loss: 2.387205408672593\n",
  1026. "loss: 2.3872056335074885\n",
  1027. "loss: 2.38720585704452\n",
  1028. "loss: 2.3872060825534622\n",
  1029. "loss: 2.3872063070962546\n",
  1030. "loss: 2.3872065331763173\n",
  1031. "loss: 2.387206736126897\n",
  1032. "loss: 2.3872057102550297\n",
  1033. "loss: 2.387205560708403\n",
  1034. "loss: 2.387205785698165\n",
  1035. "loss: 2.3872060098092085\n",
  1036. "loss: 2.3872062357191077\n",
  1037. "loss: 2.3872064606568575\n",
  1038. "loss: 2.3872066856260075\n",
  1039. "loss: 2.3872060412743403\n",
  1040. "loss: 2.3872054896579566\n",
  1041. "loss: 2.387205713503783\n",
  1042. "loss: 2.387205938855604\n",
  1043. "loss: 2.3872061633253967\n",
  1044. "loss: 2.3872063880939898\n",
  1045. "loss: 2.387206613272621\n",
  1046. "loss: 2.3872063644496766\n",
  1047. "loss: 2.387205417523893\n",
  1048. "loss: 2.387205642655334\n",
  1049. "loss: 2.3872058668390554\n",
  1050. "loss: 2.3872060910716635\n",
  1051. "loss: 2.3872063157741197\n",
  1052. "loss: 2.387206542322383\n",
  1053. "loss: 2.387206692745429\n",
  1054. "loss: 2.3872056660868344\n",
  1055. "loss: 2.3872055706534514\n",
  1056. "loss: 2.3872057946635303\n",
  1057. "loss: 2.387206019070375\n",
  1058. "loss: 2.387206245277133\n",
  1059. "loss: 2.387206470316894\n",
  1060. "loss: 2.3872066957452835\n",
  1061. "loss: 2.3872059933352743\n",
  1062. "loss: 2.3872054990132523\n",
  1063. "loss: 2.387205722878048\n",
  1064. "loss: 2.387205948575221\n",
  1065. "loss: 2.387206173353647\n",
  1066. "loss: 2.3872063984591825\n",
  1067. "loss: 2.3872066226559827\n",
  1068. "loss: 2.3872063207497023\n",
  1069. "loss: 2.3872054272954717\n",
  1070. "loss: 2.3872056526531233\n",
  1071. "loss: 2.3872058770925495\n",
  1072. "loss: 2.3872061007027394\n",
  1073. "loss: 2.387206325498709\n",
  1074. "loss: 2.387206552189155\n",
  1075. "loss: 2.3872066470668103\n",
  1076. "loss: 2.387205618564661\n",
  1077. "loss: 2.387205580450609\n",
  1078. "loss: 2.3872058044940334\n",
  1079. "loss: 2.3872060289891905\n",
  1080. "loss: 2.387206255311463\n",
  1081. "loss: 2.387206480708284\n",
  1082. "loss: 2.387206705586487\n",
  1083. "loss: 2.3872059484852777\n",
  1084. "loss: 2.387205508896236\n",
  1085. "loss: 2.387205733053314\n",
  1086. "loss: 2.3872059590012453\n",
  1087. "loss: 2.38720618326989\n",
  1088. "loss: 2.387206408350126\n",
  1089. "loss: 2.3872066339450995\n",
  1090. "loss: 2.3872062743922324\n",
  1091. "loss: 2.3872054376631935\n",
  1092. "loss: 2.3872056632364425\n",
  1093. "loss: 2.3872058871467248\n",
  1094. "loss: 2.387206111855699\n",
  1095. "loss: 2.387206337069455\n",
  1096. "loss: 2.3872065625907197\n",
  1097. "loss: 2.387206598875525\n",
  1098. "loss: 2.3872055726406853\n",
  1099. "loss: 2.3872055914219663\n",
  1100. "loss: 2.3872058158103724\n",
  1101. "loss: 2.3872060406633677\n",
  1102. "loss: 2.3872062658126643\n",
  1103. "loss: 2.3872064913732993\n",
  1104. "loss: 2.3872067166127193\n",
  1105. "loss: 2.38720589721323\n",
  1106. "loss: 2.3872055203865594\n",
  1107. "loss: 2.3872057448636044\n",
  1108. "loss: 2.387205969632241\n",
  1109. "loss: 2.3872061948051737\n",
  1110. "loss: 2.387206419656803\n",
  1111. "loss: 2.3872066454234093\n",
  1112. "loss: 2.38720622109226\n",
  1113. "loss: 2.3872054496058155\n",
  1114. "loss: 2.387205673994574\n",
  1115. "loss: 2.387205898781459\n",
  1116. "loss: 2.3872061232475654\n",
  1117. "loss: 2.3872063486162522\n",
  1118. "loss: 2.3872065743640327\n",
  1119. "loss: 2.38720654403819\n",
  1120. "loss: 2.3872055202172575\n",
  1121. "loss: 2.3872056032796642\n",
  1122. "loss: 2.3872058273645207\n",
  1123. "loss: 2.387206052339985\n",
  1124. "loss: 2.3872062776900504\n",
  1125. "loss: 2.3872065035141774\n",
  1126. "loss: 2.3872067282141374\n",
  1127. "loss: 2.3872058427484215\n",
  1128. "loss: 2.387205531994678\n",
  1129. "loss: 2.387205756582158\n",
  1130. "loss: 2.387205981539894\n",
  1131. "loss: 2.3872062069646356\n",
  1132. "loss: 2.387206431282743\n",
  1133. "loss: 2.387206657412996\n",
  1134. "loss: 2.3872061679858723\n",
  1135. "loss: 2.3872054613328837\n",
  1136. "loss: 2.3872056859037873\n",
  1137. "loss: 2.3872059109348127\n",
  1138. "loss: 2.387206134876376\n",
  1139. "loss: 2.387206360603658\n",
  1140. "loss: 2.3872065860311924\n",
  1141. "loss: 2.3872064904558536\n",
  1142. "loss: 2.3872054658487727\n",
  1143. "loss: 2.3872056154159624\n",
  1144. "loss: 2.387205838986268\n",
  1145. "loss: 2.3872060643164303\n",
  1146. "loss: 2.38720629000679\n",
  1147. "loss: 2.387206515465516\n",
  1148. "loss: 2.3872067402918593\n",
  1149. "loss: 2.3872057868542456\n",
  1150. "loss: 2.3872055436974926\n",
  1151. "loss: 2.387205768622451\n",
  1152. "loss: 2.3872059939050634\n",
  1153. "loss: 2.3872062189566643\n",
  1154. "loss: 2.387206443389059\n",
  1155. "loss: 2.387206669729105\n",
  1156. "loss: 2.387206114421231\n",
  1157. "loss: 2.3872054733876444\n",
  1158. "loss: 2.3872056982757286\n",
  1159. "loss: 2.3872059236497964\n",
  1160. "loss: 2.3872061471739836\n",
  1161. "loss: 2.38720637304622\n",
  1162. "loss: 2.387206597910911\n",
  1163. "loss: 2.3872064349529696\n",
  1164. "loss: 2.3872054090077848\n",
  1165. "loss: 2.3872056281842866\n",
  1166. "loss: 2.387205851329026\n",
  1167. "loss: 2.387206076790427\n",
  1168. "loss: 2.387206301261524\n",
  1169. "loss: 2.387206527494985\n",
  1170. "loss: 2.3872067526753113\n",
  1171. "loss: 2.387205728959938\n"
  1172. ]
  1173. }
  1174. ],
  1175. "source": [
  1176. "y_train = y_train.reshape(-1, 1)\n",
  1177. "\n",
  1178. "nn = NeuralNetwork(X_train, y_train)\n",
  1179. "\n",
  1180. "for i in range(1000000):\n",
  1181. " nn.train(X_train, y_train)\n"
  1182. ]
  1183. },
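{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## Check the accuracy\n",
"A quick sanity check (a minimal sketch): the labels are 0 and 1 and the network is trained with a squared-error loss, so we can threshold its output at 0.5 to obtain class predictions and compare them with the true labels of the training and test sets. The 0.5 threshold is an assumption of this sketch, not something fixed by the model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# turn the real-valued network output into 0/1 class predictions (0.5 threshold assumed)\n",
"train_pred = (nn.predict(X_train) > 0.5).astype(int).ravel()\n",
"test_pred = (nn.predict(X_test) > 0.5).astype(int).ravel()\n",
"\n",
"print(\"training accuracy:\", np.mean(train_pred == y_train.ravel()))\n",
"print(\"test accuracy:\", np.mean(test_pred == y_test.ravel()))"
]
},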
  1184. {
  1185. "cell_type": "markdown",
  1186. "metadata": {},
  1187. "source": [
  1188. "## Plot the loss vs. the number of epochs"
  1189. ]
  1190. },
  1191. {
  1192. "cell_type": "code",
  1193. "execution_count": 83,
  1194. "metadata": {},
  1195. "outputs": [
  1196. {
  1197. "data": {
  1198. "text/plain": [
  1199. "Text(0, 0.5, 'loss')"
  1200. ]
  1201. },
  1202. "execution_count": 83,
  1203. "metadata": {},
  1204. "output_type": "execute_result"
  1205. },
  1206. {
  1207. "data": {
  1208. "image/png": "iVBORw0KGgoAAAANSUhEUgAAAjIAAAGwCAYAAACzXI8XAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy89olMNAAAACXBIWXMAAA9hAAAPYQGoP6dpAAA3LklEQVR4nO3de3RU5b3/8c/kMpOEZCYkkISYBBEs91CLiBHlaEEQEG+0tspp8RwvCxsvgD+xtGq1rQY9p1Y9pbTHC+iqyCkVsN6gCBKqBpRL5CKNBqKgkKBAMiGQ6zy/P8IMjFydzJ6dCe/XWnuR7L0z852HBfms5/nuvR3GGCMAAIAoFGN3AQAAAKEiyAAAgKhFkAEAAFGLIAMAAKIWQQYAAEQtggwAAIhaBBkAABC14uwuwGo+n0+7du1SSkqKHA6H3eUAAIDTYIxRbW2tsrOzFRNz4nmXDh9kdu3apdzcXLvLAAAAIdi5c6dycnJOeLzDB5mUlBRJrQPhdrttrgYAAJwOr9er3NzcwO/xE+nwQca/nOR2uwkyAABEmVO1hdDsCwAAohZBBgAARC2CDAAAiFoEGQAAELUIMgAAIGoRZAAAQNQiyAAAgKhla5CZPXu28vPzA/d4KSgo0FtvvRU4fumll8rhcARtkydPtrFiAADQnth6Q7ycnBzNnDlT5557rowxeuGFF3T11Vdrw4YN6t+/vyTp1ltv1a9//evAzyQlJdlVLgAAaGdsDTLjx48P+v6RRx7R7NmztXr16kCQSUpKUlZWlh3lAQCAdq7d9Mi0tLRo/vz5qqurU0FBQWD/Sy+9pC5dumjAgAGaMWOGDh48eNLXaWhokNfrDdoAAEDHZPuzljZt2qSCggLV19crOTlZixYtUr9+/SRJN954o7p3767s7Gxt3LhR9913n8rKyrRw4cITvl5RUZEefvjhSJUPAABs5DDGGDsLaGxs1I4dO1RTU6O//e1vevbZZ1VcXBwIM0dbsWKFRowYofLycvXs2fO4r9fQ0KCGhobA9/6nZ9bU1IT1oZHVBxtVW98sd2K8PInxYXtdAADQ+vvb4/Gc8ve37UtLTqdTvXr10uDBg1VUVKRBgwbpqaeeOu65Q4cOlSSVl5ef8PVcLlfgKigrn3j92JJ/6ZLH39GL739myesDAIBTsz3IfJPP5wuaUTlaaWmpJKlbt24RrOj4YmNaHyveYu+EFgAAZzRbe2RmzJihMWPGKC8vT7W1tZo3b55WrlyppUuXatu2bZo3b57Gjh2r9PR0bdy4UVOnTtXw4cOVn59vZ9mSpFjH4SDjI8gAAGAXW4PMnj179NOf/lS7d++Wx+NRfn6+li5dqssvv1w7d+7U22+/rSeffFJ1dXXKzc3VhAkTdP/999tZckBsTOtkFkEGAAD72BpknnvuuRMey83NVXFxcQSr+XZiDy/KEWQAALBPu+uRiRYxMSwtAQBgN4JMiOJo9gUAwHYEmRDR7AsAgP0IMiGi2RcAAPsRZEJEsy8AAPYjyISIZl8AAOxHkAlRHEEGAADbEWRCFOPgqiUAAOxGkAkRMzIAANiPIBOiWIIMAAC2I8iEiMuvAQCwH0EmRFx+DQCA/QgyIaLZFwAA+xFkQhQXS48MAAB2I8iEKIZnLQEAYDuCTIjiaPYFAMB2BJkQ0ewLAID9CDIhotkXAAD7EWRCRLMvAAD2I8iEiGZfAADsR5AJEc2+AADYjyATohiafQEAsB1BJkSxNPsCAGA7gkyIaPYFAMB+BJkQ0ewLAID9CDIh8jf7+ggyAADYhiATIn+zbzNBBgAA2xBkQhSYkaHZFwAA2xBkQhTLjAwAALYjyIQo1n9DvBaCDAAAdiHIhCgupvWqpSafz+ZKAAA4cxFkQuSKax26xmaCDAAAdiHIhMgZ52/2lZpbCDMAANiBIBOi+NgjQ9dEnwwAALYgyITIPyMjsbwEAIBdCDIhiotx6PBTCtTQ0mJvMQAAnKEIMiFyOByB5SWWlgAAsAdBpg1csVy5BACAnQgybeDkEmwAAGxFkGmDI0tLBBkAAOxAkGkD/4xMAzMyAADYwtYgM3v2bOXn58vtdsvtdqugoEBvvfVW4Hh9fb0KCwuVnp6u5ORkTZgwQVVVVTZWHIylJQAA7GVrkMnJydHMmTO1bt06rV27Vt///vd19dVXa8uWLZKkqVOn6rXXXtOCBQtUXFysXbt26brrrrOz5CD+paVGlpYAALBFnJ1vPn78+KDvH3nkEc2ePVurV69WTk6OnnvuOc2bN0/f//73JUlz5sxR3759tXr1al144YV2lBzEPyPTxIwMAAC2aDc9Mi0tLZo/f77q6upUUFCgdevWqampSSNHjgyc06dPH+Xl5amkpOSEr9PQ0CCv1xu0WcXFjAwAALayPchs2rRJycnJcrlcmjx5shYtWqR+/fqpsrJSTqdTqampQednZmaqsrLyhK9XVFQkj8cT2HJzcy2rPT6u9da+9MgAAGAP24NM7969VVpaqjVr1uj222/XpEmT9PHHH4f8ejNmzFBNTU1g27lzZxirDebkhngAANjK1h4ZSXI6nerVq5ckafDgwfrwww/11FNP6Uc/+pEaGxtVXV0dNCtTVVWlrKysE76ey+WSy+WyumxJUkJ8rCSpvplnLQEAYAfbZ2S+yefzqaGhQYMHD1Z8fLyWL18eOFZWVqYdO3aooKDAxgqPSHS2BpmDjQQZAADsYOuMzIwZMzRmzBjl5eWptrZW8+bN08qVK7V06VJ5PB7dfPPNmjZtmtLS0uR2u3XnnXeqoKCgXVyxJElJBBkAAGxla5DZs2ePfvrTn2r37t3yeDzKz8/X0qVLdfnll0uSfv/73ysmJkYTJkxQQ0ODRo8erT/+8Y92lhwkydk6fIcam22uBACAM5OtQea555476fGEhATNmjVLs2bNilBF305iPDMyAADYqd31yEQT/9LSIYIMAAC2IMi0AT0yAADYiyDTBomHe2QONhFkAACwA0GmDY4sLdHsCwCAHQgybcB9ZAAAsBdBpg06+ZeWCDIAANiCINMGKQmtQaa2vsnmSgAAODMRZNrAH2S8h+iRAQDADgSZNnAnxkuSGlt8qufKJQAAIo4g0wbJzjg5HK1fe1leAgAg4ggybRAT41Cyy98nw/ISAACRRpBpI3dC6/KS9xAzMgAARBpBpo2OXLnEjAwAAJFGkGkj/4wMQQYAgMgjyLRR4BJsmn0BAIg4gkwb+S/B5qZ4AABEHkGmjbgpHgAA9iHItNGRHhlmZAAAiDSCTBsd6ZFhRgYAgEgjyLQRPTIAANiHINNG/qWlGm6IBwBAxBFk2ig1qTXI7D9IkAEAINIIMm3UOckpSao+2GhzJQAAnHkIMm3UuVPrjEz1wSYZY2yuBgCAMwtBpo38MzLNPqPaBq5cAgAgkggybZQQH6uE+NZh3F/H8hIAAJFEkAmDtMOzMjT8AgAQWQSZMEgNBBlmZAAAiCSCTBgcafglyAAAEEkEmTAIzMjUsbQEAEAkEWTCII2lJQAAbEGQCYPOgbv7EmQAA
IgkgkwYpHLVEgAAtiDIhAHNvgAA2IMgEwY0+wIAYA+CTBjQ7AsAgD0IMmHQmSADAIAtCDJhkHq4R6a+yaf6phabqwEA4MxBkAmDFFec4mIckpiVAQAgkggyYeBwOGj4BQDABgSZMPHfFI9LsAEAiBxbg0xRUZGGDBmilJQUZWRk6JprrlFZWVnQOZdeeqkcDkfQNnnyZJsqPjF/w+8+ggwAABFja5ApLi5WYWGhVq9erWXLlqmpqUmjRo1SXV1d0Hm33nqrdu/eHdgef/xxmyo+sdTAYwpYWgIAIFLi7HzzJUuWBH0/d+5cZWRkaN26dRo+fHhgf1JSkrKysiJd3reS1ql1Rqa6jhkZAAAipV31yNTU1EiS0tLSgva/9NJL6tKliwYMGKAZM2bo4MGDJ3yNhoYGeb3eoC0SeN4SAACRZ+uMzNF8Pp+mTJmiYcOGacCAAYH9N954o7p3767s7Gxt3LhR9913n8rKyrRw4cLjvk5RUZEefvjhSJUdkJLQOpS19QQZAAAipd0EmcLCQm3evFnvvvt
  1209. "text/plain": [
  1210. "<Figure size 640x480 with 1 Axes>"
  1211. ]
  1212. },
  1213. "metadata": {},
  1214. "output_type": "display_data"
  1215. }
  1216. ],
  1217. "source": [
  1218. "import matplotlib.pyplot as plt\n",
  1219. "plt.plot(nn.loss_history)\n",
  1220. "plt.xlabel(\"# epochs / 1000\")\n",
  1221. "plt.ylabel(\"loss\")"
  1222. ]
  1223. },
  1224. {
  1225. "cell_type": "markdown",
  1226. "metadata": {},
  1227. "source": []
  1228. },
  1229. {
  1230. "cell_type": "code",
  1231. "execution_count": 84,
  1232. "metadata": {},
  1233. "outputs": [
  1234. {
  1235. "data": {
  1236. "text/plain": [
  1237. "<matplotlib.colorbar.Colorbar at 0x147ecf700>"
  1238. ]
  1239. },
  1240. "execution_count": 84,
  1241. "metadata": {},
  1242. "output_type": "execute_result"
  1243. },
  1244. {
  1245. "data": {
  1246. "image/png": "iVBORw0KGgoAAAANSUhEUgAAAsIAAAJMCAYAAADwqMBxAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjYuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy89olMNAAAACXBIWXMAAA9hAAAPYQGoP6dpAAEAAElEQVR4nOzdd3gU1f7H8ffMpvce0kkIhN47gvQiIip2EezXdq+K/lRs2L3X3nsBlGJBmiBdpEjvPUBCCiQhvbfdmd8fCzEhG0jZNPJ9PU8eyJzZM2dDSD579jvnKLqu6wghhBBCCNHCqI09ACGEEEIIIRqDBGEhhBBCCNEiSRAWQgghhBAtkgRhIYQQQgjRIkkQFkIIIYQQLZIEYSGEEEII0SJJEBZCCCGEEC2SBGEhhBBCCNEiSRAWQgghhBAtkgRhIYQQQgjRIkkQFkIIIYQQjWrDhg1MmDCBwMBAFEVh0aJFl3xMcXExzz33HGFhYdjb29O6dWu+++67Gl3XppbjFUIIIYQQwiry8/Pp1q0bd999N9dff321HnPTTTeRkpLCt99+S2RkJElJSWiaVqPrShAWQgghhBCNaty4cYwbN67a569YsYK//vqLmJgYvLy8AGjdunWNr3vZBWFN0zhz5gyurq4oitLYwxFCCCFEC6brOrm5uQQGBqKqjV+RWlRURElJSYNcS9f1SlnM3t4ee3v7Ove9ZMkSevfuzVtvvcUPP/yAs7Mz11xzDa+++iqOjo7V7ueyC8JnzpwhJCSksYchhBBCCFEmISGB4ODgRh1DUVERgUGtycxIaZDrubi4kJeXV+HYjBkzeOmll+rcd0xMDJs2bcLBwYGFCxeSlpbGQw89RHp6Ot9//321+7nsgrCrqysAe7/9Glcnp0YejRBCCCFastyCArrfc19ZPmlMJSUlZGakMPuXAzg51+94CvJzmXJjFxISEnBzcys7bo3ZYDBXACiKwpw5c3B3dwfgvffe44YbbuCzzz6r9qzwZReEz0/Buzo5SRAWQgghRJPQlMo1nZxdcXZ2u/SJVuDm5lYhCFtLQEAAQUFBZSEYoEOHDui6TmJiIm3btq1WP41frCKEEEIIIUQNDBo0iDNnzlQovYiOjkZV1RqVoEgQFkIIIYQQjSovL4+9e/eyd+9eAGJjY9m7dy/x8fEATJ8+nSlTppSdf9ttt+Ht7c1dd93F4cOH2bBhA//3f//H3XffXaOb5SQICyGEEEKIRrVz50569OhBjx49AJg2bRo9evTgxRdfBCApKaksFIP5RrzVq1eTlZVF7969uf3225kwYQIfffRRja572dUICyGEEEKI5mXo0KHoul5l+8yZMysda9++PatXr67TdWVGWAghhBBCtEgShIUQQgghRIskQVgIIYQQQrRIEoSFEEIIIUSLJEFYCCGEEEK0SBKEhRBCCCFEiyRBWAghhBBCtEgShIUQQgghRIskQVgIIYQQQrRIEoSFEEIIIUSLJEFYCCGEEEK0SBKEhRBCCCFEiyRBWAghhBBCtEgShIUQQgghRIskQVgIIYQQQrRIEoSFEEIIIUSLJEFYCCGEEEK0SBKEhRBCCCFEiyRBWAghhBBCtEj1GoQ3bNjAhAkTCAwMRFEUFi1adNHz169fj6IolT6Sk5Prc5hCCCGEEKIFqtcgnJ+fT7du3fj0009r9Lhjx46RlJRU9uHn51dPIxRCCCGEEC2VTX12Pm7cOMaNG1fjx/n5+eHh4WH9AQkhhBBCCHFOk6wR7t69OwEBAYwaNYrNmzdf9Nzi4mJycnIqfAghhBBCCHEpTSoIBwQE8MUXX7BgwQIWLFhASEgIQ4cOZffu3VU+5s0338Td3b3sIyQkpAFHLIQQQgghmqt6LY2oqaioKKKioso+HzhwICdPnuT999/nhx9+sPiY6dOnM23atLLPc3JyJAwLIYQQQohLalJB2JK+ffuyadOmKtvt7e2xt7dvwBEJIYQQQojLQZMqjbBk7969BAQENPYwhBBCCCHEZaZeZ4Tz8vI4ceJE2eexsbHs3bsXLy8vQkNDmT59OqdPn2b27NkAfPDBB4SHh9OpUyeKior45ptvWLduHatWrarPYQohhBBCiBaoXoPwzp07GTZsWNnn52t5p06dysyZM0lKSiI+Pr6svaSkhCeeeILTp0/j5ORE165dWbNmTYU+hBBCCCGEsIZ6DcJDhw5F1/Uq22fOnFnh86eeeoqnnnqqPockhBBCCCEE0AxqhIUQQgghhKgPEoSFEEIIIUSLJEFYCCGEEEK0SBKEhRBCCCFEiyRBWAghhBBCtEgShIUQQgghRIskQVgIIYQQQrRIEoSFEEIIIUSLJEFYCCGEEEK0SBKEhRBCCCFEiyRBWAghhBBCtEgShIUQQgghRIskQVgIIYQQQrRIEoSFEEIIIUSLJEFYCCGEEEK0SBKEhRBCCCFEiyRBWAghhBBCtEgShIUQQgghRIskQVgIIYQQQrRIEoSFEEIIIUSLJEFYCCGEEEK0SBKEhRBCCCFEiyRBWAghhBBCtEgShIUQQgghRIskQVgIIYQQQrRIEoSFEEIIIUSLJEFYCCGEEEI0qg0bNjBhwgQCAwNRFIVFixZV+7GbN2/GxsaG7t271/i6EoSFEEIIIUSjys/Pp1u3bnz66ac1elxWVhZTpkxhxIgRtbquTa0eJYQQQgghhJWMGzeOcePG1fhxDzzwALfddhsGg6FGs8jnyYywEEIIIYRodr7//ntiYmKYMWNGrfuQGWEhhBBCCFEvcnJyKnxub2+Pvb19nfs9fvw4zzzzDBs3bsTGpvZxVoKwEEIIIUQL4unjirOLa71ew85RByAkJKTC8RkzZvDSSy/VqW+TycRtt93Gyy+/TLt27erUlwRhIYQQQghRLxISEnBzcyv73Bqzwbm5uezcuZM9e/bwyCOPAKBpGrquY2Njw6pVqxg+fHi1+pIgLIQQQggh6oWbm1uFIGytPg8cOFDh2Geffca6dev49ddfCQ8Pr3ZfEoSFEEIIIUSjysvL48SJE2Wfx8bGsnfvXry8vAgNDWX69OmcPn2a2bNno6oqnTt3rvB4Pz8/HBwcKh2/FAnCQgghhBCiUe3cuZNhw4aVfT5t2jQApk6dysyZM0lKSiI+Pt7q11V0Xdet3msjysnJwd3dnZPz5uDq5NTYwxFCCCFEC5ZbUECbW28nOzvb6iUCNXU+I63dno6zS/2OJT8vhxF9vZvE874YWUdYCCGEEEK0SBKEhRBCCCFEiyRBWAghhBBCtEgShIUQQgghRIskq0YIIcQlaJpGXEoKcSkpGEtKcXV1oV1wMJ6u9bszkxBCiPolQVgIIS6ioLiYtVu3kh0bS6uSUpxVA0kmE8fd3ejUqyc92rZt7CEKIYSoJQnCQghxEet37kQ7Gs11/n54OzoCoOk6B9PT2b51K66OTkQGBzXyKIUQQtSG1AgLIUQVUjIzyYg9xRU+3mUhGEBVFLr6+BBeXMKh48cbcYRCCCHqQoKwEEJU4UxaGs6FhQQ6O1tsj3R3Jzc5mbzCw
gYemRBCCGuQICyEEFXQNB0bQFEUi+22qgq6zmW2QacQQrQYEoSFEKIK3u5uZNvakFFUZLE9PjcXe09PnB0cGnhkQgghrEGCsBBCVCHUzw/HwCC2pJyl1GSq0Jacn89Ro4m2kZGoqvwoFUKI5khWjRBCiCqoqsrgvn1YV1zML4mJtHVwwNnGhqSCAuJUA75dOtMlIryxhymEEKKWJAgLIVoETdMoLi3Fwc6uyppfS/w8PBg/YjiH4+I4GnsKY0kJLkFB9AoPp21wkMwGCyFEMyZBWAhxWYtOSOCT3xayeOMmCktL8Xdz4/axY3jo2om4VbEaxIVcnZzo16ED/Tp0qOfRCiGEaEgShIUQl60dR49y8wsz8DIama5ptAa25OTw5a8L+OPvv1n03zdlm2QhhGjB5D09IcRlyWQy8eBb79C9tJRDmsYLwB3AZ8AOTSPlTBKvzfqhkUcphBCiMUkQFkJcltbt2UN8ejrv6zouF7R1AB7XNH5d/yc5+fmNMTwhhBBNgARhIcRl6VDsKXwMBnpX0X4VUFh
  1247. "text/plain": [
  1248. "<Figure size 900x700 with 2 Axes>"
  1249. ]
  1250. },
  1251. "metadata": {},
  1252. "output_type": "display_data"
  1253. }
  1254. ],
  1255. "source": [
  1256. "import matplotlib.pyplot as plt\n",
  1257. "from matplotlib.colors import ListedColormap\n",
  1258. "\n",
  1259. "cm = plt.cm.RdBu\n",
  1260. "cm_bright = ListedColormap([\"#FF0000\", \"#0000FF\"])\n",
  1261. "\n",
  1262. "xv = np.linspace(x_min, x_max, 10)\n",
  1263. "yv = np.linspace(y_min, y_max, 10)\n",
  1264. "Xv, Yv = np.meshgrid(xv, yv)\n",
  1265. "XYpairs = np.vstack([ Xv.reshape(-1), Yv.reshape(-1)])\n",
  1266. "zv = nn.predict(XYpairs.T)\n",
  1267. "Zv = zv.reshape(Xv.shape)\n",
  1268. "\n",
  1269. "fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(9, 7))\n",
  1270. "ax.set_aspect(1)\n",
  1271. "cn = ax.contourf(Xv, Yv, Zv, cmap=\"coolwarm_r\", alpha=0.4)\n",
  1272. "\n",
  1273. "ax.scatter(X_train[:, 0], X_train[:, 1], c=y_train, cmap=cm_bright, edgecolors=\"k\")\n",
  1274. "\n",
  1275. "# Plot the testing points\n",
  1276. "ax.scatter(X_test[:, 0], X_test[:, 1], c=y_test, cmap=cm_bright, alpha=0.4, edgecolors=\"k\")\n",
  1277. "\n",
  1278. "ax.set_xlim(x_min, x_max)\n",
  1279. "ax.set_ylim(y_min, y_max)\n",
  1280. "# ax.set_xticks(())\n",
  1281. "# ax.set_yticks(())\n",
  1282. "\n",
  1283. "fig.colorbar(cn)\n"
  1284. ]
  1285. },
  1286. {
  1287. "cell_type": "code",
  1288. "execution_count": null,
  1289. "metadata": {},
  1290. "outputs": [],
  1291. "source": []
  1292. }
  1293. ],
  1294. "metadata": {
  1295. "kernelspec": {
  1296. "display_name": "Python 3 (ipykernel)",
  1297. "language": "python",
  1298. "name": "python3"
  1299. },
  1300. "language_info": {
  1301. "codemirror_mode": {
  1302. "name": "ipython",
  1303. "version": 3
  1304. },
  1305. "file_extension": ".py",
  1306. "mimetype": "text/x-python",
  1307. "name": "python",
  1308. "nbconvert_exporter": "python",
  1309. "pygments_lexer": "ipython3",
  1310. "version": "3.10.9"
  1311. },
  1312. "vscode": {
  1313. "interpreter": {
  1314. "hash": "b0fa6594d8f4cbf19f97940f81e996739fb7646882a419484c72d19e05852a7e"
  1315. }
  1316. }
  1317. },
  1318. "nbformat": 4,
  1319. "nbformat_minor": 4
  1320. }