Concatenates the given `value` across (GPU/TPU) cores, along `axis`.
tfm.utils.cross_replica_concat(
value, axis, name='cross_replica_concat'
)
In general, each core ("replica") will pass a replica-specific value as `value` (corresponding to some element of a data-parallel computation taking place across replicas).
The resulting concatenated `Tensor` will have the same shape as `value` for all dimensions except `axis`, where it will be larger by a factor of the number of replicas. It will also have the same `dtype` as `value`.
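The shape rule above can be sketched in plain Python. This is only an illustration of the arithmetic, not part of the library: `concat_shape` and `num_replicas` are hypothetical names introduced here, and the sketch assumes a fully known (static) shape.

```python
def concat_shape(value_shape, axis, num_replicas):
    """Return the expected shape of the concatenated result.

    Hypothetical helper: every dimension is unchanged except `axis`,
    which grows by a factor of `num_replicas`.
    """
    out = list(value_shape)
    out[axis] = out[axis] * num_replicas  # only `axis` is scaled
    return tuple(out)


# A (2, 3) value gathered from 2 replicas along axis 0.
print(concat_shape((2, 3), axis=0, num_replicas=2))  # (4, 3)
# The same value gathered from 4 replicas along axis 1.
print(concat_shape((2, 3), axis=1, num_replicas=4))  # (2, 12)
```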
The position of a given replica's `value` within the resulting concatenation is determined by that replica's replica ID. For example:
With `value` for replica 0 given as
0 0 0
0 0 0
and `value` for replica 1 given as
1 1 1
1 1 1
the resulting concatenation along axis 0 will be
0 0 0
0 0 0
1 1 1
1 1 1
and this result will be identical across all replicas.
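The ordering in this example can be reproduced with a small pure-Python simulation. It is a sketch of the documented behavior only: it does not call TensorFlow, `simulate_cross_replica_concat` is a hypothetical name, and it assumes 2-D values stored as nested lists, indexed by replica ID.

```python
def simulate_cross_replica_concat(per_replica_values, axis):
    """Concatenate 2-D nested lists in replica-ID order.

    per_replica_values[i] stands in for the `value` held by replica i.
    Hypothetical stand-in for illustration, not the library function.
    """
    if axis == 0:
        # Rows of replica 0 come first, then replica 1, and so on.
        return [list(row) for value in per_replica_values for row in value]
    if axis == 1:
        num_rows = len(per_replica_values[0])
        # Each output row joins the matching rows of every replica.
        return [
            [x for value in per_replica_values for x in value[r]]
            for r in range(num_rows)
        ]
    raise ValueError("this sketch only handles 2-D values with axis 0 or 1")


replica_values = [
    [[0, 0, 0], [0, 0, 0]],  # value on replica 0
    [[1, 1, 1], [1, 1, 1]],  # value on replica 1
]
result = simulate_cross_replica_concat(replica_values, axis=0)
for row in result:
    print(*row)
```

In the real API every replica receives this same concatenated result; the simulation returns it once.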
Note that this API only works in TF2 with `tf.distribute`.
Returns | |
---|---|
The result of concatenating `value` along `axis` across replicas. |
Raises | |
---|---|
`RuntimeError` | when the batch (0-th) dimension is None. |