A cluster of computation devices for distributed computation.
```python
tf.keras.distribution.DeviceMesh(
    shape, axis_names, devices=None
)
```
This API is aligned with `jax.sharding.Mesh` and `tf.dtensor.Mesh`, which represent the computation devices in the global context. See `jax.sharding.Mesh` and `tf.dtensor.Mesh` for more details.
Args

| Argument | Description |
|---|---|
| `shape` | Tuple or list of integers. The shape of the overall `DeviceMesh`, e.g. `(8,)` for a data-parallel-only distribution, or `(4, 2)` for a model+data parallel distribution. |
| `axis_names` | List of strings. The logical name of each axis of the `DeviceMesh`. The length of `axis_names` should match the rank of `shape`. The axis names are used to match/create the `TensorLayout` when distributing data and variables. |
| `devices` | Optional list of devices. Defaults to all the devices available locally from `keras.distribution.list_devices()`. |
Attributes

| Attribute | Description |
|---|---|
| `axis_names` | The logical names of the mesh axes. |
| `devices` | The devices that make up the mesh. |
| `shape` | The shape of the mesh. |