Keeps track of the object created by invoking `trackable_factory_callable`.
```python
tft.make_and_track_object(
    trackable_factory_callable: Callable[[], base.Trackable],
    name: Optional[str] = None
) -> base.Trackable
```
This API is only for use when Transform APIs are run with TF2 behaviors enabled and `tft_beam.Context.force_tf_compat_v1` is set to `False`.
Use this API to track TF `Trackable` objects created in the `preprocessing_fn`, such as TF Hub modules or `tf.data.Dataset` objects. This ensures they are serialized correctly when the transform is exported to a SavedModel.
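The serialization guarantee rests on TF2 object tracking: `tf.saved_model.save` writes only the objects reachable from the exported root through tracked attributes. The sketch below illustrates that general TensorFlow mechanism, not the internals of `tft`; the `Holder` class and the attribute name `v` are hypothetical.

```python
import tempfile

import tensorflow as tf


class Holder(tf.Module):
    """Hypothetical root object; tf.Module tracks Trackable attributes."""


root = Holder()
# Assigning a Trackable (here a tf.Variable) to an attribute registers it as a
# dependency of `root`, so it is reachable when the SavedModel is written.
root.v = tf.Variable(3)

export_dir = tempfile.mkdtemp()
tf.saved_model.save(root, export_dir)

# Because it was tracked, the variable is restored from the SavedModel.
loaded = tf.saved_model.load(export_dir)
print(loaded.v.numpy())
```

An object that is created but never attached to a tracked parent would simply be missing from the export, which is the failure mode `make_and_track_object` is meant to prevent inside a `preprocessing_fn`.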
Example:
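```python
import tempfile

import tensorflow as tf
import tensorflow_transform as tft
import tensorflow_transform.beam as tft_beam


def preprocessing_fn(inputs):
  # Create a tf.data.Dataset and have tf.Transform track it so that it is
  # serialized with the exported SavedModel.
  dataset = tft.make_and_track_object(
      lambda: tf.data.Dataset.from_tensor_slices([1, 2, 3]))
  # Iterate over the dataset eagerly, outside the graph being traced.
  with tf.init_scope():
    dataset_list = list(dataset.as_numpy_iterator())
  return {'x_0': dataset_list[0] + inputs['x']}


raw_data = [dict(x=1), dict(x=2), dict(x=3)]
feature_spec = dict(x=tf.io.FixedLenFeature([], tf.int64))
raw_data_metadata = tft.DatasetMetadata.from_feature_spec(feature_spec)

# make_and_track_object requires TF2 behavior (force_tf_compat_v1=False).
with tft_beam.Context(temp_dir=tempfile.mkdtemp(),
                      force_tf_compat_v1=False):
  transformed_dataset, transform_fn = (
      (raw_data, raw_data_metadata)
      | tft_beam.AnalyzeAndTransformDataset(preprocessing_fn))
transformed_data, transformed_metadata = transformed_dataset
print(transformed_data)
# [{'x_0': 2}, {'x_0': 3}, {'x_0': 4}]
```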
| Returns |
|---|
| The object returned when `trackable_factory_callable` is invoked. The object creation is lifted out to the eager context using `tf.init_scope`. |
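`tf.init_scope` is what lifts the factory call out of graph construction and into the eager context. The minimal sketch below, in plain TensorFlow and assuming TF2 eager execution is enabled, shows the effect: while a `tf.function` is being traced, `tf.executing_eagerly()` is `False`, but inside `tf.init_scope()` it is `True`. The function `traced` and the dict `contexts` are hypothetical names used only for illustration.

```python
import tensorflow as tf

# Records whether each region ran eagerly; filled in during tracing.
contexts = {}


@tf.function
def traced(x):
    # This Python code runs while tf.function traces a graph.
    contexts["function_body"] = tf.executing_eagerly()
    with tf.init_scope():
        # init_scope lifts this block out to the outermost (eager) context,
        # the same mechanism the Returns note above describes.
        contexts["init_scope"] = tf.executing_eagerly()
        eager_value = tf.constant([1, 2, 3]).numpy()  # .numpy() needs eager mode
    return x + int(eager_value[0])


result = traced(tf.constant(1))
print(contexts)     # {'function_body': False, 'init_scope': True}
print(int(result))  # 2
```

This is why the example `preprocessing_fn` can call `dataset.as_numpy_iterator()` inside `tf.init_scope()`: that code runs eagerly even though the surrounding function is being traced.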