tfr.keras.pipeline.AbstractPipeline
Interface for a ranking pipeline that trains a tf.keras.Model.

AbstractPipeline is an abstract base class for training and validating a ranking model in tfr.keras.

To be implemented by subclasses:

- build_loss(): Contains the logic to build a tf.keras.losses.Loss, or a dict or list of tf.keras.losses.Loss objects, to be optimized in training.
- build_metrics(): Contains the logic to build a list or dict of tf.keras.metrics.Metric objects to monitor and evaluate training.
- build_weighted_metrics(): Contains the logic to build a list or dict of tf.keras.metrics.Metric objects that are computed with sample weights (passed to model.compile() as weighted_metrics).
- train_and_validate(): Contains the main pipeline for training and validation.
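
Assuming AbstractPipeline uses abc.ABCMeta, as its @abc.abstractmethod decorators suggest, a subclass can only be instantiated after it overrides all four methods. The sketch below illustrates that rule in plain Python; PipelineInterface, IncompletePipeline and CompletePipeline are hypothetical stand-ins, not library classes, and the strings they return are placeholders for real Keras objects.

```python
import abc
from typing import Any


class PipelineInterface(abc.ABC):
  """Hypothetical stand-in mirroring AbstractPipeline's four abstract methods."""

  @abc.abstractmethod
  def build_loss(self) -> Any:
    """Returns the loss for model.compile()."""

  @abc.abstractmethod
  def build_metrics(self) -> Any:
    """Returns the metrics for model.compile()."""

  @abc.abstractmethod
  def build_weighted_metrics(self) -> Any:
    """Returns the weighted metrics for model.compile()."""

  @abc.abstractmethod
  def train_and_validate(self, *arg, **kwargs) -> Any:
    """Runs training and validation."""


class IncompletePipeline(PipelineInterface):
  """Overrides only build_loss, so it is still abstract."""

  def build_loss(self):
    return "softmax_loss"  # placeholder for a tf.keras.losses.Loss


class CompletePipeline(IncompletePipeline):
  """Overrides the remaining three methods, so it can be instantiated."""

  def build_metrics(self):
    return ["ndcg_1", "ndcg_5", "ndcg_10"]  # placeholders for Metric objects

  def build_weighted_metrics(self):
    return ["weighted_ndcg_1", "weighted_ndcg_5", "weighted_ndcg_10"]

  def train_and_validate(self, *arg, **kwargs):
    return None


def can_instantiate(cls) -> bool:
  """Returns True if cls() succeeds, False if abc rejects it with TypeError."""
  try:
    cls()
  except TypeError:
    return False
  return True


print(can_instantiate(IncompletePipeline))  # False
print(can_instantiate(CompletePipeline))    # True
```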

Example subclass implementation:

    import tensorflow as tf
    import tensorflow_ranking as tfr

    class BasicPipeline(tfr.keras.pipeline.AbstractPipeline):
      """Trains a ranking model with a softmax loss and NDCG metrics."""

      def __init__(self, model, train_data, valid_data, name=None):
        self._model = model
        self._train_data = train_data
        self._valid_data = valid_data
        self._name = name

      def build_loss(self):
        # Listwise softmax loss, looked up by its registered name.
        return tfr.keras.losses.get('softmax_loss')

      def build_metrics(self):
        # NDCG at cutoffs 1, 5 and 10, named ndcg_1, ndcg_5 and ndcg_10.
        return [
            tfr.keras.metrics.get(
                'ndcg', topn=topn, name='ndcg_{}'.format(topn)
            ) for topn in [1, 5, 10]
        ]

      def build_weighted_metrics(self):
        # The same metrics, computed with sample weights via weighted_metrics.
        return [
            tfr.keras.metrics.get(
                'ndcg', topn=topn, name='weighted_ndcg_{}'.format(topn)
            ) for topn in [1, 5, 10]
        ]

      def train_and_validate(self, *arg, **kwargs):
        # Compile with the loss and metrics built above, then train and validate.
        self._model.compile(
            optimizer=tf.keras.optimizers.SGD(learning_rate=0.001),
            loss=self.build_loss(),
            metrics=self.build_metrics(),
            weighted_metrics=self.build_weighted_metrics())
        self._model.fit(
            x=self._train_data,
            epochs=100,
            validation_data=self._valid_data)

Methods
build_loss

    @abc.abstractmethod
    build_loss() -> Any

Returns the loss to pass to model.compile().

Example usage:

    pipeline = BasicPipeline(model, train_data, valid_data)
    loss = pipeline.build_loss()

Returns:
  A tf.keras.losses.Loss, or a dict or list of tf.keras.losses.Loss objects.
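
The example's build_loss() returns the loss registered as 'softmax_loss'. As a rough illustration of what a listwise softmax loss computes for a single list, the sketch below normalizes the relevance labels into a distribution and takes its cross-entropy against the softmax of the scores. listwise_softmax_loss is a simplified, hypothetical helper, not the TF-Ranking implementation, which additionally handles padded items, weights and batch reduction.

```python
import math
from typing import Sequence


def listwise_softmax_loss(labels: Sequence[float], scores: Sequence[float]) -> float:
  """Cross-entropy between normalized labels and softmax(scores) for one list."""
  total = sum(labels)
  if total <= 0:
    return 0.0  # simplifying choice here: a list with no relevant items adds no loss
  # Numerically stable log-softmax: log q_i = s_i - logsumexp(scores).
  m = max(scores)
  log_z = m + math.log(sum(math.exp(s - m) for s in scores))
  return -sum((y / total) * (s - log_z) for y, s in zip(labels, scores))


print(listwise_softmax_loss([1, 0, 0], [0.0, 0.0, 0.0]))  # log(3), about 1.0986
print(listwise_softmax_loss([1, 0, 0], [5.0, 0.0, 0.0]))  # smaller: relevant item scored highest
```

Minimizing this quantity pushes the softmax of the scores toward the label distribution, which rewards scoring relevant items above irrelevant ones.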
build_metrics

    @abc.abstractmethod
    build_metrics() -> Any

Returns the ranking metrics to pass as metrics to model.compile().

Example usage:

    pipeline = BasicPipeline(model, train_data, valid_data)
    metrics = pipeline.build_metrics()

Returns:
  A list or dict of tf.keras.metrics.Metric objects.
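
The example's build_metrics() requests NDCG at cutoffs 1, 5 and 10. The plain-Python sketch below shows NDCG@k for one list, assuming the gain 2**label - 1 and the rank discount 1 / log2(rank + 1), which we believe match TF-Ranking's defaults. dcg_at_k and ndcg_at_k are hypothetical helpers; the library's metric also averages over lists and handles padding.

```python
import math
from typing import Sequence


def dcg_at_k(labels_in_rank_order: Sequence[float], k: int) -> float:
  """DCG@k with gain 2**y - 1 and discount 1 / log2(rank + 1), ranks from 1."""
  return sum(
      (2.0 ** y - 1.0) / math.log2(rank + 1)
      for rank, y in enumerate(labels_in_rank_order[:k], start=1)
  )


def ndcg_at_k(labels: Sequence[float], scores: Sequence[float], k: int) -> float:
  """DCG of the predicted order divided by DCG of the ideal order."""
  # Predicted ranking: labels sorted by descending model score.
  predicted = [y for _, y in sorted(zip(scores, labels), key=lambda p: -p[0])]
  # Ideal ranking: labels sorted by descending relevance.
  ideal = sorted(labels, reverse=True)
  ideal_dcg = dcg_at_k(ideal, k)
  return dcg_at_k(predicted, k) / ideal_dcg if ideal_dcg > 0 else 0.0


labels = [0.0, 1.0, 2.0]
scores = [3.0, 2.0, 1.0]  # the model ranks the least relevant item first
for topn in [1, 5, 10]:
  print('ndcg_{}'.format(topn), ndcg_at_k(labels, scores, topn))
```

Here ndcg_1 is 0.0 because the top-ranked item has label 0; cutoffs beyond the list length simply use every item.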
build_weighted_metrics

    @abc.abstractmethod
    build_weighted_metrics() -> Any

Returns the weighted ranking metrics to pass as weighted_metrics to model.compile().

Example usage:

    pipeline = BasicPipeline(model, train_data, valid_data)
    weighted_metrics = pipeline.build_weighted_metrics()

Returns:
  A list or dict of tf.keras.metrics.Metric objects.
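
In tf.keras, sample weights (for example, sample_weight in fit() or a weights element of the dataset) are applied to metrics passed as weighted_metrics, while metrics passed as metrics are computed without them. A simplified view is that per-list metric values are averaged with those weights, as in the hypothetical mean_metric helper below. The values are made up, and any per-list weighting that TF-Ranking applies internally is ignored.

```python
from typing import Optional, Sequence


def mean_metric(values: Sequence[float], weights: Optional[Sequence[float]] = None) -> float:
  """Weighted mean sum(w * v) / sum(w); uses uniform weights when weights is None."""
  if weights is None:
    weights = [1.0] * len(values)
  total_weight = sum(weights)
  if total_weight == 0:
    return 0.0
  return sum(w * v for w, v in zip(weights, values)) / total_weight


per_list_ndcg = [1.0, 0.5, 0.0]  # hypothetical per-list NDCG@5 values
list_weights = [1.0, 1.0, 2.0]   # hypothetical per-list sample weights

print(mean_metric(per_list_ndcg))                # 0.5, unweighted average
print(mean_metric(per_list_ndcg, list_weights))  # 0.375, the third list counts twice
```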
train_and_validate

    @abc.abstractmethod
    train_and_validate(
        *arg, **kwargs
    ) -> Any

Constructs and runs the training pipeline.

Example usage:

    pipeline = BasicPipeline(model, train_data, valid_data)
    pipeline.train_and_validate()

Args:
  *arg: Positional arguments that might be used in the training pipeline.
  **kwargs: Keyword arguments that might be used in the training pipeline.

Returns:
  None, a trained tf.keras.Model, or a path to a saved tf.keras.Model.
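
AbstractPipeline leaves the meaning of *arg and **kwargs to each subclass; BasicPipeline above ignores them and hard-codes epochs=100. The sketch below shows one possible design in which a keyword argument overrides epochs, any other keyword arguments pass through to fit(), and the trained model is returned. To run without TensorFlow, it uses a hypothetical RecordingModel that only records compile() and fit() calls, and a hypothetical KwargsPipeline that omits the AbstractPipeline base class and the build_* methods.

```python
from typing import Any, Dict, List, Tuple


class RecordingModel:
  """Hypothetical stand-in for tf.keras.Model that records compile/fit calls."""

  def __init__(self) -> None:
    self.calls: List[Tuple[str, Dict[str, Any]]] = []

  def compile(self, **kwargs: Any) -> None:
    self.calls.append(("compile", kwargs))

  def fit(self, **kwargs: Any) -> None:
    self.calls.append(("fit", kwargs))


class KwargsPipeline:
  """Hypothetical pipeline whose train_and_validate accepts fit() overrides.

  In real code this would subclass tfr.keras.pipeline.AbstractPipeline and
  also implement build_loss, build_metrics and build_weighted_metrics.
  """

  def __init__(self, model, train_data, valid_data):
    self._model = model
    self._train_data = train_data
    self._valid_data = valid_data

  def train_and_validate(self, *arg, epochs: int = 100, **kwargs):
    # Positional arguments are accepted but unused in this sketch.
    self._model.compile(loss="softmax_loss")  # placeholder for build_loss() etc.
    # Remaining keyword arguments (e.g. verbose) are forwarded to fit().
    self._model.fit(
        x=self._train_data,
        epochs=epochs,
        validation_data=self._valid_data,
        **kwargs)
    return self._model  # one of the documented return options


model = RecordingModel()
trained = KwargsPipeline(model, "train_ds", "valid_ds").train_and_validate(
    epochs=3, verbose=0)
print(trained.calls[-1])
# ('fit', {'x': 'train_ds', 'epochs': 3, 'validation_data': 'valid_ds', 'verbose': 0})
```

Returning the model suits scripts that evaluate or export it afterwards; a pipeline that saves the model could instead return the saved path, the other documented option.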
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates.
Last updated 2023-08-18 UTC.