tfr.keras.pipeline.SimplePipeline
Pipeline for single-task training.
Inherits From: ModelFitPipeline, AbstractPipeline
tfr.keras.pipeline.SimplePipeline(
    model_builder: tfr.keras.model.AbstractModelBuilder,
    dataset_builder: tfr.keras.pipeline.AbstractDatasetBuilder,
    hparams: tfr.keras.pipeline.PipelineHparams
)
This pipeline handles a single loss and works with SimpleDatasetBuilder. It can also work with MultiLabelDatasetBuilder; in that case, the same loss, as well as all metrics, is applied uniformly to every label and its predictions.
Use subclassing to customize the loss and metrics.
Example usage:
import tensorflow as tf
import tensorflow_ranking as tfr

context_feature_spec = {}
example_feature_spec = {
    "example_feature_1": tf.io.FixedLenFeature(
        shape=(1,), dtype=tf.float32, default_value=0.0)
}
mask_feature_name = "list_mask"
# SimpleDatasetBuilder takes a single (label_name, feature_spec) pair.
label_spec = (
    "utility",
    tf.io.FixedLenFeature(shape=(1,), dtype=tf.float32, default_value=0.0))
dataset_hparams = tfr.keras.pipeline.DatasetHparams(
    train_input_pattern="train.dat",
    valid_input_pattern="valid.dat",
    train_batch_size=128,
    valid_batch_size=128)
pipeline_hparams = tfr.keras.pipeline.PipelineHparams(
    model_dir="model/",
    num_epochs=2,
    steps_per_epoch=5,
    validation_steps=2,
    learning_rate=0.01,
    loss="softmax_loss")
# SimpleModelBuilder stands in for a concrete model builder and is not
# defined here; any tfr.keras.model.AbstractModelBuilder implementation works.
model_builder = SimpleModelBuilder(
    context_feature_spec, example_feature_spec, mask_feature_name)
dataset_builder = tfr.keras.pipeline.SimpleDatasetBuilder(
    context_feature_spec,
    example_feature_spec,
    mask_feature_name,
    label_spec,
    dataset_hparams)
pipeline = tfr.keras.pipeline.SimplePipeline(
    model_builder, dataset_builder, pipeline_hparams)
pipeline.train_and_validate(verbose=1)
| Args | |
|---|---|
| model_builder | A model builder (an AbstractModelBuilder instance) used to build the model for fitting. |
| dataset_builder | An AbstractDatasetBuilder instance used to load the training and validation datasets and to create signatures for the SavedModel. |
| hparams | A PipelineHparams instance containing the pipeline hyperparameters. |
Methods
build_callbacks
build_callbacks() -> List[tf.keras.callbacks.Callback]
Sets up callbacks.
Example usage:
model_builder = ModelBuilder(...)
dataset_builder = DatasetBuilder(...)
hparams = PipelineHparams(...)
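pipeline = SimplePipeline(model_builder, dataset_builder, hparams)
callbacks = pipeline.build_callbacks()
Returns: A list of tf.keras.callbacks.Callback, or a tf.keras.callbacks.CallbackList, for TensorBoard and checkpointing.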
build_loss
build_loss() -> tf.keras.losses.Loss
See AbstractPipeline.
build_metrics
build_metrics() -> List[tf.keras.metrics.Metric]
See AbstractPipeline.
build_weighted_metrics
build_weighted_metrics() -> List[tf.keras.metrics.Metric]
See AbstractPipeline.
export_saved_model
export_saved_model(
model: tf.keras.Model,
export_to: str,
checkpoint: Optional[tf.train.Checkpoint] = None
)
Exports the trained model with signatures.
Example usage:
model_builder = ModelBuilder(...)
dataset_builder = DatasetBuilder(...)
hparams = PipelineHparams(...)
pipeline = SimplePipeline(model_builder, dataset_builder, hparams)
pipeline.export_saved_model(model_builder.build(), 'saved_model/')
| Args | |
|---|---|
| model | The model to be saved. |
| export_to | The directory to export the model to. |
| checkpoint | If given, the model is exported with the weights from this checkpoint. |
train_and_validate
train_and_validate(
verbose=0
)
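Main function to train and validate the model, using the distribution strategy configured in the pipeline hyperparameters (for example, a TPU strategy).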
Example usage:
import tensorflow as tf
import tensorflow_ranking as tfr

context_feature_spec = {}
example_feature_spec = {
    "example_feature_1": tf.io.FixedLenFeature(
        shape=(1,), dtype=tf.float32, default_value=0.0)
}
mask_feature_name = "list_mask"
# SimpleDatasetBuilder takes a single (label_name, feature_spec) pair.
label_spec = (
    "utility",
    tf.io.FixedLenFeature(shape=(1,), dtype=tf.float32, default_value=0.0))
dataset_hparams = tfr.keras.pipeline.DatasetHparams(
    train_input_pattern="train.dat",
    valid_input_pattern="valid.dat",
    train_batch_size=128,
    valid_batch_size=128)
pipeline_hparams = tfr.keras.pipeline.PipelineHparams(
    model_dir="model/",
    num_epochs=2,
    steps_per_epoch=5,
    validation_steps=2,
    learning_rate=0.01,
    loss="softmax_loss")
# SimpleModelBuilder stands in for a concrete model builder and is not
# defined here; any tfr.keras.model.AbstractModelBuilder implementation works.
model_builder = SimpleModelBuilder(
    context_feature_spec, example_feature_spec, mask_feature_name)
dataset_builder = tfr.keras.pipeline.SimpleDatasetBuilder(
    context_feature_spec,
    example_feature_spec,
    mask_feature_name,
    label_spec,
    dataset_hparams)
pipeline = tfr.keras.pipeline.SimplePipeline(
    model_builder, dataset_builder, pipeline_hparams)
pipeline.train_and_validate(verbose=1)
| Args | |
|---|---|
| verbose | An int for the verbosity level. |
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates.
Last updated 2023-08-18 UTC.