Training a neural network on MNIST with Keras
This simple example demonstrates how to plug TensorFlow Datasets (TFDS) into a Keras model.
import tensorflow as tf
import tensorflow_datasets as tfds
2025-08-06 11:39:27.068270: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
E0000 00:00:1754480367.093425 22003 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1754480367.101768 22003 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1754480367.121188 22003 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1754480367.121213 22003 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1754480367.121216 22003 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1754480367.121218 22003 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
Step 1: Create your input pipeline
Start by building an efficient input pipeline using advice from:
- The Performance tips guide (https://www.tensorflow.org/datasets/performances)
- The Better performance with the tf.data API guide (https://www.tensorflow.org/guide/data_performance#optimize_performance)
Load a dataset
Load the MNIST dataset with the following arguments:
- shuffle_files=True: The MNIST data is stored in a single file, but for larger datasets with multiple files on disk, it's good practice to shuffle them when training.
- as_supervised=True: Returns a tuple (img, label) instead of a dictionary {'image': img, 'label': label}.
(ds_train, ds_test), ds_info = tfds.load(
'mnist',
split=['train', 'test'],
shuffle_files=True,
as_supervised=True,
with_info=True,
)
2025-08-06 11:39:31.458052: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
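Because the dataset was loaded with with_info=True, you can inspect ds_info to confirm what was loaded. A minimal sketch (the printed values assume the standard MNIST metadata):
# Inspect the dataset metadata returned by `with_info=True`.
print(ds_info.features)                      # image/label feature specs
print(ds_info.splits['train'].num_examples)  # 60000 for MNIST
print(ds_info.splits['test'].num_examples)   # 10000 for MNIST
# Peek at one example; with `as_supervised=True` each element is (image, label).
for image, label in ds_train.take(1):
    print(image.shape, image.dtype, label.numpy())  # (28, 28, 1) uint8 <label>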
Build a training pipeline
Apply the following transformations:
- tf.data.Dataset.map: TFDS provides images of type tf.uint8, while the model expects tf.float32, so you need to normalize the images.
- tf.data.Dataset.cache: As the dataset fits in memory, cache it before shuffling for better performance. Note: Random transformations should be applied after caching.
- tf.data.Dataset.shuffle: For true randomness, set the shuffle buffer to the full dataset size. Note: For large datasets that can't fit in memory, use buffer_size=1000 if your system allows it.
- tf.data.Dataset.batch: Batch elements of the dataset after shuffling to get unique batches at each epoch.
- tf.data.Dataset.prefetch: It is good practice to end the pipeline by prefetching for performance.
def normalize_img(image, label):
"""Normalizes images: `uint8` -> `float32`."""
return tf.cast(image, tf.float32) / 255., label
ds_train = ds_train.map(
normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_train = ds_train.cache()
ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)
ds_train = ds_train.batch(128)
ds_train = ds_train.prefetch(tf.data.AUTOTUNE)
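As a quick sanity check, you can pull one batch from the finished pipeline and confirm the shapes and dtypes the model will receive. A minimal sketch (the printed shapes assume the batch size of 128 used above):
# Fetch a single batch from the training pipeline to verify its structure.
images, labels = next(iter(ds_train))
print(images.shape, images.dtype)    # (128, 28, 28, 1) float32
print(labels.shape, labels.dtype)    # (128,) int64
print(float(tf.reduce_max(images)))  # <= 1.0 after normalization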
Build an evaluation pipeline
Your testing pipeline is similar to the training pipeline with small differences:
- You don't need to call tf.data.Dataset.shuffle.
- Caching is done after batching because batches can be the same between epochs.
ds_test = ds_test.map(
normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_test = ds_test.batch(128)
ds_test = ds_test.cache()
ds_test = ds_test.prefetch(tf.data.AUTOTUNE)
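Since the two pipelines differ only in shuffling and in where caching happens, you may prefer to factor them into a single helper. This is just a convenience refactor, not part of the original example; the make_pipeline name is hypothetical:
def make_pipeline(ds, training, batch_size=128):
    """Builds the same input pipeline as above for either split (hypothetical helper)."""
    ds = ds.map(normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
    if training:
        # Cache before shuffling so the random order changes each epoch.
        ds = ds.cache()
        ds = ds.shuffle(ds_info.splits['train'].num_examples)
        ds = ds.batch(batch_size)
    else:
        # Cache after batching since evaluation batches can stay identical across epochs.
        ds = ds.batch(batch_size)
        ds = ds.cache()
    return ds.prefetch(tf.data.AUTOTUNE)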
Step 2: Create and train the model
Plug the TFDS input pipeline into a simple Keras model, compile the model, and train it.
model = tf.keras.models.Sequential([
tf.keras.layers.Flatten(input_shape=(28, 28)),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dense(10)
])
model.compile(
optimizer=tf.keras.optimizers.Adam(0.001),
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)
model.fit(
ds_train,
epochs=6,
validation_data=ds_test,
)
/tmpfs/src/tf_docs_env/lib/python3.10/site-packages/keras/src/layers/reshaping/flatten.py:37: UserWarning: Do not pass an `input_shape`/`input_dim` argument to a layer. When using Sequential models, prefer using an `Input(shape)` object as the first layer in the model instead.
super().__init__(**kwargs)
Epoch 1/6
469/469 ━━━━━━━━━━━━━━━━━━━━ 4s 3ms/step - loss: 0.3653 - sparse_categorical_accuracy: 0.8988 - val_loss: 0.1945 - val_sparse_categorical_accuracy: 0.9443
Epoch 2/6
469/469 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step - loss: 0.1658 - sparse_categorical_accuracy: 0.9530 - val_loss: 0.1354 - val_sparse_categorical_accuracy: 0.9597
Epoch 3/6
469/469 ━━━━━━━━━━━━━━━━━━━━ 1s 2ms/step - loss: 0.1201 - sparse_categorical_accuracy: 0.9658 - val_loss: 0.1136 - val_sparse_categorical_accuracy: 0.9666
Epoch 4/6
469/469 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step - loss: 0.0929 - sparse_categorical_accuracy: 0.9738 - val_loss: 0.0996 - val_sparse_categorical_accuracy: 0.9699
Epoch 5/6
469/469 ━━━━━━━━━━━━━━━━━━━━ 1s 2ms/step - loss: 0.0756 - sparse_categorical_accuracy: 0.9785 - val_loss: 0.0899 - val_sparse_categorical_accuracy: 0.9727
Epoch 6/6
469/469 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step - loss: 0.0622 - sparse_categorical_accuracy: 0.9818 - val_loss: 0.0833 - val_sparse_categorical_accuracy: 0.9753
<keras.src.callbacks.history.History at 0x7fbcbcf1ddb0>
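After training, you might evaluate the model on the test pipeline or look at a few predictions. A minimal sketch; note that the model outputs logits, so a softmax is applied before reading off predicted classes:
# Evaluate the trained model on the test pipeline.
test_loss, test_acc = model.evaluate(ds_test)
print(f'Test accuracy: {test_acc:.4f}')
# Predict on one batch; apply softmax since the model returns logits.
images, labels = next(iter(ds_test))
probs = tf.nn.softmax(model.predict(images), axis=-1)
print(tf.argmax(probs, axis=-1)[:10].numpy())  # predicted digits
print(labels[:10].numpy())                     # ground-truth digits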
[[["Easy to understand","easyToUnderstand","thumb-up"],["Solved my problem","solvedMyProblem","thumb-up"],["Other","otherUp","thumb-up"]],[["Missing the information I need","missingTheInformationINeed","thumb-down"],["Too complicated / too many steps","tooComplicatedTooManySteps","thumb-down"],["Out of date","outOfDate","thumb-down"],["Samples / code issue","samplesCodeIssue","thumb-down"],["Other","otherDown","thumb-down"]],["Last updated 2025-08-07 UTC."],[],[],null,["# Training a neural network on MNIST with Keras\n\n\u003cbr /\u003e\n\nThis simple example demonstrates how to plug TensorFlow Datasets (TFDS) into a Keras model.\n\n|-----------------------------------------------------------------------------|--------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------|-------------------------------------------------------------------------------------------------------|\n| [View on TensorFlow.org](https://www.tensorflow.org/datasets/keras_example) | [Run in Google Colab](https://colab.research.google.com/github/tensorflow/datasets/blob/master/docs/keras_example.ipynb) | [View source on GitHub](https://github.com/tensorflow/datasets/blob/master/docs/keras_example.ipynb) | [Download notebook](https://storage.googleapis.com/tensorflow_docs/datasets/docs/keras_example.ipynb) |\n\n import tensorflow as tf\n import tensorflow_datasets as tfds\n\n```\n2025-08-06 11:39:27.068270: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\nWARNING: All log messages before absl::InitializeLog() is called are written to STDERR\nE0000 00:00:1754480367.093425 22003 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\nE0000 00:00:1754480367.101768 22003 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\nW0000 00:00:1754480367.121188 22003 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\nW0000 00:00:1754480367.121213 22003 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\nW0000 00:00:1754480367.121216 22003 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.\nW0000 00:00:1754480367.121218 22003 computation_placer.cc:177] computation placer already registered. 
Please check linkage and avoid linking the same target more than once.\n```\n\nStep 1: Create your input pipeline\n----------------------------------\n\nStart by building an efficient input pipeline using advices from:\n\n- The [Performance tips](https://www.tensorflow.org/datasets/performances) guide\n- The [Better performance with the `tf.data` API](https://www.tensorflow.org/guide/data_performance#optimize_performance) guide\n\n### Load a dataset\n\nLoad the MNIST dataset with the following arguments:\n\n- `shuffle_files=True`: The MNIST data is only stored in a single file, but for larger datasets with multiple files on disk, it's good practice to shuffle them when training.\n- `as_supervised=True`: Returns a tuple `(img, label)` instead of a dictionary `{'image': img, 'label': label}`.\n\n (ds_train, ds_test), ds_info = tfds.load(\n 'mnist',\n split=['train', 'test'],\n shuffle_files=True,\n as_supervised=True,\n with_info=True,\n )\n\n```\n2025-08-06 11:39:31.458052: E external/local_xla/xla/stream_executor/cuda/cuda_platform.cc:51] failed call to cuInit: INTERNAL: CUDA error: Failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected\n```\n\n### Build a training pipeline\n\nApply the following transformations:\n\n- [`tf.data.Dataset.map`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#map): TFDS provide images of type [`tf.uint8`](https://www.tensorflow.org/api_docs/python/tf#uint8), while the model expects [`tf.float32`](https://www.tensorflow.org/api_docs/python/tf#float32). Therefore, you need to normalize images.\n- [`tf.data.Dataset.cache`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#cache) As you fit the dataset in memory, cache it before shuffling for a better performance. \n **Note:** Random transformations should be applied after caching.\n- [`tf.data.Dataset.shuffle`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#shuffle): For true randomness, set the shuffle buffer to the full dataset size. 
\n **Note:** For large datasets that can't fit in memory, use `buffer_size=1000` if your system allows it.\n- [`tf.data.Dataset.batch`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#batch): Batch elements of the dataset after shuffling to get unique batches at each epoch.\n- [`tf.data.Dataset.prefetch`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#prefetch): It is good practice to end the pipeline by prefetching [for performance](https://www.tensorflow.org/guide/data_performance#prefetching).\n\n def normalize_img(image, label):\n \"\"\"Normalizes images: `uint8` -\u003e `float32`.\"\"\"\n return tf.cast(image, tf.float32) / 255., label\n\n ds_train = ds_train.map(\n normalize_img, num_parallel_calls=tf.data.AUTOTUNE)\n ds_train = ds_train.cache()\n ds_train = ds_train.shuffle(ds_info.splits['train'].num_examples)\n ds_train = ds_train.batch(128)\n ds_train = ds_train.prefetch(tf.data.AUTOTUNE)\n\n### Build an evaluation pipeline\n\nYour testing pipeline is similar to the training pipeline with small differences:\n\n- You don't need to call [`tf.data.Dataset.shuffle`](https://www.tensorflow.org/api_docs/python/tf/data/Dataset#shuffle).\n- Caching is done after batching because batches can be the same between epochs.\n\n ds_test = ds_test.map(\n normalize_img, num_parallel_calls=tf.data.AUTOTUNE)\n ds_test = ds_test.batch(128)\n ds_test = ds_test.cache()\n ds_test = ds_test.prefetch(tf.data.AUTOTUNE)\n\nStep 2: Create and train the model\n----------------------------------\n\nPlug the TFDS input pipeline into a simple Keras model, compile the model, and train it. \n\n model = tf.keras.models.Sequential([\n tf.keras.layers.Flatten(input_shape=(28, 28)),\n tf.keras.layers.Dense(128, activation='relu'),\n tf.keras.layers.Dense(10)\n ])\n model.compile(\n optimizer=tf.keras.optimizers.Adam(0.001),\n loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],\n )\n\n model.fit(\n ds_train,\n epochs=6,\n validation_data=ds_test,\n )\n\n```\n/tmpfs/src/tf_docs_env/lib/python3.10/site-packages/keras/src/layers/reshaping/flatten.py:37: UserWarning: Do not pass an `input_shape`/`input_dim` argument to a layer. When using Sequential models, prefer using an `Input(shape)` object as the first layer in the model instead.\n super().__init__(**kwargs)\nEpoch 1/6\n469/469 ━━━━━━━━━━━━━━━━━━━━ 4s 3ms/step - loss: 0.3653 - sparse_categorical_accuracy: 0.8988 - val_loss: 0.1945 - val_sparse_categorical_accuracy: 0.9443\nEpoch 2/6\n469/469 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step - loss: 0.1658 - sparse_categorical_accuracy: 0.9530 - val_loss: 0.1354 - val_sparse_categorical_accuracy: 0.9597\nEpoch 3/6\n469/469 ━━━━━━━━━━━━━━━━━━━━ 1s 2ms/step - loss: 0.1201 - sparse_categorical_accuracy: 0.9658 - val_loss: 0.1136 - val_sparse_categorical_accuracy: 0.9666\nEpoch 4/6\n469/469 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step - loss: 0.0929 - sparse_categorical_accuracy: 0.9738 - val_loss: 0.0996 - val_sparse_categorical_accuracy: 0.9699\nEpoch 5/6\n469/469 ━━━━━━━━━━━━━━━━━━━━ 1s 2ms/step - loss: 0.0756 - sparse_categorical_accuracy: 0.9785 - val_loss: 0.0899 - val_sparse_categorical_accuracy: 0.9727\nEpoch 6/6\n469/469 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step - loss: 0.0622 - sparse_categorical_accuracy: 0.9818 - val_loss: 0.0833 - val_sparse_categorical_accuracy: 0.9753\n\u003ckeras.src.callbacks.history.History at 0x7fbcbcf1ddb0\u003e\n```"]]