tf.distribute.DistributedValues
Base class for representing distributed values.
A subclass instance of tf.distribute.DistributedValues is created when creating variables within a distribution strategy, iterating a tf.distribute.DistributedDataset, or through tf.distribute.Strategy.run. This base class should never be instantiated directly. tf.distribute.DistributedValues contains a value per replica. Depending on the subclass, the values could either be synced on update, synced on demand, or never synced.
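As a minimal sketch of the variable-creation path (assuming two GPUs visible as "GPU:0" and "GPU:1"), a tf.Variable created inside strategy.scope() is a distributed value that is synced on update, while passing synchronization=tf.VariableSynchronization.ON_READ yields one that is synced on demand:
strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
with strategy.scope():
  # Synced on update: assignments are propagated to every replica.
  weights = tf.Variable(1.0)
  # Synced on demand: replicas update locally and aggregate when read.
  counter = tf.Variable(0.0,
                        synchronization=tf.VariableSynchronization.ON_READ,
                        aggregation=tf.VariableAggregation.SUM)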
Two representative types of tf.distribute.DistributedValues are tf.types.experimental.PerReplica and tf.types.experimental.Mirrored values.
PerReplica values exist on the worker devices, with a different value for each replica. They are produced by iterating through a distributed dataset returned by tf.distribute.Strategy.experimental_distribute_dataset (Example 1, below) or tf.distribute.Strategy.distribute_datasets_from_function. They are also the typical result returned by tf.distribute.Strategy.run (Example 2).
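The distribute_datasets_from_function path is not shown in the numbered examples below; as a brief sketch (assuming the same two-GPU setup), the dataset function receives a tf.distribute.InputContext and returns the dataset for one replica, and iterating the distributed result again yields PerReplica values:
strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
def dataset_fn(input_context):
  # Each replica receives its share of the global batch.
  batch_size = input_context.get_per_replica_batch_size(global_batch_size=2)
  return tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(batch_size)
distributed_dataset = strategy.distribute_datasets_from_function(dataset_fn)
per_replica_batch = next(iter(distributed_dataset))  # A PerReplica value.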
Mirrored values are like PerReplica values, except we know that the values on all replicas are the same. Mirrored values are kept synchronized by the distribution strategy in use, while PerReplica values are left unsynchronized. Mirrored values typically represent model weights. We can safely read a Mirrored value in a cross-replica context by using the value on any replica, while PerReplica values should not be read or manipulated in a cross-replica context.
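To illustrate the distinction (a sketch, again assuming the same two-GPU setup), a variable created under the strategy scope is a Mirrored value whose per-replica copies are identical, so it can be read directly outside strategy.run; a PerReplica value should instead be reduced or inspected through the APIs shown in Examples 4 and 5:
strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
with strategy.scope():
  weight = tf.Variable(2.0)  # Mirrored: the same value on every replica.
# Safe to read in a cross-replica context; any replica's copy can be used.
weight.numpy()  # 2.0
strategy.experimental_local_results(weight)  # One identical copy per replica.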
tf.distribute.DistributedValues can be reduced via strategy.reduce to obtain a single value across replicas (Example 4), used as input into tf.distribute.Strategy.run (Example 3), or collected to inspect the per-replica values using tf.distribute.Strategy.experimental_local_results (Example 5).
Example usages:
1. Created from a tf.distribute.DistributedDataset:
strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2)
dataset_iterator = iter(strategy.experimental_distribute_dataset(dataset))
distributed_values = next(dataset_iterator)
distributed_values
PerReplica:{
0: <tf.Tensor: shape=(1,), dtype=float32, numpy=array([5.], dtype=float32)>,
1: <tf.Tensor: shape=(1,), dtype=float32, numpy=array([6.], dtype=float32)>
}
2. Returned by run:
strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
@tf.function
def run():
  ctx = tf.distribute.get_replica_context()
  return ctx.replica_id_in_sync_group
distributed_values = strategy.run(run)
distributed_values
PerReplica:{
0: <tf.Tensor: shape=(), dtype=int32, numpy=0>,
1: <tf.Tensor: shape=(), dtype=int32, numpy=1>
}
3. As input into run:
strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2)
dataset_iterator = iter(strategy.experimental_distribute_dataset(dataset))
distributed_values = next(dataset_iterator)
@tf.function
def run(input):
  return input + 1.0
updated_value = strategy.run(run, args=(distributed_values,))
updated_value
PerReplica:{
0: <tf.Tensor: shape=(1,), dtype=float32, numpy=array([6.], dtype=float32)>,
1: <tf.Tensor: shape=(1,), dtype=float32, numpy=array([7.], dtype=float32)>
}
4. As input into reduce:
strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2)
dataset_iterator = iter(strategy.experimental_distribute_dataset(dataset))
distributed_values = next(dataset_iterator)
reduced_value = strategy.reduce(tf.distribute.ReduceOp.SUM,
                                distributed_values,
                                axis=0)
reduced_value
<tf.Tensor: shape=(), dtype=float32, numpy=11.0>
5. How to inspect per-replica values locally:
strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2)
dataset_iterator = iter(strategy.experimental_distribute_dataset(dataset))
distributed_values = next(dataset_iterator)
per_replica_values = strategy.experimental_local_results(
    distributed_values)
per_replica_values
(<tf.Tensor: shape=(1,), dtype=float32, numpy=array([5.], dtype=float32)>,
<tf.Tensor: shape=(1,), dtype=float32, numpy=array([6.], dtype=float32)>)