Object that returns a tf.data.Dataset when invoked.
tf.keras.utils.experimental.DatasetCreator(
    dataset_fn
)
tf.keras.utils.experimental.DatasetCreator is designated as a supported type
for x, the input, in tf.keras.Model.fit. Pass an instance of this class
to fit when using a callable (with an input_context argument) that returns
a tf.data.Dataset.
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
model.compile(tf.keras.optimizers.SGD(), loss="mse")

def dataset_fn(input_context):
  # Derive the per-replica batch size from the global batch size.
  global_batch_size = 64
  batch_size = input_context.get_per_replica_batch_size(global_batch_size)
  dataset = tf.data.Dataset.from_tensors(([1.], [1.])).repeat()
  # Shard the data so each input pipeline reads a disjoint subset.
  dataset = dataset.shard(
      input_context.num_input_pipelines, input_context.input_pipeline_id)
  dataset = dataset.batch(batch_size)
  dataset = dataset.prefetch(2)
  return dataset

model.fit(tf.keras.utils.experimental.DatasetCreator(dataset_fn),
          epochs=10, steps_per_epoch=10)
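Here get_per_replica_batch_size divides the global batch size of 64 by the
number of replicas in sync, shard gives each input pipeline a disjoint slice
of the examples, and steps_per_epoch is required because the repeated dataset
is infinite.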
Model.fit usage with DatasetCreator is intended to work across all
tf.distribute.Strategy types, as long as Strategy.scope is used at model
creation:
# cluster_resolver is a tf.distribute.cluster_resolver.ClusterResolver that
# describes the cluster's workers and parameter servers.
strategy = tf.distribute.experimental.ParameterServerStrategy(
    cluster_resolver)
with strategy.scope():
  model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
model.compile(tf.keras.optimizers.SGD(), loss="mse")
...
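From here, training proceeds with the same pattern as in the first example;
a minimal sketch, assuming dataset_fn is the callable defined earlier:

model.fit(tf.keras.utils.experimental.DatasetCreator(dataset_fn),
          epochs=10, steps_per_epoch=10)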
Args
| Argument | Description |
|---|---|
| dataset_fn | A callable that takes a single argument of type tf.distribute.InputContext, which is used for batch size calculation and cross-worker input pipeline sharding (if neither is needed, the InputContext parameter can be ignored in the dataset_fn), and returns a tf.data.Dataset. |
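If neither batch size calculation nor sharding is needed, the InputContext
can simply be ignored; a minimal sketch:

def dataset_fn(input_context):
  # input_context is unused: no per-replica batching or sharding is derived.
  return tf.data.Dataset.from_tensors(([1.], [1.])).repeat().batch(8)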
Methods
__call__
__call__(
    *args, **kwargs
)
Call self as a function.
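Invoking the instance forwards its arguments to dataset_fn and returns the
resulting tf.data.Dataset; Keras performs this call internally during
Model.fit. A minimal sketch of calling it directly, assuming dataset_fn from
the example above:

creator = tf.keras.utils.experimental.DatasetCreator(dataset_fn)
ds = creator(tf.distribute.InputContext())  # default single-pipeline context
print(ds.element_spec)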