This notebook demonstrates how you can set up model training with early stopping, first, in TensorFlow 1 with tf.estimator.Estimator and an early stopping hook, and then, in TensorFlow 2 with Keras APIs or a custom training loop. Early stopping is a regularization technique that stops training if, for example, the validation loss reaches a certain threshold.
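As a minimal illustration of the threshold idea (plain Python, not a TensorFlow API), the hypothetical helper below scans per-epoch validation losses and reports the epoch at which a threshold rule would stop training. The helper name, the threshold, and the loss values are made-up assumptions for illustration.

```python
from typing import Optional, Sequence


def first_epoch_below(val_losses: Sequence[float], threshold: float) -> Optional[int]:
    """Hypothetical helper: return the 0-based epoch whose validation loss first
    falls to or below `threshold`, i.e. where a threshold rule would stop.
    Returns None if the threshold is never reached (training runs to the end)."""
    for epoch, loss in enumerate(val_losses):
        if loss <= threshold:
            return epoch
    return None


# Made-up per-epoch validation losses.
print(first_epoch_below([0.9, 0.6, 0.45, 0.38, 0.37], threshold=0.4))  # 3
```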
In TensorFlow 2, there are three ways to implement early stopping:
- Use a built-in Keras callback, tf.keras.callbacks.EarlyStopping, and pass it to Model.fit.
- Define a custom callback and pass it to Keras Model.fit.
- Write a custom early stopping rule in a custom training loop (with tf.GradientTape).
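For the third option, the stopping rule is ordinary Python logic that the training loop consults after each epoch. The sketch below is a hypothetical `PatienceStopper` class, not part of TensorFlow. It loosely follows the `patience` and `min_delta` ideas of tf.keras.callbacks.EarlyStopping, but its exact bookkeeping is an assumption, and the fake loss values stand in for the validation loss a real tf.GradientTape loop would compute.

```python
import math


class PatienceStopper:
    """Hypothetical early-stopping rule for a custom training loop.

    Signals a stop once the monitored value (lower is better) has failed to
    improve by more than `min_delta` for `patience` consecutive checks.
    """

    def __init__(self, patience: int = 3, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = math.inf
        self.wait = 0  # consecutive checks without sufficient improvement

    def should_stop(self, current: float) -> bool:
        if current < self.best - self.min_delta:
            self.best = current  # improvement: remember it and reset the counter
            self.wait = 0
        else:
            self.wait += 1
        return self.wait >= self.patience


# Usage inside a simulated epoch loop.
stopper = PatienceStopper(patience=2)
fake_val_losses = [0.50, 0.40, 0.41, 0.42, 0.30]
for epoch, val_loss in enumerate(fake_val_losses):
    # A real loop would run its tf.GradientTape training steps here.
    if stopper.should_stop(val_loss):
        print(f"Stopping early after epoch {epoch}")  # Stopping early after epoch 3
        break
```

Keeping the rule in its own object separates the stopping policy from the training step, so the same logic can be unit-tested without running any model.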
Setup
import time
import numpy as np
import tensorflow as tf
import tensorflow.compat.v1 as tf1
import tensorflow_datasets as tfds
TensorFlow 1: Early stopping with an early stopping hook and tf.estimator
Start by defining the functions that load and preprocess the MNIST dataset and that define the model, for use with tf.estimator.Estimator:
def normalize_img(image, label):
  return tf.cast(image, tf.float32) / 255., label
def _input_fn():
  ds_train = tfds.load(
    name='mnist',
    split='train',
    shuffle_files=True,
    as_supervised=True)
  ds_train = ds_train.map(
      normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
  ds_train = ds_train.batch(128)
  ds_train = ds_train.repeat(100)
  return ds_train
def _eval_input_fn():
  ds_test = tfds.load(
    name='mnist',
    split='test',
    shuffle_files=True,
    as_supervised=True)
  ds_test = ds_test.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
  ds_test = ds_test.batch(128)
  return ds_test
def _model_fn(features, labels, mode):
  flatten = tf1.layers.Flatten()(features)
  features = tf1.layers.Dense(128, 'relu')(flatten)
  logits = tf1.layers.Dense(10)(features)
  loss = tf1.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
  optimizer = tf1.train.AdagradOptimizer(0.005)
  train_op = optimizer.minimize(loss, global_step=tf1.train.get_global_step())
  return tf1.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
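To put the time-based limit used below in perspective, it helps to estimate how many training steps `_input_fn` can supply. The rough count below assumes the standard 60,000-image MNIST training split and that `batch(128)` keeps the final partial batch, which is its default behavior.

```python
import math

train_examples = 60_000  # size of the MNIST 'train' split
batch_size = 128         # from ds_train.batch(128)
repeats = 100            # from ds_train.repeat(100)

# batch() keeps the last, smaller batch by default, so round up.
steps_per_epoch = math.ceil(train_examples / batch_size)
max_steps = steps_per_epoch * repeats
print(steps_per_epoch, max_steps)  # 469 46900
```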
In TensorFlow 1, early stopping works by setting up an early stopping hook with tf.estimator.experimental.make_early_stopping_hook. You pass make_early_stopping_hook a function that takes no arguments as its should_stop_fn parameter, and then pass the resulting hook to tf.estimator.TrainSpec through its hooks argument. Training stops once should_stop_fn returns True.
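Because should_stop_fn is just a zero-argument callable, any Python closure can serve as one. The hypothetical factory below builds such a closure for a time budget. Unlike the example that follows, it starts its clock the first time it is called, so setup work done before the hook's first check is not counted. The `clock` parameter is an assumption added so the logic can be tested without waiting, and it defaults to `time.monotonic`, which, unlike `time.time`, is not affected by system clock adjustments.

```python
import time
from typing import Callable, Optional


def make_time_budget_stop_fn(
    max_seconds: float, clock: Callable[[], float] = time.monotonic
) -> Callable[[], bool]:
    """Hypothetical factory: return a zero-argument function usable as
    should_stop_fn. The budget starts on the first call, and the function
    returns True once more than `max_seconds` have elapsed since then."""
    start: Optional[float] = None

    def should_stop_fn() -> bool:
        nonlocal start
        now = clock()
        if start is None:
            start = now  # first check: start the budget here
        return now - start > max_seconds

    return should_stop_fn
```

You could then pass `should_stop_fn=make_time_budget_stop_fn(20)` to make_early_stopping_hook in place of the should_stop_fn defined in the next example.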
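The following example implements early stopping that limits training time to about 20 seconds. With run_every_secs=1, the hook checks should_stop_fn roughly once per second, and because the timer starts when start_time is assigned, it also counts any setup that happens before the first training step: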
estimator = tf1.estimator.Estimator(model_fn=_model_fn)
start_time = time.time()
max_train_seconds = 20
def should_stop_fn():
  return time.time() - start_time > max_train_seconds
early_stopping_hook = tf1.estimator.experimental.make_early_stopping_hook(
    estimator=estimator,
    should_stop_fn=should_stop_fn,
    run_every_secs=1,
    run_every_steps=None)
train_spec = tf1.estimator.TrainSpec(
    input_fn=_input_fn,
    hooks=[early_stopping_hook])
eval_spec = tf1.estimator.EvalSpec(input_fn=_eval_input_fn)
tf1.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
INFO:tensorflow:loss = 2.2961025, step = 0
INFO:tensorflow:global_step/sec: 335.823
INFO:tensorflow:loss = 1.318093, step = 100 (0.300 sec)
INFO:tensorflow:global_step/sec: 423.864
INFO:tensorflow:loss = 0.8593272, step = 200 (0.236 sec)
INFO:tensorflow:global_step/sec: 436.472
INFO:tensorflow:loss = 0.74179834, step = 300 (0.229 sec)
INFO:tensorflow:global_step/sec: 432.337
INFO:tensorflow:loss = 0.64878374, step = 400 (0.232 sec)
INFO:tensorflow:global_step/sec: 425.537
INFO:tensorflow:loss = 0.49600804, step = 500 (0.234 sec)
INFO:tensorflow:global_step/sec: 505.423
INFO:tensorflow:loss = 0.44587734, step = 600 (0.198 sec)
INFO:tensorflow:global_step/sec: 518.525
INFO:tensorflow:loss = 0.38416976, step = 700 (0.193 sec)
INFO:tensorflow:global_step/sec: 509.36
INFO:tensorflow:loss = 0.5128687, step = 800 (0.197 sec)
INFO:tensorflow:global_step/sec: 531.176
INFO:tensorflow:loss = 0.38511667, step = 900 (0.188 sec)
INFO:tensorflow:global_step/sec: 465.699
INFO:tensorflow:loss = 0.43401116, step = 1000 (0.215 sec)
INFO:tensorflow:global_step/sec: 495.799
INFO:tensorflow:loss = 0.44691753, step = 1100 (0.203 sec)
INFO:tensorflow:global_step/sec: 504.403
INFO:tensorflow:loss = 0.40272003, step = 1200 (0.198 sec)
INFO:tensorflow:global_step/sec: 482.232
INFO:tensorflow:loss = 0.47271937, step = 1300 (0.207 sec)
INFO:tensorflow:global_step/sec: 501.553
INFO:tensorflow:loss = 0.29635084, step = 1400 (0.200 sec)
INFO:tensorflow:global_step/sec: 459.337
INFO:tensorflow:loss = 0.293486, step = 1500 (0.218 sec)
INFO:tensorflow:global_step/sec: 511.454
INFO:tensorflow:loss = 0.40195698, step = 1600 (0.196 sec)
INFO:tensorflow:global_step/sec: 487.452
INFO:tensorflow:loss = 0.38753498, step = 1700 (0.206 sec)
INFO:tensorflow:global_step/sec: 500.287
INFO:tensorflow:loss = 0.3344679, step = 1800 (0.199 sec)
INFO:tensorflow:global_step/sec: 476.868
INFO:tensorflow:loss = 0.49753922, step = 1900 (0.210 sec)
INFO:tensorflow:global_step/sec: 484.689
INFO:tensorflow:loss = 0.21684857, step = 2000 (0.207 sec)
INFO:tensorflow:global_step/sec: 496.691
INFO:tensorflow:loss = 0.28068116, step = 2100 (0.202 sec)
INFO:tensorflow:global_step/sec: 499.682
INFO:tensorflow:loss = 0.3000077, step = 2200 (0.200 sec)
INFO:tensorflow:global_step/sec: 507.391
INFO:tensorflow:loss = 0.34870982, step = 2300 (0.197 sec)
INFO:tensorflow:global_step/sec: 475.414
INFO:tensorflow:loss = 0.24876948, step = 2400 (0.211 sec)
INFO:tensorflow:global_step/sec: 499.39
INFO:tensorflow:loss = 0.21644332, step = 2500 (0.200 sec)
INFO:tensorflow:global_step/sec: 500.659
INFO:tensorflow:loss = 0.152693, step = 2600 (0.200 sec)
INFO:tensorflow:global_step/sec: 500.201
INFO:tensorflow:loss = 0.33327985, step = 2700 (0.200 sec)
INFO:tensorflow:global_step/sec: 517.488
INFO:tensorflow:loss = 0.47266263, step = 2800 (0.193 sec)
INFO:tensorflow:global_step/sec: 466.55
INFO:tensorflow:loss = 0.24876104, step = 2900 (0.215 sec)
INFO:tensorflow:global_step/sec: 499.818
INFO:tensorflow:loss = 0.33199376, step = 3000 (0.200 sec)
INFO:tensorflow:global_step/sec: 477.253
INFO:tensorflow:loss = 0.19820198, step = 3100 (0.210 sec)
INFO:tensorflow:global_step/sec: 512.079
INFO:tensorflow:loss = 0.4163157, step = 3200 (0.195 sec)
INFO:tensorflow:global_step/sec: 461.596
INFO:tensorflow:loss = 0.3364423, step = 3300 (0.216 sec)
INFO:tensorflow:global_step/sec: 488.092
INFO:tensorflow:loss = 0.25606278, step = 3400 (0.206 sec)
INFO:tensorflow:global_step/sec: 490.26
INFO:tensorflow:loss = 0.20572862, step = 3500 (0.204 sec)
INFO:tensorflow:global_step/sec: 487.323
INFO:tensorflow:loss = 0.2212799, step = 3600 (0.206 sec)
INFO:tensorflow:global_step/sec: 492.262
INFO:tensorflow:loss = 0.282781, step = 3700 (0.203 sec)
INFO:tensorflow:global_step/sec: 449.227
INFO:tensorflow:loss = 0.365446, step = 3800 (0.223 sec)
INFO:tensorflow:global_step/sec: 521.924
INFO:tensorflow:loss = 0.22579709, step = 3900 (0.191 sec)
INFO:tensorflow:global_step/sec: 535.844
INFO:tensorflow:loss = 0.30844557, step = 4000 (0.187 sec)
INFO:tensorflow:global_step/sec: 496.551
INFO:tensorflow:loss = 0.21947613, step = 4100 (0.202 sec)
INFO:tensorflow:global_step/sec: 508.086
INFO:tensorflow:loss = 0.26513258, step = 4200 (0.197 sec)
INFO:tensorflow:global_step/sec: 462.149
INFO:tensorflow:loss = 0.29323363, step = 4300 (0.217 sec)
INFO:tensorflow:global_step/sec: 503.226
INFO:tensorflow:loss = 0.31204918, step = 4400 (0.198 sec)
INFO:tensorflow:global_step/sec: 510.874
INFO:tensorflow:loss = 0.26014802, step = 4500 (0.196 sec)
INFO:tensorflow:global_step/sec: 511.332
INFO:tensorflow:loss = 0.33227044, step = 4600 (0.196 sec)
INFO:tensorflow:global_step/sec: 472.525
INFO:tensorflow:loss = 0.13610935, step = 4700 (0.211 sec)
INFO:tensorflow:global_step/sec: 505.262
INFO:tensorflow:loss = 0.28594193, step = 4800 (0.198 sec)
INFO:tensorflow:global_step/sec: 515.453
INFO:tensorflow:loss = 0.38484165, step = 4900 (0.195 sec)
INFO:tensorflow:global_step/sec: 479.259
INFO:tensorflow:loss = 0.27081215, step = 5000 (0.208 sec)
INFO:tensorflow:global_step/sec: 491.361
INFO:tensorflow:loss = 0.33877313, step = 5100 (0.204 sec)
INFO:tensorflow:global_step/sec: 471.256
INFO:tensorflow:loss = 0.2074028, step = 5200 (0.212 sec)
INFO:tensorflow:global_step/sec: 506.672
INFO:tensorflow:loss = 0.24718614, step = 5300 (0.197 sec)
INFO:tensorflow:global_step/sec: 491.333
INFO:tensorflow:loss = 0.16439602, step = 5400 (0.203 sec)
INFO:tensorflow:global_step/sec: 500.638
INFO:tensorflow:loss = 0.22073755, step = 5500 (0.200 sec)
INFO:tensorflow:global_step/sec: 508.206
INFO:tensorflow:loss = 0.18545151, step = 5600 (0.196 sec)
INFO:tensorflow:global_step/sec: 450.634
INFO:tensorflow:loss = 0.17126478, step = 5700 (0.222 sec)
INFO:tensorflow:global_step/sec: 525.95
INFO:tensorflow:loss = 0.2689212, step = 5800 (0.190 sec)
INFO:tensorflow:global_step/sec: 483.04
INFO:tensorflow:loss = 0.21353054, step = 5900 (0.207 sec)
INFO:tensorflow:global_step/sec: 480.064
INFO:tensorflow:loss = 0.25207376, step = 6000 (0.209 sec)
INFO:tensorflow:global_step/sec: 451.147
INFO:tensorflow:loss = 0.18487616, step = 6100 (0.221 sec)
INFO:tensorflow:global_step/sec: 490.621
INFO:tensorflow:loss = 0.267612, step = 6200 (0.204 sec)
INFO:tensorflow:global_step/sec: 499.641
INFO:tensorflow:loss = 0.24492177, step = 6300 (0.200 sec)
INFO:tensorflow:global_step/sec: 498.738
INFO:tensorflow:loss = 0.28542638, step = 6400 (0.200 sec)
INFO:tensorflow:global_step/sec: 523.196
INFO:tensorflow:loss = 0.26353425, step = 6500 (0.191 sec)
INFO:tensorflow:global_step/sec: 455.054
INFO:tensorflow:loss = 0.19190696, step = 6600 (0.220 sec)
INFO:tensorflow:global_step/sec: 522.506
INFO:tensorflow:loss = 0.25657248, step = 6700 (0.191 sec)
INFO:tensorflow:global_step/sec: 506.766
INFO:tensorflow:loss = 0.39108855, step = 6800 (0.197 sec)
INFO:tensorflow:global_step/sec: 481.896
INFO:tensorflow:loss = 0.15236983, step = 6900 (0.208 sec)
INFO:tensorflow:global_step/sec: 506.833
INFO:tensorflow:loss = 0.3356074, step = 7000 (0.197 sec)
INFO:tensorflow:global_step/sec: 467.501
INFO:tensorflow:loss = 0.1956916, step = 7100 (0.214 sec)
INFO:tensorflow:global_step/sec: 494.573
INFO:tensorflow:loss = 0.24700633, step = 7200 (0.203 sec)
INFO:tensorflow:global_step/sec: 522.627
INFO:tensorflow:loss = 0.17747968, step = 7300 (0.191 sec)
INFO:tensorflow:global_step/sec: 486.106
INFO:tensorflow:loss = 0.28143722, step = 7400 (0.206 sec)
INFO:tensorflow:global_step/sec: 479.817
INFO:tensorflow:loss = 0.14399035, step = 7500 (0.209 sec)
INFO:tensorflow:global_step/sec: 487.355
INFO:tensorflow:loss = 0.2522769, step = 7600 (0.205 sec)
INFO:tensorflow:global_step/sec: 526.322
INFO:tensorflow:loss = 0.25871283, step = 7700 (0.190 sec)
INFO:tensorflow:global_step/sec: 501.783
INFO:tensorflow:loss = 0.16747415, step = 7800 (0.199 sec)
INFO:tensorflow:global_step/sec: 485.409
INFO:tensorflow:loss = 0.15161222, step = 7900 (0.205 sec)
INFO:tensorflow:global_step/sec: 462.017
INFO:tensorflow:loss = 0.18137535, step = 8000 (0.216 sec)
INFO:tensorflow:global_step/sec: 488.439
INFO:tensorflow:loss = 0.19478491, step = 8100 (0.205 sec)
INFO:tensorflow:global_step/sec: 485.238
INFO:tensorflow:loss = 0.27511528, step = 8200 (0.206 sec)
INFO:tensorflow:Requesting early stopping at global step 8286
INFO:tensorflow:Calling checkpoint listeners before saving checkpoint 8287...
INFO:tensorflow:Saving checkpoints for 8287 into /tmpfs/tmp/tmpp8x_ipb8/model.ckpt.
INFO:tensorflow:Calling checkpoint listeners after saving checkpoint 8287...
INFO:tensorflow:Calling model_fn.
INFO:tensorflow:Done calling model_fn.
INFO:tensorflow:Starting evaluation at 2023-10-04T01:37:52
INFO:tensorflow:Graph was finalized.
INFO:tensorflow:Restoring parameters from /tmpfs/tmp/tmpp8x_ipb8/model.ckpt-8287
INFO:tensorflow:Running local_init_op.
INFO:tensorflow:Done running local_init_op.
INFO:tensorflow:Evaluation [10/100]
INFO:tensorflow:Evaluation [20/100]
INFO:tensorflow:Evaluation [30/100]
INFO:tensorflow:Evaluation [30/100]
INFO:tensorflow:Evaluation [40/100]
INFO:tensorflow:Evaluation [40/100]
INFO:tensorflow:Evaluation [50/100]
INFO:tensorflow:Evaluation [50/100]
INFO:tensorflow:Evaluation [60/100]
INFO:tensorflow:Evaluation [60/100]
INFO:tensorflow:Evaluation [70/100]
INFO:tensorflow:Evaluation [70/100]
INFO:tensorflow:Inference Time : 0.51915s
INFO:tensorflow:Inference Time : 0.51915s
INFO:tensorflow:Finished evaluation at 2023-10-04-01:37:53
INFO:tensorflow:Finished evaluation at 2023-10-04-01:37:53
INFO:tensorflow:Saving dict for global step 8287: global_step = 8287, loss = 0.21179305
INFO:tensorflow:Saving dict for global step 8287: global_step = 8287, loss = 0.21179305
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 8287: /tmpfs/tmp/tmpp8x_ipb8/model.ckpt-8287
INFO:tensorflow:Saving 'checkpoint_path' summary for global step 8287: /tmpfs/tmp/tmpp8x_ipb8/model.ckpt-8287
INFO:tensorflow:Loss for final step: 0.28788015.
INFO:tensorflow:Loss for final step: 0.28788015.
({'loss': 0.21179305, 'global_step': 8287}, [])
TensorFlow 2: Early stopping with a built-in callback and Model.fit
Prepare the MNIST dataset and a simple Keras model:
(ds_train, ds_test), ds_info = tfds.load(
    'mnist',
    split=['train', 'test'],
    shuffle_files=True,
    as_supervised=True,
    with_info=True,
)
ds_train = ds_train.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_train = ds_train.batch(128)
ds_test = ds_test.map(
    normalize_img, num_parallel_calls=tf.data.AUTOTUNE)
ds_test = ds_test.batch(128)
model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dense(10)
])
model.compile(
    optimizer=tf.keras.optimizers.Adam(0.005),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)
In TensorFlow 2, when you use the built-in Keras Model.fit, you can configure early stopping by passing the built-in tf.keras.callbacks.EarlyStopping callback to the callbacks parameter of Model.fit.
The EarlyStopping callback monitors a user-specified metric and ends training when it stops improving. (Check the Training and evaluation with the built-in methods guide or the API docs for more information.)
Below is an example of an early stopping callback that monitors the training loss and stops training once the loss has not improved for 3 consecutive epochs (patience=3):
callback = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=3)
# Training stops well before the 100-epoch limit (after 16 epochs in the run below).
history = model.fit(
    ds_train,
    epochs=100,
    validation_data=ds_test,
    callbacks=[callback]
)
len(history.history['loss'])
Epoch 1/100
469/469 [==============================] - 3s 4ms/step - loss: 0.2326 - sparse_categorical_accuracy: 0.9310 - val_loss: 0.1214 - val_sparse_categorical_accuracy: 0.9630
Epoch 2/100
469/469 [==============================] - 1s 3ms/step - loss: 0.1004 - sparse_categorical_accuracy: 0.9700 - val_loss: 0.1009 - val_sparse_categorical_accuracy: 0.9684
Epoch 3/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0689 - sparse_categorical_accuracy: 0.9792 - val_loss: 0.1061 - val_sparse_categorical_accuracy: 0.9674
Epoch 4/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0528 - sparse_categorical_accuracy: 0.9835 - val_loss: 0.1244 - val_sparse_categorical_accuracy: 0.9640
Epoch 5/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0424 - sparse_categorical_accuracy: 0.9866 - val_loss: 0.1001 - val_sparse_categorical_accuracy: 0.9730
Epoch 6/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0355 - sparse_categorical_accuracy: 0.9883 - val_loss: 0.1034 - val_sparse_categorical_accuracy: 0.9728
Epoch 7/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0315 - sparse_categorical_accuracy: 0.9894 - val_loss: 0.1029 - val_sparse_categorical_accuracy: 0.9748
Epoch 8/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0319 - sparse_categorical_accuracy: 0.9888 - val_loss: 0.1267 - val_sparse_categorical_accuracy: 0.9707
Epoch 9/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0280 - sparse_categorical_accuracy: 0.9905 - val_loss: 0.1128 - val_sparse_categorical_accuracy: 0.9733
Epoch 10/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0281 - sparse_categorical_accuracy: 0.9907 - val_loss: 0.1246 - val_sparse_categorical_accuracy: 0.9718
Epoch 11/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0270 - sparse_categorical_accuracy: 0.9905 - val_loss: 0.1184 - val_sparse_categorical_accuracy: 0.9786
Epoch 12/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0217 - sparse_categorical_accuracy: 0.9927 - val_loss: 0.1231 - val_sparse_categorical_accuracy: 0.9756
Epoch 13/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0191 - sparse_categorical_accuracy: 0.9935 - val_loss: 0.1466 - val_sparse_categorical_accuracy: 0.9716
Epoch 14/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0220 - sparse_categorical_accuracy: 0.9927 - val_loss: 0.1475 - val_sparse_categorical_accuracy: 0.9754
Epoch 15/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0200 - sparse_categorical_accuracy: 0.9934 - val_loss: 0.1472 - val_sparse_categorical_accuracy: 0.9746
Epoch 16/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0195 - sparse_categorical_accuracy: 0.9937 - val_loss: 0.1454 - val_sparse_categorical_accuracy: 0.9755
16
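To see how the patience rule plays out on this run, the framework-free sketch below replays it on the per-epoch values printed in the log. It assumes the rule works the way tf.keras.callbacks.EarlyStopping documents it for a loss (lower is better, default min_delta=0, no baseline): a counter of epochs without improvement resets whenever the monitored value beats the best so far, and training stops once the counter reaches patience. The function epochs_until_early_stop is hypothetical, and the logged values are rounded to four decimals, so the replay only approximates what Keras compared internally. It also suggests that monitoring val_loss, the callback's default monitor, would have stopped this run after 8 epochs.

```python
import math


def epochs_until_early_stop(history, patience, min_delta=0.0):
    """Replay a patience-based early stopping rule on a recorded metric.

    Hypothetical helper. Assumes lower values are better (like a loss).
    Returns how many epochs would run before the rule stops training.
    """
    best = math.inf
    wait = 0  # Epochs since the last improvement.
    for epoch, value in enumerate(history):
        wait += 1
        # Only a drop of more than min_delta counts as an improvement.
        if value < best - min_delta:
            best = value
            wait = 0
        if wait >= patience:
            return epoch + 1  # Epochs are 0-indexed; report a count.
    return len(history)


# Per-epoch values copied from the Model.fit log above (rounded).
loss = [0.2326, 0.1004, 0.0689, 0.0528, 0.0424, 0.0355, 0.0315, 0.0319,
        0.0280, 0.0281, 0.0270, 0.0217, 0.0191, 0.0220, 0.0200, 0.0195]
val_loss = [0.1214, 0.1009, 0.1061, 0.1244, 0.1001, 0.1034, 0.1029, 0.1267,
            0.1128, 0.1246, 0.1184, 0.1231, 0.1466, 0.1475, 0.1472, 0.1454]

print(epochs_until_early_stop(loss, patience=3))      # 16
print(epochs_until_early_stop(val_loss, patience=3))  # 8
```

With monitor='loss', the best training loss (0.0191) came at epoch 13 and the next three epochs did not beat it, which matches the 16 epochs reported by len(history.history['loss']).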
TensorFlow 2: Early stopping with a custom callback and Model.fit
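You can also implement a custom early stopping callback and pass it to the callbacks parameter of Model.fit.
In this example, the callback limits training to a fixed wall-clock time: once the time budget is exceeded, it sets self.model.stop_training to True, which tells Model.fit to end training. Because the same model object is reused, training resumes from the weights learned in the previous section: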
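class LimitTrainingTime(tf.keras.callbacks.Callback):
  """Stops training once a wall-clock time budget (in seconds) is exceeded."""
  def __init__(self, max_time_s):
    super().__init__()
    self.max_time_s = max_time_s
    self.start_time = None
  def on_train_begin(self, logs=None):
    # Record when training starts.
    self.start_time = time.time()
  def on_train_batch_end(self, batch, logs=None):
    # After every batch, check the elapsed time and ask Model.fit to stop
    # once the budget has been exceeded.
    now = time.time()
    if now - self.start_time > self.max_time_s:
      self.model.stop_training = True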
# Limit the training time to 30 seconds.
callback = LimitTrainingTime(30)
history = model.fit(
    ds_train,
    epochs=100,
    validation_data=ds_test,
    callbacks=[callback]
)
len(history.history['loss'])
Epoch 1/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0180 - sparse_categorical_accuracy: 0.9939 - val_loss: 0.1440 - val_sparse_categorical_accuracy: 0.9785
Epoch 2/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0139 - sparse_categorical_accuracy: 0.9958 - val_loss: 0.1728 - val_sparse_categorical_accuracy: 0.9755
Epoch 3/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0213 - sparse_categorical_accuracy: 0.9937 - val_loss: 0.1567 - val_sparse_categorical_accuracy: 0.9787
Epoch 4/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0198 - sparse_categorical_accuracy: 0.9941 - val_loss: 0.1826 - val_sparse_categorical_accuracy: 0.9740
Epoch 5/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0146 - sparse_categorical_accuracy: 0.9952 - val_loss: 0.1658 - val_sparse_categorical_accuracy: 0.9756
Epoch 6/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0126 - sparse_categorical_accuracy: 0.9959 - val_loss: 0.2150 - val_sparse_categorical_accuracy: 0.9726
Epoch 7/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0174 - sparse_categorical_accuracy: 0.9952 - val_loss: 0.1890 - val_sparse_categorical_accuracy: 0.9769
Epoch 8/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0140 - sparse_categorical_accuracy: 0.9956 - val_loss: 0.1852 - val_sparse_categorical_accuracy: 0.9772
Epoch 9/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0141 - sparse_categorical_accuracy: 0.9958 - val_loss: 0.2051 - val_sparse_categorical_accuracy: 0.9761
Epoch 10/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0156 - sparse_categorical_accuracy: 0.9961 - val_loss: 0.2260 - val_sparse_categorical_accuracy: 0.9727
Epoch 11/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0167 - sparse_categorical_accuracy: 0.9952 - val_loss: 0.2008 - val_sparse_categorical_accuracy: 0.9757
Epoch 12/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0133 - sparse_categorical_accuracy: 0.9963 - val_loss: 0.2283 - val_sparse_categorical_accuracy: 0.9755
Epoch 13/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0106 - sparse_categorical_accuracy: 0.9969 - val_loss: 0.2270 - val_sparse_categorical_accuracy: 0.9739
Epoch 14/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0139 - sparse_categorical_accuracy: 0.9962 - val_loss: 0.2169 - val_sparse_categorical_accuracy: 0.9775
Epoch 15/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0121 - sparse_categorical_accuracy: 0.9964 - val_loss: 0.2282 - val_sparse_categorical_accuracy: 0.9773
Epoch 16/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0128 - sparse_categorical_accuracy: 0.9966 - val_loss: 0.2723 - val_sparse_categorical_accuracy: 0.9738
Epoch 17/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0118 - sparse_categorical_accuracy: 0.9967 - val_loss: 0.2223 - val_sparse_categorical_accuracy: 0.9784
Epoch 18/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0168 - sparse_categorical_accuracy: 0.9959 - val_loss: 0.2489 - val_sparse_categorical_accuracy: 0.9770
Epoch 19/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0129 - sparse_categorical_accuracy: 0.9967 - val_loss: 0.2607 - val_sparse_categorical_accuracy: 0.9753
Epoch 20/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0112 - sparse_categorical_accuracy: 0.9975 - val_loss: 0.2267 - val_sparse_categorical_accuracy: 0.9776
Epoch 21/100
469/469 [==============================] - 1s 3ms/step - loss: 0.0105 - sparse_categorical_accuracy: 0.9971 - val_loss: 0.2758 - val_sparse_categorical_accuracy: 0.9733
21
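The LimitTrainingTime callback mixes two concerns: measuring elapsed time and telling Keras to stop. The framework-free sketch below separates the time check so it can be exercised with a fake clock. TimeBudget and run_batches are hypothetical names, and run_batches is a simplified stand-in for a training loop, not how Model.fit is implemented. The sketch defaults to time.monotonic rather than time.time, because a monotonic clock cannot go backward and is not affected by system clock updates.

```python
import time


class TimeBudget:
    """Tracks elapsed time against a fixed budget in seconds (hypothetical).

    `clock` is any zero-argument function returning seconds; it defaults to
    time.monotonic so that system clock adjustments do not skew the result.
    """

    def __init__(self, max_time_s, clock=time.monotonic):
        self.max_time_s = max_time_s
        self.clock = clock
        self.start_time = None

    def start(self):
        self.start_time = self.clock()

    def exceeded(self):
        if self.start_time is None:
            raise RuntimeError("call start() before exceeded()")
        return self.clock() - self.start_time > self.max_time_s


def run_batches(num_batches, budget, batch_fn):
    """Hypothetical loop that checks the budget after every batch."""
    budget.start()
    completed = 0
    for batch in range(num_batches):
        batch_fn(batch)
        completed += 1
        # Plays the role of setting self.model.stop_training = True.
        if budget.exceeded():
            break
    return completed


# A fake clock that advances by 1 second per batch keeps the demo instant.
fake_now = [0.0]
budget = TimeBudget(max_time_s=2.5, clock=lambda: fake_now[0])


def fake_batch(_):
    fake_now[0] += 1.0


print(run_batches(10, budget, fake_batch))  # 3
```

Inside the Keras callback, you could call budget.start() from on_train_begin and set self.model.stop_training = True from on_train_batch_end whenever budget.exceeded() returns True.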
TensorFlow 2: Early stopping with a custom training loop
In TensorFlow 2, you can implement early stopping in a custom training loop if you're not training and evaluating with the built-in Keras methods.
Start by using Keras APIs to define another simple model, an optimizer, a loss function, and metrics:
model = tf.keras.models.Sequential([
  tf.keras.layers.Flatten(input_shape=(28, 28)),
  tf.keras.layers.Dense(128, activation='relu'),
  tf.keras.layers.Dense(10)
])
optimizer = tf.keras.optimizers.Adam(0.005)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
train_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()
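# The model outputs logits, so the loss metrics need from_logits=True.
train_loss_metric = tf.keras.metrics.SparseCategoricalCrossentropy(from_logits=True)
val_acc_metric = tf.keras.metrics.SparseCategoricalAccuracy()
val_loss_metric = tf.keras.metrics.SparseCategoricalCrossentropy(from_logits=True)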
Define the training and evaluation step functions with tf.GradientTape and the @tf.function decorator, which compiles them into TensorFlow graphs for a speedup:
@tf.function
def train_step(x, y):
  # Run the forward pass under the tape so gradients can be computed.
  with tf.GradientTape() as tape:
    logits = model(x, training=True)
    loss_value = loss_fn(y, logits)
  # Backward pass and parameter update.
  grads = tape.gradient(loss_value, model.trainable_weights)
  optimizer.apply_gradients(zip(grads, model.trainable_weights))
  # Accumulate this batch into the running training metrics.
  train_acc_metric.update_state(y, logits)
  train_loss_metric.update_state(y, logits)
  return loss_value
@tf.function
def test_step(x, y):
  # Evaluation only: no gradient tape and no parameter updates.
  logits = model(x, training=False)
  val_acc_metric.update_state(y, logits)
  val_loss_metric.update_state(y, logits)
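The step functions only call update_state(); the training loop reads each metric with result() and clears it between epochs. As a rough, framework-free model of that accumulate, read, and reset cycle, the hypothetical StreamingMean class below keeps a running total and count, much as Keras mean-style metrics keep total and count state. Weighting by example count matters here: MNIST's 60,000 training images split into 468 full batches of 128 plus a final batch of 96, so a plain average of per-batch means would slightly overweight the last batch.

```python
class StreamingMean:
    """Hypothetical stand-in for a mean-style metric's running state."""

    def __init__(self):
        self.reset_state()

    def update_state(self, values):
        # Accumulate every per-example value in the batch.
        self.total += sum(values)
        self.count += len(values)

    def result(self):
        # Mean over all examples seen since the last reset.
        return self.total / self.count if self.count else 0.0

    def reset_state(self):
        self.total = 0.0
        self.count = 0


metric = StreamingMean()
metric.update_state([1.0, 2.0, 3.0])  # A full batch of three examples.
metric.update_state([4.0])            # A smaller final batch.
print(metric.result())  # 2.5 (averaging the two batch means would give 3.0)
metric.reset_state()    # Start the next epoch from a clean state.
print(metric.result())  # 0.0
```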
Next, write a custom training loop, where you can implement your early stopping rule manually.
The example below stops training when the validation loss has not improved for `patience` (here, 5) consecutive epochs:
epochs = 100
patience = 5
wait = 0
best = float('inf')
for epoch in range(epochs):
    print("\nStart of epoch %d" % (epoch,))
    start_time = time.time()
    for step, (x_batch_train, y_batch_train) in enumerate(ds_train):
      loss_value = train_step(x_batch_train, y_batch_train)
      if step % 200 == 0:
        print("Training loss at step %d: %.4f" % (step, loss_value.numpy()))
        print("Seen so far: %s samples" % ((step + 1) * 128))        
    train_acc = train_acc_metric.result()
    train_loss = train_loss_metric.result()
    train_acc_metric.reset_states()
    train_loss_metric.reset_states()
    print("Training acc over epoch: %.4f" % (train_acc.numpy()))
    for x_batch_val, y_batch_val in ds_test:
      test_step(x_batch_val, y_batch_val)
    val_acc = val_acc_metric.result()
    val_loss = val_loss_metric.result()
    val_acc_metric.reset_states()
    val_loss_metric.reset_states()
    print("Validation acc: %.4f" % (float(val_acc),))
    print("Time taken: %.2fs" % (time.time() - start_time))
    # The early stopping strategy: stop the training if `val_loss` does not
    # decrease over a certain number of epochs.
    wait += 1
    if val_loss < best:
      best = val_loss
      wait = 0
    if wait >= patience:
      break
Start of epoch 0
Training loss at step 0: 2.4644
Seen so far: 128 samples
Training loss at step 200: 0.2622
Seen so far: 25728 samples
Training loss at step 400: 0.2129
Seen so far: 51328 samples
Training acc over epoch: 0.9297
Validation acc: 0.9610
Time taken: 2.07s

Start of epoch 1
Training loss at step 0: 0.0756
Seen so far: 128 samples
Training loss at step 200: 0.1215
Seen so far: 25728 samples
Training loss at step 400: 0.1359
Seen so far: 51328 samples
Training acc over epoch: 0.9685
Validation acc: 0.9664
Time taken: 1.36s

Start of epoch 2
Training loss at step 0: 0.0390
Seen so far: 128 samples
Training loss at step 200: 0.0995
Seen so far: 25728 samples
Training loss at step 400: 0.1270
Seen so far: 51328 samples
Training acc over epoch: 0.9778
Validation acc: 0.9690
Time taken: 1.37s

Start of epoch 3
Training loss at step 0: 0.0282
Seen so far: 128 samples
Training loss at step 200: 0.0673
Seen so far: 25728 samples
Training loss at step 400: 0.0700
Seen so far: 51328 samples
Training acc over epoch: 0.9826
Validation acc: 0.9695
Time taken: 1.35s

Start of epoch 4
Training loss at step 0: 0.0249
Seen so far: 128 samples
Training loss at step 200: 0.0403
Seen so far: 25728 samples
Training loss at step 400: 0.0324
Seen so far: 51328 samples
Training acc over epoch: 0.9862
Validation acc: 0.9693
Time taken: 1.37s

Start of epoch 5
Training loss at step 0: 0.0279
Seen so far: 128 samples
Training loss at step 200: 0.0529
Seen so far: 25728 samples
Training loss at step 400: 0.0592
Seen so far: 51328 samples
Training acc over epoch: 0.9875
Validation acc: 0.9660
Time taken: 1.32s

Start of epoch 6
Training loss at step 0: 0.0200
Seen so far: 128 samples
Training loss at step 200: 0.0344
Seen so far: 25728 samples
Training loss at step 400: 0.0142
Seen so far: 51328 samples
Training acc over epoch: 0.9884
Validation acc: 0.9720
Time taken: 1.36s

Start of epoch 7
Training loss at step 0: 0.0325
Seen so far: 128 samples
Training loss at step 200: 0.0499
Seen so far: 25728 samples
Training loss at step 400: 0.0167
Seen so far: 51328 samples
Training acc over epoch: 0.9898
Validation acc: 0.9711
Time taken: 1.39s

Start of epoch 8
Training loss at step 0: 0.0155
Seen so far: 128 samples
Training loss at step 200: 0.0219
Seen so far: 25728 samples
Training loss at step 400: 0.0213
Seen so far: 51328 samples
Training acc over epoch: 0.9914
Validation acc: 0.9715
Time taken: 1.34s

Start of epoch 9
Training loss at step 0: 0.0068
Seen so far: 128 samples
Training loss at step 200: 0.0288
Seen so far: 25728 samples
Training loss at step 400: 0.0320
Seen so far: 51328 samples
Training acc over epoch: 0.9921
Validation acc: 0.9743
Time taken: 1.34s

Start of epoch 10
Training loss at step 0: 0.0069
Seen so far: 128 samples
Training loss at step 200: 0.0419
Seen so far: 25728 samples
Training loss at step 400: 0.0458
Seen so far: 51328 samples
Training acc over epoch: 0.9919
Validation acc: 0.9720
Time taken: 1.34s

Start of epoch 11
Training loss at step 0: 0.0015
Seen so far: 128 samples
Training loss at step 200: 0.0113
Seen so far: 25728 samples
Training loss at step 400: 0.0285
Seen so far: 51328 samples
Training acc over epoch: 0.9930
Validation acc: 0.9740
Time taken: 1.33s

Start of epoch 12
Training loss at step 0: 0.0010
Seen so far: 128 samples
Training loss at step 200: 0.0450
Seen so far: 25728 samples
Training loss at step 400: 0.0561
Seen so far: 51328 samples
Training acc over epoch: 0.9926
Validation acc: 0.9744
Time taken: 1.34s

Start of epoch 13
Training loss at step 0: 0.0038
Seen so far: 128 samples
Training loss at step 200: 0.0309
Seen so far: 25728 samples
Training loss at step 400: 0.0124
Seen so far: 51328 samples
Training acc over epoch: 0.9937
Validation acc: 0.9740
Time taken: 1.36s

Start of epoch 14
Training loss at step 0: 0.0011
Seen so far: 128 samples
Training loss at step 200: 0.0068
Seen so far: 25728 samples
Training loss at step 400: 0.0110
Seen so far: 51328 samples
Training acc over epoch: 0.9930
Validation acc: 0.9736
Time taken: 1.35s

Start of epoch 15
Training loss at step 0: 0.0043
Seen so far: 128 samples
Training loss at step 200: 0.0085
Seen so far: 25728 samples
Training loss at step 400: 0.0042
Seen so far: 51328 samples
Training acc over epoch: 0.9942
Validation acc: 0.9752
Time taken: 1.34s

Start of epoch 16
Training loss at step 0: 0.0208
Seen so far: 128 samples
Training loss at step 200: 0.0042
Seen so far: 25728 samples
Training loss at step 400: 0.1063
Seen so far: 51328 samples
Training acc over epoch: 0.9947
Validation acc: 0.9740
Time taken: 1.34s

Start of epoch 17
Training loss at step 0: 0.0067
Seen so far: 128 samples
Training loss at step 200: 0.0277
Seen so far: 25728 samples
Training loss at step 400: 0.0787
Seen so far: 51328 samples
Training acc over epoch: 0.9951
Validation acc: 0.9729
Time taken: 1.33s

Start of epoch 18
Training loss at step 0: 0.0017
Seen so far: 128 samples
Training loss at step 200: 0.0131
Seen so far: 25728 samples
Training loss at step 400: 0.0431
Seen so far: 51328 samples
Training acc over epoch: 0.9943
Validation acc: 0.9739
Time taken: 1.40s

Start of epoch 19
Training loss at step 0: 0.0004
Seen so far: 128 samples
Training loss at step 200: 0.0220
Seen so far: 25728 samples
Training loss at step 400: 0.0662
Seen so far: 51328 samples
Training acc over epoch: 0.9952
Validation acc: 0.9738
Time taken: 1.34s

Start of epoch 20
Training loss at step 0: 0.0003
Seen so far: 128 samples
Training loss at step 200: 0.0306
Seen so far: 25728 samples
Training loss at step 400: 0.0083
Seen so far: 51328 samples
Training acc over epoch: 0.9955
Validation acc: 0.9753
Time taken: 1.37s

Start of epoch 21
Training loss at step 0: 0.0016
Seen so far: 128 samples
Training loss at step 200: 0.0003
Seen so far: 25728 samples
Training loss at step 400: 0.0069
Seen so far: 51328 samples
Training acc over epoch: 0.9946
Validation acc: 0.9729
Time taken: 1.36s

Start of epoch 22
Training loss at step 0: 0.0626
Seen so far: 128 samples
Training loss at step 200: 0.0013
Seen so far: 25728 samples
Training loss at step 400: 0.0278
Seen so far: 51328 samples
Training acc over epoch: 0.9946
Validation acc: 0.9740
Time taken: 1.34s

Start of epoch 23
Training loss at step 0: 0.0318
Seen so far: 128 samples
Training loss at step 200: 0.0514
Seen so far: 25728 samples
Training loss at step 400: 0.0001
Seen so far: 51328 samples
Training acc over epoch: 0.9952
Validation acc: 0.9758
Time taken: 1.36s

Start of epoch 24
Training loss at step 0: 0.0004
Seen so far: 128 samples
Training loss at step 200: 0.0043
Seen so far: 25728 samples
Training loss at step 400: 0.0339
Seen so far: 51328 samples
Training acc over epoch: 0.9956
Validation acc: 0.9752
Time taken: 1.37s

Start of epoch 25
Training loss at step 0: 0.0000
Seen so far: 128 samples
Training loss at step 200: 0.0057
Seen so far: 25728 samples
Training loss at step 400: 0.1485
Seen so far: 51328 samples
Training acc over epoch: 0.9961
Validation acc: 0.9733
Time taken: 1.35s

Start of epoch 26
Training loss at step 0: 0.0005
Seen so far: 128 samples
Training loss at step 200: 0.0992
Seen so far: 25728 samples
Training loss at step 400: 0.0033
Seen so far: 51328 samples
Training acc over epoch: 0.9972
Validation acc: 0.9776
Time taken: 1.32s

Start of epoch 27
Training loss at step 0: 0.0004
Seen so far: 128 samples
Training loss at step 200: 0.0402
Seen so far: 25728 samples
Training loss at step 400: 0.0002
Seen so far: 51328 samples
Training acc over epoch: 0.9966
Validation acc: 0.9784
Time taken: 1.34s

Start of epoch 28
Training loss at step 0: 0.0005
Seen so far: 128 samples
Training loss at step 200: 0.0307
Seen so far: 25728 samples
Training loss at step 400: 0.1466
Seen so far: 51328 samples
Training acc over epoch: 0.9948
Validation acc: 0.9742
Time taken: 1.37s

Start of epoch 29
Training loss at step 0: 0.0092
Seen so far: 128 samples
Training loss at step 200: 0.0039
Seen so far: 25728 samples
Training loss at step 400: 0.0358
Seen so far: 51328 samples
Training acc over epoch: 0.9959
Validation acc: 0.9791
Time taken: 1.33s
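The wait/best/patience bookkeeping in the loop above can be factored into a small reusable object. The hypothetical EarlyStopper class below applies the same rule, adds a min_delta threshold (a decrease smaller than min_delta does not count as an improvement, which is also how EarlyStopping documents its min_delta parameter), and records best_epoch. The validation losses in the usage example are made up for illustration.

```python
import math


class EarlyStopper:
    """Patience-based early stopping for a hand-written loop (hypothetical).

    Lower monitored values are treated as better, e.g. a validation loss.
    """

    def __init__(self, patience, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = math.inf
        self.best_epoch = None
        self.wait = 0  # Epochs since the last improvement.

    def update(self, epoch, value):
        """Record one epoch's value; return True if training should stop."""
        if value < self.best - self.min_delta:
            self.best = value
            self.best_epoch = epoch
            self.wait = 0
            return False
        self.wait += 1
        return self.wait >= self.patience


# Made-up per-epoch validation losses.
val_losses = [0.30, 0.25, 0.24, 0.245, 0.26, 0.27, 0.20]
stopper = EarlyStopper(patience=3)
for epoch, val_loss in enumerate(val_losses):
    if stopper.update(epoch, val_loss):
        # Prints: Stopping after epoch 5; best was epoch 2
        print(f"Stopping after epoch {epoch}; best was epoch {stopper.best_epoch}")
        break
```

In the training loop above, the manual bookkeeping could become `if stopper.update(epoch, float(val_loss)): break`. Because best_epoch is tracked, you could also save weights (for example with model.save_weights) whenever it changes and reload them after the loop, similar in spirit to the restore_best_weights option of EarlyStopping.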
Next steps
- Learn more about the Keras built-in early stopping callback API in the API docs.
- Learn to write custom Keras callbacks, including early stopping at a minimum loss.
- Learn about Training and evaluation with the Keras built-in methods.
- Explore common regularization techniques in the Overfit and underfit tutorial that uses the EarlyStopping callback.