Parallel map of `function` on axis 0 of tensor(s) `elements`.
tf.keras.ops.vectorized_map(
function, elements
)
Schematically, `vectorized_map` implements the following in the case of a single tensor input `elements`:
import numpy as np

def vectorized_map(function, elements):
    outputs = []
    for e in elements:  # iterate over slices along axis 0
        outputs.append(function(e))
    return np.stack(outputs)
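To make the single-tensor schematic concrete, here is a small NumPy emulation of it. It is a sketch of the documented semantics only, not how Keras implements the op, since a backend may vectorize the function instead of looping. The names `reference_vectorized_map` and `squared_norm` are hypothetical.

```python
import numpy as np


def reference_vectorized_map(function, elements):
    """Hypothetical NumPy emulation of the single-tensor case.

    Applies `function` to each slice along axis 0 and stacks the results.
    """
    outputs = [function(e) for e in elements]  # iterating a 2-D array yields rows
    return np.stack(outputs)


def squared_norm(row):
    # Written for one example (a 1-D row), with no batch dimension.
    return np.dot(row, row)


x = np.arange(12.0).reshape(3, 4)  # batch of 3 examples, 4 features each
result = reference_vectorized_map(squared_norm, x)
print(result.shape)  # (3,)
print(np.allclose(result, np.sum(x * x, axis=1)))  # True
```

`squared_norm` only ever sees a single example. The batch axis reappears as axis 0 of the stacked output.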
In the case of an iterable of tensors `elements`, it implements the following:
import numpy as np

def vectorized_map(function, elements):
    batch_size = elements[0].shape[0]
    outputs = []
    for index in range(batch_size):
        # Pass one list holding the index-th slice of every tensor.
        outputs.append(function([e[index] for e in elements]))
    return np.stack(outputs)
In this case, `function` is expected to take as input a single list of tensor arguments.
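Below is a similar NumPy emulation of the iterable case, where `function` receives one list holding the `index`-th slice of every tensor. It is again only a sketch of the documented semantics. It assumes that all tensors share the same size along axis 0. The names `reference_vectorized_map_multi` and `matvec` are hypothetical.

```python
import numpy as np


def reference_vectorized_map_multi(function, elements):
    """Hypothetical NumPy emulation of the iterable-of-tensors case."""
    batch_size = elements[0].shape[0]
    # Assumption: every tensor has the same leading (batch) dimension.
    assert all(e.shape[0] == batch_size for e in elements)
    outputs = []
    for index in range(batch_size):
        # `function` receives a single list with one slice per input tensor.
        outputs.append(function([e[index] for e in elements]))
    return np.stack(outputs)


def matvec(args):
    # Unpack the single list argument: one matrix and one vector per example.
    matrix, vector = args
    return matrix @ vector


matrices = np.arange(24.0).reshape(2, 3, 4)  # batch of 2 matrices, each 3x4
vectors = np.ones((2, 4))                    # batch of 2 vectors of length 4
result = reference_vectorized_map_multi(matvec, [matrices, vectors])
print(result.shape)  # (2, 3)
print(np.allclose(result, np.einsum("bij,bj->bi", matrices, vectors)))  # True
```

Because `function` takes one list rather than separate positional arguments, `matvec` unpacks that list itself.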