452 lines
54 KiB
Plaintext
452 lines
54 KiB
Plaintext
|
{
|
|||
|
"cells": [
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"# Example: Regression with XGBoost"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "markdown",
|
|||
|
"metadata": {},
|
|||
|
"source": [
|
|||
|
"Superconductivty Data Set: Predict the critical temperature based on 81 material features."
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 19,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"import pandas as pd\n",
|
|||
|
"import numpy as np\n",
|
|||
|
"import matplotlib.pyplot as plt\n",
|
|||
|
"from sklearn.model_selection import train_test_split\n",
|
|||
|
"from sklearn.metrics import mean_squared_error"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 20,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"filename = \"https://www.physi.uni-heidelberg.de/~reygers/lectures/2021/ml/data/train_critical_temp.csv\"\n",
|
|||
|
"df = pd.read_csv(filename, engine='python')"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 21,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"text/html": [
|
|||
|
"<div>\n",
|
|||
|
"<style scoped>\n",
|
|||
|
" .dataframe tbody tr th:only-of-type {\n",
|
|||
|
" vertical-align: middle;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe tbody tr th {\n",
|
|||
|
" vertical-align: top;\n",
|
|||
|
" }\n",
|
|||
|
"\n",
|
|||
|
" .dataframe thead th {\n",
|
|||
|
" text-align: right;\n",
|
|||
|
" }\n",
|
|||
|
"</style>\n",
|
|||
|
"<table border=\"1\" class=\"dataframe\">\n",
|
|||
|
" <thead>\n",
|
|||
|
" <tr style=\"text-align: right;\">\n",
|
|||
|
" <th></th>\n",
|
|||
|
" <th>number_of_elements</th>\n",
|
|||
|
" <th>mean_atomic_mass</th>\n",
|
|||
|
" <th>wtd_mean_atomic_mass</th>\n",
|
|||
|
" <th>gmean_atomic_mass</th>\n",
|
|||
|
" <th>wtd_gmean_atomic_mass</th>\n",
|
|||
|
" <th>entropy_atomic_mass</th>\n",
|
|||
|
" <th>wtd_entropy_atomic_mass</th>\n",
|
|||
|
" <th>range_atomic_mass</th>\n",
|
|||
|
" <th>wtd_range_atomic_mass</th>\n",
|
|||
|
" <th>std_atomic_mass</th>\n",
|
|||
|
" <th>...</th>\n",
|
|||
|
" <th>wtd_mean_Valence</th>\n",
|
|||
|
" <th>gmean_Valence</th>\n",
|
|||
|
" <th>wtd_gmean_Valence</th>\n",
|
|||
|
" <th>entropy_Valence</th>\n",
|
|||
|
" <th>wtd_entropy_Valence</th>\n",
|
|||
|
" <th>range_Valence</th>\n",
|
|||
|
" <th>wtd_range_Valence</th>\n",
|
|||
|
" <th>std_Valence</th>\n",
|
|||
|
" <th>wtd_std_Valence</th>\n",
|
|||
|
" <th>critical_temp</th>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </thead>\n",
|
|||
|
" <tbody>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>0</th>\n",
|
|||
|
" <td>4</td>\n",
|
|||
|
" <td>88.944468</td>\n",
|
|||
|
" <td>57.862692</td>\n",
|
|||
|
" <td>66.361592</td>\n",
|
|||
|
" <td>36.116612</td>\n",
|
|||
|
" <td>1.181795</td>\n",
|
|||
|
" <td>1.062396</td>\n",
|
|||
|
" <td>122.90607</td>\n",
|
|||
|
" <td>31.794921</td>\n",
|
|||
|
" <td>51.968828</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>2.257143</td>\n",
|
|||
|
" <td>2.213364</td>\n",
|
|||
|
" <td>2.219783</td>\n",
|
|||
|
" <td>1.368922</td>\n",
|
|||
|
" <td>1.066221</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>1.085714</td>\n",
|
|||
|
" <td>0.433013</td>\n",
|
|||
|
" <td>0.437059</td>\n",
|
|||
|
" <td>29.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>1</th>\n",
|
|||
|
" <td>5</td>\n",
|
|||
|
" <td>92.729214</td>\n",
|
|||
|
" <td>58.518416</td>\n",
|
|||
|
" <td>73.132787</td>\n",
|
|||
|
" <td>36.396602</td>\n",
|
|||
|
" <td>1.449309</td>\n",
|
|||
|
" <td>1.057755</td>\n",
|
|||
|
" <td>122.90607</td>\n",
|
|||
|
" <td>36.161939</td>\n",
|
|||
|
" <td>47.094633</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>2.257143</td>\n",
|
|||
|
" <td>1.888175</td>\n",
|
|||
|
" <td>2.210679</td>\n",
|
|||
|
" <td>1.557113</td>\n",
|
|||
|
" <td>1.047221</td>\n",
|
|||
|
" <td>2</td>\n",
|
|||
|
" <td>1.128571</td>\n",
|
|||
|
" <td>0.632456</td>\n",
|
|||
|
" <td>0.468606</td>\n",
|
|||
|
" <td>26.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>2</th>\n",
|
|||
|
" <td>4</td>\n",
|
|||
|
" <td>88.944468</td>\n",
|
|||
|
" <td>57.885242</td>\n",
|
|||
|
" <td>66.361592</td>\n",
|
|||
|
" <td>36.122509</td>\n",
|
|||
|
" <td>1.181795</td>\n",
|
|||
|
" <td>0.975980</td>\n",
|
|||
|
" <td>122.90607</td>\n",
|
|||
|
" <td>35.741099</td>\n",
|
|||
|
" <td>51.968828</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>2.271429</td>\n",
|
|||
|
" <td>2.213364</td>\n",
|
|||
|
" <td>2.232679</td>\n",
|
|||
|
" <td>1.368922</td>\n",
|
|||
|
" <td>1.029175</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>1.114286</td>\n",
|
|||
|
" <td>0.433013</td>\n",
|
|||
|
" <td>0.444697</td>\n",
|
|||
|
" <td>19.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>3</th>\n",
|
|||
|
" <td>4</td>\n",
|
|||
|
" <td>88.944468</td>\n",
|
|||
|
" <td>57.873967</td>\n",
|
|||
|
" <td>66.361592</td>\n",
|
|||
|
" <td>36.119560</td>\n",
|
|||
|
" <td>1.181795</td>\n",
|
|||
|
" <td>1.022291</td>\n",
|
|||
|
" <td>122.90607</td>\n",
|
|||
|
" <td>33.768010</td>\n",
|
|||
|
" <td>51.968828</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>2.264286</td>\n",
|
|||
|
" <td>2.213364</td>\n",
|
|||
|
" <td>2.226222</td>\n",
|
|||
|
" <td>1.368922</td>\n",
|
|||
|
" <td>1.048834</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>1.100000</td>\n",
|
|||
|
" <td>0.433013</td>\n",
|
|||
|
" <td>0.440952</td>\n",
|
|||
|
" <td>22.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" <tr>\n",
|
|||
|
" <th>4</th>\n",
|
|||
|
" <td>4</td>\n",
|
|||
|
" <td>88.944468</td>\n",
|
|||
|
" <td>57.840143</td>\n",
|
|||
|
" <td>66.361592</td>\n",
|
|||
|
" <td>36.110716</td>\n",
|
|||
|
" <td>1.181795</td>\n",
|
|||
|
" <td>1.129224</td>\n",
|
|||
|
" <td>122.90607</td>\n",
|
|||
|
" <td>27.848743</td>\n",
|
|||
|
" <td>51.968828</td>\n",
|
|||
|
" <td>...</td>\n",
|
|||
|
" <td>2.242857</td>\n",
|
|||
|
" <td>2.213364</td>\n",
|
|||
|
" <td>2.206963</td>\n",
|
|||
|
" <td>1.368922</td>\n",
|
|||
|
" <td>1.096052</td>\n",
|
|||
|
" <td>1</td>\n",
|
|||
|
" <td>1.057143</td>\n",
|
|||
|
" <td>0.433013</td>\n",
|
|||
|
" <td>0.428809</td>\n",
|
|||
|
" <td>23.0</td>\n",
|
|||
|
" </tr>\n",
|
|||
|
" </tbody>\n",
|
|||
|
"</table>\n",
|
|||
|
"<p>5 rows × 82 columns</p>\n",
|
|||
|
"</div>"
|
|||
|
],
|
|||
|
"text/plain": [
|
|||
|
" number_of_elements mean_atomic_mass wtd_mean_atomic_mass \\\n",
|
|||
|
"0 4 88.944468 57.862692 \n",
|
|||
|
"1 5 92.729214 58.518416 \n",
|
|||
|
"2 4 88.944468 57.885242 \n",
|
|||
|
"3 4 88.944468 57.873967 \n",
|
|||
|
"4 4 88.944468 57.840143 \n",
|
|||
|
"\n",
|
|||
|
" gmean_atomic_mass wtd_gmean_atomic_mass entropy_atomic_mass \\\n",
|
|||
|
"0 66.361592 36.116612 1.181795 \n",
|
|||
|
"1 73.132787 36.396602 1.449309 \n",
|
|||
|
"2 66.361592 36.122509 1.181795 \n",
|
|||
|
"3 66.361592 36.119560 1.181795 \n",
|
|||
|
"4 66.361592 36.110716 1.181795 \n",
|
|||
|
"\n",
|
|||
|
" wtd_entropy_atomic_mass range_atomic_mass wtd_range_atomic_mass \\\n",
|
|||
|
"0 1.062396 122.90607 31.794921 \n",
|
|||
|
"1 1.057755 122.90607 36.161939 \n",
|
|||
|
"2 0.975980 122.90607 35.741099 \n",
|
|||
|
"3 1.022291 122.90607 33.768010 \n",
|
|||
|
"4 1.129224 122.90607 27.848743 \n",
|
|||
|
"\n",
|
|||
|
" std_atomic_mass ... wtd_mean_Valence gmean_Valence wtd_gmean_Valence \\\n",
|
|||
|
"0 51.968828 ... 2.257143 2.213364 2.219783 \n",
|
|||
|
"1 47.094633 ... 2.257143 1.888175 2.210679 \n",
|
|||
|
"2 51.968828 ... 2.271429 2.213364 2.232679 \n",
|
|||
|
"3 51.968828 ... 2.264286 2.213364 2.226222 \n",
|
|||
|
"4 51.968828 ... 2.242857 2.213364 2.206963 \n",
|
|||
|
"\n",
|
|||
|
" entropy_Valence wtd_entropy_Valence range_Valence wtd_range_Valence \\\n",
|
|||
|
"0 1.368922 1.066221 1 1.085714 \n",
|
|||
|
"1 1.557113 1.047221 2 1.128571 \n",
|
|||
|
"2 1.368922 1.029175 1 1.114286 \n",
|
|||
|
"3 1.368922 1.048834 1 1.100000 \n",
|
|||
|
"4 1.368922 1.096052 1 1.057143 \n",
|
|||
|
"\n",
|
|||
|
" std_Valence wtd_std_Valence critical_temp \n",
|
|||
|
"0 0.433013 0.437059 29.0 \n",
|
|||
|
"1 0.632456 0.468606 26.0 \n",
|
|||
|
"2 0.433013 0.444697 19.0 \n",
|
|||
|
"3 0.433013 0.440952 22.0 \n",
|
|||
|
"4 0.433013 0.428809 23.0 \n",
|
|||
|
"\n",
|
|||
|
"[5 rows x 82 columns]"
|
|||
|
]
|
|||
|
},
|
|||
|
"execution_count": 21,
|
|||
|
"metadata": {},
|
|||
|
"output_type": "execute_result"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"df.head()"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 22,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"y = df['critical_temp'].values\n",
|
|||
|
"X = df[[col for col in df.columns if col!=\"critical_temp\"]]"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 23,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, shuffle=True)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 24,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"16.105452060699463\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"import xgboost as xgb\n",
|
|||
|
"import time\n",
|
|||
|
"# XGBreg = xgb.sklearn.XGBRegressor(nthread=-1, seed=1, n_estimators=1000)\n",
|
|||
|
"XGBreg = xgb.sklearn.XGBRegressor()\n",
|
|||
|
"\n",
|
|||
|
"start_time = time.time()\n",
|
|||
|
"XGBreg.fit(X_train, y_train)\n",
|
|||
|
"run_time = time.time() - start_time\n",
|
|||
|
"\n",
|
|||
|
"print(run_time)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 31,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"y_pred = XGBreg.predict(X_test)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 32,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"data": {
|
|||
|
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAYkAAAETCAYAAADDIPqYAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy86wFpkAAAACXBIWXMAAAsTAAALEwEAmpwYAAB5AklEQVR4nO29e3wc1Xn//360u7pfbdlYYMkYIgO2IXawY1KIc4PgEpPQJjWENECA0KShoU2+pHZa54LT4pZe4vySNCUGAg00uKElsUMhQFscCDiWwQFfwALbkmzkSLJ1X112V+f3x8wZz45mVyNpVxf7vF+vfUk7Ozvz7Kx0zpzn8nlEKYXBYDAYDH7kTLYBBoPBYJi6mEnCYDAYDCkxk4TBYDAYUmImCYPBYDCkxEwSBoPBYEiJmSQMBoPBkBIzSRgMBoMhJWaSMBgMBkNKwkF3FJE84EygAGhVSrVmzSqDwWAwTAnSriREpEREPi8i24FO4E1gD3BMRBpF5IcisnwiDDUYDAbDxCOpZDlE5EvAXwEHgZ8DvwHeBvqAGcBi4L3AHwAvAX+mlKqfAJsNBoPBMEGkmyS2AHcppfakPYBIPnAzMKiU2px5Ew0Gg8EwWaScJAIfQCRfKdWfIXsMBoPBMIUYKSbx2RFezwe2ZdQig8FgMEwZRspu+q6ItCml/sv7gojkYsUq3pEVy8ZAZWWlOvvssyfbDIPBYJhW7Nq1q00pNcvvtZEmiduBh0XkI0qp/9UbRSQC/AxYCKzMmKXj5Oyzz6aurm6yzTAYDIZphYg0pHot7SShlPqhiFQCj4vIB5RSL9sriP8C3gmsVEodzKy5BoPBYJgqjFhxrZS6G9gM/LeILAZ+ClwMfFAp9WbQE4nI/SLSIiLDsqVE5MsiouwJCbH4joi8KSKvisi7An8ig8FgMGSMQLIcSqkvA08ALwMrsCaI10d5rh8Bq7wbRaQa+DDQ6Nr8+0Ct/bgN+JdRnstgMBgMGSCtu0lEvuN62g8MAa8BnxMR5wWl1BdHOpFSaruInO3z0j8DX8GKcWg+BjykrPzcl0SkXESqlFLNI53HYDAYDJljpMD1hZ7nLwIhz/YxF1qIyMeAo0qp37onHeAsoMn1/Ii9bdgkISK3Ya02qKmpGaspBoPBYPBhpMD1B7J1YhEpBL6K5WoaM0qpe4F7AZYtWza+ykCDwWAwJBFYBTYLnAvMB/QqYi7wsoi8GzgKVLv2nWtvMxgMBsMEkjJwLSJ/LSJFQQ4iIpeKyNWjObFS6jWl1Gyl1NlKqbOxXErvUkodwyrSu8HOcroE6DTxCIPBMJ3Z1dDODfftYFdD+2SbMirSZTedCzSKyL0icrWIVOkXRCRfRN4lIl8Ukd8A/wak/eQi8u9YMY3zROSIiNySZvcnsNRn3wR+CPxpwM9jMBgMk47fhLDpmQNsr29j0zMHJtGy0ZPS3aSU+oyIXIhVdf0QUCoiCogBuYBgpcTeCzyolBpIdyKl1CdHeP1s1+8K+ELAz2AwGAxTCj0hADx0ywoA7rh8QdLP6cJIgevXgD8Rkc8DFwHzsDrTtQG7lVJt2TfRYDAYphd+E8LF8yq44/IFbHrmAHdcvoCL51VMlnmjIlDgWik1BOy2HwaDwWBwsauhPWnwv3hehbOCcONeYUyXCWMys5sMBoPhlMDPveSHe4UR9D2TjZkkDAaDYZwEjTe4VxjTJUYx7s50U4lly5YpIxVuMBgMo0NEdimllvm9ZlYSBoPBkGV0zGLV4iq27GwEEdavXjilYxGaQCqwbkTkDBEZ9fsMBoNhKjIRRW46/nDPU6+z+0gnu5s6pk29RKDBXkQiIvL3ItKNJY9xtr3970TEFLoZDIZpy0QUua1aXEVFYYRrl1WzZG4ZtbOL6eqLTYvq66Argq8DVwN/DLiL5n4D3JRhmwwGg2HCuOPyBaysrcxaAHlXQzv3PPU67dEY+5q7ePz2y6gqy2f3kc5psZoIGpP4JHCzUuo5ERlybd8DTO3QvMFgMKQhVU1Dptj0zAHaozEqCiPDMpqmemYTBJ8kzgT8GmWHR3EMg8FgOO1wTwg6UJ3tiSmTBHU37QVW+mxfA+zKnDkGg+FUYLoqnmYDPSFkM5Mpm9c76CTxTeD/E5G/wupM90ci8gCwFtiQcasMBsO0ZjzB4KkywWTTjo1P7OcdX32CjU/sz8jxshl8D6rdtFVE1mB1khvCCmS/DFytlHom41YZDIZpzXh87kHlKrx6SZlCH7erL8buI50j2jEaHtnRyD1PvU5nNMYQ8IPtB7li0Zxx25/NGMeIKwmd/gq8rpR6n1KqWClVqJS6TCn1y4xbZDAYpj3jcbEEzTYa791zqpWCM0mJZDzrSWc5RcInh95M3P1n06U14kpCKRWzayG+n/GzGwwGg4d0QV336mGsd88jrRT8As2Z4n0LZvH47rf5/cVz+PR7zmbD1r00d/ZzzXefZ/3Vi6ZkBXbQmMRTwAezaYjBYDCMhHv1MNa755FWCmM5btD4xXMHWp2fF8+roLQgQn1Lz5SumQiavvos8LcichFWNlOv+0Wl1H9m2jCDwXBqMp5YQiZ879lYKQSNo9x55fnc89Tr3Hnl+Y4NzZ39HOvsY9XiqpTvm0wCqcB6Cui8KKVUKHMmjR2jAmswTH1uuG8H2+vbWFlbOWVqBdwTFzDqSWw8E99UuB7jVoFVSo1b0E9E7gdWAy1KqcX2tnuw5D4GgbeAzyilOuzX1gG3AAngi0qpp8Zrg8FgmHymWrXxroZ2bn1wJ+3RmLNttM2AxlMcN9Wuh5eJVHP9EbDKs+1pYLFS6iLgALAOQEQWAtcBi+z3fF9EpsRqxWAwjI/xZOKM5PsfS22DVzYj21pOXiai2G48BFpJiMiX0r2ulPqnkY6hlNouImd7trlTaF8CPmH//jHgJ0qpAeCQiLwJvBt4MYi9BoPh1CSV73+kjKV07iC/GMVDt6xwJpyp3oM62wRdSfyZ5/El4B7gLuD2DNlyM/Df9u9nAU2u147Y24YhIreJSJ2I1LW2tmbIFIPBMFGM5u5/YVUp4RxhYVVp0nY9efQOJqgojLCwqjTpmOlqKtx38m5b0r3nkR2NLL3rl2x8Yv+UqA7PJkFjEvO920TkDOAB4IfjNcKW+4gDD4/2vUqpe4F7wQpcj9cWg8GQmmxUOQfNDAJ4tK6J+JDi4R2N7GvucuzQq4Hmzn7aozEeerGBaCxBV3+cx79waVq/v/szaVu6+uOgFEuqy33fo4viNj9/iPiQCmT7dGXMMQml1O+AvwL+fjwGiMhNWAHtT6mTqVZHgWrXbnPtbQaDYRLJhkbQaGIA1y6rJpwjlBWEk+zQE8Wxzj57T3sosYeUdH7/Ddv2sb2+jQ3b9nHH5QtYMreMt1q62X2kk4bjvcP2ByuVtSQvxJyy/JQTyanCeAPXOcAZY32ziKwCvgJ8VCkVdb30c+A6EckTkflALVaDI4PBkIZsi+ONNKDr8z+yozGwHaMJ3O5r7iI+pKgsyR9mx6ZnDtA9YLmb/nr1IlbWVrJmec3IdqiTE4oucOseSBDOEdqjMTZs2zfsGNevqGFpTQVH2vsozQ+f0jGLoIHrP/RuAqqALwC/CniMfwfeD1SKyBEskcB1QB7wtIgAvKSU+pxSaq+IbAH2YbmhvqCUSgQ5j8FwOjMa181YGCnVU5//taOdTkppJu1IVwjnfe36FTVODUI6O9ZfvSipRkL/XLW4iif3NNPVF/M9xlRPXc0UYy2mU0Ar8D/Al5VSzVmwbdSYYjrD6U62lFFHe349wPrZkc7G8Ra1jcWeoMdwvzcTx51KpCumCzRJTBfMJGEwTH3SVRi7XwNGrERONeF4t6c7566GdjZs2wdKsf7qRcDIk5M+XkVhhPZobEpVj4+FcVdci8gNwKN23YJ7ey5wnVLqofGbaTAYTgfSuWn8XkvnznG713R2kjtL6bWjnWy+cXnac2565gC7mzoAuPXBncybUThiHwmvS+pUdjkFdTclgCqlVItn+0wsmY0pUQ1tVhIGw+QxEa4u7zncbh+dlqoD2lpqQz9P5+LasG0fb7V00z2QYEl1OaX
|
|||
|
"text/plain": [
|
|||
|
"<Figure size 432x288 with 1 Axes>"
|
|||
|
]
|
|||
|
},
|
|||
|
"metadata": {
|
|||
|
"needs_background": "light"
|
|||
|
},
|
|||
|
"output_type": "display_data"
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"plt.scatter(y_test, y_pred, s=2)\n",
|
|||
|
"plt.xlabel(\"true critical temperature (K)\", fontsize=14)\n",
|
|||
|
"plt.ylabel(\"predicted critical temperature (K)\", fontsize=14)\n",
|
|||
|
"plt.savefig(\"critical_temperature.pdf\")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 33,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"root mean square error 9.68\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"rms = np.sqrt(mean_squared_error(y_test, y_pred))\n",
|
|||
|
"print(f\"root mean square error {rms:.2f}\")"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 28,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": [
|
|||
|
"# compare with other regressors\n",
|
|||
|
"\n",
|
|||
|
"from sklearn.ensemble import RandomForestRegressor\n",
|
|||
|
"rfr = RandomForestRegressor()\n",
|
|||
|
"\n",
|
|||
|
"from sklearn.ensemble import GradientBoostingRegressor\n",
|
|||
|
"gbr = GradientBoostingRegressor()\n",
|
|||
|
"\n",
|
|||
|
"from sklearn.neural_network import MLPRegressor\n",
|
|||
|
"mlpr = MLPRegressor(hidden_layer_sizes=(50,50), activation='relu', random_state=1, max_iter=5000)"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": 29,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [
|
|||
|
{
|
|||
|
"name": "stdout",
|
|||
|
"output_type": "stream",
|
|||
|
"text": [
|
|||
|
"83.4200599193573\n",
|
|||
|
"root mean square error 9.54\n",
|
|||
|
"\n",
|
|||
|
"28.77874779701233\n",
|
|||
|
"root mean square error 12.56\n",
|
|||
|
"\n",
|
|||
|
"9.2709481716156\n",
|
|||
|
"root mean square error 18.58\n",
|
|||
|
"\n"
|
|||
|
]
|
|||
|
}
|
|||
|
],
|
|||
|
"source": [
|
|||
|
"regressors = [rfr, gbr, mlpr]\n",
|
|||
|
"\n",
|
|||
|
"for reg in regressors:\n",
|
|||
|
" \n",
|
|||
|
" start_time = time.time()\n",
|
|||
|
" reg.fit(X_train, y_train)\n",
|
|||
|
" run_time = time.time() - start_time\n",
|
|||
|
" \n",
|
|||
|
" y_pred = reg.predict(X_test)\n",
|
|||
|
" rms = np.sqrt(mean_squared_error(y_test, y_pred))\n",
|
|||
|
" \n",
|
|||
|
" print(run_time)\n",
|
|||
|
" print(f\"root mean square error {rms:.2f}\\n\")\n"
|
|||
|
]
|
|||
|
},
|
|||
|
{
|
|||
|
"cell_type": "code",
|
|||
|
"execution_count": null,
|
|||
|
"metadata": {},
|
|||
|
"outputs": [],
|
|||
|
"source": []
|
|||
|
}
|
|||
|
],
|
|||
|
"metadata": {
|
|||
|
"kernelspec": {
|
|||
|
"display_name": "Python 3",
|
|||
|
"language": "python",
|
|||
|
"name": "python3"
|
|||
|
},
|
|||
|
"language_info": {
|
|||
|
"codemirror_mode": {
|
|||
|
"name": "ipython",
|
|||
|
"version": 3
|
|||
|
},
|
|||
|
"file_extension": ".py",
|
|||
|
"mimetype": "text/x-python",
|
|||
|
"name": "python",
|
|||
|
"nbconvert_exporter": "python",
|
|||
|
"pygments_lexer": "ipython3",
|
|||
|
"version": "3.8.5"
|
|||
|
}
|
|||
|
},
|
|||
|
"nbformat": 4,
|
|||
|
"nbformat_minor": 4
|
|||
|
}
|