tf.experimental.dispatch_for_unary_elementwise_apis
Decorator to override default implementation for unary elementwise APIs.
```python
tf.experimental.dispatch_for_unary_elementwise_apis(
    x_type
)
```
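Because the function returns a decorator, the `@` syntax used later on this page should be equivalent to calling the returned decorator on the handler directly. The sketch below registers a handler that way, assuming a hypothetical single-field extension type named `Boxed` and a hypothetical handler `boxed_handler` (neither is part of TensorFlow).

```python
import tensorflow as tf


class Boxed(tf.experimental.ExtensionType):
    """Hypothetical single-field extension type used only for this sketch."""
    values: tf.Tensor


def boxed_handler(api_func, x):
    # Apply the elementwise op to the wrapped tensor and re-wrap the result.
    return Boxed(api_func(x.values))


# Calling the function with `x_type` returns a decorator; applying that
# decorator to the handler registers it, as the `@` syntax would.
register = tf.experimental.dispatch_for_unary_elementwise_apis(Boxed)
register(boxed_handler)

result = tf.abs(Boxed(tf.constant([-4, 5])))
print(result.values.numpy())  # [4 5]
```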
Used in the guide: Extension types (https://www.tensorflow.org/guide/extension_type)
The decorated function (known as the "elementwise API handler") overrides the default implementation for any unary elementwise API whenever the value for the first argument (typically named `x`) matches the type annotation `x_type`. The elementwise API handler is called with two arguments:

`elementwise_api_handler(api_func, x)`

where `api_func` is a function that takes a single parameter and performs the elementwise operation (e.g., `tf.abs`), and `x` is the first argument to the elementwise API.
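To make the dispatch rule concrete, here is a minimal sketch that records each handler call, assuming a hypothetical extension type named `Wrapped` and an illustrative `handler_calls` list. Only the call whose first argument matches `x_type` should reach the handler; a plain `tf.Tensor` still goes through the default implementation of `tf.abs`.

```python
import tensorflow as tf


class Wrapped(tf.experimental.ExtensionType):
    """Hypothetical extension type used only for this sketch."""
    values: tf.Tensor


handler_calls = []  # Records the type name of `x` each time the handler runs.


@tf.experimental.dispatch_for_unary_elementwise_apis(Wrapped)
def wrapped_handler(api_func, x):
    handler_calls.append(type(x).__name__)
    # `api_func` takes exactly one argument: the tensor to operate on.
    return Wrapped(api_func(x.values))


# A plain tf.Tensor does not match `Wrapped`, so the default tf.abs runs.
plain = tf.abs(tf.constant([-1, 2]))
# A `Wrapped` value matches `x_type`, so the handler is called instead.
wrapped = tf.abs(Wrapped(tf.constant([-1, 2])))

print(handler_calls)  # ['Wrapped']
```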
The following example shows how this decorator can be used to update all unary elementwise operations to handle a `MaskedTensor` type:
```python
import tensorflow as tf

class MaskedTensor(tf.experimental.ExtensionType):
    values: tf.Tensor
    mask: tf.Tensor

@tf.experimental.dispatch_for_unary_elementwise_apis(MaskedTensor)
def unary_elementwise_api_handler(api_func, x):
    # Apply the op to the values only; the mask passes through unchanged.
    return MaskedTensor(api_func(x.values), x.mask)

mt = MaskedTensor([1, -2, -3], [True, False, True])
abs_mt = tf.abs(mt)
print(f"values={abs_mt.values.numpy()}, mask={abs_mt.mask.numpy()}")
# values=[1 2 3], mask=[ True False  True]
```
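Because one handler is registered for every unary elementwise API at once, the same `MaskedTensor` handler should also serve other unary operations. The sketch below assumes that `tf.math.negative` and `tf.math.square` are among the registered unary elementwise APIs (they are not named on this page); it redefines `MaskedTensor` so that it can run on its own.

```python
import tensorflow as tf


class MaskedTensor(tf.experimental.ExtensionType):
    values: tf.Tensor
    mask: tf.Tensor


@tf.experimental.dispatch_for_unary_elementwise_apis(MaskedTensor)
def unary_elementwise_api_handler(api_func, x):
    # One handler serves every registered unary elementwise API.
    return MaskedTensor(api_func(x.values), x.mask)


mt = MaskedTensor([1, -2, -3], [True, False, True])

# Both calls are routed to the handler above; only `values` is transformed.
neg_mt = tf.math.negative(mt)
sq_mt = tf.math.square(mt)
print(neg_mt.values.numpy(), sq_mt.values.numpy())
```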
For unary elementwise operations that take extra arguments beyond `x`, those arguments are *not* passed to the elementwise API handler, but are automatically added when `api_func` is called. In the following example, the `dtype` parameter is not passed to `unary_elementwise_api_handler`, but is added by `api_func`.
```python
# Continues the example above: reuses `mt` and the registered handler.
ones_mt = tf.ones_like(mt, dtype=tf.float32)
print(f"values={ones_mt.values.numpy()}, mask={ones_mt.mask.numpy()}")
# values=[1. 1. 1.], mask=[ True False  True]
```
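To see that extra arguments are bound into `api_func` rather than passed to the handler, the following sketch records the dtype of each tensor that `api_func` returns. The `result_dtypes` list is an illustrative addition, not part of the API; the handler signature stays `(api_func, x)` whether or not `dtype` is supplied.

```python
import tensorflow as tf


class MaskedTensor(tf.experimental.ExtensionType):
    values: tf.Tensor
    mask: tf.Tensor


result_dtypes = []  # dtypes produced by `api_func`, recorded for inspection


@tf.experimental.dispatch_for_unary_elementwise_apis(MaskedTensor)
def unary_elementwise_api_handler(api_func, x):
    # The handler never receives `dtype`; `api_func` already has it bound.
    values = api_func(x.values)
    result_dtypes.append(values.dtype)
    return MaskedTensor(values, x.mask)


mt = MaskedTensor([1, -2, -3], [True, False, True])
ones_int64 = tf.ones_like(mt, dtype=tf.int64)  # dtype bound into api_func
ones_default = tf.ones_like(mt)  # no dtype: keeps the input's int32 dtype

print(result_dtypes)
```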
Args

| Argument | Description |
|---|---|
| `x_type` | A type annotation indicating when the API handler should be called. See `dispatch_for_api` for a list of supported annotation types. |

Returns

A decorator.
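As one example of an annotation other than a single class, the sketch below assumes that `typing.Union` is among the annotation types accepted by `dispatch_for_api`, and registers one handler for two hypothetical extension types, `Meters` and `Feet`.

```python
import typing

import tensorflow as tf


class Meters(tf.experimental.ExtensionType):
    """Hypothetical extension type used only for this sketch."""
    values: tf.Tensor


class Feet(tf.experimental.ExtensionType):
    """Second hypothetical extension type used only for this sketch."""
    values: tf.Tensor


@tf.experimental.dispatch_for_unary_elementwise_apis(typing.Union[Meters, Feet])
def length_handler(api_func, x):
    # type(x)(...) rebuilds whichever of the two types was passed in.
    return type(x)(api_func(x.values))


m = tf.abs(Meters(tf.constant([-1.5, 2.0])))
f = tf.abs(Feet(tf.constant([-3.0])))
print(type(m).__name__, type(f).__name__)
```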
Registered APIs
The unary elementwise APIs are:
<>
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates. Some content is licensed under the numpy license.
Last updated 2024-04-26 UTC.