Continually saving the "best" model or its weights/parameters during training has several benefits: you can track training progress and reload the model from any of the saved states.
In TensorFlow 1, to configure checkpoint saving during training/validation with the tf.estimator.Estimator APIs, you specify a schedule in tf.estimator.RunConfig or use tf.estimator.CheckpointSaverHook. This guide demonstrates how to migrate from this workflow to TensorFlow 2 Keras APIs.
In TensorFlow 2, you can configure tf.keras.callbacks.ModelCheckpoint in a number of ways:
- Save the "best" version according to a monitored metric using the save_best_only=True parameter, where monitor can be, for example, 'loss', 'val_loss', 'accuracy', or 'val_accuracy'.
- Save continually at a certain frequency (using the save_freq argument).
- Save the weights/parameters only instead of the whole model by setting save_weights_only to True.
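The three options above can be sketched as follows. This is a minimal illustration rather than part of the original guide: the file names under the temporary directory are hypothetical, and the .keras and .weights.h5 suffixes are an assumption made so that the snippet also works with recent Keras releases, which validate the file extension.

```python
import os
import tempfile

import tensorflow as tf

ckpt_dir = tempfile.mkdtemp()  # hypothetical location for this sketch

# 1. Keep only the "best" model according to a monitored metric.
best_cb = tf.keras.callbacks.ModelCheckpoint(
    filepath=os.path.join(ckpt_dir, "best.keras"),
    monitor="val_loss",
    save_best_only=True)

# 2. Save at a fixed frequency: an integer save_freq counts batches,
#    while the default "epoch" saves at the end of every epoch.
periodic_cb = tf.keras.callbacks.ModelCheckpoint(
    filepath=os.path.join(ckpt_dir, "periodic.keras"),
    save_freq=100)

# 3. Save only the weights instead of the whole model.
weights_cb = tf.keras.callbacks.ModelCheckpoint(
    filepath=os.path.join(ckpt_dir, "model.weights.h5"),
    save_weights_only=True)
```

Any of these callbacks can then be passed to Model.fit through its callbacks argument. Monitoring 'val_loss' only makes sense when validation data is supplied to Model.fit.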
For more details, refer to the tf.keras.callbacks.ModelCheckpoint API docs and the Save checkpoints during training section in the Save and load models tutorial. Learn more about the Checkpoint format in the TF Checkpoint format section in the Save and load Keras models guide. In addition, to add fault tolerance, you can use tf.keras.callbacks.BackupAndRestore or tf.train.Checkpoint for manual checkpointing. Learn more in the Fault tolerance migration guide.
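For manual checkpointing, a minimal sketch with tf.train.Checkpoint and tf.train.CheckpointManager might look like the following. The tracked variables and the loop are hypothetical stand-ins for a real model and custom training loop; the point is only to show the save and restore calls.

```python
import tempfile

import tensorflow as tf

ckpt_dir = tempfile.mkdtemp()  # hypothetical location for this sketch

# Track any trackable state: here, a step counter and a single weight.
step = tf.Variable(0, dtype=tf.int64)
weight = tf.Variable(1.0)
checkpoint = tf.train.Checkpoint(step=step, weight=weight)
manager = tf.train.CheckpointManager(checkpoint, ckpt_dir, max_to_keep=2)

# Stand-in for a custom training loop: update the state, then save.
for _ in range(3):
    step.assign_add(1)
    weight.assign(weight * 0.5)
    manager.save(checkpoint_number=int(step.numpy()))

# Overwrite the state, then restore the most recent checkpoint.
step.assign(0)
weight.assign(0.0)
checkpoint.restore(manager.latest_checkpoint)

print(int(step.numpy()), float(weight.numpy()))
print(manager.checkpoints)
```

With max_to_keep=2, the manager deletes older checkpoints as new ones are written, so only the two most recent remain on disk.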
Keras callbacks are objects that are called at different points during training/evaluation/prediction in the built-in Keras Model.fit/Model.evaluate/Model.predict APIs. Learn more in the Next steps section at the end of the guide.
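To make the "called at different points" idea concrete, here is a small sketch of a custom callback that records which hooks Keras invokes during Model.fit. The EventRecorder class, the random data, and the one-layer model are all hypothetical and exist only for this illustration.

```python
import numpy as np
import tensorflow as tf


class EventRecorder(tf.keras.callbacks.Callback):
    """Hypothetical callback that records which hooks Keras invokes."""

    def __init__(self):
        super().__init__()
        self.events = []

    def on_train_begin(self, logs=None):
        self.events.append("train_begin")

    def on_epoch_end(self, epoch, logs=None):
        # The epoch index passed to callbacks is zero-based.
        self.events.append(f"epoch_end:{epoch}")

    def on_train_end(self, logs=None):
        self.events.append("train_end")


# Tiny random regression problem, used only to drive Model.fit.
x = np.random.rand(32, 4).astype("float32")
y = np.random.rand(32, 1).astype("float32")

model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.compile(optimizer="sgd", loss="mse")

recorder = EventRecorder()
model.fit(x, y, epochs=2, batch_size=8, verbose=0, callbacks=[recorder])
print(recorder.events)
```

Built-in callbacks such as tf.keras.callbacks.ModelCheckpoint plug into the same hooks, which is how they can save at the end of an epoch or after a given number of batches.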
Setup
Start with imports and a simple dataset for demonstration purposes:
import tensorflow.compat.v1 as tf1
import tensorflow as tf
import numpy as np
import tempfile
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
TensorFlow 1: Save checkpoints with tf.estimator APIs
This TensorFlow 1 example shows how to configure tf.estimator.RunConfig to save checkpoints at every step during training/evaluation with the tf.estimator.Estimator APIs:
feature_columns = [tf1.feature_column.numeric_column("x", shape=[28, 28])]
config = tf1.estimator.RunConfig(save_summary_steps=1,
                                 save_checkpoints_steps=1)
path = tempfile.mkdtemp()
classifier = tf1.estimator.DNNClassifier(
    feature_columns=feature_columns,
    hidden_units=[256, 32],
    optimizer=tf1.train.AdamOptimizer(0.001),
    n_classes=10,
    dropout=0.2,
    model_dir=path,
    config=config
)
train_input_fn = tf1.estimator.inputs.numpy_input_fn(
    x={"x": x_train},
    y=y_train.astype(np.int32),
    num_epochs=10,
    batch_size=50,
    shuffle=True,
)
test_input_fn = tf1.estimator.inputs.numpy_input_fn(
    x={"x": x_test},
    y=y_test.astype(np.int32),
    num_epochs=10,
    shuffle=False
)
train_spec = tf1.estimator.TrainSpec(input_fn=train_input_fn, max_steps=10)
eval_spec = tf1.estimator.EvalSpec(input_fn=test_input_fn,
                                   steps=10,
                                   throttle_secs=0)
tf1.estimator.train_and_evaluate(estimator=classifier,
                                 train_spec=train_spec,
                                 eval_spec=eval_spec)
%ls {classifier.model_dir}
TensorFlow 2: Save checkpoints with a Keras callback for Model.fit
In TensorFlow 2, when you use the built-in Keras Model.fit (or Model.evaluate) for training/evaluation, you can configure tf.keras.callbacks.ModelCheckpoint and then pass it to the callbacks parameter of Model.fit (or Model.evaluate). (Learn more in the API docs and the Using callbacks section in the Training and evaluation with the built-in methods guide.)
In the example below, you will use a tf.keras.callbacks.ModelCheckpoint callback to store checkpoints in a temporary directory:
def create_model():
  return tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(512, activation='relu'),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation='softmax')
  ])
model = create_model()
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'],
              steps_per_execution=10)
log_dir = tempfile.mkdtemp()
model_checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=log_dir)
model.fit(x=x_train,
          y=y_train,
          epochs=10,
          validation_data=(x_test, y_test),
          callbacks=[model_checkpoint_callback])
%ls {model_checkpoint_callback.filepath}
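As a complementary sketch, the snippet below writes one weights-only checkpoint per epoch and then loads the last one into a freshly built model. It is not part of the original example: build_small_model, the random data, and the file name pattern are hypothetical, and the .weights.h5 suffix is an assumption chosen so the snippet also runs on recent Keras releases that validate the file extension.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


def build_small_model():
    # Hypothetical stand-in for create_model(), kept tiny so it runs quickly.
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(4,)),
        tf.keras.layers.Dense(8, activation="relu"),
        tf.keras.layers.Dense(1),
    ])
    model.compile(optimizer="adam", loss="mse")
    return model


# Random data, used only to drive Model.fit.
x = np.random.rand(64, 4).astype("float32")
y = np.random.rand(64, 1).astype("float32")

ckpt_dir = tempfile.mkdtemp()
# {epoch:02d} is filled in by the callback, giving one file per epoch.
weights_cb = tf.keras.callbacks.ModelCheckpoint(
    filepath=os.path.join(ckpt_dir, "epoch_{epoch:02d}.weights.h5"),
    save_weights_only=True)

model = build_small_model()
model.fit(x, y, epochs=3, batch_size=16, verbose=0, callbacks=[weights_cb])

saved_files = sorted(os.listdir(ckpt_dir))
print(saved_files)

# Load the last checkpoint into a new model with the same architecture.
restored = build_small_model()
restored.load_weights(os.path.join(ckpt_dir, saved_files[-1]))
```

Because only the weights are saved, the restoring code has to rebuild the same architecture before calling load_weights.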
Next steps
Learn more about checkpointing in:
- API docs: tf.keras.callbacks.ModelCheckpoint
- Tutorial: Save and load models (the Save checkpoints during training section)
- Guide: Save and load Keras models (the TF Checkpoint format section)
Learn more about callbacks in:
- API docs: tf.keras.callbacks.Callback
- Guide: Writing your own callbacks
- Guide: Training and evaluation with the built-in methods (the Using callbacks section)
You may also find the following migration-related resources useful:
- The Fault tolerance migration guide: tf.keras.callbacks.BackupAndRestore for Model.fit, or tf.train.Checkpoint and tf.train.CheckpointManager APIs for a custom training loop
- The Early stopping migration guide: tf.keras.callbacks.EarlyStopping is a built-in early stopping callback
- The TensorBoard migration guide: TensorBoard enables tracking and displaying metrics
- The LoggingTensorHook and StopAtStepHook to Keras callbacks migration guide
- The SessionRunHook to Keras callbacks guide