{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Example: Regression with XGBoost" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Superconductivity Data Set: Predict the critical temperature based on 81 material features." ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.metrics import mean_squared_error" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "filename = \"https://www.physi.uni-heidelberg.de/~reygers/lectures/2021/ml/data/train_critical_temp.csv\"\n", "df = pd.read_csv(filename, engine='python')" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", " | number_of_elements | \n", "mean_atomic_mass | \n", "wtd_mean_atomic_mass | \n", "gmean_atomic_mass | \n", "wtd_gmean_atomic_mass | \n", "entropy_atomic_mass | \n", "wtd_entropy_atomic_mass | \n", "range_atomic_mass | \n", "wtd_range_atomic_mass | \n", "std_atomic_mass | \n", "... | \n", "wtd_mean_Valence | \n", "gmean_Valence | \n", "wtd_gmean_Valence | \n", "entropy_Valence | \n", "wtd_entropy_Valence | \n", "range_Valence | \n", "wtd_range_Valence | \n", "std_Valence | \n", "wtd_std_Valence | \n", "critical_temp | \n", "
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | \n", "4 | \n", "88.944468 | \n", "57.862692 | \n", "66.361592 | \n", "36.116612 | \n", "1.181795 | \n", "1.062396 | \n", "122.90607 | \n", "31.794921 | \n", "51.968828 | \n", "... | \n", "2.257143 | \n", "2.213364 | \n", "2.219783 | \n", "1.368922 | \n", "1.066221 | \n", "1 | \n", "1.085714 | \n", "0.433013 | \n", "0.437059 | \n", "29.0 | \n", "
1 | \n", "5 | \n", "92.729214 | \n", "58.518416 | \n", "73.132787 | \n", "36.396602 | \n", "1.449309 | \n", "1.057755 | \n", "122.90607 | \n", "36.161939 | \n", "47.094633 | \n", "... | \n", "2.257143 | \n", "1.888175 | \n", "2.210679 | \n", "1.557113 | \n", "1.047221 | \n", "2 | \n", "1.128571 | \n", "0.632456 | \n", "0.468606 | \n", "26.0 | \n", "
2 | \n", "4 | \n", "88.944468 | \n", "57.885242 | \n", "66.361592 | \n", "36.122509 | \n", "1.181795 | \n", "0.975980 | \n", "122.90607 | \n", "35.741099 | \n", "51.968828 | \n", "... | \n", "2.271429 | \n", "2.213364 | \n", "2.232679 | \n", "1.368922 | \n", "1.029175 | \n", "1 | \n", "1.114286 | \n", "0.433013 | \n", "0.444697 | \n", "19.0 | \n", "
3 | \n", "4 | \n", "88.944468 | \n", "57.873967 | \n", "66.361592 | \n", "36.119560 | \n", "1.181795 | \n", "1.022291 | \n", "122.90607 | \n", "33.768010 | \n", "51.968828 | \n", "... | \n", "2.264286 | \n", "2.213364 | \n", "2.226222 | \n", "1.368922 | \n", "1.048834 | \n", "1 | \n", "1.100000 | \n", "0.433013 | \n", "0.440952 | \n", "22.0 | \n", "
4 | \n", "4 | \n", "88.944468 | \n", "57.840143 | \n", "66.361592 | \n", "36.110716 | \n", "1.181795 | \n", "1.129224 | \n", "122.90607 | \n", "27.848743 | \n", "51.968828 | \n", "... | \n", "2.242857 | \n", "2.213364 | \n", "2.206963 | \n", "1.368922 | \n", "1.096052 | \n", "1 | \n", "1.057143 | \n", "0.433013 | \n", "0.428809 | \n", "23.0 | \n", "
5 rows × 82 columns
\n", "