tf.math.top_k
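Finds values and indices of the `k` largest entries for the last dimension.

Main alias: `tf.nn.top_k`. Compat aliases for migration: `tf.compat.v1.math.top_k`, `tf.compat.v1.nn.top_k`.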
```python
tf.math.top_k(
    input,
    k=1,
    sorted=True,
    index_type=tf.dtypes.int32,
    name=None
)
```
If the input is a vector (rank 1), finds the `k` largest entries in the vector and outputs their values and indices as vectors. Thus `values[j]` is the `j`-th largest entry in `input`, and its index is `indices[j]`.
```python
>>> import tensorflow as tf
>>> result = tf.math.top_k([1, 2, 98, 1, 1, 99, 3, 1, 3, 96, 4, 1],
...                        k=3)
>>> result.values.numpy()
array([99, 98, 96], dtype=int32)
>>> result.indices.numpy()
array([5, 2, 9], dtype=int32)
```
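To make the rank-1 semantics concrete, here is a minimal pure-Python sketch that reproduces what `tf.math.top_k` returns for the vector above. It is an illustrative reference model, not TensorFlow's implementation, and the function name `top_k_1d` is hypothetical. It orders positions by descending value, breaks ties toward the lower index, and keeps the first `k`.

```python
from typing import List, Sequence, Tuple


def top_k_1d(x: Sequence[float], k: int) -> Tuple[List[float], List[int]]:
    """Hypothetical reference model of tf.math.top_k for a rank-1 input.

    Returns (values, indices): values[j] is the j-th largest entry of x
    and indices[j] is its position in x.
    """
    if not 0 <= k <= len(x):
        raise ValueError(f"k={k} must be in [0, {len(x)}]")
    # Sort positions by descending value; equal values go to the lower index.
    order = sorted(range(len(x)), key=lambda i: (-x[i], i))
    indices = order[:k]
    values = [x[i] for i in indices]
    return values, indices


values, indices = top_k_1d([1, 2, 98, 1, 1, 99, 3, 1, 3, 96, 4, 1], k=3)
print(values)   # [99, 98, 96]
print(indices)  # [5, 2, 9]
```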
For matrices (resp. higher-rank inputs), computes the top `k` entries in each row (resp. each vector along the last dimension). The output therefore has the shape of `input` with its last dimension replaced by `k`:
```python
>>> input = tf.random.normal(shape=(3, 4, 5, 6))
>>> k = 2
>>> values, indices = tf.math.top_k(input, k=k)
>>> values.shape.as_list()
[3, 4, 5, 2]
>>> values.shape == indices.shape == input.shape[:-1] + [k]
True
```
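For higher-rank inputs the same rule is applied independently to every vector along the last dimension, which is why the result shape is `input.shape[:-1] + [k]`. The pure-Python sketch below models this over nested lists; the helper names `top_k_last_dim` and `shape` are hypothetical, and this models the documented semantics rather than TensorFlow's kernel.

```python
from typing import Any, List, Tuple


def top_k_last_dim(x: Any, k: int) -> Tuple[Any, Any]:
    """Hypothetical reference model: apply top-k to every innermost list."""
    if x and isinstance(x[0], list):
        # Not yet at the last dimension: recurse into each sub-list.
        pairs = [top_k_last_dim(sub, k) for sub in x]
        return [p[0] for p in pairs], [p[1] for p in pairs]
    # Innermost list: sort positions by descending value, ties by index.
    order = sorted(range(len(x)), key=lambda i: (-x[i], i))[:k]
    return [x[i] for i in order], order


def shape(x: Any) -> List[int]:
    """Shape of a rectangular nested list."""
    dims = []
    while isinstance(x, list):
        dims.append(len(x))
        x = x[0] if x else None
    return dims


# A 2 x 3 x 4 input; each innermost row is reduced from 4 entries to k.
data = [[[float((7 * (a + b + c)) % 5) for c in range(4)] for b in range(3)]
        for a in range(2)]
k = 2
values, indices = top_k_last_dim(data, k)
print(shape(values))  # [2, 3, 2]
assert shape(values) == shape(indices) == shape(data)[:-1] + [k]
```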
The indices can be used to `gather` from a tensor whose shape matches `input`:
```python
>>> gathered_values = tf.gather(input, indices, batch_dims=-1)
>>> assert tf.reduce_all(gathered_values == values)
```
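With `batch_dims=-1`, every dimension of `indices` except the last is treated as a batch dimension, so row `r` of the indices only selects from row `r` of `input`. Here is a pure-Python sketch of that rule for the 2-D case; the helpers `top_k_rows` and `gather_last_dim` are hypothetical stand-ins, not TensorFlow functions.

```python
from typing import List, Tuple


def top_k_rows(x: List[List[float]], k: int) -> Tuple[List[List[float]], List[List[int]]]:
    """Hypothetical row-wise top-k reference model for a 2-D input."""
    values, indices = [], []
    for row in x:
        # Descending by value, lower index first on ties.
        order = sorted(range(len(row)), key=lambda i: (-row[i], i))[:k]
        indices.append(order)
        values.append([row[i] for i in order])
    return values, indices


def gather_last_dim(x: List[List[float]], idx: List[List[int]]) -> List[List[float]]:
    """Models tf.gather(x, idx, batch_dims=-1) for 2-D x and idx:
    out[r][j] = x[r][idx[r][j]]."""
    return [[row[i] for i in row_idx] for row, row_idx in zip(x, idx)]


x = [[0.5, 2.0, -1.0, 3.0],
     [4.0, 4.0, 1.0, 0.0],
     [-2.0, -3.0, -1.0, -5.0]]
values, indices = top_k_rows(x, k=2)
gathered = gather_last_dim(x, indices)
assert gathered == values
print(indices)  # [[3, 1], [0, 1], [2, 0]]
```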
If two elements are equal, the lower-index element appears first.
```python
>>> result = tf.math.top_k([1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0],
...                        k=3)
>>> result.indices.numpy()
array([0, 1, 3], dtype=int32)
```
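One way to picture the tie-breaking rule: a stable sort by descending value keeps equal elements in their original (ascending-index) order. Python's `sorted` is guaranteed stable, so the sketch below reproduces the indices from the example above. It illustrates the rule only; it is not how TensorFlow computes the result.

```python
x = [1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0]
k = 3

# sorted() is stable: among equal keys, earlier positions stay first.
# Negating the key sorts by descending value while preserving that order.
order = sorted(range(len(x)), key=lambda i: -x[i])
top_indices = order[:k]
print(top_indices)  # [0, 1, 3]
```

`reverse=True` would give the same answer here, since Python documents that `reverse` also maintains sort stability.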
By default, indices are returned as type `int32`; however, this can be changed by specifying `index_type`.
```python
>>> result = tf.math.top_k([1, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0],
...                        k=3, index_type=tf.int16)
>>> result.indices.numpy()
array([0, 1, 3], dtype=int16)
```
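A narrower `index_type` only makes sense when every index along the last dimension fits in it. The hypothetical helper below picks the narrowest of `int16`, `int32` and `int64` (assumed here to be the choices of interest) that can represent the largest index `n - 1` for a last dimension of size `n`. It returns dtype names as strings, which a caller could map to `tf.int16` and so on; it is not part of TensorFlow.

```python
# Largest representable value of each signed integer index dtype.
_INDEX_DTYPE_MAX = [("int16", 2**15 - 1), ("int32", 2**31 - 1), ("int64", 2**63 - 1)]


def smallest_index_type(last_dim: int) -> str:
    """Hypothetical helper: name of the narrowest signed dtype that can
    represent every index 0..last_dim-1."""
    if last_dim < 1:
        raise ValueError("last dimension must be at least 1")
    for name, max_value in _INDEX_DTYPE_MAX:
        if last_dim - 1 <= max_value:
            return name
    raise ValueError(f"last dimension {last_dim} is too large for int64 indices")


print(smallest_index_type(14))      # int16 (the 14-element example above)
print(smallest_index_type(40_000))  # int32
```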
Args

| Argument | Description |
|---|---|
| `input` | 1-D or higher `Tensor` with last dimension at least `k`. |
| `k` | 0-D `Tensor` of type `int16`, `int32`, or `int64`. Number of top elements to look for along the last dimension (along each row for matrices). |
| `sorted` | If `True`, the resulting `k` elements are sorted by value in descending order. |
| `index_type` | Optional dtype for the output indices. |
| `name` | Optional name for the operation. |
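The `k` and `sorted` arguments can be summarized with a small pure-Python model. It checks the documented precondition that the last dimension is at least `k` and, when sorting is turned off, returns the same `k` elements in ascending index order. That ordering is an assumption chosen for illustration: the argument description only promises descending order when `sorted` is true, so code should not rely on any particular order otherwise. The name `top_k_model` is hypothetical, and its flag is spelled `sorted_` so it does not shadow Python's built-in `sorted`, which the function calls.

```python
from typing import List, Sequence, Tuple


def top_k_model(x: Sequence[float], k: int, sorted_: bool = True) -> Tuple[List[float], List[int]]:
    """Hypothetical 1-D model of tf.math.top_k's k and sorted arguments."""
    if k < 0 or len(x) < k:
        raise ValueError(f"last dimension ({len(x)}) must be at least k ({k})")
    # Pick the k largest positions (ties resolved toward the lower index).
    chosen = sorted(range(len(x)), key=lambda i: (-x[i], i))[:k]
    if not sorted_:
        # Illustrative choice only: report the same elements in index order.
        chosen.sort()
    return [x[i] for i in chosen], chosen


x = [7, 1, 9, 5]
print(top_k_model(x, 2))                 # ([9, 7], [2, 0])
print(top_k_model(x, 2, sorted_=False))  # ([7, 9], [0, 2])
```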
Returns

A tuple with two named fields:

| Field | Description |
|---|---|
| `values` | The `k` largest elements along each last-dimension slice. |
| `indices` | The indices of `values` within the last dimension of `input`. |
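Because the result is a tuple with named fields, it supports both attribute access (`result.values`, as in the examples above) and positional unpacking (`values, indices = ...`). Below is a pure-Python stand-in built with `collections.namedtuple`; the type name `TopKResult` and the helper `top_k_result` are hypothetical and are not the class TensorFlow actually returns.

```python
from collections import namedtuple

# Hypothetical stand-in for the two-field result of tf.math.top_k.
TopKResult = namedtuple("TopKResult", ["values", "indices"])


def top_k_result(x, k):
    """Build a TopKResult for a 1-D list (ties go to the lower index)."""
    order = sorted(range(len(x)), key=lambda i: (-x[i], i))[:k]
    return TopKResult(values=[x[i] for i in order], indices=order)


result = top_k_result([3, 8, 8, 1], k=2)
print(result.values, result.indices)  # [8, 8] [1, 2]

# Positional unpacking yields the same two fields, in the same order.
values, indices = result
assert (values, indices) == (result.values, result.indices)
```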
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates. Some content is licensed under the numpy license.
Last updated 2023-10-06 UTC.