Machine learning course held as part of the Studierendentage in the summer semester (SS) 2023

{
"cells": [
{
"cell_type": "markdown",
"id": "042acd49",
"metadata": {},
"source": [
"# Test a minimizer: fitting a cubic model with `scipy.optimize.minimize`"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "cb51a492",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from scipy.optimize import minimize\n",
"\n",
"# Use the ggplot style for all plots in this notebook\n",
"plt.style.use(\"ggplot\")"
]
},
{
"cell_type": "markdown",
"id": "2ac3651a",
"metadata": {},
"source": [
"`plt.rcParams` controls the appearance of plots globally,\n",
"affecting all plots created afterwards in the current session."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "97ef9933",
"metadata": {},
"outputs": [],
"source": [
"# Global plot settings: no grid, larger font, 12 x 9 inch figures, bigger markers\n",
"plt.rcParams.update({\n",
"    'axes.grid': False,\n",
"    'font.size': 20,\n",
"    'figure.figsize': (12, 9),\n",
"    'lines.markersize': 8,\n",
"})"
]
},
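{
"cell_type": "markdown",
"id": "3f1a9c20",
"metadata": {},
"source": [
"Because changes to `plt.rcParams` are global, a setting meant for a single plot can instead be scoped with the context manager `plt.rc_context`. The next cell is a minimal sketch of that alternative and is not part of the original exercise. The font size of 10 is an arbitrary example value."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3f1a9c21",
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"\n",
"# Remember the current global font size\n",
"global_size = plt.rcParams['font.size']\n",
"\n",
"# Inside the with-block the override is active ...\n",
"with plt.rc_context({'font.size': 10}):\n",
"    inside_size = plt.rcParams['font.size']\n",
"\n",
"# ... and on leaving it the previous global settings are restored\n",
"after_size = plt.rcParams['font.size']\n",
"print(global_size, inside_size, after_size)"
]
},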
{
"cell_type": "code",
"execution_count": null,
"id": "f15200f9",
"metadata": {},
"outputs": [],
"source": [
"# Generate 100 points x ~ U(0, 1) with labels y = 5 x^3 + 1,\n",
"# smeared with Gaussian noise (mean 0, standard deviation 0.1)\n",
"data = np.random.uniform(size=100)\n",
"labels = 5. * data**3 + 1 + np.random.normal(loc=0.0, scale=0.1, size=100)"
]
},
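{
"cell_type": "markdown",
"id": "5b7e2d40",
"metadata": {},
"source": [
"The data above are drawn anew on every run, so the fitted numbers change slightly from run to run. The next cell is a minimal sketch of a reproducible variant, assuming NumPy's `Generator` interface (`np.random.default_rng`). The helper `make_data` and the seed 42 are hypothetical choices, and the cell leaves `data` and `labels` untouched."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5b7e2d41",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Hypothetical helper: same recipe as above (y = 5 x^3 + 1 plus Gaussian noise),\n",
"# but drawn from a seeded random generator so every run gives the same numbers\n",
"def make_data(seed, n=100, sigma=0.1):\n",
"    rng = np.random.default_rng(seed)\n",
"    x = rng.uniform(size=n)\n",
"    y = 5.0 * x**3 + 1.0 + rng.normal(loc=0.0, scale=sigma, size=n)\n",
"    return x, y\n",
"\n",
"x1, y1 = make_data(seed=42)\n",
"x2, y2 = make_data(seed=42)\n",
"# The same seed reproduces exactly the same data\n",
"print(np.array_equal(x1, x2) and np.array_equal(y1, y2))"
]
},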
{
"cell_type": "code",
"execution_count": null,
"id": "7237f5ed",
"metadata": {},
"outputs": [],
"source": [
"# Scatter plot of the generated data\n",
"plt.scatter(data, labels, label=\"data\")\n",
"plt.legend()\n",
"plt.show()"
]
},
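{
"cell_type": "markdown",
"id": "6c8f3e50",
"metadata": {},
"source": [
"For orientation, the noisy points can be compared with the noise-free curve $y = 5x^3 + 1$ they were generated from. The next cell is a minimal sketch using Matplotlib's object-oriented interface (`plt.subplots`). It draws its own sample from a seeded generator (seed chosen arbitrarily) so that it runs on its own."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6c8f3e51",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"\n",
"# Seeded sample following the same recipe as above (seed is arbitrary)\n",
"rng = np.random.default_rng(1)\n",
"x = rng.uniform(size=100)\n",
"y = 5.0 * x**3 + 1.0 + rng.normal(loc=0.0, scale=0.1, size=100)\n",
"\n",
"grid = np.linspace(0, 1, 100)\n",
"fig, ax = plt.subplots()\n",
"ax.scatter(x, y, label='data')\n",
"# Noise-free curve used to generate the data\n",
"ax.plot(grid, 5.0 * grid**3 + 1.0, color='black', linestyle='--', label='true curve')\n",
"ax.set_xlabel('x')\n",
"ax.set_ylabel('y')\n",
"ax.legend()\n",
"plt.show()"
]
},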
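{
"cell_type": "code",
"execution_count": null,
"id": "0d6e104c",
"metadata": {},
"outputs": [],
"source": [
"# Chi^2-like cost function: mean squared error between the labels and the\n",
"# cubic model W * x^3 + b (proportional to chi^2 if all points share one uncertainty)\n",
"def cost(params):\n",
"    W, b = params\n",
"    return np.mean((labels - (W * data**3 + b))**2)"
]
},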
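{
"cell_type": "markdown",
"id": "7d904f60",
"metadata": {},
"source": [
"The cost above is called chi2-like because, for $N$ points that share a common uncertainty $\\sigma$, $\\chi^2 = \\sum_i (y_i - f(x_i))^2 / \\sigma^2 = N \\cdot \\mathrm{MSE} / \\sigma^2$. The two differ only by a constant factor, so they have their minimum at the same parameters. The next cell is a minimal numerical check of this relation. The helper names `chi2` and `mse` and the random test values are hypothetical."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7d904f61",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Hypothetical helpers: chi^2 with a common uncertainty sigma, and the mean squared error\n",
"def chi2(y, f, sigma):\n",
"    return np.sum((y - f)**2 / sigma**2)\n",
"\n",
"def mse(y, f):\n",
"    return np.mean((y - f)**2)\n",
"\n",
"rng = np.random.default_rng(2)\n",
"y_obs = rng.normal(size=50)\n",
"f_model = rng.normal(size=50)\n",
"sigma = 0.1\n",
"\n",
"# chi^2 equals N * MSE / sigma^2, i.e. the two differ only by a constant factor\n",
"lhs = chi2(y_obs, f_model, sigma)\n",
"rhs = len(y_obs) * mse(y_obs, f_model) / sigma**2\n",
"print(np.isclose(lhs, rhs))"
]
},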
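{
"cell_type": "markdown",
"id": "8e00e16a",
"metadata": {},
"source": [
"Call the minimizer. `scipy.optimize` provides a collection of optimization algorithms for minimizing (or maximizing) a given function.\n",
"`minimize` searches for a local minimum of a scalar function of one or more parameters, starting from an initial guess `x0`; for problems without bounds or constraints it uses the BFGS method by default."
]
},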
{
"cell_type": "code",
"execution_count": null,
"id": "433975c3",
"metadata": {},
"outputs": [],
"source": [
"# Start the minimization at W = 1, b = 1\n",
"res = minimize(cost, [1., 1.])\n",
"# minimize returns an OptimizeResult object:\n",
"#   res.x   : the solution (parameters at the minimum) as an array\n",
"#   res.fun : the value of the cost function at the minimum\n",
"W, b = res.x\n",
"print(f'cost at the minimum: {res.fun:.4g}, fitted W = {W:.4g}, b = {b:.4g}')"
]
},
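{
"cell_type": "markdown",
"id": "8ea15a70",
"metadata": {},
"source": [
"Because the model $W x^3 + b$ is linear in the parameters $W$ and $b$, the least-squares minimum can also be computed exactly. This gives a reference to test the minimizer against. The next cell is a minimal sketch on a seeded sample (seed arbitrary). It runs `minimize` with the methods `Nelder-Mead`, `Powell` and `BFGS` and compares each result with the closed-form solution from `np.linalg.lstsq`. The helper name `cost_check` is hypothetical."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8ea15a71",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from scipy.optimize import minimize\n",
"\n",
"# Seeded sample following the same recipe as above (seed is arbitrary)\n",
"rng = np.random.default_rng(0)\n",
"x = rng.uniform(size=100)\n",
"y = 5.0 * x**3 + 1.0 + rng.normal(loc=0.0, scale=0.1, size=100)\n",
"\n",
"# Same mean-squared-error cost as above, for this sample\n",
"def cost_check(params):\n",
"    W_, b_ = params\n",
"    return np.mean((y - (W_ * x**3 + b_))**2)\n",
"\n",
"# Exact least-squares solution: solve A @ [W, b] = y with design matrix A = [x^3, 1]\n",
"A = np.column_stack([x**3, np.ones_like(x)])\n",
"exact, *_ = np.linalg.lstsq(A, y, rcond=None)\n",
"print('lstsq      :', exact)\n",
"\n",
"# Run several of the algorithms offered by minimize from the same starting point\n",
"fits = {}\n",
"for method in ['Nelder-Mead', 'Powell', 'BFGS']:\n",
"    r = minimize(cost_check, x0=[1.0, 1.0], method=method)\n",
"    fits[method] = r.x\n",
"    print(f'{method:<11}:', r.x)"
]
},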
{
"cell_type": "code",
"execution_count": null,
"id": "1e1f4e81",
"metadata": {},
"outputs": [],
"source": [
"# Evaluate the fitted model on 100 equally spaced points in [0, 1]\n",
"points = np.linspace(0, 1, 100)\n",
"prediction = W * points**3 + b"
]
},
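{
"cell_type": "markdown",
"id": "9fb26b80",
"metadata": {},
"source": [
"The cubic formula appears twice, once in `cost` and once in `prediction`. The next cell is a small sketch that defines it once in a hypothetical helper `cubic_model`. The helper accepts scalars, lists or NumPy arrays, so the fitted curve and the plotted curve cannot drift apart."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9fb26b81",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Hypothetical helper: the model W * x^3 + b, vectorized over x\n",
"def cubic_model(x, W, b):\n",
"    return W * np.asarray(x)**3 + b\n",
"\n",
"# Usage: cost and prediction can both be written with the same function, e.g.\n",
"#   np.mean((labels - cubic_model(data, W, b))**2)\n",
"#   cubic_model(points, W, b)\n",
"print(cubic_model([0.0, 1.0, 2.0], 5.0, 1.0))"
]
},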
{
"cell_type": "code",
"execution_count": null,
"id": "d8de971e",
"metadata": {},
"outputs": [],
"source": [
"# Plot the data together with the fitted model\n",
"plt.scatter(data, labels, label=\"data\")\n",
"plt.plot(points, prediction, label=\"model\", color=\"green\")\n",
"plt.legend()\n",
"plt.show()"
]
},
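{
"cell_type": "markdown",
"id": "a0c37c90",
"metadata": {},
"source": [
"The cost at the minimum also provides a numerical check on the fit. It is the mean squared residual, so its square root should come out close to the noise level $\\sigma = 0.1$ used to generate the data. The next cell is a minimal sketch of this check on a seeded sample (seed arbitrary). For brevity it uses the closed-form least-squares fit; for the fit above, the analogous quantity is `np.sqrt(res.fun)`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a0c37c91",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Seeded sample following the same recipe as above (seed is arbitrary)\n",
"rng = np.random.default_rng(3)\n",
"x = rng.uniform(size=100)\n",
"y = 5.0 * x**3 + 1.0 + rng.normal(loc=0.0, scale=0.1, size=100)\n",
"\n",
"# Least-squares fit of W and b (equivalent to minimizing the MSE cost)\n",
"A = np.column_stack([x**3, np.ones_like(x)])\n",
"(W_fit, b_fit), *_ = np.linalg.lstsq(A, y, rcond=None)\n",
"\n",
"# Root mean squared residual: should be close to the true noise level 0.1\n",
"rms = np.sqrt(np.mean((y - (W_fit * x**3 + b_fit))**2))\n",
"print(f'RMS residual: {rms:.3f}')"
]
},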
{
"cell_type": "code",
"execution_count": null,
"id": "4a7d62c2",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.16"
}
},
"nbformat": 4,
"nbformat_minor": 5
}