Machine Learning Kurs im Rahmen der Studierendentage im SS 2023
You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

227 lines
5.9 KiB

2 years ago
  1. {
  2. "cells": [
  3. {
  4. "cell_type": "markdown",
  5. "id": "abe24003",
  6. "metadata": {},
  7. "source": [
  8. "Use a convolutional neural network (CNN) to classify the MNIST dataset of handwritten digits."
  9. ]
  10. },
  11. {
  12. "cell_type": "code",
  13. "execution_count": null,
  14. "id": "6c95fefb",
  15. "metadata": {},
  16. "outputs": [],
  17. "source": [
  18. "import tensorflow as tf\n",
  19. "import numpy as np\n",
  20. "import matplotlib.pyplot as plt"
  21. ]
  22. },
  23. {
  24. "cell_type": "code",
  25. "execution_count": null,
  26. "id": "280c5099",
  27. "metadata": {},
  28. "outputs": [],
  29. "source": [
  30. "# load the data\n",
  31. "(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()"
  32. ]
  33. },
  34. {
  35. "cell_type": "code",
  36. "execution_count": null,
  37. "id": "8c3fc1b2",
  38. "metadata": {},
  39. "outputs": [],
  40. "source": [
  41. "# Normalize the pixel values to be between 0 and 1\n",
  42. "x_train = x_train / 255\n",
  43. "x_test = x_test / 255"
  44. ]
  45. },
  46. {
  47. "cell_type": "code",
  48. "execution_count": null,
  49. "id": "3a9686ff",
  50. "metadata": {},
  51. "outputs": [],
  52. "source": [
  53. "# Convert the labels into one-hot encoded arrays\n",
  54. "y_train = tf.keras.utils.to_categorical(y_train, num_classes=10)\n",
  55. "y_test = tf.keras.utils.to_categorical(y_test, num_classes=10)"
  56. ]
  57. },
  58. {
  59. "cell_type": "code",
  60. "execution_count": null,
  61. "id": "e80e582c",
  62. "metadata": {},
  63. "outputs": [],
  64. "source": [
  65. "# Define the model\n",
  66. "model = tf.keras.models.Sequential()\n",
  67. "model.add(tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)))\n",
  68. "model.add(tf.keras.layers.MaxPooling2D((2, 2)))\n",
  69. "model.add(tf.keras.layers.Conv2D(64, (3, 3), activation='relu'))\n",
  70. "model.add(tf.keras.layers.MaxPooling2D((2, 2)))\n",
  71. "model.add(tf.keras.layers.Flatten())\n",
  72. "model.add(tf.keras.layers.Dense(64, activation='relu'))\n",
  73. "model.add(tf.keras.layers.Dense(10, activation='softmax'))"
  74. ]
  75. },
  76. {
  77. "cell_type": "code",
  78. "execution_count": null,
  79. "id": "aff6c38a",
  80. "metadata": {},
  81. "outputs": [],
  82. "source": [
  83. "# Compile the model\n",
  84. "model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])"
  85. ]
  86. },
  87. {
  88. "cell_type": "code",
  89. "execution_count": null,
  90. "id": "049f9d49",
  91. "metadata": {},
  92. "outputs": [],
  93. "source": [
  94. "# Train the model in mini-batches of 64 samples and record the history (note: the test set doubles as validation data)\n",
  95. "history = model.fit(x_train, y_train, epochs=10, batch_size=64, validation_data=(x_test, y_test))"
  96. ]
  97. },
  98. {
  99. "cell_type": "code",
  100. "execution_count": null,
  101. "id": "f7a8baea",
  102. "metadata": {},
  103. "outputs": [],
  104. "source": [
  105. "# Get the weight matrix (64 x 10) of the final, output Dense layer\n",
  106. "# plot the weights as a heatmap or image, where the weights are represented\n",
  107. "# as pixel values.\n",
  108. "last_layer_weights = model.layers[-1].get_weights()[0]\n",
  109. "# Plot the weights as a heatmap\n",
  110. "plt.imshow(last_layer_weights, cmap='coolwarm')\n",
  111. "plt.colorbar()\n",
  112. "plt.title('weights in the output layer')\n",
  113. "plt.show()"
  114. ]
  115. },
  116. {
  117. "cell_type": "code",
  118. "execution_count": null,
  119. "id": "6cc4d04b",
  120. "metadata": {},
  121. "outputs": [],
  122. "source": [
  123. "# Plot loss and accuracy\n",
  124. "plt.figure(figsize=(12, 4))\n",
  125. "\n",
  126. "# Plot the loss during training\n",
  127. "plt.subplot(1, 2, 1)\n",
  128. "plt.plot(history.history['loss'], label='training loss')\n",
  129. "plt.plot(history.history['val_loss'], label='validation loss')\n",
  130. "plt.xlabel('Epoch')\n",
  131. "plt.ylabel('Loss')\n",
  132. "plt.legend()\n",
  133. "\n",
  134. "plt.subplot(1, 2, 2)\n",
  135. "plt.plot(history.history['accuracy'], label='training accuracy')\n",
  136. "plt.plot(history.history['val_accuracy'], label='validation accuracy')\n",
  137. "plt.xlabel('Epoch')\n",
  138. "plt.ylabel('Accuracy')\n",
  139. "plt.legend()\n",
  140. "\n",
  141. "plt.show()"
  142. ]
  143. },
  144. {
  145. "cell_type": "code",
  146. "execution_count": null,
  147. "id": "dcdef199",
  148. "metadata": {},
  149. "outputs": [],
  150. "source": [
  151. "# Plot a confusion matrix of the test set predictions\n",
  152. "test_preds = np.argmax(model.predict(x_test), axis=1)\n",
  153. "conf_mat = tf.math.confusion_matrix(y_test.argmax(axis=1), test_preds)\n",
  154. "\n",
  155. "plt.imshow(conf_mat, cmap=\"Blues\")\n",
  156. "plt.xlabel(\"Predicted labels\")\n",
  157. "plt.ylabel(\"True labels\")\n",
  158. "plt.xticks(np.arange(10))\n",
  159. "plt.yticks(np.arange(10))\n",
  160. "plt.colorbar()\n",
  161. "plt.show()"
  162. ]
  163. },
  164. {
  165. "cell_type": "code",
  166. "execution_count": null,
  167. "id": "8f28f1ea",
  168. "metadata": {},
  169. "outputs": [],
  170. "source": [
  171. "# Evaluate the model on the test set\n",
  172. "test_loss, test_acc = model.evaluate(x_test, y_test)\n",
  173. "print('Test accuracy:', test_acc)"
  174. ]
  175. },
  176. {
  177. "cell_type": "code",
  178. "execution_count": null,
  179. "id": "50c0f27a",
  180. "metadata": {},
  181. "outputs": [],
  182. "source": [
  183. "# Make predictions on the test set\n",
  184. "y_pred = model.predict(x_test)\n",
  185. "y_pred = np.argmax(y_pred, axis=1)"
  186. ]
  187. },
  188. {
  189. "cell_type": "code",
  190. "execution_count": null,
  191. "id": "f781bef8",
  192. "metadata": {},
  193. "outputs": [],
  194. "source": [
  195. "# Plot some examples from the test set and their predictions\n",
  196. "fig, axes = plt.subplots(4, 4, figsize=(10, 10))\n",
  197. "for i, ax in enumerate(axes.ravel()):\n",
  198. " ax.imshow(x_test[i].reshape(28, 28), cmap='gray')\n",
  199. " ax.set_title(\"True: %d\\nPred: %d\" % (np.argmax(y_test[i]), y_pred[i]))\n",
  200. " ax.axis('off')\n",
  201. "plt.suptitle(\"Examples of test set images and their predictions\")\n",
  202. "plt.show()\n"
  203. ]
  204. }
  205. ],
  206. "metadata": {
  207. "kernelspec": {
  208. "display_name": "Python 3 (ipykernel)",
  209. "language": "python",
  210. "name": "python3"
  211. },
  212. "language_info": {
  213. "codemirror_mode": {
  214. "name": "ipython",
  215. "version": 3
  216. },
  217. "file_extension": ".py",
  218. "mimetype": "text/x-python",
  219. "name": "python",
  220. "nbconvert_exporter": "python",
  221. "pygments_lexer": "ipython3",
  222. "version": "3.8.16"
  223. }
  224. },
  225. "nbformat": 4,
  226. "nbformat_minor": 5
  227. }