tf.train.experimental.ShardingCallback
Checkpoint sharding callback function, along with a text description.
A callback function wrapper that will be executed to determine how tensors
will be split into shards when the saver writes the checkpoint shards to disk.

The callback takes a list of tf.train.experimental.ShardableTensor objects as
input (as well as any kwargs defined by the tf.train.experimental.ShardingCallback
subclass) and organizes the input tensors into different shards. Tensors are
first grouped by device task (see tf.DeviceSpec); the callback is then invoked
once for each group of tensors.
There are a few restrictions to keep in mind when creating a custom callback:
- Tensors must not be removed from the checkpoint.
- Tensors must not be reshaped.
- Tensor dtypes must not change.
- Tensors within a shard must belong to the same task.
Validation checks will be performed after the callback function is executed to
ensure these restrictions aren't violated.
Here's an example of a simple custom callback:
    # Place all tensors in a single shard.
    class AllInOnePolicy(tf.train.experimental.ShardingCallback):
      @property
      def description(self):
        return "Place all tensors in a single shard."

      def __call__(self, shardable_tensors):
        tensors = {}
        for shardable_tensor in shardable_tensors:
          tensor = shardable_tensor.tensor_save_spec.tensor
          checkpoint_key = shardable_tensor.checkpoint_key
          slice_spec = shardable_tensor.slice_spec

          tensors.setdefault(checkpoint_key, {})[slice_spec] = tensor
        return [tensors]

    ckpt.save(
        "path",
        options=tf.train.CheckpointOptions(
            experimental_sharding_callback=AllInOnePolicy()))
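For contrast, the sketch below is an illustration rather than part of the official docs (the class name OneShardPerTensorPolicy is made up): it places each tensor in its own shard. It keeps every tensor with its original shape and dtype, so the validation checks described above pass, and it returns the structure the saver expects: a sequence of dicts mapping checkpoint keys to {slice_spec: tensor} dicts.

    # Illustrative sketch (hypothetical class name): one shard per tensor.
    class OneShardPerTensorPolicy(tf.train.experimental.ShardingCallback):
      @property
      def description(self):
        return "Place each tensor in its own shard."

      def __call__(self, shardable_tensors):
        shards = []
        for shardable_tensor in shardable_tensors:
          tensor = shardable_tensor.tensor_save_spec.tensor
          checkpoint_key = shardable_tensor.checkpoint_key
          slice_spec = shardable_tensor.slice_spec
          # Each tensor keeps its checkpoint key, slice spec, shape, and dtype,
          # so the post-callback validation checks are satisfied.
          shards.append({checkpoint_key: {slice_spec: tensor}})
        return shards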
The description attribute is used to identify the callback and to aid
debugging during saving and restoration.
To take in kwargs, simply define the constructor and pass them in:
    class ParameterPolicy(tf.train.experimental.ShardingCallback):
      def __init__(self, custom_param):
        self.custom_param = custom_param
      ...

    ckpt.save(
        "path",
        options=tf.train.CheckpointOptions(
            experimental_sharding_callback=ParameterPolicy(custom_param=...)))
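The body of ParameterPolicy above is intentionally elided. As one hedged illustration of how a constructor argument might drive the sharding decision (the class name and grouping logic here are made up, not from the TensorFlow docs), a policy could cap the number of tensors placed in each shard:

    # Hypothetical policy: at most `max_tensors_per_shard` tensors per shard.
    class MaxTensorsPerShardPolicy(tf.train.experimental.ShardingCallback):
      def __init__(self, max_tensors_per_shard):
        self.max_tensors_per_shard = max_tensors_per_shard

      @property
      def description(self):
        return ("Group tensors into shards of at most "
                f"{self.max_tensors_per_shard} tensors each.")

      def __call__(self, shardable_tensors):
        shards = [{}]
        tensors_in_shard = 0
        for shardable_tensor in shardable_tensors:
          # Start a new shard once the current one reaches the cap.
          if tensors_in_shard == self.max_tensors_per_shard:
            shards.append({})
            tensors_in_shard = 0
          tensor = shardable_tensor.tensor_save_spec.tensor
          checkpoint_key = shardable_tensor.checkpoint_key
          slice_spec = shardable_tensor.slice_spec
          shards[-1].setdefault(checkpoint_key, {})[slice_spec] = tensor
          tensors_in_shard += 1
        return shards

    ckpt.save(
        "path",
        options=tf.train.CheckpointOptions(
            experimental_sharding_callback=MaxTensorsPerShardPolicy(
                max_tensors_per_shard=10)))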
Methods
__call__
View source: https://github.com/tensorflow/tensorflow/blob/v2.16.1/tensorflow/python/checkpoint/sharding/sharding_util.py#L144-L148
    @abc.abstractmethod
    __call__(
        shardable_tensors: Sequence[tf.train.experimental.ShardableTensor]
    ) -> Sequence[TensorSliceDict]
Call self as a function.