Checkpoint sharding callback function, along with a text description.
A callback function wrapper that will be executed to determine how tensors will be split into shards when the saver writes the checkpoint shards to disk.
The callback takes a list of `tf.train.experimental.ShardableTensor`s as input (as well as any kwargs defined by the `tf.train.experimental.ShardingCallback` subclass) and organizes the input tensors into different shards. Tensors are first organized by device task (see `tf.DeviceSpec`), and the callback is then called once for each task's collection of tensors.
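The per-task dispatch described above can be modeled in a few lines of plain Python. This is a simplified, hypothetical sketch, not TensorFlow's implementation: `FakeShardableTensor`, `run_callback_per_task`, and `all_in_one` are illustrative stand-ins, and the fake tensor records its task as an integer rather than exposing a `tf.DeviceSpec`. It only shows that grouping happens before the callback runs, so every call sees tensors from a single task.

```python
from collections import defaultdict
from dataclasses import dataclass


@dataclass(frozen=True)
class FakeShardableTensor:
  """Hypothetical stand-in for tf.train.experimental.ShardableTensor."""
  checkpoint_key: str
  slice_spec: str
  value: float
  task: int  # Assumption: a plain task index instead of a tf.DeviceSpec.


def run_callback_per_task(shardable_tensors, callback):
  """Groups tensors by task, then calls `callback` once per group."""
  by_task = defaultdict(list)
  for t in shardable_tensors:
    by_task[t.task].append(t)
  shards = []
  for task in sorted(by_task):
    # Each call only ever sees tensors that belong to a single task.
    shards.extend(callback(by_task[task]))
  return shards


def all_in_one(tensors):
  """Same grouping logic as an all-in-one-shard callback, on stand-ins."""
  shard = {}
  for t in tensors:
    shard.setdefault(t.checkpoint_key, {})[t.slice_spec] = t.value
  return [shard]


tensors = [
    FakeShardableTensor("a", "", 1.0, task=0),
    FakeShardableTensor("b", "", 2.0, task=1),
    FakeShardableTensor("c", "", 3.0, task=0),
]
shards = run_callback_per_task(tensors, all_in_one)
print(shards)  # → [{'a': {'': 1.0}, 'c': {'': 3.0}}, {'b': {'': 2.0}}]
```

Under this model, a callback that always returns a single shard still yields one shard per task, because it is invoked separately for each task's tensors.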
There are a few restrictions to keep in mind when creating a custom callback:
- Tensors must not be removed from the checkpoint.
- Tensors must not be reshaped.
- Tensor dtypes must not change.
- Tensors within a shard must belong to the same task.

Validation checks are performed after the callback function is executed to ensure these restrictions aren't violated.
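The four restrictions can be read as checks that compare the callback's input with its output. The sketch below is a hypothetical, simplified version of such validation in plain Python; it is not TensorFlow's validation code. `FakeTensor` and `find_violations` are illustrative names, and the stand-in records only the checkpoint key, slice spec, shape, dtype, and task that the checks need. Shards use the same `{checkpoint_key: {slice_spec: tensor}}` layout as the callback example that follows.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class FakeTensor:
  """Hypothetical stand-in for one tensor slice passed to a callback."""
  checkpoint_key: str
  slice_spec: str
  shape: tuple
  dtype: str
  task: int


def find_violations(inputs, shards):
  """Returns a list of rule violations; an empty list means all checks pass."""
  expected = {(t.checkpoint_key, t.slice_spec): t for t in inputs}
  seen = set()
  problems = []
  for i, shard in enumerate(shards):
    tasks = set()
    for key, slices in shard.items():
      for spec, out in slices.items():
        orig = expected.get((key, spec))
        if orig is None:
          problems.append(f"unknown tensor {key!r}")
          continue
        seen.add((key, spec))
        tasks.add(orig.task)  # A tensor's task is where it lives.
        if out.shape != orig.shape:
          problems.append(f"{key!r} was reshaped")
        if out.dtype != orig.dtype:
          problems.append(f"{key!r} changed dtype")
    if len(tasks) > 1:
      problems.append(f"shard {i} mixes tasks {sorted(tasks)}")
  # Every input slice must appear in some shard.
  for key, spec in sorted(set(expected) - seen):
    problems.append(f"{key!r} was removed")
  return problems


a = FakeTensor("a", "", (2,), "float32", task=0)
b = FakeTensor("b", "", (3,), "int32", task=1)
print(find_violations([a, b], [{"a": {"": a}}, {"b": {"": b}}]))  # → []
print(find_violations([a, b], [{"a": {"": a}, "b": {"": b}}]))
# → ['shard 0 mixes tasks [0, 1]']
```

Collecting all violations instead of stopping at the first one is a design choice for the sketch; it makes a faulty callback easier to debug in one pass.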
Here's an example of a simple custom callback:
```python
import tensorflow as tf

# Place all tensors in a single shard.
class AllInOnePolicy(tf.train.experimental.ShardingCallback):

  @property
  def description(self):
    return "Place all tensors in a single shard."

  def __call__(self, shardable_tensors):
    # Maps checkpoint_key -> {slice_spec: tensor}.
    tensors = {}
    for shardable_tensor in shardable_tensors:
      tensor = shardable_tensor.tensor
      checkpoint_key = shardable_tensor.checkpoint_key
      slice_spec = shardable_tensor.slice_spec
      tensors.setdefault(checkpoint_key, {})[slice_spec] = tensor
    # One dict in the returned list means one shard for this task's tensors.
    return [tensors]

# Any tf.train.Checkpoint works here; this one tracks a single variable.
ckpt = tf.train.Checkpoint(v=tf.Variable(1.0))
ckpt.save(
    "path",
    options=tf.train.CheckpointOptions(
        experimental_sharding_callback=AllInOnePolicy()))
```
The `description` attribute is used to identify the callback and to aid debugging during saving and restoration.
To accept kwargs, define a constructor and pass them in when instantiating the callback:
```python
class ParameterPolicy(tf.train.experimental.ShardingCallback):
  def __init__(self, custom_param):
    self.custom_param = custom_param
  ...

ckpt.save(
    "path",
    options=tf.train.CheckpointOptions(
        experimental_sharding_callback=ParameterPolicy(custom_param=...)))
```
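Attributes | |
---|---|
`description` | A text description of the callback, used to identify it and to aid debugging during saving and restoration.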
Methods
`__call__`

```python
@abc.abstractmethod
__call__(
    shardable_tensors: Sequence[tf.train.experimental.ShardableTensor]
) -> Sequence[TensorSliceDict]
```
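Organizes `shardable_tensors` into shards. Returns a sequence with one dict per shard; as in `AllInOnePolicy`, each dict maps a checkpoint key to a `{slice_spec: tensor}` dict.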