160 lines
3.4 KiB
Plaintext
160 lines
3.4 KiB
Plaintext
|
{
|
||
|
"cells": [
|
||
|
{
|
||
|
"attachments": {},
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"# Hyperparameter optimization"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"Superconductivity Data Set: Predict the critical temperature based on 81 material features."
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"import pandas as pd\n",
|
||
|
"import numpy as np\n",
|
||
|
"import matplotlib.pyplot as plt\n",
|
||
|
"from sklearn.model_selection import train_test_split\n",
|
||
|
"from sklearn.metrics import mean_squared_error"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"filename = \"https://www.physi.uni-heidelberg.de/~reygers/lectures/2021/ml/data/train_critical_temp.csv\"\n",
|
||
|
"df = pd.read_csv(filename, engine='python')"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"df.head()"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"y = df['critical_temp'].values\n",
|
||
|
"# feature matrix: every column except the regression target\n",
"X = df.drop(columns=\"critical_temp\")"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"# Fixed seed (same as the MLPRegressor below) so the split, and hence the reported RMSE, is reproducible on Restart & Run All\n",
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, shuffle=True, random_state=1)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"from sklearn.neural_network import MLPRegressor\n",
|
||
|
"import time\n",
|
||
|
"\n",
|
||
|
"mlpr = MLPRegressor(hidden_layer_sizes=(50,50), activation='relu', random_state=1, max_iter=5000)\n",
|
||
|
"\n",
|
||
|
"start_time = time.time()\n",
|
||
|
"mlpr.fit(X_train, y_train)\n",
|
||
|
"run_time = time.time() - start_time\n",
|
||
|
"\n",
|
||
|
"print(run_time)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"y_pred = mlpr.predict(X_test)"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"plt.scatter(y_test, y_pred, s=2)\n",
|
||
|
"plt.xlabel(\"true critical temperature (K)\", fontsize=14)\n",
|
||
|
"plt.ylabel(\"predicted critical temperature (K)\", fontsize=14)\n",
|
||
|
"plt.savefig(\"critical_temperature.pdf\")"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"rms = np.sqrt(mean_squared_error(y_test, y_pred))\n",
|
||
|
"print(f\"root mean square error {rms:.2f}\")"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"attachments": {},
|
||
|
"cell_type": "markdown",
|
||
|
"metadata": {},
|
||
|
"source": [
|
||
|
"## Now try to optimize the hyperparameters"
|
||
|
]
|
||
|
},
|
||
|
{
|
||
|
"cell_type": "code",
|
||
|
"execution_count": null,
|
||
|
"metadata": {},
|
||
|
"outputs": [],
|
||
|
"source": [
|
||
|
"from sklearn.model_selection import GridSearchCV\n",
|
||
|
"\n",
|
||
|
"### Your code here\n"
|
||
|
]
|
||
|
}
|
||
|
],
|
||
|
"metadata": {
|
||
|
"kernelspec": {
|
||
|
"display_name": "Python 3",
|
||
|
"language": "python",
|
||
|
"name": "python3"
|
||
|
},
|
||
|
"language_info": {
|
||
|
"codemirror_mode": {
|
||
|
"name": "ipython",
|
||
|
"version": 3
|
||
|
},
|
||
|
"file_extension": ".py",
|
||
|
"mimetype": "text/x-python",
|
||
|
"name": "python",
|
||
|
"nbconvert_exporter": "python",
|
||
|
"pygments_lexer": "ipython3",
|
||
|
"version": "3.10.11"
|
||
|
}
|
||
|
},
|
||
|
"nbformat": 4,
|
||
|
"nbformat_minor": 4
|
||
|
}
|