{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "abe24003",
   "metadata": {},
   "source": [
    "# Classifying MNIST digits with a convolutional neural network\n",
    "\n",
    "Train a small convolutional neural network (CNN) to classify the 28 x 28 grayscale images of handwritten digits in the MNIST dataset. Then inspect the training curves, the output-layer weights, the confusion matrix and a few example predictions on the held-out test set."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6c95fefb",
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5d1e9a20",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Configuration: everything a reader might want to tune.\n",
    "SEED = 42\n",
    "NUM_CLASSES = 10           # digits 0-9\n",
    "IMAGE_SHAPE = (28, 28, 1)  # height, width, channels (MNIST is grayscale)\n",
    "EPOCHS = 10\n",
    "BATCH_SIZE = 64\n",
    "VALIDATION_SPLIT = 0.1     # fraction of the training set held out for validation\n",
    "\n",
    "# Seed NumPy and TensorFlow so weight initialisation and batch shuffling are repeatable.\n",
    "# This does not force deterministic GPU kernels, so small run-to-run differences may remain.\n",
    "np.random.seed(SEED)\n",
    "tf.random.set_seed(SEED)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a3c8f0b1",
   "metadata": {},
   "source": [
    "## Load and preprocess the data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "280c5099",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the data. The raw arrays are never modified, so the preprocessing cells below can be re-run safely.\n",
    "(x_train_raw, y_train_raw), (x_test_raw, y_test_raw) = tf.keras.datasets.mnist.load_data()\n",
    "x_train_raw.shape, x_test_raw.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8c3fc1b2",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Scale pixel values from 0-255 to float32 values in [0, 1] and add a trailing channel axis:\n",
    "# load_data() returns (num_images, 28, 28), but the Conv2D input expects (28, 28, 1).\n",
    "x_train = (x_train_raw.astype('float32') / 255.0)[..., np.newaxis]\n",
    "x_test = (x_test_raw.astype('float32') / 255.0)[..., np.newaxis]\n",
    "\n",
    "assert x_train.shape[1:] == IMAGE_SHAPE and x_test.shape[1:] == IMAGE_SHAPE\n",
    "assert 0.0 <= x_train.min() and x_train.max() <= 1.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3a9686ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convert the integer labels into one-hot encoded arrays to match the categorical cross-entropy loss\n",
    "y_train = tf.keras.utils.to_categorical(y_train_raw, num_classes=NUM_CLASSES)\n",
    "y_test = tf.keras.utils.to_categorical(y_test_raw, num_classes=NUM_CLASSES)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b7e4c2d9",
   "metadata": {},
   "source": [
    "## Define and train the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e80e582c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Two convolution + max-pooling stages extract spatial features. They are followed by a\n",
    "# small fully connected classifier with one softmax output per digit class.\n",
    "model = tf.keras.models.Sequential([\n",
    "    tf.keras.Input(shape=IMAGE_SHAPE),\n",
    "    tf.keras.layers.Conv2D(32, (3, 3), activation='relu'),\n",
    "    tf.keras.layers.MaxPooling2D((2, 2)),\n",
    "    tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),\n",
    "    tf.keras.layers.MaxPooling2D((2, 2)),\n",
    "    tf.keras.layers.Flatten(),\n",
    "    tf.keras.layers.Dense(64, activation='relu'),\n",
    "    tf.keras.layers.Dense(NUM_CLASSES, activation='softmax'),\n",
    "])\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "aff6c38a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compile the model: one-hot labels + softmax outputs -> categorical cross-entropy\n",
    "model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "049f9d49",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train in mini-batches of BATCH_SIZE and record the per-epoch history. The last\n",
    "# VALIDATION_SPLIT fraction of the training arrays (taken before shuffling) is held out\n",
    "# for validation, so the test set is only used for the final evaluation below.\n",
    "history = model.fit(\n",
    "    x_train,\n",
    "    y_train,\n",
    "    epochs=EPOCHS,\n",
    "    batch_size=BATCH_SIZE,\n",
    "    validation_split=VALIDATION_SPLIT,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c9a1e5f3",
   "metadata": {},
   "source": [
    "## Inspect the trained model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f7a8baea",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the kernel of the output Dense layer as a heatmap. Keras stores a Dense kernel\n",
    "# as (input_dim, units), so rows are the 64 hidden units and columns are the 10 classes.\n",
    "output_weights = model.layers[-1].get_weights()[0]\n",
    "# Use symmetric colour limits so that zero sits at the neutral centre of the diverging colormap.\n",
    "limit = np.abs(output_weights).max()\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(6, 8))\n",
    "im = ax.imshow(output_weights, cmap='coolwarm', vmin=-limit, vmax=limit, aspect='auto')\n",
    "fig.colorbar(im, ax=ax)\n",
    "ax.set_xticks(np.arange(NUM_CLASSES))\n",
    "ax.set_xlabel('Output class')\n",
    "ax.set_ylabel('Hidden unit')\n",
    "ax.set_title('Weights in the output layer')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6cc4d04b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot loss and accuracy per epoch (numbered from 1) for the training and validation data\n",
    "epoch_numbers = np.arange(1, len(history.history['loss']) + 1)\n",
    "fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(12, 4))\n",
    "\n",
    "ax_loss.plot(epoch_numbers, history.history['loss'], label='training loss')\n",
    "ax_loss.plot(epoch_numbers, history.history['val_loss'], label='validation loss')\n",
    "ax_loss.set(xlabel='Epoch', ylabel='Loss', title='Loss during training')\n",
    "ax_loss.legend()\n",
    "\n",
    "# Both curves need labels; otherwise the legend would be empty.\n",
    "ax_acc.plot(epoch_numbers, history.history['accuracy'], label='training accuracy')\n",
    "ax_acc.plot(epoch_numbers, history.history['val_accuracy'], label='validation accuracy')\n",
    "ax_acc.set(xlabel='Epoch', ylabel='Accuracy', title='Accuracy during training')\n",
    "ax_acc.legend()\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d2f6b8a4",
   "metadata": {},
   "source": [
    "## Evaluate on the test set"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f28f1ea",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Evaluate the model on the held-out test set\n",
    "test_loss, test_acc = model.evaluate(x_test, y_test, verbose=0)\n",
    "print(f'Test loss: {test_loss:.4f}')\n",
    "print(f'Test accuracy: {test_acc:.4f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "50c0f27a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Predict class probabilities for the test set once and keep the most likely class per image.\n",
    "# The confusion matrix and example plots below both reuse y_pred.\n",
    "y_pred = np.argmax(model.predict(x_test), axis=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dcdef199",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Confusion matrix of the test set predictions: rows are true labels, columns are predicted labels\n",
    "conf_mat = tf.math.confusion_matrix(y_test_raw, y_pred, num_classes=NUM_CLASSES).numpy()\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(7, 6))\n",
    "im = ax.imshow(conf_mat, cmap='Blues')\n",
    "# Write each count into its cell, switching to white text on the darker cells\n",
    "threshold = conf_mat.max() / 2\n",
    "for (row, col), count in np.ndenumerate(conf_mat):\n",
    "    ax.text(col, row, str(count), ha='center', va='center', fontsize=8,\n",
    "            color='white' if count > threshold else 'black')\n",
    "ax.set_xticks(np.arange(NUM_CLASSES))\n",
    "ax.set_yticks(np.arange(NUM_CLASSES))\n",
    "ax.set_xlabel('Predicted labels')\n",
    "ax.set_ylabel('True labels')\n",
    "ax.set_title('Confusion matrix on the test set')\n",
    "fig.colorbar(im, ax=ax)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f781bef8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the first 16 test images with their true and predicted labels; mispredictions are titled in red\n",
    "fig, axes = plt.subplots(4, 4, figsize=(10, 10))\n",
    "for i, ax in enumerate(axes.ravel()):\n",
    "    ax.imshow(x_test[i].squeeze(), cmap='gray')\n",
    "    ax.set_title(f'True: {y_test_raw[i]}\\nPred: {y_pred[i]}',\n",
    "                 color='black' if y_test_raw[i] == y_pred[i] else 'red')\n",
    "    ax.axis('off')\n",
    "fig.suptitle('Examples of test set images and their predictions')\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}