tf.keras.estimator.model_to_estimator
Constructs an Estimator instance from a given Keras model.
tf.keras.estimator.model_to_estimator(
    keras_model=None, keras_model_path=None, custom_objects=None, model_dir=None,
    config=None, checkpoint_format='checkpoint'
)
If you use infrastructure or other tooling that relies on Estimators, you can
still build a Keras model and use model_to_estimator to convert the Keras
model to an Estimator for use with downstream systems.
For a usage example, see the guide Creating estimators from Keras Models.
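A minimal conversion-and-training sketch, assuming a TensorFlow 2.x release that still provides tf.keras.estimator and tf.estimator (as the release this page documents does). The layer sizes, random data, and step counts are arbitrary placeholders, and this input_fn yields plain (features, labels) batches without sample weights.

```python
import numpy as np
import tensorflow as tf

# Placeholder model: any compiled Functional or Sequential model is converted the same way.
keras_model = tf.keras.Sequential([
    tf.keras.layers.Dense(8, activation='relu', input_shape=(4,)),
    tf.keras.layers.Dense(1),
])
# model_to_estimator raises ValueError for an uncompiled model.
keras_model.compile(optimizer='adam', loss='mse')

estimator = tf.keras.estimator.model_to_estimator(keras_model=keras_model)


def input_fn():
    # Random regression data, for illustration only.
    x = np.random.rand(32, 4).astype('float32')
    y = np.random.rand(32, 1).astype('float32')
    # Plain (features, labels) batches; no sample weights here.
    return tf.data.Dataset.from_tensor_slices((x, y)).batch(8)


estimator.train(input_fn=input_fn, steps=4)
metrics = estimator.evaluate(input_fn=input_fn, steps=1)
print(metrics)
```

Because model_dir is not set here, checkpoints and summaries are written to a temporary directory created with tempfile.mkdtemp, as described for the model_dir argument.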
Sample Weights:
Estimators returned by model_to_estimator are configured to handle sample weights (similar to keras_model.fit(x, y, sample_weight=sample_weights)).
To pass sample weights when training or evaluating the Estimator, the first item returned by the input function should be a dictionary with the keys 'features' and 'sample_weights', as in the following example:
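    import numpy as np
    import tensorflow as tf
    # Any compiled Keras model works; this one is a placeholder.
    keras_model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(3,))])
    keras_model.compile(optimizer='sgd', loss='mse')
    estimator = tf.keras.estimator.model_to_estimator(keras_model)

    def input_fn():
        features = np.random.rand(4, 3).astype('float32')
        sample_weights = np.array([1.0, 1.0, 0.5, 0.0], dtype='float32')  # one weight per example
        targets = np.random.rand(4, 1).astype('float32')
        # The first element is a dict with the keys 'features' and 'sample_weights'.
        return tf.data.Dataset.from_tensors(
            ({'features': features, 'sample_weights': sample_weights}, targets))

    estimator.train(input_fn, steps=1)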
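To make the required layout of each input_fn element explicit, here is a small sketch. make_weighted_element is a hypothetical helper, not part of TensorFlow; it only packs arrays under the documented key names and checks that there is one weight per example.

```python
from typing import Any, Dict, Sequence, Tuple


def make_weighted_element(features: Sequence[Any],
                          sample_weights: Sequence[float],
                          targets: Sequence[Any]) -> Tuple[Dict[str, Any], Sequence[Any]]:
    """Hypothetical helper: packs one input_fn element in the documented layout."""
    if not (len(features) == len(sample_weights) == len(targets)):
        # One weight per example, as with keras_model.fit(..., sample_weight=...).
        raise ValueError('features, sample_weights and targets must have the same length.')
    # 'features' and 'sample_weights' are the key names the converted Estimator expects.
    return {'features': features, 'sample_weights': sample_weights}, targets


element = make_weighted_element([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], [1.0, 0.5], [[1.0], [0.0]])
print(sorted(element[0]))  # ['features', 'sample_weights']
```

The returned tuple has the same shape as the value passed to tf.data.Dataset.from_tensors in the sample-weights example.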
Args

| Argument | Description |
|---|---|
| keras_model | A compiled Keras model object. Mutually exclusive with keras_model_path. The Estimator's model_fn uses the structure of the model to clone it. Defaults to None. |
| keras_model_path | Path to a compiled Keras model saved on disk in HDF5 format, such as one written by the save() method of a Keras model. Mutually exclusive with keras_model. Defaults to None. |
| custom_objects | Dictionary of custom objects needed to clone the model, for classes that are not part of this pip package. For example, if you maintain a relu6 class that inherits from tf.keras.layers.Layer, pass custom_objects={'relu6': relu6}. Defaults to None. |
| model_dir | Directory in which to save Estimator model parameters, the graph, summary files for TensorBoard, and so on. Defaults to None; if unset, a directory is created with tempfile.mkdtemp. |
| config | RunConfig used to configure the Estimator. Allows setting up things in model_fn based on configuration such as num_ps_replicas or model_dir. Defaults to None. If both config.model_dir and the model_dir argument are specified, the model_dir argument takes precedence. |
| checkpoint_format | Format of the checkpoints the Estimator saves during training: 'saver' saves name-based checkpoints with tf.compat.v1.train.Saver, and 'checkpoint' saves object-based checkpoints with tf.train.Checkpoint. Estimators use name-based tf.compat.v1.train.Saver checkpoints, while Keras models use object-based checkpoints from tf.train.Checkpoint. Currently, saving object-based checkpoints from model_to_estimator is supported only for Functional and Sequential models. Defaults to 'checkpoint'. |
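The model_dir and config arguments together define which directory the Estimator writes to. The sketch below encodes that documented rule in plain Python; resolve_model_dir is hypothetical (not a TensorFlow function), and it assumes config.model_dir is used when the model_dir argument is unset, which the argument descriptions imply but do not state outright.

```python
import tempfile
from typing import Optional


def resolve_model_dir(model_dir: Optional[str] = None,
                      config_model_dir: Optional[str] = None) -> str:
    """Hypothetical helper: picks the directory an Estimator would write to.

    Mirrors the documented rule only; it is not TensorFlow's implementation.
    """
    if model_dir is not None:
        # An explicit model_dir argument wins over config.model_dir.
        return model_dir
    if config_model_dir is not None:
        # Assumption: RunConfig's model_dir is used when the argument is unset.
        return config_model_dir
    # Neither is set: create a fresh temporary directory.
    return tempfile.mkdtemp()


print(resolve_model_dir('/tmp/arg_dir', '/tmp/config_dir'))  # /tmp/arg_dir
```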
Returns

An Estimator instance built from the given Keras model.
Raises

| Exception | Condition |
|---|---|
| ValueError | If neither keras_model nor keras_model_path is given. |
| ValueError | If both keras_model and keras_model_path are given. |
| ValueError | If keras_model_path is a GCS URI. |
| ValueError | If keras_model has not been compiled. |
| ValueError | If an invalid checkpoint_format is given. |
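The argument-level ValueError conditions can be checked before calling the function. This is a hypothetical pre-flight sketch, not TensorFlow's source: check_model_to_estimator_args is an illustrative name, it treats a path starting with gs:// as a GCS URI, and it omits the "not compiled" check because that requires a real Keras model.

```python
from typing import Any, Optional

VALID_CHECKPOINT_FORMATS = ('saver', 'checkpoint')


def check_model_to_estimator_args(keras_model: Any = None,
                                  keras_model_path: Optional[str] = None,
                                  checkpoint_format: str = 'checkpoint') -> None:
    """Hypothetical check mirroring the documented ValueError conditions."""
    if keras_model is None and keras_model_path is None:
        raise ValueError('Either keras_model or keras_model_path must be given.')
    if keras_model is not None and keras_model_path is not None:
        raise ValueError('Give only one of keras_model and keras_model_path.')
    if keras_model_path is not None and keras_model_path.startswith('gs://'):
        # gs:// is the Google Cloud Storage URI scheme, which is rejected.
        raise ValueError('keras_model_path must not be a GCS URI.')
    if checkpoint_format not in VALID_CHECKPOINT_FORMATS:
        raise ValueError(
            f'checkpoint_format must be one of {VALID_CHECKPOINT_FORMATS}, '
            f'got {checkpoint_format!r}.')


for kwargs in ({}, {'keras_model_path': 'gs://bucket/model.h5'}):
    try:
        check_model_to_estimator_args(**kwargs)
    except ValueError as err:
        print(err)
```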
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates.
Last updated 2020-10-01 UTC.