In TensorFlow 1, to customize the behavior of training, you use tf.estimator.SessionRunHook with tf.estimator.Estimator. This guide demonstrates how to migrate from SessionRunHook to TensorFlow 2's custom callbacks with the tf.keras.callbacks.Callback API, which works with Keras Model.fit for training (as well as Model.evaluate and Model.predict). You will learn how to do this by implementing a SessionRunHook and a Callback that measure examples per second during training.
Examples of callbacks are checkpoint saving (tf.keras.callbacks.ModelCheckpoint) and TensorBoard summary writing. Keras callbacks are objects that are called at different points during training, evaluation, and prediction in the built-in Keras Model.fit/Model.evaluate/Model.predict APIs. You can learn more about callbacks in the tf.keras.callbacks.Callback API docs, as well as the Writing your own callbacks and Training and evaluation with the built-in methods (the Using callbacks section) guides.
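For orientation, here is a minimal, self-contained sketch of passing two built-in callbacks to Model.fit; the checkpoint and log directories are placeholder paths, and the tiny model and data are illustrative only:

import tensorflow as tf

model = tf.keras.models.Sequential([tf.keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")

# Built-in callbacks: save weights after each epoch and write TensorBoard summaries.
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    filepath="/tmp/ckpt/weights-{epoch:02d}", save_weights_only=True)
tensorboard_cb = tf.keras.callbacks.TensorBoard(log_dir="/tmp/logs")

model.fit(
    x=[[1., 1.5], [2., 2.5]],
    y=[[0.3], [0.5]],
    epochs=2,
    callbacks=[checkpoint_cb, tensorboard_cb],
    verbose=0)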
Setup
Start with imports and a simple dataset for demonstration purposes:
import tensorflow as tf
import tensorflow.compat.v1 as tf1
import time
from datetime import datetime
from absl import flags
features = [[1., 1.5], [2., 2.5], [3., 3.5]]
labels = [[0.3], [0.5], [0.7]]
eval_features = [[4., 4.5], [5., 5.5], [6., 6.5]]
eval_labels = [[0.8], [0.9], [1.]]
TensorFlow 1: Create a custom SessionRunHook with tf.estimator APIs
The following TensorFlow 1 example shows how to set up a custom SessionRunHook that measures examples per second during training. After creating the hook (LoggerHook), pass it to the hooks parameter of tf.estimator.Estimator.train.
def _input_fn():
  return tf1.data.Dataset.from_tensor_slices(
      (features, labels)).batch(1).repeat(100)

def _model_fn(features, labels, mode):
  logits = tf1.layers.Dense(1)(features)
  loss = tf1.losses.mean_squared_error(labels=labels, predictions=logits)
  optimizer = tf1.train.AdagradOptimizer(0.05)
  train_op = optimizer.minimize(loss, global_step=tf1.train.get_global_step())
  return tf1.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
class LoggerHook(tf1.train.SessionRunHook):
  """Logs loss and runtime."""

  def begin(self):
    self._step = -1
    self._start_time = time.time()
    self.log_frequency = 10

  def before_run(self, run_context):
    self._step += 1

  def after_run(self, run_context, run_values):
    if self._step % self.log_frequency == 0:
      current_time = time.time()
      duration = current_time - self._start_time
      self._start_time = current_time
      examples_per_sec = self.log_frequency / duration
      print('Time:', datetime.now(), ', Step #:', self._step,
            ', Examples per second:', examples_per_sec)
estimator = tf1.estimator.Estimator(model_fn=_model_fn)
# Begin training.
estimator.train(_input_fn, hooks=[LoggerHook()])
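As an aside, a SessionRunHook can also request tensor values on each step: before_run can return a tf1.train.SessionRunArgs request, and after_run receives the fetched values through run_values.results. The following is a minimal sketch (not part of the original example) that reads the loss from the graph's LOSSES collection:

class LossLoggerHook(tf1.train.SessionRunHook):
  """Sketch: fetch and print the loss value on every step."""

  def begin(self):
    # The mean squared error op defined in _model_fn is added to the
    # LOSSES collection of the default graph.
    self._loss = tf1.get_collection(tf1.GraphKeys.LOSSES)[0]

  def before_run(self, run_context):
    # Ask the session to also evaluate the loss tensor on this run call.
    return tf1.train.SessionRunArgs(self._loss)

  def after_run(self, run_context, run_values):
    # run_values.results holds the fetched loss value.
    print('Loss:', run_values.results)

estimator.train(_input_fn, hooks=[LossLoggerHook()], steps=10)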
TensorFlow 2: Create a custom Keras callback for Model.fit
In TensorFlow 2, when you use the built-in Keras Model.fit (or Model.evaluate) for training/evaluation, you can configure a custom tf.keras.callbacks.Callback, which you then pass to the callbacks parameter of Model.fit (or Model.evaluate). (Learn more in the Writing your own callbacks guide.)
In the example below, you will write a custom tf.keras.callbacks.Callback that logs various metrics; it will measure examples per second, which should be comparable to the metrics in the previous SessionRunHook example.
class CustomCallback(tf.keras.callbacks.Callback):

  def on_train_begin(self, logs=None):
    self._step = -1
    self._start_time = time.time()
    self.log_frequency = 10

  def on_train_batch_begin(self, batch, logs=None):
    self._step += 1

  def on_train_batch_end(self, batch, logs=None):
    if self._step % self.log_frequency == 0:
      current_time = time.time()
      duration = current_time - self._start_time
      self._start_time = current_time
      examples_per_sec = self.log_frequency / duration
      print('Time:', datetime.now(), ', Step #:', self._step,
            ', Examples per second:', examples_per_sec)
callback = CustomCallback()
dataset = tf.data.Dataset.from_tensor_slices(
    (features, labels)).batch(1).repeat(100)
model = tf.keras.models.Sequential([tf.keras.layers.Dense(1)])
optimizer = tf.keras.optimizers.Adagrad(learning_rate=0.05)
model.compile(optimizer, "mse")
# Begin training.
result = model.fit(dataset, callbacks=[callback], verbose=0)
# Provide the results of training metrics.
result.history
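The same pattern carries over to evaluation: Model.evaluate invokes a callback's on_test_begin/on_test_batch_end/on_test_end methods. The following is a minimal sketch that reuses the eval data from the Setup section; because the batch size is 1, batches and examples coincide:

class EvalSpeedCallback(tf.keras.callbacks.Callback):
  """Sketch: measure examples per second during Model.evaluate."""

  def on_test_begin(self, logs=None):
    self._start_time = time.time()
    self._batches = 0

  def on_test_batch_end(self, batch, logs=None):
    self._batches += 1

  def on_test_end(self, logs=None):
    duration = time.time() - self._start_time
    print('Eval examples per second:', self._batches / duration)

eval_dataset = tf.data.Dataset.from_tensor_slices(
    (eval_features, eval_labels)).batch(1)
model.evaluate(eval_dataset, callbacks=[EvalSpeedCallback()], verbose=0)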
Next steps
Learn more about callbacks in:
- API docs: tf.keras.callbacks.Callback
- Guide: Writing your own callbacks
- Guide: Training and evaluation with the built-in methods (the Using callbacks section)
You may also find the following migration-related resources useful:
- The Early stopping migration guide: tf.keras.callbacks.EarlyStopping is a built-in early stopping callback (a minimal sketch follows this list)
- The TensorBoard migration guide: TensorBoard enables tracking and displaying metrics
- The LoggingTensorHook and StopAtStepHook to Keras callbacks migration guide
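For illustration, here is a minimal sketch of the built-in early stopping callback applied to the model and dataset from the TensorFlow 2 example above; the monitored quantity and patience value are illustrative choices, not part of this guide's example:

# Stop training once the training loss has not improved for 3 consecutive epochs.
early_stopping = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=3)
model.fit(dataset, epochs=20, callbacks=[early_stopping], verbose=0)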