ML-Kurs-SS2023/notebooks/04_decision_trees_critical_temp_regression.ipynb

452 lines
54 KiB
Plaintext
Raw Normal View History

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Example: Regression with XGBoost"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Superconductivty Data Set: Predict the critical temperature based on 81 material features."
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"from sklearn.model_selection import train_test_split\n",
"from sklearn.metrics import mean_squared_error"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"filename = \"https://www.physi.uni-heidelberg.de/~reygers/lectures/2021/ml/data/train_critical_temp.csv\"\n",
"df = pd.read_csv(filename, engine='python')"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>number_of_elements</th>\n",
" <th>mean_atomic_mass</th>\n",
" <th>wtd_mean_atomic_mass</th>\n",
" <th>gmean_atomic_mass</th>\n",
" <th>wtd_gmean_atomic_mass</th>\n",
" <th>entropy_atomic_mass</th>\n",
" <th>wtd_entropy_atomic_mass</th>\n",
" <th>range_atomic_mass</th>\n",
" <th>wtd_range_atomic_mass</th>\n",
" <th>std_atomic_mass</th>\n",
" <th>...</th>\n",
" <th>wtd_mean_Valence</th>\n",
" <th>gmean_Valence</th>\n",
" <th>wtd_gmean_Valence</th>\n",
" <th>entropy_Valence</th>\n",
" <th>wtd_entropy_Valence</th>\n",
" <th>range_Valence</th>\n",
" <th>wtd_range_Valence</th>\n",
" <th>std_Valence</th>\n",
" <th>wtd_std_Valence</th>\n",
" <th>critical_temp</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>4</td>\n",
" <td>88.944468</td>\n",
" <td>57.862692</td>\n",
" <td>66.361592</td>\n",
" <td>36.116612</td>\n",
" <td>1.181795</td>\n",
" <td>1.062396</td>\n",
" <td>122.90607</td>\n",
" <td>31.794921</td>\n",
" <td>51.968828</td>\n",
" <td>...</td>\n",
" <td>2.257143</td>\n",
" <td>2.213364</td>\n",
" <td>2.219783</td>\n",
" <td>1.368922</td>\n",
" <td>1.066221</td>\n",
" <td>1</td>\n",
" <td>1.085714</td>\n",
" <td>0.433013</td>\n",
" <td>0.437059</td>\n",
" <td>29.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>5</td>\n",
" <td>92.729214</td>\n",
" <td>58.518416</td>\n",
" <td>73.132787</td>\n",
" <td>36.396602</td>\n",
" <td>1.449309</td>\n",
" <td>1.057755</td>\n",
" <td>122.90607</td>\n",
" <td>36.161939</td>\n",
" <td>47.094633</td>\n",
" <td>...</td>\n",
" <td>2.257143</td>\n",
" <td>1.888175</td>\n",
" <td>2.210679</td>\n",
" <td>1.557113</td>\n",
" <td>1.047221</td>\n",
" <td>2</td>\n",
" <td>1.128571</td>\n",
" <td>0.632456</td>\n",
" <td>0.468606</td>\n",
" <td>26.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>4</td>\n",
" <td>88.944468</td>\n",
" <td>57.885242</td>\n",
" <td>66.361592</td>\n",
" <td>36.122509</td>\n",
" <td>1.181795</td>\n",
" <td>0.975980</td>\n",
" <td>122.90607</td>\n",
" <td>35.741099</td>\n",
" <td>51.968828</td>\n",
" <td>...</td>\n",
" <td>2.271429</td>\n",
" <td>2.213364</td>\n",
" <td>2.232679</td>\n",
" <td>1.368922</td>\n",
" <td>1.029175</td>\n",
" <td>1</td>\n",
" <td>1.114286</td>\n",
" <td>0.433013</td>\n",
" <td>0.444697</td>\n",
" <td>19.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>4</td>\n",
" <td>88.944468</td>\n",
" <td>57.873967</td>\n",
" <td>66.361592</td>\n",
" <td>36.119560</td>\n",
" <td>1.181795</td>\n",
" <td>1.022291</td>\n",
" <td>122.90607</td>\n",
" <td>33.768010</td>\n",
" <td>51.968828</td>\n",
" <td>...</td>\n",
" <td>2.264286</td>\n",
" <td>2.213364</td>\n",
" <td>2.226222</td>\n",
" <td>1.368922</td>\n",
" <td>1.048834</td>\n",
" <td>1</td>\n",
" <td>1.100000</td>\n",
" <td>0.433013</td>\n",
" <td>0.440952</td>\n",
" <td>22.0</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>4</td>\n",
" <td>88.944468</td>\n",
" <td>57.840143</td>\n",
" <td>66.361592</td>\n",
" <td>36.110716</td>\n",
" <td>1.181795</td>\n",
" <td>1.129224</td>\n",
" <td>122.90607</td>\n",
" <td>27.848743</td>\n",
" <td>51.968828</td>\n",
" <td>...</td>\n",
" <td>2.242857</td>\n",
" <td>2.213364</td>\n",
" <td>2.206963</td>\n",
" <td>1.368922</td>\n",
" <td>1.096052</td>\n",
" <td>1</td>\n",
" <td>1.057143</td>\n",
" <td>0.433013</td>\n",
" <td>0.428809</td>\n",
" <td>23.0</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>5 rows × 82 columns</p>\n",
"</div>"
],
"text/plain": [
" number_of_elements mean_atomic_mass wtd_mean_atomic_mass \\\n",
"0 4 88.944468 57.862692 \n",
"1 5 92.729214 58.518416 \n",
"2 4 88.944468 57.885242 \n",
"3 4 88.944468 57.873967 \n",
"4 4 88.944468 57.840143 \n",
"\n",
" gmean_atomic_mass wtd_gmean_atomic_mass entropy_atomic_mass \\\n",
"0 66.361592 36.116612 1.181795 \n",
"1 73.132787 36.396602 1.449309 \n",
"2 66.361592 36.122509 1.181795 \n",
"3 66.361592 36.119560 1.181795 \n",
"4 66.361592 36.110716 1.181795 \n",
"\n",
" wtd_entropy_atomic_mass range_atomic_mass wtd_range_atomic_mass \\\n",
"0 1.062396 122.90607 31.794921 \n",
"1 1.057755 122.90607 36.161939 \n",
"2 0.975980 122.90607 35.741099 \n",
"3 1.022291 122.90607 33.768010 \n",
"4 1.129224 122.90607 27.848743 \n",
"\n",
" std_atomic_mass ... wtd_mean_Valence gmean_Valence wtd_gmean_Valence \\\n",
"0 51.968828 ... 2.257143 2.213364 2.219783 \n",
"1 47.094633 ... 2.257143 1.888175 2.210679 \n",
"2 51.968828 ... 2.271429 2.213364 2.232679 \n",
"3 51.968828 ... 2.264286 2.213364 2.226222 \n",
"4 51.968828 ... 2.242857 2.213364 2.206963 \n",
"\n",
" entropy_Valence wtd_entropy_Valence range_Valence wtd_range_Valence \\\n",
"0 1.368922 1.066221 1 1.085714 \n",
"1 1.557113 1.047221 2 1.128571 \n",
"2 1.368922 1.029175 1 1.114286 \n",
"3 1.368922 1.048834 1 1.100000 \n",
"4 1.368922 1.096052 1 1.057143 \n",
"\n",
" std_Valence wtd_std_Valence critical_temp \n",
"0 0.433013 0.437059 29.0 \n",
"1 0.632456 0.468606 26.0 \n",
"2 0.433013 0.444697 19.0 \n",
"3 0.433013 0.440952 22.0 \n",
"4 0.433013 0.428809 23.0 \n",
"\n",
"[5 rows x 82 columns]"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df.head()"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [],
"source": [
"y = df['critical_temp'].values\n",
"X = df[[col for col in df.columns if col!=\"critical_temp\"]]"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, shuffle=True)"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"16.105452060699463\n"
]
}
],
"source": [
"import xgboost as xgb\n",
"import time\n",
"# XGBreg = xgb.sklearn.XGBRegressor(nthread=-1, seed=1, n_estimators=1000)\n",
"XGBreg = xgb.sklearn.XGBRegressor()\n",
"\n",
"start_time = time.time()\n",
"XGBreg.fit(X_train, y_train)\n",
"run_time = time.time() - start_time\n",
"\n",
"print(run_time)"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [],
"source": [
"y_pred = XGBreg.predict(X_test)"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYkAAAETCAYAAADDIPqYAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy86wFpkAAAACXBIWXMAAAsTAAALEwEAmpwYAAB5AklEQVR4nO29e3wc1Xn//360u7pfbdlYYMkYIgO2IXawY1KIc4PgEpPQJjWENECA0KShoU2+pHZa54LT4pZe4vySNCUGAg00uKElsUMhQFscCDiWwQFfwALbkmzkSLJ1X112V+f3x8wZz45mVyNpVxf7vF+vfUk7Ozvz7Kx0zpzn8nlEKYXBYDAYDH7kTLYBBoPBYJi6mEnCYDAYDCkxk4TBYDAYUmImCYPBYDCkxEwSBoPBYEiJmSQMBoPBkBIzSRgMBoMhJWaSMBgMBkNKwkF3FJE84EygAGhVSrVmzSqDwWAwTAnSriREpEREPi8i24FO4E1gD3BMRBpF5IcisnwiDDUYDAbDxCOpZDlE5EvAXwEHgZ8DvwHeBvqAGcBi4L3AHwAvAX+mlKqfAJsNBoPBMEGkmyS2AHcppfakPYBIPnAzMKiU2px5Ew0Gg8EwWaScJAIfQCRfKdWfIXsMBoPBMIUYKSbx2RFezwe2ZdQig8FgMEwZRspu+q6ItCml/sv7gojkYsUq3pEVy8ZAZWWlOvvssyfbDIPBYJhW7Nq1q00pNcvvtZEmiduBh0XkI0qp/9UbRSQC/AxYCKzMmKXj5Oyzz6aurm6yzTAYDIZphYg0pHot7SShlPqhiFQCj4vIB5RSL9sriP8C3gmsVEodzKy5BoPBYJgqjFhxrZS6G9gM/LeILAZ+ClwMfFAp9WbQE4nI/SLSIiLDsqVE5MsiouwJCbH4joi8KSKvisi7An8ig8FgMGSMQLIcSqkvA08ALwMrsCaI10d5rh8Bq7wbRaQa+DDQ6Nr8+0Ct/bgN+JdRnstgMBgMGSCtu0lEvuN62g8MAa8BnxMR5wWl1BdHOpFSaruInO3z0j8DX8GKcWg+BjykrPzcl0SkXESqlFLNI53HYDAYDJljpMD1hZ7nLwIhz/YxF1qIyMeAo0qp37onHeAsoMn1/Ii9bdgkISK3Ya02qKmpGaspBoPBYPBhpMD1B7J1YhEpBL6K5WoaM0qpe4F7AZYtWza+ykCDwWAwJBFYBTYLnAvMB/QqYi7wsoi8GzgKVLv2nWtvMxgMBsMEkjJwLSJ/LSJFQQ4iIpeKyNWjObFS6jWl1Gyl1NlKqbOxXErvUkodwyrSu8HOcroE6DTxCIPBMJ3Z1dDODfftYFdD+2SbMirSZTedCzSKyL0icrWIVOkXRCRfRN4lIl8Ukd8A/wak/eQi8u9YMY3zROSIiNySZvcnsNRn3wR+CPxpwM9jMBgMk47fhLDpmQNsr29j0zMHJtGy0ZPS3aSU+oyIXIhVdf0QUCoiCogBuYBgpcTeCzyolBpIdyKl1CdHeP1s1+8K+ELAz2AwGAxTCj0hADx0ywoA7rh8QdLP6cJIgevXgD8Rkc8DFwHzsDrTtQG7lVJt2TfRYDAYphd+E8LF8yq44/IFbHrmAHdcvoCL51VMlnmjIlDgWik1BOy2HwaDwWBwsauhPWnwv3hehbOCcONeYUyXCWMys5sMBoPhlMDPveSHe4UR9D2TjZkkDAaDYZwEjTe4VxjTJUYx7s50U4lly5YpIxVuMBgMo0NEdimllvm9ZlYSBoPBkGV0zGLV4iq27GwEEdavXjilYxGaQCqwbkTkDBEZ9fsMBoNhKjIRRW46/nDPU6+z+0gnu5s6pk29RKDBXkQiIvL3ItKNJY9xtr3970TEFLoZDIZpy0QUua1aXEVFYYRrl1WzZG4ZtbOL6eqLTYvq66Argq8DVwN/DLiL5n4D3JRhmwwGg2HCuOPyBaysrcxaAHlXQzv3PPU67dEY+5q7ePz2y6gqy2f3kc5psZoIGpP4JHCzUuo5ERlybd8DTO3QvMFgMKQhVU1Dptj0zAHaozEqCiPDMpqmemYTBJ8kzgT8GmWHR3EMg8FgOO1wTwg6UJ3tiSmTBHU37QVW+mxfA+zKnDkGg+FUYLoqnmYDPSFkM5Mpm9c76CTxTeD/E5G/wupM90ci8gCwFtiQcasMBsO0ZjzB4KkywWTTjo1P7OcdX32CjU/sz8jxshl8D6rdtFVE1mB1khvCCmS/DFytlHom41YZDIZpzXh87kHlKrx6SZlCH7erL8buI50j2jEaHtnRyD1PvU5nNMYQ8IPtB7li0Zxx25/NGMeIKwmd/gq8rpR6n1KqWClVqJS6TCn1y4xbZDAYpj3jcbEEzTYa791zqpWCM0mJZDzrSWc5RcInh95M3P1n06U14kpCKRWzayG+n/GzGwwGg4d0QV336mGsd88jrRT8As2Z4n0LZvH47rf5/cVz+PR7zmbD1r00d/ZzzXefZ/3Vi6ZkBXbQmMRTwAezaYjBYDCMhHv1MNa755FWCmM5btD4xXMHWp2fF8+roLQgQn1Lz5SumQiavvos8LcichFWNlOv+0Wl1H9m2jCDwXBqMp5YQiZ879lYKQSNo9x55fnc89Tr3Hnl+Y4NzZ39HOvsY9XiqpTvm0wCqcB6Cui8KKVUKHMmjR2jAmswTH1uuG8H2+vbWFlbOWVqBdwTFzDqSWw8E99UuB7jVoFVSo1b0E9E7gdWAy1KqcX2tnuw5D4GgbeAzyilOuzX1gG3AAngi0qpp8Zrg8FgmHymWrXxroZ2bn1wJ+3RmLNttM2AxlMcN9Wuh5eJVHP9EbDKs+1pYLFS6iLgALAOQEQWAtcBi+z3fF9EpsRqxWAwjI/xZOKM5PsfS22DVzYj21pOXiai2G48BFpJiMiX0r2ulPqnkY6hlNouImd7trlTaF8CPmH//jHgJ0qpAeCQiLwJvBt4MYi9BoPh1CSV73+kjKV07iC/GMVDt6xwJpyp3oM62wRdSfyZ5/El4B7gLuD2DNlyM/Df9u9nAU2u147Y24YhIreJSJ2I1LW2tmbIFIPBMFGM5u5/YVUp4RxhYVVp0nY9efQOJqgojLCwqjTpmOlqKtx38m5b0r3nkR2NLL3rl2x8Yv+UqA7PJkFjEvO920TkDOAB4IfjNcKW+4gDD4/2vUqpe4F7wQpcj9cWg8GQmmxUOQfNDAJ4tK6J+JDi4R2N7GvucuzQq4Hmzn7aozEeerGBaCxBV3+cx79waVq/v/szaVu6+uOgFEuqy33fo4viNj9/iPiQCmT7dGXMMQml1O+AvwL+fjwGiMhNWAHtT6mTqVZHgWrXbnPtbQaDYRLJhkbQaGIA1y6rJpwjlBWEk+zQE8Wxzj57T3sosYeUdH7/Ddv2sb2+jQ3b9nHH5QtYMreMt1q62X2kk4bjvcP2ByuVtSQvxJyy/JQTyanCeAPXOcAZY32ziKwCvgJ8VCkVdb30c+A6EckTkflALVaDI4PBkIZsi+ONNKDr8z+yozGwHaMJ3O5r7iI+pKgsyR9mx6ZnDtA9YLmb/nr1IlbWVrJmec3IdqiTE4oucOseSBDOEdqjMTZs2zfsGNevqGFpTQVH2vsozQ+f0jGLoIHrP/RuAqqALwC/CniMfwfeD1SKyBEskcB1QB7wtIgAvKSU+pxSaq+IbAH2YbmhvqCUSgQ5j8FwOjMa181YGCnVU5//taOdTkppJu1IVwjnfe36FTVODUI6O9ZfvSipRkL/XLW4iif3NNPVF/M9xlRPXc0UYy2mU0Ar8D/Al5VSzVmwbdSYYjrD6U62lFFHe349wPrZkc7G8Ra1jcWeoMdwvzcTx51KpCumCzRJTBfMJGEwTH3SVRi7XwNGrERONeF4t6c7566GdjZs2wdKsf7qRcDIk5M+XkVhhPZobEpVj4+FcVdci8gNwKN23YJ7ey5wnVLqofGbaTAYTgfSuWn8XkvnznG713R2kjtL6bWjnWy+cXnac2565gC7mzoAuPXBncybUThiHwmvS+pUdjkFdTclgCqlVItn+0wsmY0pUQ1tVhIGw+QxEa4u7zncbh+dlqoD2lpqQz9P5+LasG0fb7V00z2QYEl1OaX
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.scatter(y_test, y_pred, s=2)\n",
"plt.xlabel(\"true critical temperature (K)\", fontsize=14)\n",
"plt.ylabel(\"predicted critical temperature (K)\", fontsize=14)\n",
"plt.savefig(\"critical_temperature.pdf\")"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"root mean square error 9.68\n"
]
}
],
"source": [
"rms = np.sqrt(mean_squared_error(y_test, y_pred))\n",
"print(f\"root mean square error {rms:.2f}\")"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"# compare with other regressors\n",
"\n",
"from sklearn.ensemble import RandomForestRegressor\n",
"rfr = RandomForestRegressor()\n",
"\n",
"from sklearn.ensemble import GradientBoostingRegressor\n",
"gbr = GradientBoostingRegressor()\n",
"\n",
"from sklearn.neural_network import MLPRegressor\n",
"mlpr = MLPRegressor(hidden_layer_sizes=(50,50), activation='relu', random_state=1, max_iter=5000)"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"83.4200599193573\n",
"root mean square error 9.54\n",
"\n",
"28.77874779701233\n",
"root mean square error 12.56\n",
"\n",
"9.2709481716156\n",
"root mean square error 18.58\n",
"\n"
]
}
],
"source": [
"regressors = [rfr, gbr, mlpr]\n",
"\n",
"for reg in regressors:\n",
" \n",
" start_time = time.time()\n",
" reg.fit(X_train, y_train)\n",
" run_time = time.time() - start_time\n",
" \n",
" y_pred = reg.predict(X_test)\n",
" rms = np.sqrt(mean_squared_error(y_test, y_pred))\n",
" \n",
" print(run_time)\n",
" print(f\"root mean square error {rms:.2f}\\n\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
}
},
"nbformat": 4,
"nbformat_minor": 4
}