text.RandomItemSelector
An ItemSelector implementation that randomly selects items in a batch.
text.RandomItemSelector(
max_selections_per_batch,
selection_rate,
unselectable_ids=None,
shuffle_fn=None
)
Used in the guide: BERT Preprocessing with TF Text
RandomItemSelector randomly selects items in a batch, subject to the restrictions set by max_selections_per_batch, selection_rate, and unselectable_ids.
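To make the interaction between selection_rate and max_selections_per_batch concrete, here is a minimal pure-Python sketch. It assumes (this page does not state it) that the selector computes, for each row, ceil(number_of_selectable_items * selection_rate) and then caps that at max_selections_per_batch. num_to_select is a hypothetical helper, not part of tf_text.

```python
import math


def num_to_select(num_selectable: int, selection_rate: float,
                  max_selections_per_batch: int) -> int:
    """Hypothetical helper: how many items one row would get selected.

    Assumes the count is ceil(num_selectable * selection_rate), capped at
    max_selections_per_batch. This is a model, not the tf_text API.
    """
    wanted = math.ceil(num_selectable * selection_rate)
    return min(wanted, max_selections_per_batch)


# Each row in the example below has 3 selectable items (IDs 0, 3, 4 are
# unselectable), so with rate 0.2 each row gets ceil(0.6) = 1 item.
print(num_to_select(3, 0.2, 2))   # 1
# A long row hits the cap instead: ceil(50 * 0.2) = 10, capped at 2.
print(num_to_select(50, 0.2, 2))  # 2
```

Under this assumption, selection_rate governs short rows and max_selections_per_batch only matters once a row is long enough for the rate to exceed the cap.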
Example:
import tensorflow as tf
import tensorflow_text as tf_text

vocab = ["[UNK]", "[MASK]", "[RANDOM]", "[CLS]", "[SEP]",
         "abc", "def", "ghi"]
# In masked language model work there are commonly special tokens we
# don't want to mask, like CLS, SEP, and probably any OOV (out-of-vocab)
# tokens, here called UNK.
# If, for example, there are bucketed OOV tokens, that might be a use case
# for overriding `get_selectable()` to exclude a range of IDs rather than
# enumerating them.
tf.random.set_seed(1234)
selector = tf_text.RandomItemSelector(
    max_selections_per_batch=2,
    selection_rate=0.2,
    unselectable_ids=[0, 3, 4])  # vocab IDs of UNK, CLS, SEP
selection = selector.get_selection_mask(
    tf.ragged.constant([[3, 5, 7, 7], [4, 6, 7, 5]]), axis=1)
print(selection)
# <tf.RaggedTensor [[False, False, False, True], [False, False, True, False]]>
The selection skipped the first element of each row (the CLS and SEP token IDs, which are unselectable) and picked random elements from the remaining items of each segment. With a different random seed, the selections may differ.
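The behavior described above can be modeled in plain Python. This is a sketch under stated assumptions, not the library's implementation: it assumes each row is handled independently (find the selectable positions, shuffle them, keep the first min(ceil(n * selection_rate), max_selections_per_batch)). random_selection_mask is a hypothetical name, nested lists stand in for a RaggedTensor, and Python's random module will not reproduce the exact TensorFlow output shown above.

```python
import math
import random


def random_selection_mask(rows, unselectable_ids, selection_rate,
                          max_selections_per_batch, rng):
    """Hypothetical pure-Python model of a random selection mask.

    For each row: collect positions whose IDs are not unselectable,
    shuffle them, and mark the first k as selected, where
    k = min(ceil(n * selection_rate), max_selections_per_batch).
    """
    blocked = set(unselectable_ids)
    masks = []
    for row in rows:
        positions = [i for i, tok in enumerate(row) if tok not in blocked]
        rng.shuffle(positions)
        k = min(math.ceil(len(positions) * selection_rate),
                max_selections_per_batch)
        chosen = set(positions[:k])
        masks.append([i in chosen for i in range(len(row))])
    return masks


rows = [[3, 5, 7, 7], [4, 6, 7, 5]]
mask = random_selection_mask(rows, [0, 3, 4], 0.2, 2, random.Random(1234))
# One True per row, never at position 0 (CLS/SEP); which one varies by seed.
print(mask)
```

Like the TensorFlow example, every row gets exactly one selected item here, and the leading special token is never chosen.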
Args:
  max_selections_per_batch: An int; the maximum number of items to mask out.
  selection_rate: The rate at which items are randomly selected.
  unselectable_ids: (optional) A list of Python ints, or a 1D Tensor of ints, giving IDs that will not be masked.
  shuffle_fn: (optional) A function that shuffles a 1D Tensor. Defaults to tf.random.shuffle.
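This page only says that shuffle_fn shuffles a 1D Tensor. A reasonable reading is that any function returning a permutation of its input works, and that a deterministic permutation makes selection reproducible, for example in tests. The plain-Python sketch below models that, assuming the selector keeps the first k entries of the shuffled candidate positions; pick_first_k, reverse_shuffle, and seeded_shuffle are hypothetical names. In TensorFlow, an analogous deterministic choice might be something like lambda t: tf.reverse(t, axis=[0]), but that is untested here.

```python
import random


def pick_first_k(positions, k, shuffle_fn):
    """Hypothetical model: permute candidate positions, keep the first k."""
    return shuffle_fn(list(positions))[:k]


def reverse_shuffle(xs):
    """Deterministic 'shuffle': always the same permutation (a reversal)."""
    return xs[::-1]


def seeded_shuffle(xs, seed=0):
    """Random permutation driven by a fixed seed, so runs are repeatable."""
    out = list(xs)
    random.Random(seed).shuffle(out)
    return out


candidates = [1, 2, 3]  # selectable positions in the row [3, 5, 7, 7]
print(pick_first_k(candidates, 1, reverse_shuffle))  # [3]
print(pick_first_k(candidates, 1, seeded_shuffle))   # one of [1], [2], [3]
```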
Attributes:
  max_selections_per_batch
  selection_rate
  shuffle_fn
  unselectable_ids
These correspond to the constructor arguments of the same name.
Methods
get_selectable
get_selectable(
input_ids, axis
)
Returns a boolean mask of items that can be chosen for selection.
The default implementation marks as selectable every item whose ID is not in the unselectable_ids list. Override this method if selectability needs a more complex or algorithmic rule.
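As a plain-Python illustration of the default rule and of the kind of override the example's comment suggests (blocking a contiguous ID range, such as bucketed OOV IDs, instead of enumerating IDs), consider the sketch below. default_selectable and range_selectable are hypothetical helpers operating on nested lists, not tf_text methods; a real override would subclass the selector and return a RaggedTensor.

```python
def default_selectable(rows, unselectable_ids):
    """Model of the default rule: selectable iff the ID is not listed."""
    blocked = set(unselectable_ids)
    return [[tok not in blocked for tok in row] for row in rows]


def range_selectable(rows, first_blocked, last_blocked):
    """Hypothetical override rule: block an inclusive ID range."""
    return [[not (first_blocked <= tok <= last_blocked) for tok in row]
            for row in rows]


rows = [[3, 5, 7, 7], [4, 6, 7, 5]]
print(default_selectable(rows, [0, 3, 4]))
# [[False, True, True, True], [False, True, True, True]]

# IDs 0-4 are all special tokens in the example vocab, so blocking the
# range gives the same mask for this input.
print(range_selectable(rows, 0, 4))
```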
Args:
  input_ids: A RaggedTensor.
  axis: The axis to apply selection on.
Returns:
  A RaggedTensor of dtype bool with shape input_ids.shape[:axis + 1]. Its values are True if the corresponding item (or its broadcasted subitems) should be considered for masking. In the default implementation, all input_ids items not listed in unselectable_ids (the constructor argument) are considered selectable.
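The phrase "(or broadcasted subitems)" matters when input_ids has more nesting levels than axis, for example [batch, words, wordpieces] with axis=1. The plain-Python sketch below assumes (not confirmed on this page) that the inner level is collapsed with a logical AND, in the spirit of tf.reduce_all, so a word is selectable only if every one of its wordpieces is. collapse_to_axis1 is a hypothetical helper, and the nested lists stand in for a RaggedTensor.

```python
def collapse_to_axis1(mask_3d):
    """Model: reduce a [batch, words, wordpieces] mask to [batch, words].

    A word is True only if all of its wordpieces are True (logical AND).
    """
    return [[all(pieces) for pieces in row] for row in mask_3d]


# Hypothetical IDs grouped as [batch, words, wordpieces], using the vocab
# from the example: 0 = [UNK], 3 = [CLS], 4 = [SEP].
ids = [[[3], [5, 6], [7]], [[4], [6, 0], [5]]]
unselectable = {0, 3, 4}

# Per-wordpiece selectability, following the default unselectable_ids rule.
per_piece = [[[tok not in unselectable for tok in word] for word in row]
             for row in ids]

word_mask = collapse_to_axis1(per_piece)
print(word_mask)  # [[False, True, True], [False, False, True]]
```

Here the word [6, 0] is not selectable because one of its wordpieces is [UNK], and the resulting mask has the [batch, words] shape.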
get_selection_mask
get_selection_mask(
input_ids, axis
)
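Returns a mask of items that have been selected.
In the base ItemSelector, the default implementation simply returns all items not excluded by get_selectable. RandomItemSelector instead selects a random subset of those items, as the example above shows.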
Args:
  input_ids: A RaggedTensor.
  axis: (optional) An int specifying the dimension to apply selection on. Default is the 1st dimension.
Returns:
  A RaggedTensor of dtype bool with shape input_ids.shape[:axis + 1]. Its values are True if the corresponding item (or its broadcasted subitems) should be selected.
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates.
Last updated 2025-04-11 UTC.