Policy that splits tensors into shards with a max shard size.
Inherits From: ShardingCallback
tf.train.experimental.MaxShardSizePolicy(
    max_shard_size: int
)
Shards may exceed the max shard size if they contain either:
1. a single scalar or string tensor that cannot be sliced and exceeds the max shard size, or
2. the checkpoint object graph, whose size cannot be calculated when saving.
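The splitting rule above can be illustrated with a small pure-Python sketch. It is not TensorFlow's implementation, and several things in it are assumptions: `FakeTensor` is a hypothetical stand-in for a checkpointed tensor, sizes are taken to be in bytes, sliceable tensors are split only along their first dimension, and slice specs are simplified `"start:stop"` row ranges (plus a hypothetical `"full"` marker) rather than TensorFlow's slice-spec strings. Tensors are packed greedily in input order, and an unsliceable tensor larger than the limit gets a shard of its own that exceeds it, mirroring case 1.

```python
from dataclasses import dataclass
from typing import Dict, List


@dataclass
class FakeTensor:
    """Hypothetical stand-in for a checkpointed tensor (not a TensorFlow type)."""
    key: str              # checkpoint key
    rows: int             # length of the first dimension (1 for scalars/strings)
    row_nbytes: int       # bytes per row, assumed > 0
    sliceable: bool = True

    @property
    def nbytes(self) -> int:
        return self.rows * self.row_nbytes


def split_into_shards(tensors: List[FakeTensor],
                      max_shard_size: int) -> List[Dict[str, Dict[str, int]]]:
    """Greedily pack tensors into shards of at most max_shard_size bytes.

    Each shard maps checkpoint key -> {slice spec: slice size in bytes}.
    Sliceable tensors are split row-wise across shards. An unsliceable tensor
    (or a single row) larger than the limit is placed alone and exceeds it.
    """
    shards: List[Dict[str, Dict[str, int]]] = [{}]
    used = 0  # bytes already placed in the last shard

    def place(key: str, spec: str, nbytes: int) -> None:
        nonlocal used
        shards[-1].setdefault(key, {})[spec] = nbytes
        used += nbytes

    def new_shard() -> None:
        nonlocal used
        if shards[-1]:  # never leave an empty shard behind
            shards.append({})
            used = 0

    for t in tensors:
        if not t.sliceable:
            if used + t.nbytes > max_shard_size:
                new_shard()
            place(t.key, "full", t.nbytes)  # may exceed the limit on its own
            continue
        start = 0
        while start < t.rows:
            fit = (max_shard_size - used) // t.row_nbytes  # rows that still fit
            if fit <= 0:
                if used > 0:
                    new_shard()
                    continue
                fit = 1  # one row is already too big; place it alone anyway
            stop = min(t.rows, start + fit)
            place(t.key, f"{start}:{stop}", (stop - start) * t.row_nbytes)
            start = stop
    return shards


tensors = [
    FakeTensor("dense/kernel", rows=10, row_nbytes=4),
    FakeTensor("vocab_name", rows=1, row_nbytes=50, sliceable=False),
    FakeTensor("dense/bias", rows=3, row_nbytes=4),
]
shards = split_into_shards(tensors, max_shard_size=16)
sizes = [sum(sum(s.values()) for s in shard.values()) for shard in shards]
print(sizes)  # [16, 16, 8, 50, 12]
```

Packing here is strictly in input order, which is why the third shard holds only 8 bytes; that is a simplification of this sketch, not a claim about how the policy orders tensors.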
Attributes | |
---|---|
description | A text description of the sharding policy. |
Methods
__call__
__call__(
    shardable_tensors: Sequence[tf.train.experimental.ShardableTensor]
) -> Sequence[sharding_util.TensorSliceDict]
Callback to split tensors into shards with a max shard size.
Args | |
---|---|
shardable_tensors | A list of ShardableTensors. |
Returns | |
---|---|
List of shard dicts containing tensors, one dict per shard, in the form [ {checkpoint key: {slice_spec: tensor} } ]. |
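To make the return structure concrete, here is a plain-Python sketch that walks a value shaped like [ {checkpoint key: {slice_spec: tensor} } ]. It makes several assumptions: `bytes` objects stand in for tensors, `TensorSliceDict` is a hypothetical alias rather than `sharding_util.TensorSliceDict` itself, and the slice-spec strings are illustrative only. It reports the payload size of each shard and checks that no (checkpoint key, slice_spec) pair lands in more than one shard, a property one would presumably expect of any sharding callback's output.

```python
from typing import Dict, List, Sequence, Tuple

# Hypothetical alias: {checkpoint key: {slice_spec: tensor}}, with bytes
# standing in for tensor values.
TensorSliceDict = Dict[str, Dict[str, bytes]]


def shard_sizes(shards: Sequence[TensorSliceDict]) -> List[int]:
    """Return the total payload size, in bytes, of each shard."""
    return [
        sum(len(value) for slices in shard.values() for value in slices.values())
        for shard in shards
    ]


def locate_slices(shards: Sequence[TensorSliceDict]) -> Dict[Tuple[str, str], int]:
    """Map each (checkpoint key, slice_spec) pair to the index of its shard.

    Raises ValueError if the same slice appears in more than one shard.
    """
    location: Dict[Tuple[str, str], int] = {}
    for index, shard in enumerate(shards):
        for key, slices in shard.items():
            for spec in slices:
                if (key, spec) in location:
                    raise ValueError(
                        f"{key!r} slice {spec!r} is in shards "
                        f"{location[(key, spec)]} and {index}")
                location[(key, spec)] = index
    return location


shards = [
    {"dense/kernel": {"0:2": bytes(8)}},
    {"dense/kernel": {"2:4": bytes(8)}, "dense/bias": {"": bytes(4)}},
]
print(shard_sizes(shards))  # [8, 12]
print(locate_slices(shards))
```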