Returns the indices of a tensor that give its sorted order along an axis.
tfp.experimental.distributions.marginal_fns.ps.argsort(
values, axis=-1, direction='ASCENDING', stable=False, name=None
)
>>> import tensorflow as tf
>>> values = [1, 10, 26.9, 2.8, 166.32, 62.3]
>>> sort_order = tf.argsort(values)
>>> sort_order.numpy()
array([0, 3, 1, 2, 5, 4], dtype=int32)
For a 1D tensor, gathering with these indices yields the sorted values:
>>> sorted_values = tf.gather(values, sort_order)
>>> assert tf.reduce_all(sorted_values == tf.sort(values))
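For intuition, the 1D behavior can be mimicked in plain Python. This is a minimal sketch, not the TFP or TensorFlow implementation: the hypothetical helper reference_argsort orders the positions 0..n-1 by the value stored at each position, and handles direction by reversing the comparison. Python's sorted is stable even with reverse=True, so tied values in this sketch keep their original relative order; that is a property of the sketch, not a documented guarantee of tf.argsort.

```python
def reference_argsort(values, direction="ASCENDING"):
    """Hypothetical pure-Python analogue of a 1D argsort (illustration only)."""
    if direction not in ("ASCENDING", "DESCENDING"):
        # Mirrors the documented ValueError for an invalid direction.
        raise ValueError(f"Unrecognized direction: {direction!r}")
    # Order positions 0..n-1 by the value stored at each position.
    return sorted(
        range(len(values)),
        key=lambda i: values[i],
        reverse=(direction == "DESCENDING"),
    )


values = [1, 10, 26.9, 2.8, 166.32, 62.3]
ascending = reference_argsort(values)
descending = reference_argsort(values, direction="DESCENDING")
print(ascending)   # [0, 3, 1, 2, 5, 4], matching the tf.argsort output above
print(descending)  # [4, 5, 2, 1, 3, 0]

# Gathering with the ascending indices reproduces the sorted values.
assert [values[i] for i in ascending] == sorted(values)
```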
For higher-dimensional inputs, the output has the same shape as values.
Along the given axis, each entry is the index (within that slice of the
tensor) of the element that belongs at that position in sorted order.
>>> mat = [[30, 20, 10],
...        [20, 10, 30],
...        [10, 30, 20]]
>>> indices = tf.argsort(mat)
>>> indices.numpy()
array([[2, 1, 0],
       [1, 0, 2],
       [0, 2, 1]], dtype=int32)
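The per-slice statement above can be checked mechanically. This sketch uses NumPy's np.argsort as a stand-in (an assumption for illustration; it is not the function documented on this page) on the same matrix. It confirms that the output shape matches the input and that reading each row at its returned indices gives that row in ascending order.

```python
import numpy as np

# Same matrix as the tf.argsort example; NumPy is only a stand-in here.
mat = np.array([[30, 20, 10], [20, 10, 30], [10, 30, 20]])
indices = np.argsort(mat, axis=-1, kind="stable")

# The output has the same shape as the input.
assert indices.shape == mat.shape

# Entry k of each row of `indices` is the position, within that row,
# of the row's k-th smallest element.
for row, idx in zip(mat, indices):
    assert np.array_equal(row[idx], np.sort(row))

print(indices.tolist())  # [[2, 1, 0], [1, 0, 2], [0, 2, 1]]
```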
If axis=-1, these indices can be passed to tf.gather with batch_dims=-1 to apply the sort:
>>> tf.gather(mat, indices, batch_dims=-1).numpy()
array([[10, 20, 30],
       [10, 20, 30],
       [10, 20, 30]], dtype=int32)
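The batch_dims=-1 recipe covers the last axis. For intuition about another axis, the sketch below uses NumPy's np.argsort and np.take_along_axis as a stand-in (not the TensorFlow API): it sorts each column of the same matrix by computing indices along axis=0 and gathering along that same axis. With TensorFlow ops, moving the target axis last with tf.transpose before applying the recipe above is one way to get the same effect.

```python
import numpy as np

mat = np.array([[30, 20, 10], [20, 10, 30], [10, 30, 20]])

# Indices that sort each column (axis=0) independently.
col_order = np.argsort(mat, axis=0, kind="stable")

# take_along_axis gathers along the same axis the indices were computed on.
sorted_cols = np.take_along_axis(mat, col_order, axis=0)

print(sorted_cols)
# [[10 10 10]
#  [20 20 20]
#  [30 30 30]]
assert np.array_equal(sorted_cols, np.sort(mat, axis=0))
```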
| Returns | |
|---|---|
| An int32 Tensor with the same shape as values, holding the indices that would sort each slice of values along the given axis. | |
| Raises | |
|---|---|
| ValueError | If axis is not a constant scalar, or the direction is invalid. |
| tf.errors.InvalidArgumentError | If values.dtype is not a float or int type. |