tf.keras.utils.unpack_x_y_sample_weight
Unpacks user-provided data tuple.
tf.keras.utils.unpack_x_y_sample_weight(
data
)
This is a convenience utility to be used when overriding Model.train_step, Model.test_step, or Model.predict_step.
This utility makes it easy to support data of the form (x,), (x, y), or (x, y, sample_weight).
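The padding behavior described above can be illustrated without TensorFlow. The following is a minimal sketch, not the library's implementation: `unpack_sketch` is a hypothetical name, and raising `ValueError` for anything other than a tuple of one to three elements is an assumption of this sketch rather than documented behavior.

```python
def unpack_sketch(data):
    """Pad a 1-, 2-, or 3-tuple to (x, y, sample_weight).

    Hypothetical stand-in used only to illustrate the documented forms;
    it is not the Keras implementation.
    """
    # Assumption of this sketch: reject anything that is not one of the
    # three documented tuple forms.
    if not isinstance(data, tuple) or not 1 <= len(data) <= 3:
        raise ValueError(
            "Expected (x,), (x, y), or (x, y, sample_weight); "
            f"got {data!r}"
        )
    # Fill the missing trailing positions with None.
    return data + (None,) * (3 - len(data))


only_x = unpack_sketch(("features",))
x_and_y = unpack_sketch(("features", "labels"))
all_three = unpack_sketch(("features", "labels", "weights"))
```

Because the result always has three elements, a step function can unpack it with a single `x, y, sample_weight = ...` assignment regardless of which form the caller supplied.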
Standalone usage:
import tensorflow as tf

features_batch = tf.ones((10, 5))
labels_batch = tf.zeros((10, 5))
data = (features_batch, labels_batch)
# `y` and `sample_weight` will default to `None` if not provided.
x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(data)
print(sample_weight is None)  # True
Example in overridden Model.train_step:
class MyModel(tf.keras.Model):

    def train_step(self, data):
        # If `sample_weight` is not provided, all samples will be weighted
        # equally.
        x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(data)

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)
            loss = self.compiled_loss(
                y, y_pred, sample_weight, regularization_losses=self.losses)
        trainable_variables = self.trainable_variables
        gradients = tape.gradient(loss, trainable_variables)
        self.optimizer.apply_gradients(zip(gradients, trainable_variables))

        self.compiled_metrics.update_state(y, y_pred, sample_weight)
        return {m.name: m.result() for m in self.metrics}
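The same utility is useful in Model.predict_step, where only `x` is needed and the other two values can be discarded. The following is a minimal sketch under stated assumptions: the `PredictOnlyModel` class and its single `Dense` layer are hypothetical and exist only to make the example runnable.

```python
import tensorflow as tf


class PredictOnlyModel(tf.keras.Model):
    """Hypothetical model used only to illustrate `predict_step`."""

    def __init__(self):
        super().__init__()
        self.dense = tf.keras.layers.Dense(1)

    def call(self, inputs, training=False):
        return self.dense(inputs)

    def predict_step(self, data):
        # Prediction needs only the features; `y` and `sample_weight`
        # are ignored whether or not the caller supplied them.
        x, _, _ = tf.keras.utils.unpack_x_y_sample_weight(data)
        return self(x, training=False)


model = PredictOnlyModel()
features_batch = tf.ones((4, 3))

# Both forms yield predictions for the same features.
preds_from_x = model.predict_step((features_batch,))
preds_from_xy = model.predict_step((features_batch, tf.zeros((4, 1))))
```

Unpacking first lets the same `predict_step` accept a dataset that also yields labels or weights, without a separate code path for each tuple length.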
Args
data: A tuple of the form (x,), (x, y), or (x, y, sample_weight).

Returns
The unpacked tuple, with Nones for y and sample_weight if they are not provided.
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates. Some content is licensed under the numpy license.
Last updated 2023-10-06 UTC.