This guide demonstrates how to migrate the single-worker multiple-GPU workflows from TensorFlow 1 to TensorFlow 2.
To perform synchronous training across multiple GPUs on one machine:
- In TensorFlow 1, you use the tf.estimator.Estimator APIs with tf.distribute.MirroredStrategy.
- In TensorFlow 2, you can use Keras Model.fit or a custom training loop with tf.distribute.MirroredStrategy.

Learn more in the Distributed training with TensorFlow guide.
Setup
Start with imports and a simple dataset for demonstration purposes:
import tensorflow as tf
import tensorflow.compat.v1 as tf1

# A toy regression dataset: three training and three evaluation examples.
features = [[1., 1.5], [2., 2.5], [3., 3.5]]
labels = [[0.3], [0.5], [0.7]]
eval_features = [[4., 4.5], [5., 5.5], [6., 6.5]]
eval_labels = [[0.8], [0.9], [1.]]
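Optionally, you can verify how many GPUs tf.distribute.MirroredStrategy will mirror across. The toy examples below also run on a CPU-only machine, where the strategy falls back to a single device:

# List the GPUs visible to this process.
print('Visible GPUs:', tf.config.list_physical_devices('GPU'))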
TensorFlow 1: Single-worker distributed training with tf.estimator.Estimator
This example demonstrates the canonical TensorFlow 1 workflow for single-worker multiple-GPU training. You set the distribution strategy (tf.distribute.MirroredStrategy) through the config parameter of tf.estimator.Estimator:
def _input_fn():
  return tf1.data.Dataset.from_tensor_slices((features, labels)).batch(1)

def _eval_input_fn():
  return tf1.data.Dataset.from_tensor_slices(
      (eval_features, eval_labels)).batch(1)

def _model_fn(features, labels, mode):
  # A single dense layer regressing to one output, trained with a
  # mean-squared-error loss and the Adagrad optimizer.
  logits = tf1.layers.Dense(1)(features)
  loss = tf1.losses.mean_squared_error(labels=labels, predictions=logits)
  optimizer = tf1.train.AdagradOptimizer(0.05)
  train_op = optimizer.minimize(loss, global_step=tf1.train.get_global_step())
  return tf1.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)
strategy = tf1.distribute.MirroredStrategy()
config = tf1.estimator.RunConfig(
    train_distribute=strategy, eval_distribute=strategy)
estimator = tf1.estimator.Estimator(model_fn=_model_fn, config=config)
train_spec = tf1.estimator.TrainSpec(input_fn=_input_fn)
eval_spec = tf1.estimator.EvalSpec(input_fn=_eval_input_fn)
tf1.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
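By default, MirroredStrategy mirrors across all GPUs visible to the process. As a sketch, assuming the machine exposes at least two GPUs named '/gpu:0' and '/gpu:1', you can restrict the strategy to specific devices:

# Mirror variables across two specific GPUs only
# (assumes '/gpu:0' and '/gpu:1' exist on this machine).
strategy = tf1.distribute.MirroredStrategy(devices=['/gpu:0', '/gpu:1'])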
TensorFlow 2: Single-worker training with Keras
When migrating to TensorFlow 2, you can use the Keras APIs with tf.distribute.MirroredStrategy.
If you use the tf.keras APIs for model building and Keras Model.fit for training, the main difference is that you instantiate the Keras model, the optimizer, and the metrics in the context of Strategy.scope, instead of defining a config for tf.estimator.Estimator.
If you need to use a custom training loop, check out the Using tf.distribute.Strategy with custom training loops guide; a minimal sketch follows the Keras example below.
dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(1)
eval_dataset = tf.data.Dataset.from_tensor_slices(
    (eval_features, eval_labels)).batch(1)

strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
  # Create and compile the model inside the strategy scope so that its
  # variables are mirrored across all replicas.
  model = tf.keras.models.Sequential([tf.keras.layers.Dense(1)])
  optimizer = tf.keras.optimizers.Adagrad(learning_rate=0.05)
  model.compile(optimizer=optimizer, loss='mse')

model.fit(dataset)
model.evaluate(eval_dataset, return_dict=True)
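For reference, here is a minimal sketch of the equivalent custom training loop, reusing the strategy and dataset defined above. The key pieces are strategy.experimental_distribute_dataset to shard batches across replicas, strategy.run to execute the step on each replica, and tf.nn.compute_average_loss to scale the per-example loss by the global batch size so that gradients sum correctly across replicas:

dist_dataset = strategy.experimental_distribute_dataset(dataset)

with strategy.scope():
  model = tf.keras.models.Sequential([tf.keras.layers.Dense(1)])
  optimizer = tf.keras.optimizers.Adagrad(learning_rate=0.05)
  # Disable Keras' built-in reduction; the loss is averaged manually below.
  loss_fn = tf.keras.losses.MeanSquaredError(
      reduction=tf.keras.losses.Reduction.NONE)

@tf.function
def train_step(dist_inputs):
  def step_fn(inputs):
    x, y = inputs
    with tf.GradientTape() as tape:
      preds = model(x, training=True)
      loss = tf.nn.compute_average_loss(loss_fn(y, preds))
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss
  per_replica_losses = strategy.run(step_fn, args=(dist_inputs,))
  return strategy.reduce(
      tf.distribute.ReduceOp.SUM, per_replica_losses, axis=None)

for batch in dist_dataset:
  train_step(batch)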
Next steps
To learn more about distributed training with tf.distribute.MirroredStrategy in TensorFlow 2, check out the following documentation:
- The Distributed training on one machine with Keras tutorial
- The Distributed training on one machine with a custom training loop tutorial
- The Distributed training with TensorFlow guide
- The Using multiple GPUs guide
- The Optimize the performance on the multi-GPU single host (with the TensorFlow Profiler) guide