Update inputs via updates at scattered (sparse) indices.
tf.keras.ops.scatter_update(
inputs, indices, updates
)
At a high level, this operation does inputs[indices] = updates.
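To make the inputs[indices] = updates description concrete, here is a minimal pure-NumPy sketch of the same semantics. The function name `scatter_update_reference` is hypothetical and not part of Keras. It assumes `indices` is integer-valued with shape `(num_updates, k)`, and it makes no attempt to match how any Keras backend treats duplicate indices.

```python
import numpy as np


def scatter_update_reference(inputs, indices, updates):
    """Hypothetical NumPy reference for `inputs[indices] = updates`.

    Assumes `indices` has shape (num_updates, k) with k <= inputs.ndim and
    `updates` has shape (num_updates,) + inputs.shape[k:].
    """
    out = np.array(inputs, copy=True)  # work on a copy; `inputs` is left untouched
    idx = np.asarray(indices, dtype=np.int64)
    # Transpose so each of the k index columns becomes one per-axis index array,
    # then let NumPy advanced indexing select the targeted elements or slices.
    out[tuple(idx.T)] = updates
    return out
```

Because the write goes into a copy, the caller keeps the original array, which mirrors the documented pattern of reassigning `inputs = keras.ops.scatter_update(...)`.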
Assume inputs is an n-dimensional tensor of shape (D0, D1, ..., D(n-1)). There are two main usages of scatter_update:
1. indices is a 2D tensor of shape (num_updates, n), where num_updates is the number of updates to perform, and updates is a 1D tensor of shape (num_updates,). For example, if inputs is zeros((4, 4, 4)) and we want to set inputs[1, 2, 3] and inputs[0, 1, 3] to 1, we can use:
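import numpy as np
import keras

inputs = np.zeros((4, 4, 4))
indices = [[1, 2, 3], [0, 1, 3]]  # each row is a full (i, j, k) coordinate
updates = np.array([1., 1.])      # one scalar per index row
inputs = keras.ops.scatter_update(inputs, indices, updates)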
2. indices is a 2D tensor of shape (num_updates, k), where num_updates is the number of updates to perform and k (k < n) is the size of each index in indices. updates is an (n - k + 1)-D tensor of shape (num_updates, *inputs.shape[k:]), that is, one (n - k)-D slice per update. For example, if inputs = np.zeros((4, 4, 4)) and we want to set inputs[1, 2, :] and inputs[2, 3, :] to [1, 1, 1, 1], then indices has shape (num_updates, 2) (k = 2) and updates has shape (num_updates, 4) (inputs.shape[2:] = (4,)). See the code below:
import numpy as np
import keras

inputs = np.zeros((4, 4, 4))
indices = [[1, 2], [2, 3]]  # each row addresses the slice inputs[i, j, :]
updates = np.array([[1., 1., 1., 1.], [1., 1., 1., 1.]])  # shape (2, 4)
inputs = keras.ops.scatter_update(inputs, indices, updates)
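Both usages follow one shape rule: updates.shape == (num_updates,) + inputs.shape[k:], where k = n in the first usage (so inputs.shape[n:] is empty and each update is a scalar). The sketch below checks that rule with NumPy fancy indexing as a stand-in for Keras. The helper `expected_updates_shape` is hypothetical and only illustrates the arithmetic.

```python
import numpy as np


def expected_updates_shape(inputs_shape, indices_shape):
    """Hypothetical helper: the shape `updates` needs for a given `indices`.

    indices_shape is (num_updates, k); each update fills inputs_shape[k:].
    """
    num_updates, k = indices_shape
    return (num_updates,) + tuple(inputs_shape[k:])


inputs = np.zeros((4, 4, 4))
for indices in ([[1, 2, 3], [0, 1, 3]], [[1, 2], [2, 3]], [[0], [3]]):
    idx = np.asarray(indices)
    shape = expected_updates_shape(inputs.shape, idx.shape)
    out = inputs.copy()
    # NumPy advanced indexing selects exactly `shape`, so the assignment fits.
    out[tuple(idx.T)] = np.ones(shape)
    # Prints k, the required updates shape, and how many entries were set to 1.
    print(idx.shape[1], shape, int(out.sum()))
```

The three iterations print `3 (2,) 2`, `2 (2, 4) 8` and `1 (2, 4, 4) 32`: the smaller k is, the larger the slice each index row overwrites.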
Returns
A tensor with the same shape and dtype as inputs.
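The return contract (same shape and dtype as inputs) can be illustrated with a NumPy stand-in that writes into a copy of `inputs`. This is a sketch under that assumption only: NumPy casts the written values to the target array's dtype, and whether a given Keras backend casts `updates` the same way is not stated here.

```python
import numpy as np

inputs = np.zeros((4, 4, 4), dtype=np.int32)
indices = np.array([[1, 2], [2, 3]])
updates = np.ones((2, 4), dtype=np.int64)  # deliberately a different dtype

# Mimic the documented contract with NumPy: write into a copy of `inputs`.
result = inputs.copy()
result[tuple(indices.T)] = updates  # values are cast to int32 on assignment

# The output keeps the shape and dtype of `inputs`, not of `updates`.
print(result.shape, result.dtype)  # (4, 4, 4) int32
```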