167 lines
3.7 KiB
Plaintext
167 lines
3.7 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "042acd49",
|
|
"metadata": {},
|
|
"source": [
|
|
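"# Test a minimizer\n",
"\n",
"Fit the cubic model $y = W x^3 + b$ to noisy synthetic data by minimizing a mean-squared-error cost with `scipy.optimize.minimize`, then plot the fitted curve against the data."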
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "cb51a492",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# All imports first, then global configuration.\n",
"import tensorflow as tf  # NOTE(review): tf is never used in this notebook; candidate for removal\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from matplotlib import colors, cm  # NOTE(review): colors and cm are not used below\n",
"from scipy.optimize import minimize\n",
"\n",
"plt.style.use(\"ggplot\")"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "2ac3651a",
|
|
"metadata": {},
|
|
"source": [
|
|
"`plt.rcParams` controls the appearance of plots globally,\n",
|
|
"affecting all subsequent plots created in your session."
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "97ef9933",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Global plot styling: applies to every figure created after this cell runs.\n",
"plt.rcParams.update({\n",
"    \"axes.grid\": False,\n",
"    \"font.size\": 20,\n",
"    \"figure.figsize\": (12, 9),\n",
"    \"lines.markersize\": 8,\n",
"})"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "f15200f9",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Generate synthetic data: labels = 5*x^3 + 1 plus Gaussian noise (sigma = 0.1).\n",
"SEED = 42  # fixed seed so Restart & Run All reproduces the same data and fit\n",
"N_POINTS = 100\n",
"rng = np.random.default_rng(SEED)\n",
"data = rng.uniform(size=N_POINTS)  # x values drawn uniformly from [0, 1)\n",
"labels = 5.0 * data**3 + 1.0 + rng.normal(loc=0.0, scale=0.1, size=N_POINTS)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "7237f5ed",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Scatter plot of the raw (noisy) data.\n",
"fig, ax = plt.subplots()\n",
"ax.scatter(data, labels, label=\"data\")\n",
"ax.set(xlabel=\"x\", ylabel=\"y\", title=\"Synthetic data\")\n",
"ax.legend()\n",
"plt.show()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "0d6e104c",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Chi^2-like cost: mean squared residual of the cubic model W*x^3 + b.\n",
"def cost(params):\n",
"    \"\"\"Return the mean squared residual of the model W*data**3 + b.\n",
"\n",
"    `params` is the pair (W, b) passed in by `scipy.optimize.minimize`.\n",
"    Reads the notebook-level arrays `data` and `labels`. With a constant\n",
"    per-point uncertainty this is proportional to chi^2, so both have the\n",
"    same minimizer.\n",
"    \"\"\"\n",
"    slope, offset = params\n",
"    residuals = labels - (slope * data * data * data + offset)\n",
"    return np.mean(residuals**2)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "markdown",
|
|
"id": "8e00e16a",
|
|
"metadata": {},
|
|
"source": [
|
|
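"## Call the minimizer\n",
"\n",
"`scipy.optimize` provides a collection of optimization algorithms for finding the minimum of a given function (a maximum is found by minimizing the negated function). Here `minimize` searches for the parameters `(W, b)` that minimize `cost`, starting from the initial guess `[1, 1]`."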
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "433975c3",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Minimize the cost starting from W = b = 1. `res` is a scipy OptimizeResult:\n",
"# res.x holds the fitted parameters, res.fun the cost at the minimum.\n",
"x0 = [1.0, 1.0]\n",
"res = minimize(cost, x0)\n",
"if not res.success:\n",
"    # Surface convergence problems instead of silently using a failed fit.\n",
"    print('WARNING: minimizer did not report success:', res.message)\n",
"W, b = res.x\n",
"print('function value at the minimum and fitted parameters', res.fun, ' ', W, ' ', b)"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "1e1f4e81",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Evaluate the fitted cubic on 100 evenly spaced points covering [0, 1].\n",
"points = np.linspace(start=0.0, stop=1.0, num=100)\n",
"prediction = W * points * points * points + b"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "d8de971e",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"# Overlay the fitted model on the data.\n",
"fig, ax = plt.subplots()\n",
"ax.scatter(data, labels, label=\"data\")\n",
"ax.plot(points, prediction, label=\"model\", color=\"green\")\n",
"ax.set(xlabel=\"x\", ylabel=\"y\", title=\"Cubic fit: y = W x^3 + b\")\n",
"ax.legend()\n",
"plt.show()"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": null,
|
|
"id": "4a7d62c2",
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": []
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3 (ipykernel)",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.8.16"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 5
|
|
}
|