Machine Learning Kurs im Rahmen der Studierendentage im SS 2023
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

159 lines
3.4 KiB

  1. {
  2. "cells": [
  3. {
  4. "attachments": {},
  5. "cell_type": "markdown",
  6. "metadata": {},
  7. "source": [
  8. "# Hyperparameter optimization"
  9. ]
  10. },
  11. {
  12. "cell_type": "markdown",
  13. "metadata": {},
  14. "source": [
  15. "Superconductivity Data Set: predict the critical temperature (in K) of a superconductor from 81 material features."
  16. ]
  17. },
  18. {
  19. "cell_type": "code",
  20. "execution_count": null,
  21. "metadata": {},
  22. "outputs": [],
  23. "source": [
  24. "import pandas as pd\n",
  25. "import numpy as np\n",
  26. "import matplotlib.pyplot as plt\n",
  27. "from sklearn.model_selection import train_test_split\n",
  28. "from sklearn.metrics import mean_squared_error"
  29. ]
  30. },
  31. {
  32. "cell_type": "code",
  33. "execution_count": null,
  34. "metadata": {},
  35. "outputs": [],
  36. "source": [
  37. "filename = \"https://www.physi.uni-heidelberg.de/~reygers/lectures/2021/ml/data/train_critical_temp.csv\"\n",
  38. "df = pd.read_csv(filename)  # plain comma-separated file: the default (fast) C parser suffices, engine='python' is much slower"
  39. ]
  40. },
  41. {
  42. "cell_type": "code",
  43. "execution_count": null,
  44. "metadata": {},
  45. "outputs": [],
  46. "source": [
  47. "df.head(n=5)"
  48. ]
  49. },
  50. {
  51. "cell_type": "code",
  52. "execution_count": null,
  53. "metadata": {},
  54. "outputs": [],
  55. "source": [
  56. "X = df.drop(columns=\"critical_temp\")  # all material features, target column removed\n",
  57. "y = df[\"critical_temp\"].to_numpy()"
  58. ]
  59. },
  60. {
  61. "cell_type": "code",
  62. "execution_count": null,
  63. "metadata": {},
  64. "outputs": [],
  65. "source": [
  66. "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, shuffle=True, random_state=1)  # fixed seed: reproducible split, model and RMS"
  67. ]
  68. },
  69. {
  70. "cell_type": "code",
  71. "execution_count": null,
  72. "metadata": {},
  73. "outputs": [],
  74. "source": [
  75. "from sklearn.neural_network import MLPRegressor\n",
  76. "import time\n",
  77. "\n",
  78. "mlpr = MLPRegressor(hidden_layer_sizes=(50,50), activation='relu', random_state=1, max_iter=5000)\n",
  79. "\n",
  80. "# perf_counter() is monotonic and high-resolution; time.time() can jump if the system clock is adjusted\n",
  81. "start_time = time.perf_counter()\n",
  82. "mlpr.fit(X_train, y_train)\n",
  83. "run_time = time.perf_counter() - start_time\n",
  84. "print(f\"training time: {run_time:.1f} s\")"
  85. ]
  86. },
  87. {
  88. "cell_type": "code",
  89. "execution_count": null,
  90. "metadata": {},
  91. "outputs": [],
  92. "source": [
  93. "y_pred = mlpr.predict(X=X_test)  # predictions for the held-out test split"
  94. ]
  95. },
  96. {
  97. "cell_type": "code",
  98. "execution_count": null,
  99. "metadata": {},
  100. "outputs": [],
  101. "source": [
  102. "plt.scatter(y_test, y_pred, s=2)\n",
  103. "plt.axline((0, 0), slope=1, color=\"k\", lw=1)  # y = x: where perfect predictions would lie\n",
  104. "plt.xlabel(\"true critical temperature (K)\", fontsize=14)\n",
  105. "plt.ylabel(\"predicted critical temperature (K)\", fontsize=14)\n",
  106. "plt.savefig(\"critical_temperature.pdf\")"
  106. ]
  107. },
  108. {
  109. "cell_type": "code",
  110. "execution_count": null,
  111. "metadata": {},
  112. "outputs": [],
  113. "source": [
  114. "rms = np.sqrt(mean_squared_error(y_true=y_test, y_pred=y_pred))  # RMSE on the test split\n",
  115. "print(f\"root mean square error {rms:.2f}\")"
  116. ]
  117. },
  118. {
  119. "attachments": {},
  120. "cell_type": "markdown",
  121. "metadata": {},
  122. "source": [
  123. "## Now try to optimize the hyperparameters"
  124. ]
  125. },
  126. {
  127. "cell_type": "code",
  128. "execution_count": null,
  129. "metadata": {},
  130. "outputs": [],
  131. "source": [
  132. "from sklearn.model_selection import GridSearchCV\n",
  133. "# Exercise: tune MLPRegressor hyperparameters (e.g. hidden_layer_sizes, activation) with GridSearchCV on X_train, y_train only; keep X_test for comparing against the RMS above.\n",
  134. "### Your code here\n"
  135. ]
  136. }
  137. ],
  138. "metadata": {
  139. "kernelspec": {
  140. "display_name": "Python 3",
  141. "language": "python",
  142. "name": "python3"
  143. },
  144. "language_info": {
  145. "codemirror_mode": {
  146. "name": "ipython",
  147. "version": 3
  148. },
  149. "file_extension": ".py",
  150. "mimetype": "text/x-python",
  151. "name": "python",
  152. "nbconvert_exporter": "python",
  153. "pygments_lexer": "ipython3",
  154. "version": "3.10.11"
  155. }
  156. },
  157. "nbformat": 4,
  158. "nbformat_minor": 4
  159. }