This notebook demonstrates how you can set up model training with early stopping, first in TensorFlow 1 with tf.estimator.Estimator and an early stopping hook, and then in TensorFlow 2 with Keras APIs or a custom training loop. Early stopping is a regularization technique that stops training if, for example, the validation loss reaches a certain threshold.
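The threshold criterion mentioned above can be sketched in plain Python (no TensorFlow needed). The sketch scans a list of per-epoch validation losses and reports the epoch at which training would stop. The function name and the loss values are hypothetical and serve only as an illustration.

```python
def first_epoch_below_threshold(val_losses, threshold):
    """Return the 0-based epoch at which training would stop, or None.

    Hypothetical illustration: training stops at the first epoch whose
    validation loss has reached (dropped to or below) `threshold`.
    """
    for epoch, loss in enumerate(val_losses):
        if loss <= threshold:
            return epoch
    return None  # the threshold was never reached


# Made-up validation losses, for illustration only.
history = [0.90, 0.55, 0.41, 0.33, 0.29, 0.27]
print(first_epoch_below_threshold(history, threshold=0.30))  # 4
```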
In TensorFlow 2, there are three ways to implement early stopping:
- Use a built-in Keras callback, tf.keras.callbacks.EarlyStopping, and pass it to Model.fit.
- Define a custom callback and pass it to Keras Model.fit.
- Write a custom early stopping rule in a custom training loop (with tf.GradientTape).
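With the third option, the stopping rule is ordinary Python that you call from your own loop. The sketch below shows one common rule called patience: training stops once the monitored loss has failed to improve by at least min_delta for patience consecutive epochs. The class name PatienceStopper is hypothetical. The rule is loosely modeled on the monitor/min_delta/patience idea behind tf.keras.callbacks.EarlyStopping, but it is not the Keras implementation.

```python
class PatienceStopper:
    """Hypothetical patience-based stopping rule for a custom training loop.

    Loosely modeled on the min_delta/patience idea of
    tf.keras.callbacks.EarlyStopping; NOT the Keras implementation.
    """

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience    # epochs to wait without improvement
        self.min_delta = min_delta  # smallest decrease that counts as improvement
        self.best = float("inf")    # best (lowest) loss seen so far
        self.wait = 0               # epochs since the last improvement

    def update(self, val_loss):
        """Record one epoch's validation loss; return True if training should stop."""
        if val_loss < self.best - self.min_delta:
            self.best = val_loss
            self.wait = 0
        else:
            self.wait += 1
        return self.wait >= self.patience


stopper = PatienceStopper(patience=2, min_delta=0.01)
losses = [0.50, 0.40, 0.35, 0.345, 0.36, 0.30]  # made-up values
for epoch, loss in enumerate(losses):
    if stopper.update(loss):
        print(f"Stopping after epoch {epoch}; best loss {stopper.best}")
        break
```

In a real tf.GradientTape loop, you would call update once per epoch, after computing the validation loss, and break out of the epoch loop when it returns True.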
Setup
```python
import time

import numpy as np
import tensorflow as tf
import tensorflow.compat.v1 as tf1
import tensorflow_datasets as tfds
```
TensorFlow 1: Early stopping with an early stopping hook and tf.estimator
Start by defining functions for loading and preprocessing the MNIST dataset, and a model definition to be used with tf.estimator.Estimator:
```python
def normalize_img(image, label):
    # Scale uint8 pixel values from [0, 255] to float32 values in [0, 1].
    return tf.cast(image, tf.float32) / 255., label

def _input_fn():
    # Training input: the MNIST train split as (image, label) pairs.
    ds_train = tfds.load(
        name='mnist',
        split='train',
        shuffle_files=True,
        as_supervised=True)
    ds_train = ds_train.map(
        normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
    ds_train = ds_train.batch(128)
    # Repeat the data 100 times; a hook may stop training before that.
    ds_train = ds_train.repeat(100)
    return ds_train

def _eval_input_fn():
    # Evaluation input: the MNIST test split, normalized and batched the same way.
    ds_test = tfds.load(
        name='mnist',
        split='test',
        shuffle_files=True,
        as_supervised=True)
    ds_test = ds_test.map(
        normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
    ds_test = ds_test.batch(128)
    return ds_test

def _model_fn(features, labels, mode):
    # A small fully connected classifier: 28x28 image -> 128 ReLU units -> 10 logits.
    flatten = tf1.layers.Flatten()(features)
    features = tf1.layers.Dense(128, 'relu')(flatten)
    logits = tf1.layers.Dense(10)(features)
    loss = tf1.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
    optimizer = tf1.train.AdagradOptimizer(0.005)
    train_op = optimizer.minimize(loss, global_step=tf1.train.get_global_step())
    return tf1.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
```
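A quick calculation shows how long training could run if nothing stopped it. The standard MNIST splits have 60,000 training and 10,000 test examples. batch() keeps the final partial batch by default, so each pass produces ceil(examples / 128) batches. Because the TrainSpec used below sets no max_steps, training ends only when the repeated dataset is exhausted or a hook stops it. The constant names here are illustrative.

```python
import math

# Sizes of the standard MNIST splits loaded with tfds.load(name='mnist', ...).
NUM_TRAIN_EXAMPLES = 60_000
NUM_TEST_EXAMPLES = 10_000
BATCH_SIZE = 128   # ds.batch(128) in both input functions
NUM_REPEATS = 100  # ds_train.repeat(100) in _input_fn

# batch() keeps the last partial batch (drop_remainder=False by default),
# so the number of batches per pass is the ceiling of examples / batch size.
train_batches_per_epoch = math.ceil(NUM_TRAIN_EXAMPLES / BATCH_SIZE)
eval_batches = math.ceil(NUM_TEST_EXAMPLES / BATCH_SIZE)

# Upper bound on training steps when no hook stops training early.
max_train_steps = train_batches_per_epoch * NUM_REPEATS

print(train_batches_per_epoch, eval_batches, max_train_steps)  # 469 79 46900
```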
In TensorFlow 1, early stopping works by setting up an early stopping hook with tf.estimator.experimental.make_early_stopping_hook. You pass make_early_stopping_hook a function that takes no arguments as its should_stop_fn parameter, and then pass the resulting hook to tf.estimator.TrainSpec through its hooks argument. Training stops once should_stop_fn returns True. The run_every_secs and run_every_steps arguments control how often should_stop_fn is checked.
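Because should_stop_fn takes no arguments, any state it needs must come from the enclosing scope. The toy loop below is a hypothetical stand-in, not the Estimator internals. It illustrates that contract by running fake training steps, recording a fake loss, and polling a zero-argument predicate every run_every_steps steps until the predicate returns True.

```python
def run_toy_training(should_stop_fn, history, max_steps, run_every_steps=1):
    """Hypothetical toy loop: polls should_stop_fn every run_every_steps steps.

    Returns the number of steps completed before stopping.
    """
    step = 0
    while step < max_steps:
        step += 1
        history.append(1.0 / step)  # fake, steadily decreasing "loss"
        if step % run_every_steps == 0 and should_stop_fn():
            break
    return step


loss_history = []

def should_stop_fn():
    # No arguments: the predicate reads shared state from the enclosing scope.
    return loss_history[-1] < 0.04

steps_run = run_toy_training(should_stop_fn, loss_history,
                             max_steps=1000, run_every_steps=10)
print(steps_run)  # 30
```

Checking less often (a larger run_every_steps, or run_every_secs in the real hook) lowers the polling overhead. The cost is that training can run up to one interval past the point where the condition first became true.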
The following example demonstrates how to implement an early stopping technique that limits the training time to a maximum of 20 seconds:
```python
estimator = tf1.estimator.Estimator(model_fn=_model_fn)

start_time = time.time()
max_train_seconds = 20

def should_stop_fn():
    # Takes no arguments; returns True once the time budget is used up.
    return time.time() - start_time > max_train_seconds

early_stopping_hook = tf1.estimator.experimental.make_early_stopping_hook(
    estimator=estimator,
    should_stop_fn=should_stop_fn,
    run_every_secs=1,  # check should_stop_fn once per second
    run_every_steps=None)

train_spec = tf1.estimator.TrainSpec(
    input_fn=_input_fn,
    hooks=[early_stopping_hook])
eval_spec = tf1.estimator.EvalSpec(input_fn=_eval_input_fn)

tf1.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
```
INFO:tensorflow:loss = 2.2961025, step = 0
...
INFO:tensorflow:loss = 0.27511528, step = 8200 (0.206 sec)
INFO:tensorflow:Requesting early stopping at global step 8286
INFO:tensorflow:Saving checkpoints for 8287 into /tmpfs/tmp/tmpp8x_ipb8/model.ckpt.
INFO:tensorflow:Starting evaluation at 2023-10-04T01:37:52
INFO:tensorflow:Finished evaluation at 2023-10-04-01:37:53
INFO:tensorflow:Saving dict for global step 8287: global_step = 8287, loss = 0.21179305
INFO:tensorflow:Loss for final step: 0.28788015.
({'loss': 0.21179305, 'global_step': 8287}, [])
TensorFlow 2: Early stopping with a built-in callback and Model.fit
Prepare the MNIST dataset and a simple Keras model:
(ds_train, ds_test), ds_info = tfds.load(
'mnist',
split=['train', 'test'],
shuffle_files=True,
as_supervised=True,
with_info=True,
)
ds_train = ds_train.map(
normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_train = ds_train.batch(128)
ds_test = ds_test.map(
normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_test = ds_test.batch(128)
model = tf.keras.models.Sequential([
tf.keras.layers.Flatten(input_shape=(28, 28)),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dense(10)
])
model.compile(
optimizer=tf.keras.optimizers.Adam(0.005),
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)
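Both splits are batched with batch(128), and by default (drop_remainder=False) the final partial batch is kept, so the number of steps per epoch is the ceiling of the split size divided by 128. The plain-Python sketch below works this out for the 60,000 training and 10,000 test images of MNIST; the helper names are illustrative, not TensorFlow APIs. The result, 469, is the step counter Keras prints for each training epoch.

```python
import math

# Split sizes of the standard MNIST 'train' and 'test' splits.
NUM_TRAIN = 60_000
NUM_TEST = 10_000
BATCH_SIZE = 128  # same value passed to ds_train.batch and ds_test.batch


def steps_per_epoch(num_examples, batch_size):
    """Batches per pass when the final partial batch is kept (drop_remainder=False)."""
    return math.ceil(num_examples / batch_size)


def last_batch_size(num_examples, batch_size):
    """Size of the final batch, which may be smaller than batch_size."""
    remainder = num_examples % batch_size
    return remainder if remainder else batch_size


print(steps_per_epoch(NUM_TRAIN, BATCH_SIZE), last_batch_size(NUM_TRAIN, BATCH_SIZE))  # 469 96
print(steps_per_epoch(NUM_TEST, BATCH_SIZE), last_batch_size(NUM_TEST, BATCH_SIZE))    # 79 16
```

The last training batch holds only 96 images, which matters for any metric that averages per batch instead of per example.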
In TensorFlow 2, when you use the built-in Keras Model.fit, you can configure early stopping by passing a built-in callback, tf.keras.callbacks.EarlyStopping, to the callbacks parameter of Model.fit.
The EarlyStopping callback monitors a user-specified metric and ends training when that metric stops improving. (Check the Training and evaluation with the built-in methods guide or the API docs for more information.)
Below is an example of an early stopping callback that monitors the training loss (monitor='loss') and stops training once the loss has not improved for 3 consecutive epochs (patience=3):
callback = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=3)
# Training stops well before 100 epochs; in the run below, it stopped after 16.
history = model.fit(
ds_train,
epochs=100,
validation_data=ds_test,
callbacks=[callback]
)
len(history.history['loss'])
Epoch 1/100
469/469 [==============================] - 3s 4ms/step - loss: 0.2326 - sparse_categorical_accuracy: 0.9310 - val_loss: 0.1214 - val_sparse_categorical_accuracy: 0.9630
Epoch 2/100
469/469 [==============================] - 1s 3ms/step - loss: 0.1004 - sparse_categorical_accuracy: 0.9700 - val_loss: 0.1009 - val_sparse_categorical_accuracy: 0.9684
Epoch 3/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0689 - sparse_categorical_accuracy: 0.9792 - val_loss: 0.1061 - val_sparse_categorical_accuracy: 0.9674
Epoch 4/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0528 - sparse_categorical_accuracy: 0.9835 - val_loss: 0.1244 - val_sparse_categorical_accuracy: 0.9640
Epoch 5/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0424 - sparse_categorical_accuracy: 0.9866 - val_loss: 0.1001 - val_sparse_categorical_accuracy: 0.9730
Epoch 6/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0355 - sparse_categorical_accuracy: 0.9883 - val_loss: 0.1034 - val_sparse_categorical_accuracy: 0.9728
Epoch 7/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0315 - sparse_categorical_accuracy: 0.9894 - val_loss: 0.1029 - val_sparse_categorical_accuracy: 0.9748
Epoch 8/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0319 - sparse_categorical_accuracy: 0.9888 - val_loss: 0.1267 - val_sparse_categorical_accuracy: 0.9707
Epoch 9/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0280 - sparse_categorical_accuracy: 0.9905 - val_loss: 0.1128 - val_sparse_categorical_accuracy: 0.9733
Epoch 10/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0281 - sparse_categorical_accuracy: 0.9907 - val_loss: 0.1246 - val_sparse_categorical_accuracy: 0.9718
Epoch 11/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0270 - sparse_categorical_accuracy: 0.9905 - val_loss: 0.1184 - val_sparse_categorical_accuracy: 0.9786
Epoch 12/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0217 - sparse_categorical_accuracy: 0.9927 - val_loss: 0.1231 - val_sparse_categorical_accuracy: 0.9756
Epoch 13/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0191 - sparse_categorical_accuracy: 0.9935 - val_loss: 0.1466 - val_sparse_categorical_accuracy: 0.9716
Epoch 14/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0220 - sparse_categorical_accuracy: 0.9927 - val_loss: 0.1475 - val_sparse_categorical_accuracy: 0.9754
Epoch 15/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0200 - sparse_categorical_accuracy: 0.9934 - val_loss: 0.1472 - val_sparse_categorical_accuracy: 0.9746
Epoch 16/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0195 - sparse_categorical_accuracy: 0.9937 - val_loss: 0.1454 - val_sparse_categorical_accuracy: 0.9755
16
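To see why this run stopped after 16 epochs, the patience rule can be replayed by hand. The sketch below is a simplified plain-Python restatement of that rule, not the Keras source: the helper epochs_until_stop is hypothetical, and it ignores EarlyStopping options such as baseline and restore_best_weights. An epoch counts as an improvement only if the monitored value beats the best value so far by more than min_delta, and training stops once patience consecutive epochs show no improvement. Fed the rounded per-epoch losses printed above, it reproduces the 16 epochs.

```python
def epochs_until_stop(values, patience, min_delta=0.0, mode="min"):
    """Return how many epochs run before a patience-based rule stops training.

    If the rule never triggers, every epoch in `values` runs.
    """
    best = float("inf") if mode == "min" else float("-inf")
    wait = 0
    for epoch, value in enumerate(values, start=1):
        if mode == "min":
            improved = value < best - min_delta
        else:
            improved = value > best + min_delta
        wait += 1
        if improved:
            best = value
            wait = 0
        if wait >= patience:
            return epoch
    return len(values)


# Per-epoch values from the run above, rounded to 4 decimals as printed.
train_loss = [0.2326, 0.1004, 0.0689, 0.0528, 0.0424, 0.0355, 0.0315, 0.0319,
              0.0280, 0.0281, 0.0270, 0.0217, 0.0191, 0.0220, 0.0200, 0.0195]
val_loss = [0.1214, 0.1009, 0.1061, 0.1244, 0.1001, 0.1034, 0.1029, 0.1267,
            0.1128, 0.1246, 0.1184, 0.1231, 0.1466, 0.1475, 0.1472, 0.1454]

print(epochs_until_stop(train_loss, patience=3))  # 16
print(epochs_until_stop(val_loss, patience=3))    # 8
```

Replaying the logged val_loss column with the same rule would have triggered after epoch 8, which is one reason monitor='val_loss' is a common choice when the goal is to stop before the model overfits. This is only a replay of this particular run, not a prediction for other runs.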
TensorFlow 2: Early stopping with a custom callback and Model.fit
You can also implement a custom early stopping callback and pass it to the callbacks parameter of Model.fit (or Model.evaluate).
In this example, the callback stops training once a time budget is exceeded by setting self.model.stop_training to True:
class LimitTrainingTime(tf.keras.callbacks.Callback):
    """Stops training once a wall-clock time budget is exceeded."""

    def __init__(self, max_time_s):
        super().__init__()
        self.max_time_s = max_time_s
        self.start_time = None

    def on_train_begin(self, logs=None):
        # Record when training starts.
        self.start_time = time.time()

    def on_train_batch_end(self, batch, logs=None):
        # After each batch, request a stop if the time budget is used up.
        now = time.time()
        if now - self.start_time > self.max_time_s:
            self.model.stop_training = True
# Limit the training time to 30 seconds.
callback = LimitTrainingTime(30)
history = model.fit(
ds_train,
epochs=100,
validation_data=ds_test,
callbacks=[callback]
)
len(history.history['loss'])
Epoch 1/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0180 - sparse_categorical_accuracy: 0.9939 - val_loss: 0.1440 - val_sparse_categorical_accuracy: 0.9785
Epoch 2/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0139 - sparse_categorical_accuracy: 0.9958 - val_loss: 0.1728 - val_sparse_categorical_accuracy: 0.9755
Epoch 3/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0213 - sparse_categorical_accuracy: 0.9937 - val_loss: 0.1567 - val_sparse_categorical_accuracy: 0.9787
Epoch 4/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0198 - sparse_categorical_accuracy: 0.9941 - val_loss: 0.1826 - val_sparse_categorical_accuracy: 0.9740
Epoch 5/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0146 - sparse_categorical_accuracy: 0.9952 - val_loss: 0.1658 - val_sparse_categorical_accuracy: 0.9756
Epoch 6/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0126 - sparse_categorical_accuracy: 0.9959 - val_loss: 0.2150 - val_sparse_categorical_accuracy: 0.9726
Epoch 7/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0174 - sparse_categorical_accuracy: 0.9952 - val_loss: 0.1890 - val_sparse_categorical_accuracy: 0.9769
Epoch 8/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0140 - sparse_categorical_accuracy: 0.9956 - val_loss: 0.1852 - val_sparse_categorical_accuracy: 0.9772
Epoch 9/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0141 - sparse_categorical_accuracy: 0.9958 - val_loss: 0.2051 - val_sparse_categorical_accuracy: 0.9761
Epoch 10/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0156 - sparse_categorical_accuracy: 0.9961 - val_loss: 0.2260 - val_sparse_categorical_accuracy: 0.9727
Epoch 11/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0167 - sparse_categorical_accuracy: 0.9952 - val_loss: 0.2008 - val_sparse_categorical_accuracy: 0.9757
Epoch 12/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0133 - sparse_categorical_accuracy: 0.9963 - val_loss: 0.2283 - val_sparse_categorical_accuracy: 0.9755
Epoch 13/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0106 - sparse_categorical_accuracy: 0.9969 - val_loss: 0.2270 - val_sparse_categorical_accuracy: 0.9739
Epoch 14/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0139 - sparse_categorical_accuracy: 0.9962 - val_loss: 0.2169 - val_sparse_categorical_accuracy: 0.9775
Epoch 15/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0121 - sparse_categorical_accuracy: 0.9964 - val_loss: 0.2282 - val_sparse_categorical_accuracy: 0.9773
Epoch 16/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0128 - sparse_categorical_accuracy: 0.9966 - val_loss: 0.2723 - val_sparse_categorical_accuracy: 0.9738
Epoch 17/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0118 - sparse_categorical_accuracy: 0.9967 - val_loss: 0.2223 - val_sparse_categorical_accuracy: 0.9784
Epoch 18/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0168 - sparse_categorical_accuracy: 0.9959 - val_loss: 0.2489 - val_sparse_categorical_accuracy: 0.9770
Epoch 19/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0129 - sparse_categorical_accuracy: 0.9967 - val_loss: 0.2607 - val_sparse_categorical_accuracy: 0.9753
Epoch 20/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0112 - sparse_categorical_accuracy: 0.9975 - val_loss: 0.2267 - val_sparse_categorical_accuracy: 0.9776
Epoch 21/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0105 - sparse_categorical_accuracy: 0.9971 - val_loss: 0.2758 - val_sparse_categorical_accuracy: 0.9733
21
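The mechanism behind LimitTrainingTime is just a flag: the callback sets model.stop_training, and the training loop checks it. The toy loop below illustrates that contract under stated assumptions; it is not the Keras fit implementation, and every name in it (FakeClock, ToyModel, TimeBudget, toy_fit) is hypothetical. It checks the flag after every batch and injects a fake clock so the outcome is deterministic: with a 30-second budget and one simulated second per batch, it stops on the first batch of the fourth 10-batch epoch.

```python
class FakeClock:
    """Deterministic stand-in for time.time: each call advances `tick` seconds."""

    def __init__(self, tick):
        self.now = 0.0
        self.tick = tick

    def __call__(self):
        current = self.now
        self.now += self.tick
        return current


class ToyModel:
    """Stand-in for a Keras model; only the stop_training flag matters here."""

    def __init__(self):
        self.stop_training = False


class TimeBudget:
    """Same logic as LimitTrainingTime, with the clock passed in."""

    def __init__(self, max_time_s, clock):
        self.max_time_s = max_time_s
        self.clock = clock
        self.model = None
        self.start_time = None

    def on_train_begin(self, logs=None):
        self.start_time = self.clock()

    def on_train_batch_end(self, batch, logs=None):
        if self.clock() - self.start_time > self.max_time_s:
            self.model.stop_training = True


def toy_fit(model, callback, epochs, steps_per_epoch):
    """Toy training loop that checks model.stop_training after every batch.

    Returns (epochs started, batches completed in the last epoch).
    """
    callback.model = model
    model.stop_training = False
    callback.on_train_begin()
    for epoch in range(epochs):
        for step in range(steps_per_epoch):
            # A real loop would run one optimization step here.
            callback.on_train_batch_end(step)
            if model.stop_training:
                return epoch + 1, step + 1
    return epochs, steps_per_epoch


model = ToyModel()
budget = TimeBudget(max_time_s=30, clock=FakeClock(tick=1.0))
print(toy_fit(model, budget, epochs=100, steps_per_epoch=10))  # (4, 1)
```

Injecting the clock is only a testing convenience; the callback in the notebook reads time.time() directly.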
TensorFlow 2: Early stopping with a custom training loop
In TensorFlow 2, you can implement early stopping in a custom training loop if you're not training and evaluating with the built-in Keras methods.
Start by using Keras APIs to define another simple model, an optimizer, a loss function, and metrics:
model = tf.keras.models.Sequential([
tf.keras.layers.Flatten(input_shape=(28, 28)),
tf.keras.layers.Dense(128, activation='relu'),
tf.keras.layers.Dense(10)
])
optimizer = tf.keras.optimizers.Adam(0.005)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
train_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()
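# The model outputs logits, so the loss metrics also need from_logits=True.
train_loss_metric = tf.keras.metrics.SparseCategoricalCrossentropy(from_logits=True)
val_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()
val_loss_metric = tf.keras.metrics.SparseCategoricalCrossentropy(from_logits=True)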
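The final Dense(10) layer has no activation, so the model emits raw logits, which is why loss_fn is created with from_logits=True. For a single example, sparse categorical cross-entropy on logits equals -log(softmax(logits)[label]), and it can be computed stably as the log-sum-exp of the logits minus the true class's logit. The plain-Python sketch below checks that identity; the helpers are illustrative, not the Keras implementation.

```python
import math


def softmax(logits):
    """Turn raw logits into probabilities (shifted by the max for stability)."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sparse_ce_from_logits(logits, label):
    """-log(softmax(logits)[label]) computed as log-sum-exp(logits) - logits[label]."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]


logits = [2.0, 1.0, 0.1]  # hypothetical raw outputs for a 3-class toy example
label = 0
direct = -math.log(softmax(logits)[label])
stable = sparse_ce_from_logits(logits, label)
print(direct, stable)  # the two agree up to floating-point rounding

# Equal logits over 10 classes give a loss of ln(10).
print(sparse_ce_from_logits([0.0] * 10, label=3))
```

With equal logits over MNIST's 10 classes the loss is ln(10), about 2.303, a rough reference point for the loss of an untrained model.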
Define the parameter update functions with tf.GradientTape and the @tf.function decorator for a speedup:
@tf.function
def train_step(x, y):
    with tf.GradientTape() as tape:
        logits = model(x, training=True)
        loss_value = loss_fn(y, logits)
    grads = tape.gradient(loss_value, model.trainable_weights)
    optimizer.apply_gradients(zip(grads, model.trainable_weights))
    train_acc_metric.update_state(y, logits)
    train_loss_metric.update_state(y, logits)
    return loss_value

@tf.function
def test_step(x, y):
    logits = model(x, training=False)
    val_acc_metric.update_state(y, logits)
    val_loss_metric.update_state(y, logits)
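train_step and test_step only call update_state: the metric objects accumulate values across batches until they are read with result() and cleared, which is why the training loop resets them once per epoch. The toy class below (hypothetical, not tf.keras.metrics.Mean) shows that accumulate, read, and reset cycle for a per-example average.

```python
class ToyMeanMetric:
    """Toy stateful metric: running mean of per-example values since the last reset."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def update_state(self, values):
        # Accumulate one batch of per-example values (e.g. 0/1 correctness).
        self.total += sum(values)
        self.count += len(values)

    def result(self):
        # Mean over every example seen since the last reset (0.0 if none).
        return self.total / self.count if self.count else 0.0

    def reset_state(self):
        # Clear the accumulated state, e.g. at the end of an epoch.
        self.total = 0.0
        self.count = 0


acc = ToyMeanMetric()
acc.update_state([1, 1, 1, 1])  # a full batch of 4 examples, all correct
acc.update_state([0, 0])        # a smaller final batch of 2, both wrong
print(acc.result())  # 4 / 6, not the mean of the two batch means (0.5)
acc.reset_state()
print(acc.result())  # 0.0 after the reset
```

Because the toy sums per-example values, a smaller final batch is weighted by its size rather than counted as a full batch; forgetting the reset would blend one epoch's value into the next.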
Next, write a custom training loop, where you can implement your early stopping rule manually.
The example below stops training when the validation loss has not improved for patience (here, 5) consecutive epochs:
epochs = 100
patience = 5
wait = 0
best = float('inf')

for epoch in range(epochs):
    print("\nStart of epoch %d" % (epoch,))
    start_time = time.time()

    for step, (x_batch_train, y_batch_train) in enumerate(ds_train):
        loss_value = train_step(x_batch_train, y_batch_train)
        if step % 200 == 0:
            print("Training loss at step %d: %.4f" % (step, loss_value.numpy()))
            print("Seen so far: %s samples" % ((step + 1) * 128))

    train_acc = train_acc_metric.result()
    train_loss = train_loss_metric.result()
    train_acc_metric.reset_state()
    train_loss_metric.reset_state()
    print("Training acc over epoch: %.4f" % (train_acc.numpy()))

    for x_batch_val, y_batch_val in ds_test:
        test_step(x_batch_val, y_batch_val)
    val_acc = val_acc_metric.result()
    val_loss = val_loss_metric.result()
    val_acc_metric.reset_state()
    val_loss_metric.reset_state()
    print("Validation acc: %.4f" % (float(val_acc),))
    print("Time taken: %.2fs" % (time.time() - start_time))

    # The early stopping strategy: stop the training if `val_loss` has not
    # decreased for `patience` consecutive epochs.
    wait += 1
    if val_loss < best:
        best = val_loss
        wait = 0
    if wait >= patience:
        break
Start of epoch 0
Training loss at step 0: 2.4644
Seen so far: 128 samples
Training loss at step 200: 0.2622
Seen so far: 25728 samples
Training loss at step 400: 0.2129
Seen so far: 51328 samples
Training acc over epoch: 0.9297
Validation acc: 0.9610
Time taken: 2.07s

Start of epoch 1
Training loss at step 0: 0.0756
Seen so far: 128 samples
Training loss at step 200: 0.1215
Seen so far: 25728 samples
Training loss at step 400: 0.1359
Seen so far: 51328 samples
Training acc over epoch: 0.9685
Validation acc: 0.9664
Time taken: 1.36s

Start of epoch 2
Training loss at step 0: 0.0390
Seen so far: 128 samples
Training loss at step 200: 0.0995
Seen so far: 25728 samples
Training loss at step 400: 0.1270
Seen so far: 51328 samples
Training acc over epoch: 0.9778
Validation acc: 0.9690
Time taken: 1.37s

Start of epoch 3
Training loss at step 0: 0.0282
Seen so far: 128 samples
Training loss at step 200: 0.0673
Seen so far: 25728 samples
Training loss at step 400: 0.0700
Seen so far: 51328 samples
Training acc over epoch: 0.9826
Validation acc: 0.9695
Time taken: 1.35s

Start of epoch 4
Training loss at step 0: 0.0249
Seen so far: 128 samples
Training loss at step 200: 0.0403
Seen so far: 25728 samples
Training loss at step 400: 0.0324
Seen so far: 51328 samples
Training acc over epoch: 0.9862
Validation acc: 0.9693
Time taken: 1.37s

Start of epoch 5
Training loss at step 0: 0.0279
Seen so far: 128 samples
Training loss at step 200: 0.0529
Seen so far: 25728 samples
Training loss at step 400: 0.0592
Seen so far: 51328 samples
Training acc over epoch: 0.9875
Validation acc: 0.9660
Time taken: 1.32s

Start of epoch 6
Training loss at step 0: 0.0200
Seen so far: 128 samples
Training loss at step 200: 0.0344
Seen so far: 25728 samples
Training loss at step 400: 0.0142
Seen so far: 51328 samples
Training acc over epoch: 0.9884
Validation acc: 0.9720
Time taken: 1.36s

Start of epoch 7
Training loss at step 0: 0.0325
Seen so far: 128 samples
Training loss at step 200: 0.0499
Seen so far: 25728 samples
Training loss at step 400: 0.0167
Seen so far: 51328 samples
Training acc over epoch: 0.9898
Validation acc: 0.9711
Time taken: 1.39s

Start of epoch 8
Training loss at step 0: 0.0155
Seen so far: 128 samples
Training loss at step 200: 0.0219
Seen so far: 25728 samples
Training loss at step 400: 0.0213
Seen so far: 51328 samples
Training acc over epoch: 0.9914
Validation acc: 0.9715
Time taken: 1.34s

Start of epoch 9
Training loss at step 0: 0.0068
Seen so far: 128 samples
Training loss at step 200: 0.0288
Seen so far: 25728 samples
Training loss at step 400: 0.0320
Seen so far: 51328 samples
Training acc over epoch: 0.9921
Validation acc: 0.9743
Time taken: 1.34s

Start of epoch 10
Training loss at step 0: 0.0069
Seen so far: 128 samples
Training loss at step 200: 0.0419
Seen so far: 25728 samples
Training loss at step 400: 0.0458
Seen so far: 51328 samples
Training acc over epoch: 0.9919
Validation acc: 0.9720
Time taken: 1.34s

Start of epoch 11
Training loss at step 0: 0.0015
Seen so far: 128 samples
Training loss at step 200: 0.0113
Seen so far: 25728 samples
Training loss at step 400: 0.0285
Seen so far: 51328 samples
Training acc over epoch: 0.9930
Validation acc: 0.9740
Time taken: 1.33s

Start of epoch 12
Training loss at step 0: 0.0010
Seen so far: 128 samples
Training loss at step 200: 0.0450
Seen so far: 25728 samples
Training loss at step 400: 0.0561
Seen so far: 51328 samples
Training acc over epoch: 0.9926
Validation acc: 0.9744
Time taken: 1.34s

Start of epoch 13
Training loss at step 0: 0.0038
Seen so far: 128 samples
Training loss at step 200: 0.0309
Seen so far: 25728 samples
Training loss at step 400: 0.0124
Seen so far: 51328 samples
Training acc over epoch: 0.9937
Validation acc: 0.9740
Time taken: 1.36s

Start of epoch 14
Training loss at step 0: 0.0011
Seen so far: 128 samples
Training loss at step 200: 0.0068
Seen so far: 25728 samples
Training loss at step 400: 0.0110
Seen so far: 51328 samples
Training acc over epoch: 0.9930
Validation acc: 0.9736
Time taken: 1.35s

Start of epoch 15
Training loss at step 0: 0.0043
Seen so far: 128 samples
Training loss at step 200: 0.0085
Seen so far: 25728 samples
Training loss at step 400: 0.0042
Seen so far: 51328 samples
Training acc over epoch: 0.9942
Validation acc: 0.9752
Time taken: 1.34s

Start of epoch 16
Training loss at step 0: 0.0208
Seen so far: 128 samples
Training loss at step 200: 0.0042
Seen so far: 25728 samples
Training loss at step 400: 0.1063
Seen so far: 51328 samples
Training acc over epoch: 0.9947
Validation acc: 0.9740
Time taken: 1.34s

Start of epoch 17
Training loss at step 0: 0.0067
Seen so far: 128 samples
Training loss at step 200: 0.0277
Seen so far: 25728 samples
Training loss at step 400: 0.0787
Seen so far: 51328 samples
Training acc over epoch: 0.9951
Validation acc: 0.9729
Time taken: 1.33s

Start of epoch 18
Training loss at step 0: 0.0017
Seen so far: 128 samples
Training loss at step 200: 0.0131
Seen so far: 25728 samples
Training loss at step 400: 0.0431
Seen so far: 51328 samples
Training acc over epoch: 0.9943
Validation acc: 0.9739
Time taken: 1.40s

Start of epoch 19
Training loss at step 0: 0.0004
Seen so far: 128 samples
Training loss at step 200: 0.0220
Seen so far: 25728 samples
Training loss at step 400: 0.0662
Seen so far: 51328 samples
Training acc over epoch: 0.9952
Validation acc: 0.9738
Time taken: 1.34s

Start of epoch 20
Training loss at step 0: 0.0003
Seen so far: 128 samples
Training loss at step 200: 0.0306
Seen so far: 25728 samples
Training loss at step 400: 0.0083
Seen so far: 51328 samples
Training acc over epoch: 0.9955
Validation acc: 0.9753
Time taken: 1.37s

Start of epoch 21
Training loss at step 0: 0.0016
Seen so far: 128 samples
Training loss at step 200: 0.0003
Seen so far: 25728 samples
Training loss at step 400: 0.0069
Seen so far: 51328 samples
Training acc over epoch: 0.9946
Validation acc: 0.9729
Time taken: 1.36s

Start of epoch 22
Training loss at step 0: 0.0626
Seen so far: 128 samples
Training loss at step 200: 0.0013
Seen so far: 25728 samples
Training loss at step 400: 0.0278
Seen so far: 51328 samples
Training acc over epoch: 0.9946
Validation acc: 0.9740
Time taken: 1.34s

Start of epoch 23
Training loss at step 0: 0.0318
Seen so far: 128 samples
Training loss at step 200: 0.0514
Seen so far: 25728 samples
Training loss at step 400: 0.0001
Seen so far: 51328 samples
Training acc over epoch: 0.9952
Validation acc: 0.9758
Time taken: 1.36s

Start of epoch 24
Training loss at step 0: 0.0004
Seen so far: 128 samples
Training loss at step 200: 0.0043
Seen so far: 25728 samples
Training loss at step 400: 0.0339
Seen so far: 51328 samples
Training acc over epoch: 0.9956
Validation acc: 0.9752
Time taken: 1.37s

Start of epoch 25
Training loss at step 0: 0.0000
Seen so far: 128 samples
Training loss at step 200: 0.0057
Seen so far: 25728 samples
Training loss at step 400: 0.1485
Seen so far: 51328 samples
Training acc over epoch: 0.9961
Validation acc: 0.9733
Time taken: 1.35s

Start of epoch 26
Training loss at step 0: 0.0005
Seen so far: 128 samples
Training loss at step 200: 0.0992
Seen so far: 25728 samples
Training loss at step 400: 0.0033
Seen so far: 51328 samples
Training acc over epoch: 0.9972
Validation acc: 0.9776
Time taken: 1.32s

Start of epoch 27
Training loss at step 0: 0.0004
Seen so far: 128 samples
Training loss at step 200: 0.0402
Seen so far: 25728 samples
Training loss at step 400: 0.0002
Seen so far: 51328 samples
Training acc over epoch: 0.9966
Validation acc: 0.9784
Time taken: 1.34s

Start of epoch 28
Training loss at step 0: 0.0005
Seen so far: 128 samples
Training loss at step 200: 0.0307
Seen so far: 25728 samples
Training loss at step 400: 0.1466
Seen so far: 51328 samples
Training acc over epoch: 0.9948
Validation acc: 0.9742
Time taken: 1.37s

Start of epoch 29
Training loss at step 0: 0.0092
Seen so far: 128 samples
Training loss at step 200: 0.0039
Seen so far: 25728 samples
Training loss at step 400: 0.0358
Seen so far: 51328 samples
Training acc over epoch: 0.9959
Validation acc: 0.9791
Time taken: 1.33s
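The loop above keeps training until patience runs out, so the model it leaves behind holds the weights from the last epoch, not from the best one. One possible refactor, sketched here under stated assumptions, moves the bookkeeping into a small helper (PatienceTracker is a hypothetical name) that also keeps a copy of the state from the best epoch, similar in spirit to the restore_best_weights option of EarlyStopping. In the TensorFlow loop, the state would be model.get_weights() and could be restored afterwards with model.set_weights(tracker.best_state); below, a plain list stands in for the weights so the sketch runs on its own.

```python
import copy


class PatienceTracker:
    """Tracks the best (lowest) monitored value and the state seen at that epoch."""

    def __init__(self, patience, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.best_epoch = None
        self.best_state = None
        self.wait = 0

    def update(self, epoch, value, state):
        """Record one epoch; return True once `patience` epochs pass without improvement."""
        if value < self.best - self.min_delta:
            self.best = value
            self.best_epoch = epoch
            # Copy so later in-place changes to `state` do not alter the snapshot.
            self.best_state = copy.deepcopy(state)
            self.wait = 0
        else:
            self.wait += 1
        return self.wait >= self.patience


# Hypothetical per-epoch validation losses; a list stands in for model.get_weights().
val_losses = [0.30, 0.25, 0.26, 0.24, 0.27, 0.28]
tracker = PatienceTracker(patience=2)
for epoch, val_loss in enumerate(val_losses):
    weights = [epoch]
    if tracker.update(epoch, val_loss, weights):
        break

print(epoch, tracker.best_epoch, tracker.best_state)  # 5 3 [3]
# In the TensorFlow loop: model.set_weights(tracker.best_state)
```

Keeping a full weight copy costs memory equal to one extra set of model parameters; for a model this small that is negligible.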
Next steps
- Learn more about the Keras built-in early stopping callback API in the API docs.
- Learn to write custom Keras callbacks, including early stopping at a minimum loss.
- Learn about Training and evaluation with the Keras built-in methods.
- Explore common regularization techniques in the Overfit and underfit tutorial, which uses the EarlyStopping callback.