Computes the max of segments in a tensor.
tf.keras.ops.segment_max(
    data, segment_ids, num_segments=None, sorted=False
)
Args:
  data: Input tensor.
  segment_ids: A 1-D tensor containing the segment index for each element in data.
  num_segments: An integer representing the total number of segments. If not specified, it is inferred from the maximum value in segment_ids.
  sorted: A boolean indicating whether segment_ids is sorted. Defaults to False.
Returns:
  A tensor containing the max of segments, where each element represents the max of the corresponding segment in data.
Example:
import keras

data = keras.ops.convert_to_tensor([1, 2, 10, 20, 100, 200])
segment_ids = keras.ops.convert_to_tensor([0, 0, 1, 1, 2, 2])
num_segments = 3
keras.ops.segment_max(data, segment_ids, num_segments)
# array([2, 20, 200], dtype=int32)
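
The grouping that segment_max performs can also be illustrated without Keras. The following is a minimal NumPy sketch of the semantics described above, not the library's implementation. The helper name segment_max_reference is hypothetical, and two behaviours are assumptions of this sketch rather than documented guarantees: when num_segments is omitted it is taken as the largest segment id plus one, and a segment that receives no elements keeps the lowest representable value of the dtype.

```python
import numpy as np


def segment_max_reference(data, segment_ids, num_segments=None):
    """Hypothetical reference helper: per-segment maximum along axis 0."""
    data = np.asarray(data)
    segment_ids = np.asarray(segment_ids)

    if num_segments is None:
        # Assumption of this sketch: infer the count as largest id + 1.
        num_segments = int(segment_ids.max()) + 1

    # Assumption of this sketch: segments with no elements keep the
    # lowest value the dtype can represent.
    if np.issubdtype(data.dtype, np.floating):
        fill = -np.inf
    else:
        fill = np.iinfo(data.dtype).min

    out = np.full((num_segments,) + data.shape[1:], fill, dtype=data.dtype)

    # Unbuffered element-wise maximum: each row of data is folded into
    # the output row selected by its segment id. The ids need not be sorted.
    np.maximum.at(out, segment_ids, data)
    return out


result = segment_max_reference(
    [1, 2, 10, 20, 100, 200], [0, 0, 1, 1, 2, 2], num_segments=3
)
print(result.tolist())  # [2, 20, 200]
```

Because np.maximum.at applies its updates one at a time, the sketch does not depend on segment_ids being sorted, which corresponds to the sorted=False default of the documented function.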