Returns the current worker index, when called within a worker closure.
```python
tf.distribute.coordinator.experimental_get_current_worker_index()
```
Some parameter server training workloads may require each worker to know its index, for example to shard data for reduced-variance training.
This method may be used within a `tf.function` that is executed on a worker; that is, either a `dataset_fn` that runs via `ClusterCoordinator.create_per_worker_dataset`, or any other function scheduled via `ClusterCoordinator.schedule`.
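For instance, the index can be read inside any function dispatched with `ClusterCoordinator.schedule`. Below is a minimal sketch of that second usage path; the `TFConfigClusterResolver` setup and the `report_index` function are illustrative assumptions, not part of the documented example:

```python
import tensorflow as tf

# Hypothetical cluster setup; any valid cluster resolver would do.
strategy = tf.distribute.ParameterServerStrategy(
    cluster_resolver=tf.distribute.cluster_resolver.TFConfigClusterResolver())
coordinator = tf.distribute.coordinator.ClusterCoordinator(strategy)

@tf.function
def report_index():
  # Runs remotely on a worker, so the index is available here.
  return tf.distribute.coordinator.experimental_get_current_worker_index()

# schedule() returns a RemoteValue; fetch() blocks until it is ready.
remote_value = coordinator.schedule(report_index)
print(remote_value.fetch())
```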
Example (sharding data by worker):
```python
strategy = tf.distribute.ParameterServerStrategy(
    cluster_resolver=...)
coordinator = (
    tf.distribute.coordinator.ClusterCoordinator(strategy))

def dataset_fn(context):
  dataset = tf.data.Dataset.range(10)
  # Runs on each worker; returns that worker's position in the cluster.
  worker_index = (
      tf.distribute.coordinator.experimental_get_current_worker_index()
  )
  # num_workers is the total number of workers in the cluster and must be
  # defined by the caller (see the sketch after this example).
  dataset = dataset.shard(
      num_shards=num_workers,
      index=worker_index,
  )
  return dataset

@tf.function
def per_worker_dataset_fn():
  return strategy.distribute_datasets_from_function(dataset_fn)

per_worker_dataset = coordinator.create_per_worker_dataset(
    per_worker_dataset_fn)
```
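The example leaves `num_workers` undefined. One plausible way to derive it, assuming the worker job is registered under the name `"worker"` in the cluster spec, is shown below; it would need to run before the dataset is built:

```python
# Sketch (assumption, not part of the documented example): derive
# num_workers from the strategy's cluster spec.
cluster_spec = strategy.cluster_resolver.cluster_spec().as_dict()
num_workers = len(cluster_spec["worker"])

# Workers consume their shards through a per-worker iterator, which can
# be passed as an argument to functions scheduled on the cluster.
per_worker_iterator = iter(per_worker_dataset)
```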
| Raises | |
|---|---|
| `RuntimeError` | If called from outside a `tf.function` or outside of a remote closure execution context (that is, on a non-worker machine). |
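As a quick illustration of this constraint, calling the function eagerly on the coordinator (outside any scheduled closure) raises the error, which can be caught as usual; a minimal sketch:

```python
# Running on the coordinator, a non-worker machine: the call fails.
try:
  tf.distribute.coordinator.experimental_get_current_worker_index()
except RuntimeError as e:
  print("Not inside a worker closure:", e)
```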