ML-Kurs-SS2023/notebooks/03_ml_basics_display_HorseOrHuman.ipynb

194 lines
4.9 KiB
Plaintext
Raw Normal View History

2023-04-03 13:08:49 +02:00
{
"cells": [
{
"cell_type": "markdown",
"id": "2eaba66b",
"metadata": {},
"source": [
"Read and Display Horse or Human machine learning dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f1e48ac0",
"metadata": {},
"outputs": [],
"source": [
"import tensorflow as tf\n",
"import numpy as np\n",
"import tensorflow_datasets as tfds\n",
"from tensorflow.keras import regularizers\n",
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "feda024e",
"metadata": {},
"outputs": [],
"source": [
"# Load the horse or human dataset\n",
"#(300, 300, 3) unint8\n",
"dataset, label = tfds.load('horses_or_humans', with_info=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "35991dec",
"metadata": {},
"outputs": [],
"source": [
"# Extract the horse/human class\n",
"horse_ds = dataset['train'].filter(lambda x: x['label'] == 0)\n",
"human_ds = dataset['train'].filter(lambda x: x['label'] == 1)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fab03aa8",
"metadata": {},
"outputs": [],
"source": [
"# Take a few examples < 16\n",
"n_examples = 5\n",
"horse_examples = horse_ds.take(n_examples)\n",
"human_examples = human_ds.take(n_examples)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c33f1acd",
"metadata": {},
"outputs": [],
"source": [
"# Display the examples\n",
"fig, axes = plt.subplots(1, n_examples, figsize=(15, 15))\n",
"for i, example in enumerate(human_examples):\n",
" image = example['image']\n",
" axes[i].imshow(image)\n",
" axes[i].set_title(f\"humans {i+1}\")\n",
"plt.show()\n",
"\n",
"fig, axes = plt.subplots(1, n_examples, figsize=(15, 15))\n",
"for i, example in enumerate(horse_examples):\n",
" image = example['image']\n",
" axes[i].imshow(image)\n",
" axes[i].set_title(f\"horses {i+1}\")\n",
"plt.show()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "25f3eeb3",
"metadata": {},
"outputs": [],
"source": [
"# Split the dataset into training and validation sets\n",
"# as_supervised: Specifies whether to return the dataset as a tuple\n",
"# of (input, label) pairs.\n",
"train_dataset, valid_dataset = tfds.load('horses_or_humans', split=['train','test'], as_supervised=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "29dc0e62",
"metadata": {},
"outputs": [],
"source": [
"# Get the number of elements in the training and validation dataset\n",
"train_size = tf.data.experimental.cardinality(train_dataset).numpy()\n",
2023-04-05 18:25:09 +02:00
"valid_size = tf.data.experimental.cardinality(valid_dataset).numpy()\n",
"print(\"Training dataset size:\", train_size)\n",
"print(\"Validation dataset size:\", valid_size)"
2023-04-03 13:08:49 +02:00
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "db8aaf91",
"metadata": {},
"outputs": [],
"source": [
"IMG_SIZE = 300\n",
"NUM_CLASSES = 2\n",
"\n",
"def preprocess(image, label):\n",
" image = tf.cast(image, tf.float32)\n",
"# # Resize the images to a fixed size\n",
" image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))\n",
"# # Rescale the pixel values to be between 0 and 1\n",
" image = image / 255.0\n",
" label = tf.one_hot(label, NUM_CLASSES)\n",
" return image, label"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d59661c3",
"metadata": {},
"outputs": [],
"source": [
"# Apply the preprocessing function to the datasets\n",
"train_dataset = train_dataset.map(preprocess)\n",
"valid_dataset = valid_dataset.map(preprocess)\n",
"\n",
"# Batch and shuffle the datasets\n",
2023-04-05 18:25:09 +02:00
"train_dataset = train_dataset.shuffle(1000).batch(80)\n",
2023-04-03 13:08:49 +02:00
"valid_dataset = valid_dataset.batch(20)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "13af7d53",
"metadata": {},
"outputs": [],
"source": [
"# Store images and labels of the validation data for predictions\n",
"for images, labels in valid_dataset:\n",
" x_val = images\n",
" y_val = labels\n",
" \n",
"print(x_val.shape, y_val.shape)"
]
2023-04-05 18:25:09 +02:00
},
{
"cell_type": "code",
"execution_count": null,
"id": "67e152ff-0713-4629-8471-1afbb1bf22a6",
"metadata": {},
"outputs": [],
"source": []
2023-04-03 13:08:49 +02:00
}
],
"metadata": {
"kernelspec": {
2023-04-05 18:25:09 +02:00
"display_name": "Python [conda env:ML]",
2023-04-03 13:08:49 +02:00
"language": "python",
2023-04-05 18:25:09 +02:00
"name": "conda-env-ML-py"
2023-04-03 13:08:49 +02:00
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
2023-04-05 18:25:09 +02:00
"version": "3.10.9"
2023-04-03 13:08:49 +02:00
}
},
"nbformat": 4,
"nbformat_minor": 5
}