tf.contrib.estimator.multi_label_head
Creates a `_Head` for multi-label classification.
```python
tf.contrib.estimator.multi_label_head(
    n_classes, weight_column=None, thresholds=None, label_vocabulary=None,
    loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE, loss_fn=None,
    classes_for_class_based_metrics=None, name=None
)
```
Multi-label classification handles the case where each example may have zero or more associated labels from a discrete set. This is distinct from `multi_class_head`, which has exactly one label per example.
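Uses `sigmoid_cross_entropy` loss, averaged over classes and weighted-summed over the batch. That is, if the input logits have shape `[batch_size, n_classes]`, the loss is the average over `n_classes` and the weighted sum over `batch_size`. How that weighted sum is finally reduced is controlled by `loss_reduction`; with the default `SUM_OVER_BATCH_SIZE`, it is divided by the batch size.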
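The reduction described above can be sketched in plain Python. This is a minimal illustration, not TensorFlow's implementation: it assumes labels and logits are nested lists of shape `[batch_size, n_classes]`, uses the standard numerically stable form of sigmoid cross-entropy, and the helper names `sigmoid_cross_entropy` and `multi_label_loss` are hypothetical.

```python
import math


def sigmoid_cross_entropy(label, logit):
    # Numerically stable form: max(x, 0) - x * z + log(1 + exp(-|x|)).
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))


def multi_label_loss(labels, logits, weights=None, divide_by_batch_size=True):
    """labels, logits: rows of length n_classes; weights: one float per example.

    divide_by_batch_size=True roughly mimics SUM_OVER_BATCH_SIZE,
    False roughly mimics SUM.
    """
    batch_size = len(logits)
    if weights is None:
        weights = [1.0] * batch_size
    total = 0.0
    for row_labels, row_logits, weight in zip(labels, logits, weights):
        per_class = [sigmoid_cross_entropy(z, x) for z, x in zip(row_labels, row_logits)]
        example_loss = sum(per_class) / len(per_class)  # average over n_classes
        total += weight * example_loss                  # weighted sum over the batch
    return total / batch_size if divide_by_batch_size else total
```

For a single example with labels `[[1, 0]]` and logits `[[0.0, 0.0]]`, each class contributes `log(2)`, so the averaged loss is `log(2)`.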
The head expects `logits` with shape `[D0, D1, ... DN, n_classes]`. In many applications, the shape is `[batch_size, n_classes]`.
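Because each of the `n_classes` outputs goes through an independent sigmoid, the predicted probabilities for one example do not have to sum to 1, unlike a softmax over classes. The hypothetical helper below sketches this on nested lists of any rank, so it preserves a `[D0, D1, ... DN, n_classes]` shape; it is an illustration, not the head's own prediction code.

```python
import math


def sigmoid(x):
    # Numerically stable logistic function.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def logits_to_probabilities(logits):
    """Apply a sigmoid elementwise, keeping the nested [D0, ..., DN, n_classes] shape."""
    if isinstance(logits, list):
        return [logits_to_probabilities(item) for item in logits]
    return sigmoid(logits)


# Logits with shape [D0=2, D1=1, n_classes=3].
probs = logits_to_probabilities([[[0.0, 2.0, -2.0]], [[5.0, 5.0, 5.0]]])
```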
Labels can be:

- A multi-hot tensor of shape `[D0, D1, ... DN, n_classes]`.
- An integer `SparseTensor` of class indices. The `dense_shape` must be `[D0, D1, ... DN, ?]` and the values must be within `[0, n_classes)`.
- If `label_vocabulary` is given, either a string `SparseTensor` whose `dense_shape` is `[D0, D1, ... DN, ?]` and whose values are in `label_vocabulary`, or a multi-hot tensor of shape `[D0, D1, ... DN, n_classes]`.
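The sparse label formats above all describe the same multi-hot encoding. As a rough sketch, assuming each example's labels are given as a Python list (standing in for one row of a `SparseTensor`), the hypothetical helper `to_multi_hot` maps integer class IDs, or strings looked up in `label_vocabulary`, to rows of length `n_classes`. An empty row represents an example with zero labels.

```python
def to_multi_hot(label_rows, n_classes, label_vocabulary=None):
    """Convert per-example label lists into multi-hot rows of length n_classes.

    Each inner list holds integer class IDs in [0, n_classes) or, when
    label_vocabulary is given, string labels from that vocabulary.
    """
    if label_vocabulary is not None:
        index_of = {name: i for i, name in enumerate(label_vocabulary)}
    multi_hot = []
    for row in label_rows:
        encoded = [0] * n_classes
        for label in row:
            if label_vocabulary is not None:
                if label not in index_of:
                    raise ValueError(f"label {label!r} is not in label_vocabulary")
                class_id = index_of[label]
            else:
                if isinstance(label, str):
                    raise ValueError("string labels require label_vocabulary")
                class_id = label
            if not 0 <= class_id < n_classes:
                raise ValueError(f"class id {class_id} is outside [0, {n_classes})")
            encoded[class_id] = 1
        multi_hot.append(encoded)
    return multi_hot
```

For example, `to_multi_hot([[0, 2], [], [1]], 3)` gives `[[1, 0, 1], [0, 0, 0], [0, 1, 0]]`.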
If `weight_column` is specified, weights must have shape `[D0, D1, ... DN]` or `[D0, D1, ... DN, 1]`.
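Both accepted weight shapes carry exactly one weight per example; a trailing dimension larger than 1 would amount to per-class weighting, which the `weight_column` argument below states is not supported. A minimal sketch of that shape rule, using a hypothetical helper on nested lists rather than tensors:

```python
def normalize_weights(weights):
    """Accept per-example weights shaped [batch] or [batch, 1]; return a flat list."""
    flat = []
    for weight in weights:
        if isinstance(weight, list):
            # A trailing dimension must have size 1: one weight per example.
            if len(weight) != 1:
                raise ValueError("weights must have shape [batch] or [batch, 1]")
            weight = weight[0]
        flat.append(float(weight))
    return flat
```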
The head also supports a custom `loss_fn`. `loss_fn` takes `(labels, logits)` or `(labels, logits, features)` as arguments and returns the unreduced loss with shape `[D0, D1, ... DN, 1]`. `loss_fn` must support indicator `labels` with shape `[D0, D1, ... DN, n_classes]`: the head applies `label_vocabulary` and converts the input labels to this indicator form before passing them to `loss_fn`.
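The shape contract for `loss_fn` can be illustrated without TensorFlow. In real use `loss_fn` receives tensors; the sketch below uses nested lists only to show the expected input (indicator labels of shape `[batch_size, n_classes]`) and output (unreduced loss of shape `[batch_size, 1]`). The function name `my_loss_fn` and the choice to sum rather than average over classes are hypothetical; a three-argument version would also accept `features`.

```python
import math


def sigmoid_cross_entropy(label, logit):
    # Numerically stable form: max(x, 0) - x * z + log(1 + exp(-|x|)).
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))


def my_loss_fn(labels, logits):
    """Hypothetical custom loss.

    labels: indicator rows of shape [batch_size, n_classes].
    Returns the unreduced loss with shape [batch_size, 1].
    """
    losses = []
    for row_labels, row_logits in zip(labels, logits):
        per_class = [sigmoid_cross_entropy(z, x) for z, x in zip(row_labels, row_logits)]
        # Hypothetical choice: sum over classes instead of averaging.
        losses.append([sum(per_class)])
    return losses
```

The head, not `loss_fn`, is then responsible for applying weights and `loss_reduction` to this unreduced output.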
The head can be used with a canned estimator. Example:

```python
import tensorflow as tf  # TensorFlow 1.x; tf.contrib is not available in 2.x

my_head = tf.contrib.estimator.multi_label_head(n_classes=3)
my_estimator = tf.estimator.DNNEstimator(
    head=my_head,
    hidden_units=...,
    feature_columns=...)
```
It can also be used with a custom `model_fn`. Example:

```python
import tensorflow as tf  # TensorFlow 1.x; tf.contrib is not available in 2.x

def _my_model_fn(features, labels, mode):
  my_head = tf.contrib.estimator.multi_label_head(n_classes=3)
  # Any network that maps features to logits of shape [batch_size, 3].
  logits = tf.keras.Model(...)(features)

  return my_head.create_estimator_spec(
      features=features,
      mode=mode,
      labels=labels,
      optimizer=tf.train.AdagradOptimizer(learning_rate=0.1),
      logits=logits)

my_estimator = tf.estimator.Estimator(model_fn=_my_model_fn)
```
| Args | |
|---|---|
| `n_classes` | Number of classes; must be greater than 1 (for a single class, use `binary_classification_head`). |
| `weight_column` | A string or a `_NumericColumn` created by `tf.feature_column.numeric_column` defining the feature column that represents weights. It is used to down-weight or boost examples during training: the loss of each example is multiplied by its weight. Per-class weighting is not supported. |
| `thresholds` | Iterable of floats in the range `(0, 1)`. Accuracy, precision, and recall metrics are evaluated for each threshold value. The threshold is applied to the predicted probabilities: a probability above the threshold counts as `true`, below it as `false`. |
| `label_vocabulary` | A list of strings representing the possible label values. If it is not given, labels must already be encoded as integers within `[0, n_classes)` or as a multi-hot `Tensor`. If it is given, string labels must be a `SparseTensor` with values in `label_vocabulary`. An error is raised if labels are strings and no vocabulary is provided. |
| `loss_reduction` | One of `tf.losses.Reduction` except `NONE`. Describes how to reduce the training loss over the batch. Defaults to `SUM_OVER_BATCH_SIZE`, i.e. the weighted sum of losses divided by the batch size. See `tf.losses.Reduction`. |
| `loss_fn` | Optional custom loss function. See the description of custom `loss_fn` above. |
| `classes_for_class_based_metrics` | List of integer class IDs or string class names for which per-class metrics are evaluated. If integers, all must be in the range `[0, n_classes - 1]`. If strings, all must be in `label_vocabulary`. |
| `name` | Name of the head. If provided, summary and metrics keys are suffixed by `"/" + name`. Also used as `name_scope` when creating ops. |
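To make the `thresholds` argument concrete, the sketch below computes precision and recall at several thresholds by pooling all (example, class) pairs. It is a simplified, hypothetical illustration: TensorFlow's streaming metrics accumulate over batches and may aggregate differently, and treating a probability exactly equal to the threshold as `false` is an assumption of this sketch.

```python
def precision_recall_at_thresholds(labels, probabilities, thresholds):
    """labels: multi-hot rows; probabilities: predicted rows, both [batch, n_classes].

    Returns {threshold: (precision, recall)} over all (example, class) pairs.
    """
    results = {}
    for t in thresholds:
        if not 0.0 < t < 1.0:
            raise ValueError("thresholds must be in (0, 1)")
        tp = fp = fn = 0
        for row_labels, row_probs in zip(labels, probabilities):
            for z, p in zip(row_labels, row_probs):
                predicted = p > t  # above the threshold counts as true
                if predicted and z:
                    tp += 1
                elif predicted and not z:
                    fp += 1
                elif not predicted and z:
                    fn += 1
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        results[t] = (precision, recall)
    return results


scores = precision_recall_at_thresholds(
    [[1, 0, 1], [0, 1, 0]],
    [[0.9, 0.4, 0.6], [0.2, 0.7, 0.55]],
    thresholds=[0.5, 0.8],
)
```

Raising the threshold from 0.5 to 0.8 here trades recall for precision: fewer classes are predicted, so false positives disappear but more true labels are missed.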
| Returns |
|---|
| An instance of `_Head` for multi-label classification. |
| Raises | |
|---|---|
| `ValueError` | If `n_classes`, `thresholds`, `loss_reduction`, `loss_fn`, or `classes_for_class_based_metrics` is invalid. |
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates.
Last updated 2020-10-01 UTC.