tfr.keras.pipeline.MultiTaskPipeline
Pipeline for multi-task training.
Inherits From: ModelFitPipeline, AbstractPipeline
tfr.keras.pipeline.MultiTaskPipeline(
    model_builder: tfr.keras.model.AbstractModelBuilder,
    dataset_builder: tfr.keras.pipeline.AbstractDatasetBuilder,
    hparams: tfr.keras.pipeline.PipelineHparams
)
This pipeline handles a set of losses and labels and is intended to work mainly with MultiLabelDatasetBuilder.
Use subclassing to customize the losses and metrics.
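For instance, a minimal sketch of such a subclass, assuming hparams are stored on the pipeline as self._hparams and that hparams.loss is a per-task dict as in the class example below (the metric choices here are illustrative, not the library defaults):

import tensorflow_ranking as tfr

class CustomMultiTaskPipeline(tfr.keras.pipeline.MultiTaskPipeline):
  """MultiTaskPipeline with custom per-task metrics."""

  def build_metrics(self):
    # One metric list per task, keyed by the task names in the loss dict.
    return {
        task_name: [
            tfr.keras.metrics.NDCGMetric(name="%s_ndcg_5" % task_name, topn=5),
            tfr.keras.metrics.MRRMetric(name="%s_mrr" % task_name),
        ]
        for task_name in self._hparams.loss.keys()
    }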
Example usage:
context_feature_spec = {}
example_feature_spec = {
"example_feature_1": tf.io.FixedLenFeature(
shape=(1,), dtype=tf.float32, default_value=0.0)
}
mask_feature_name = "list_mask"
label_spec_tuple = ("utility",
tf.io.FixedLenFeature(
shape=(1,),
dtype=tf.float32,
default_value=_PADDING_LABEL))
label_spec = {"task1": label_spec_tuple, "task2": label_spec_tuple}
weight_spec = ("weight",
tf.io.FixedLenFeature(
shape=(1,), dtype=tf.float32, default_value=1.))
dataset_hparams = DatasetHparams(
train_input_pattern="train.dat",
valid_input_pattern="valid.dat",
train_batch_size=128,
valid_batch_size=128)
pipeline_hparams = PipelineHparams(
model_dir="model/",
num_epochs=2,
steps_per_epoch=5,
validation_steps=2,
learning_rate=0.01,
loss={
"task1": "softmax_loss",
"task2": "pairwise_logistic_loss"
},
loss_weights={
"task1": 1.0,
"task2": 2.0
},
export_best_model=True)
model_builder = MultiTaskModelBuilder(...)
dataset_builder = MultiLabelDatasetBuilder(
context_feature_spec,
example_feature_spec,
mask_feature_name,
label_spec,
dataset_hparams,
sample_weight_spec=weight_spec)
pipeline = MultiTaskPipeline(model_builder, dataset_builder, pipeline_hparams)
pipeline.train_and_validate(verbose=1)
Args:
  model_builder: A ModelBuilder instance for model fit.
  dataset_builder: An AbstractDatasetBuilder instance to load training and validation datasets and to create signatures for the SavedModel.
  hparams: A PipelineHparams instance containing the pipeline hyperparameters.
Methods
build_callbacks
build_callbacks() -> List[tf.keras.callbacks.Callback]
Sets up Callbacks.
Example usage:
model_builder = ModelBuilder(...)
dataset_builder = DatasetBuilder(...)
hparams = PipelineHparams(...)
pipeline = BasicModelFitPipeline(model_builder, dataset_builder, hparams)
callbacks = pipeline.build_callbacks()

Returns:
  A list of tf.keras.callbacks.Callback or a tf.keras.callbacks.CallbackList for tensorboard and checkpoint.
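To go beyond the default tensorboard and checkpoint callbacks, one option is to subclass the pipeline and extend the returned list. A minimal sketch (the EarlyStopping settings here are illustrative):

import tensorflow as tf
import tensorflow_ranking as tfr

class PipelineWithEarlyStopping(tfr.keras.pipeline.MultiTaskPipeline):

  def build_callbacks(self):
    # Keep the tensorboard/checkpoint callbacks from the base class.
    callbacks = super().build_callbacks()
    callbacks.append(
        tf.keras.callbacks.EarlyStopping(
            monitor="val_loss", patience=3, restore_best_weights=True))
    return callbacks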
build_loss
build_loss() -> Dict[str, tf.keras.losses.Loss]
See AbstractPipeline.
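For multi-task training this returns one Keras loss per task, keyed by task name. A hedged sketch of an override, assuming hparams.loss maps task names to loss names (as in the class example above) and using the tfr.keras.losses.get factory:

import tensorflow as tf
import tensorflow_ranking as tfr

class PipelineWithCustomLosses(tfr.keras.pipeline.MultiTaskPipeline):

  def build_loss(self):
    # One Keras loss object per task, keyed by task name.
    return {
        task_name: tfr.keras.losses.get(
            loss=loss_name,
            reduction=tf.losses.Reduction.SUM_OVER_BATCH_SIZE)
        for task_name, loss_name in self._hparams.loss.items()
    }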
build_metrics
build_metrics() -> Dict[str, List[tf.keras.metrics.Metric]]
See AbstractPipeline.
build_weighted_metrics
build_weighted_metrics() -> Dict[str, List[tf.keras.metrics.Metric]]
See AbstractPipeline.
export_saved_model
export_saved_model(
model: tf.keras.Model,
export_to: str,
checkpoint: Optional[tf.train.Checkpoint] = None
)
Exports the trained model with signatures.
Example usage:
model_builder = ModelBuilder(...)
dataset_builder = DatasetBuilder(...)
hparams = PipelineHparams(...)
pipeline = BasicModelFitPipeline(model_builder, dataset_builder, hparams)
pipeline.export_saved_model(model_builder.build(), 'saved_model/')
Args:
  model: Model to be saved.
  export_to: Specifies the directory the model is to be exported to.
  checkpoint: If given, export the model with weights from this checkpoint.
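The exported model can then be loaded back for inference with standard TensorFlow APIs. A minimal sketch; the available signatures and their input specs depend on what the dataset builder created, so inspect them rather than assuming serving_default:

import tensorflow as tf

loaded = tf.saved_model.load("saved_model/")
# List the exported signatures (assumption: includes "serving_default").
print(list(loaded.signatures.keys()))
infer = loaded.signatures["serving_default"]
# Inspect the expected inputs before calling the signature.
print(infer.structured_input_signature)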
train_and_validate
train_and_validate(
verbose=0
)
Main function to train the model with the distribution strategy specified in hparams.
Example usage:
context_feature_spec = {}
example_feature_spec = {
"example_feature_1": tf.io.FixedLenFeature(
shape=(1,), dtype=tf.float32, default_value=0.0)
}
mask_feature_name = "list_mask"
label_spec = {
"utility": tf.io.FixedLenFeature(
shape=(1,), dtype=tf.float32, default_value=0.0)
}
dataset_hparams = DatasetHparams(
train_input_pattern="train.dat",
valid_input_pattern="valid.dat",
train_batch_size=128,
valid_batch_size=128)
pipeline_hparams = PipelineHparams(
model_dir="model/",
num_epochs=2,
steps_per_epoch=5,
validation_steps=2,
learning_rate=0.01,
loss="softmax_loss")
model_builder = SimpleModelBuilder(
context_feature_spec, example_feature_spec, mask_feature_name)
dataset_builder = SimpleDatasetBuilder(
context_feature_spec,
example_feature_spec,
mask_feature_name,
label_spec,
dataset_hparams)
pipeline = BasicModelFitPipeline(
model_builder, dataset_builder, pipeline_hparams)
pipeline.train_and_validate(verbose=1)
Args:
  verbose: An int for the verbosity level.
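The distribution strategy itself is configured through the strategy field on PipelineHparams. A hedged sketch; the accepted string values (e.g. "MirroredStrategy", "TPUStrategy") should be checked against tfr.keras.strategy_utils:

import tensorflow_ranking as tfr

pipeline_hparams = tfr.keras.pipeline.PipelineHparams(
    model_dir="model/",
    num_epochs=2,
    steps_per_epoch=5,
    validation_steps=2,
    learning_rate=0.01,
    loss="softmax_loss",
    # Assumption: "MirroredStrategy" selects multi-GPU training; TPU use
    # additionally requires the tpu address hparam.
    strategy="MirroredStrategy")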