The function enables you to use a tf.data.Dataset in a stateless
"tensor-in tensor-out" expression, without creating an iterator.
This facilitates data transformation on tensors using the
optimized tf.data.Dataset abstraction on top of them.
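As a minimal sketch of that contract (the tensor values below are purely illustrative, and the fuller preprocessing example follows), a single-element dataset goes in and the element comes back out as a tensor, with no iterator in between:

    import tensorflow as tf

    # A dataset holding exactly one element: the tensor [1, 2, 3].
    dataset = tf.data.Dataset.from_tensors(tf.constant([1, 2, 3]))

    # No iterator is created; the single element comes back as a tensor.
    element = tf.data.experimental.get_single_element(dataset)
    print(element)  # tf.Tensor([1 2 3], shape=(3,), dtype=int32)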
For example, let's consider a preprocessing_fn which takes the raw
features as input and returns the processed feature along with
its label.
    def preprocessing_fn(raw_feature):
      # ... the raw_feature is preprocessed as per the use-case
      return feature

    raw_features = ...  # input batch of BATCH_SIZE elements.
    dataset = (tf.data.Dataset.from_tensor_slices(raw_features)
               .map(preprocessing_fn, num_parallel_calls=BATCH_SIZE)
               .batch(BATCH_SIZE))

    processed_features = tf.data.experimental.get_single_element(dataset)
In the above example, the raw_features tensor of length BATCH_SIZE
was converted to a tf.data.Dataset. Next, each raw_feature was
mapped using the preprocessing_fn, and the processed features were
grouped into a single batch. The final dataset contains only one element,
which is a batch of all the processed features.

Note: The dataset should contain only one element.
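To illustrate that constraint, here is a small sketch (the dataset name and values are only for demonstration) of what happens when the dataset holds more than one element; the call fails at runtime:

    import tensorflow as tf

    multi_element_ds = tf.data.Dataset.range(5)  # five elements, not one

    try:
      tf.data.experimental.get_single_element(multi_element_ds)
    except tf.errors.InvalidArgumentError as e:
      # Fails because the dataset does not contain exactly one element.
      print("InvalidArgumentError:", e.message)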
Now, instead of creating an iterator for the dataset and retrieving the
batch of features, the tf.data.experimental.get_single_element() function
is used to skip the iterator creation process and directly output the batch
of features.
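For comparison, a rough sketch of the two retrieval styles under eager execution (the dataset and values here are illustrative, not part of the original example):

    import tensorflow as tf

    dataset = tf.data.Dataset.from_tensors(tf.constant([10, 20, 30]))

    # Iterator-based retrieval: explicitly create an iterator and pull the element.
    batch_via_iterator = next(iter(dataset))

    # get_single_element: the element is returned directly, with no iterator.
    batch_direct = tf.data.experimental.get_single_element(dataset)

    assert batch_via_iterator.numpy().tolist() == batch_direct.numpy().tolist()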
This can be particularly useful when your tensor transformations are
expressed as tf.data.Dataset operations, and you want to use those
transformations while serving your model.
Keras
=====
    model = ...  # A pre-built or custom model

    class PreprocessingModel(tf.keras.Model):
      def __init__(self, model):
        super().__init__()
        self.model = model

      @tf.function(input_signature=[...])
      def serving_fn(self, data):
        ds = tf.data.Dataset.from_tensor_slices(data)
        ds = ds.map(preprocessing_fn, num_parallel_calls=BATCH_SIZE)
        ds = ds.batch(batch_size=BATCH_SIZE)
        return tf.argmax(
            self.model(tf.data.experimental.get_single_element(ds)),
            axis=-1)

    preprocessing_model = PreprocessingModel(model)
    your_exported_model_dir = ...  # save the model to this path.
    tf.saved_model.save(preprocessing_model, your_exported_model_dir,
                        signatures={'serving_default': preprocessing_model.serving_fn})
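Once exported, the model could be reloaded and the serving function invoked roughly as sketched below; raw_data is a placeholder for a batch of raw features matching the declared input_signature, and your_exported_model_dir is the path used above:

    import tensorflow as tf

    # Reload the exported model from the directory it was saved to.
    loaded = tf.saved_model.load(your_exported_model_dir)

    # Call the restored serving function; preprocessing and prediction happen
    # inside serving_fn, so the caller passes raw features directly.
    predictions = loaded.serving_fn(raw_data)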
[[["Easy to understand","easyToUnderstand","thumb-up"],["Solved my problem","solvedMyProblem","thumb-up"],["Other","otherUp","thumb-up"]],[["Missing the information I need","missingTheInformationINeed","thumb-down"],["Too complicated / too many steps","tooComplicatedTooManySteps","thumb-down"],["Out of date","outOfDate","thumb-down"],["Samples / code issue","samplesCodeIssue","thumb-down"],["Other","otherDown","thumb-down"]],["Last updated 2024-04-26 UTC."],[],[],null,["# tf.data.experimental.get_single_element\n\n\u003cbr /\u003e\n\n|-------------------------------------------------------------------------------------------------------------------------------------------------------|\n| [View source on GitHub](https://github.com/tensorflow/tensorflow/blob/v2.16.1/tensorflow/python/data/experimental/ops/get_single_element.py#L21-L109) |\n\nReturns the single element of the `dataset` as a nested structure of tensors. (deprecated)\n\n#### View aliases\n\n\n**Compat aliases for migration**\n\nSee\n[Migration guide](https://www.tensorflow.org/guide/migrate) for\nmore details.\n\n[`tf.compat.v1.data.experimental.get_single_element`](https://www.tensorflow.org/api_docs/python/tf/data/experimental/get_single_element)\n\n\u003cbr /\u003e\n\n tf.data.experimental.get_single_element(\n dataset\n )\n\n| **Deprecated:** THIS FUNCTION IS DEPRECATED. It will be removed in a future version. Instructions for updating: Use [`tf.data.Dataset.get_single_element()`](../../../tf/data/Dataset#get_single_element).\n\nThe function enables you to use a [`tf.data.Dataset`](../../../tf/data/Dataset) in a stateless\n\"tensor-in tensor-out\" expression, without creating an iterator.\nThis facilitates the ease of data transformation on tensors using the\noptimized [`tf.data.Dataset`](../../../tf/data/Dataset) abstraction on top of them.\n\nFor example, lets consider a `preprocessing_fn` which would take as an\ninput the raw features and returns the processed feature along with\nit's label. \n\n def preprocessing_fn(raw_feature):\n # ... the raw_feature is preprocessed as per the use-case\n return feature\n\n raw_features = ... # input batch of BATCH_SIZE elements.\n dataset = (tf.data.Dataset.from_tensor_slices(raw_features)\n .map(preprocessing_fn, num_parallel_calls=BATCH_SIZE)\n .batch(BATCH_SIZE))\n\n processed_features = tf.data.experimental.get_single_element(dataset)\n\nIn the above example, the `raw_features` tensor of length=BATCH_SIZE\nwas converted to a [`tf.data.Dataset`](../../../tf/data/Dataset). Next, each of the `raw_feature` was\nmapped using the `preprocessing_fn` and the processed features were\ngrouped into a single batch. The final `dataset` contains only one element\nwhich is a batch of all the processed features.\n| **Note:** The `dataset` should contain only one element.\n\nNow, instead of creating an iterator for the `dataset` and retrieving the\nbatch of features, the [`tf.data.experimental.get_single_element()`](../../../tf/data/experimental/get_single_element) function\nis used to skip the iterator creation process and directly output the batch\nof features.\n\nThis can be particularly useful when your tensor transformations are\nexpressed as [`tf.data.Dataset`](../../../tf/data/Dataset) operations, and you want to use those\ntransformations while serving your model.\n\nKeras\n=====\n\n\n model = ... 
# A pre-built or custom model\n\n class PreprocessingModel(tf.keras.Model):\n def __init__(self, model):\n super().__init__(self)\n self.model = model\n\n @tf.function(input_signature=[...])\n def serving_fn(self, data):\n ds = tf.data.Dataset.from_tensor_slices(data)\n ds = ds.map(preprocessing_fn, num_parallel_calls=BATCH_SIZE)\n ds = ds.batch(batch_size=BATCH_SIZE)\n return tf.argmax(\n self.model(tf.data.experimental.get_single_element(ds)),\n axis=-1\n )\n\n preprocessing_model = PreprocessingModel(model)\n your_exported_model_dir = ... # save the model to this path.\n tf.saved_model.save(preprocessing_model, your_exported_model_dir,\n signatures={'serving_default': preprocessing_model.serving_fn})\n\n\u003cbr /\u003e\n\n\u003cbr /\u003e\n\n\u003cbr /\u003e\n\n| Args ---- ||\n|-----------|-------------------------------------------------------------------------------------|\n| `dataset` | A [`tf.data.Dataset`](../../../tf/data/Dataset) object containing a single element. |\n\n\u003cbr /\u003e\n\n\u003cbr /\u003e\n\n\u003cbr /\u003e\n\n\u003cbr /\u003e\n\n| Returns ------- ||\n|---|---|\n| A nested structure of [`tf.Tensor`](../../../tf/Tensor) objects, corresponding to the single element of `dataset`. ||\n\n\u003cbr /\u003e\n\n\u003cbr /\u003e\n\n\u003cbr /\u003e\n\n\u003cbr /\u003e\n\n| Raises ------ ||\n|------------------------|-----------------------------------------------------------------------------|\n| `TypeError` | if `dataset` is not a [`tf.data.Dataset`](../../../tf/data/Dataset) object. |\n| `InvalidArgumentError` | (at runtime) if `dataset` does not contain exactly one element. |\n\n\u003cbr /\u003e"]]