Distribution for data parallelism.
```python
tf.keras.distribution.DataParallel(
    device_mesh=None, devices=None
)
```
You can choose to create this instance by specifying either the `device_mesh` or the `devices` argument (but not both).

The `device_mesh` argument is expected to be a `DeviceMesh` instance and to be 1D only. If the mesh has multiple axes, the first axis will be treated as the data-parallel dimension (and a warning will be raised).

When a list of `devices` is provided, it will be used to construct a 1D mesh.

When both `device_mesh` and `devices` are absent, `list_devices()` will be used to detect any available devices and create a 1D mesh from them.
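A minimal sketch of the three construction paths described above (the `keras` import path, device count, and axis name are illustrative assumptions, not part of this reference):

```python
import keras

# Option 1: no arguments -- available devices are detected via list_devices()
# and a 1D mesh is built from them.
data_parallel = keras.distribution.DataParallel()

# Option 2: an explicit list of devices, used to construct a 1D mesh.
devices = keras.distribution.list_devices()
data_parallel = keras.distribution.DataParallel(devices=devices)

# Option 3: a 1D DeviceMesh; if it had multiple axes, the first axis would be
# treated as the data-parallel dimension.
mesh = keras.distribution.DeviceMesh(
    shape=(len(devices),), axis_names=["batch"], devices=devices
)
data_parallel = keras.distribution.DataParallel(device_mesh=mesh)

# Install it as the global distribution for models created afterwards.
keras.distribution.set_distribution(data_parallel)
```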
| Args | |
|---|---|
| `device_mesh` | Optional `DeviceMesh` instance. |
| `devices` | Optional list of devices. |
| Attributes | |
|---|---|
| `device_mesh` | |
Methods
distribute_dataset
```python
distribute_dataset(
    dataset
)
```
Create a distributed dataset instance from the original user dataset.
| Args | |
|---|---|
| `dataset` | the original global dataset instance. Only `tf.data.Dataset` is supported at the moment. |
| Returns | |
|---|---|
| a sharded `tf.data.Dataset` instance, which will produce data for the current local worker/process. |
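A short usage sketch, assuming the `data_parallel` instance from the constructor example above; the shapes and batch size here are made up:

```python
import numpy as np
import tensorflow as tf

# A global dataset built on every process; distribute_dataset() returns the
# shard that the current local worker should iterate over.
features = np.random.rand(64, 8).astype("float32")
labels = np.random.rand(64, 1).astype("float32")
dataset = tf.data.Dataset.from_tensor_slices((features, labels)).batch(16)

local_dataset = data_parallel.distribute_dataset(dataset)
for x, y in local_dataset:
    ...  # train on the local shard
```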
get_data_layout
```python
get_data_layout(
    data_shape
)
```
Retrieve the `TensorLayout` for the input data.
| Args | |
|---|---|
| `data_shape` | shape for the input data, in list or tuple format. |
| Returns | |
|---|---|
| The `TensorLayout` for the data, which can be used by `backend.distribute_value()` to redistribute the input data. |
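For example (assuming the `data_parallel` instance from above; the exact axis names depend on how the mesh was built):

```python
# Layout for an input batch of shape (batch, features).
layout = data_parallel.get_data_layout((16, 8))

# For data parallelism, only the first (batch) dimension is expected to be
# sharded across the mesh; the remaining dimensions stay unsharded.
print(layout.axes)  # e.g. ('batch', None)
```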
get_tensor_layout
```python
get_tensor_layout(
    path
)
```
Retrieve the `TensorLayout` for the intermediate tensor.
| Args | |
|---|---|
| `path` | a string path for the corresponding tensor. |
| Returns | |
|---|---|
| The `TensorLayout` for the intermediate tensor, which can be used by `backend.relayout()` to reshard the tensor. Could also return `None`. |
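An illustrative call (the path string is hypothetical; for pure data parallelism no special layout is attached to intermediate tensors, so `None` is the typical result):

```python
# Ask for the layout of an intermediate tensor by its string path.
layout = data_parallel.get_tensor_layout("functional/dense_1/output")
if layout is None:
    # No explicit layout: the tensor keeps whatever layout the backend infers.
    pass
```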
get_variable_layout
```python
get_variable_layout(
    variable
)
```
Retrieve the `TensorLayout` for the variable.
| Args | |
|---|---|
| `variable` | A `KerasVariable` instance. |
| Returns | |
|---|---|
| The `TensorLayout` for the variable, which can be used by `backend.distribute_value()` to redistribute a variable. |
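A small sketch (the layer and shapes are illustrative; under data parallelism variables are typically replicated, i.e. no axis is sharded):

```python
import keras

layer = keras.layers.Dense(4)
layer.build((None, 8))  # creates layer.kernel and layer.bias

# Layout for the kernel variable of shape (8, 4).
layout = data_parallel.get_variable_layout(layer.kernel)
print(layout.axes)  # e.g. (None, None) -- fully replicated
```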
scope
```python
@contextlib.contextmanager
scope()
```
Context manager to make the `Distribution` current.
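A typical usage sketch (assuming the `data_parallel` instance from above): variables created while the scope is active pick up this distribution's layouts.

```python
import keras

with data_parallel.scope():
    model = keras.Sequential([
        keras.layers.Dense(16, activation="relu"),
        keras.layers.Dense(1),
    ])
# Outside the scope, the previously active distribution (if any) is restored.
```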