ImageClassifier class for inference and exporting to TFLite.
tflite_model_maker.image_classifier.ImageClassifier(
model_spec,
index_to_label,
shuffle=True,
hparams=hub_lib.get_default_hparams(),
use_augmentation=False,
representative_data=None
)
Args |
model_spec
|
Specification for the model.
|
index_to_label
|
A list that maps each index to its label class name.
|
shuffle
|
Whether the data should be shuffled.
|
hparams
|
A namedtuple of hyperparameters. This function expects:
.dropout_rate: The fraction of the input units to drop, used in the
dropout layer.
.do_fine_tuning: If true, the Hub module is trained together with the
classification layer on top.
|
use_augmentation
|
Whether to use data augmentation for preprocessing.
|
representative_data
|
Representative dataset for full integer
quantization. Used when converting the keras model to the TFLite model
with full integer quantization.
|
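As a minimal sketch of the two hparams fields named above, the snippet below builds a stand-in namedtuple. The HParams type defined here is hypothetical and exists only for illustration; the default returned by hub_lib.get_default_hparams() may carry additional fields that this section does not list.

```python
import collections

# Hypothetical stand-in: only the two fields documented above are modeled.
HParams = collections.namedtuple("HParams", ["dropout_rate", "do_fine_tuning"])

# Train only the classification layer, dropping 20% of its input units.
head_only = HParams(dropout_rate=0.2, do_fine_tuning=False)

# namedtuples are immutable, so _replace returns a modified copy.
fine_tuned = head_only._replace(do_fine_tuning=True)
```

Because _replace returns a new object, the original settings stay available for comparison runs.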
Methods
create
@classmethod
create(
train_data,
model_spec='efficientnet_lite0',
validation_data=None,
batch_size=None,
epochs=None,
steps_per_epoch=None,
train_whole_model=None,
dropout_rate=None,
learning_rate=None,
momentum=None,
shuffle=False,
use_augmentation=False,
use_hub_library=True,
warmup_steps=None,
model_dir=None,
do_train=True
)
Loads data and retrains the model based on data for image classification.
Args |
train_data
|
Training data.
|
model_spec
|
Specification for the model.
|
validation_data
|
Validation data. If None, skips validation process.
|
batch_size
|
Number of samples per training step. If use_hub_library is
False, the base learning rate is defined for a training batch size of 256
and scales linearly with the batch size.
|
epochs
|
Number of epochs for training.
|
steps_per_epoch
|
Integer or None. Total number of steps (batches of
samples) before declaring one epoch finished and starting the next
epoch. If steps_per_epoch is None, the epoch will run until the input
dataset is exhausted.
|
train_whole_model
|
If true, the Hub module is trained together with the
classification layer on top. Otherwise, only train the top
classification layer.
|
dropout_rate
|
The rate for dropout.
|
learning_rate
|
Base learning rate for a training batch size of 256. It scales
linearly with the batch size.
|
momentum
|
A Python float forwarded to the optimizer. Only used when
use_hub_library is True.
|
shuffle
|
Whether the data should be shuffled.
|
use_augmentation
|
Whether to use data augmentation for preprocessing.
|
use_hub_library
|
Whether to use make_image_classifier_lib from TensorFlow Hub to
retrain the model.
|
warmup_steps
|
Number of warmup steps for warmup schedule on learning rate.
If None, the default warmup_steps is used which is the total training
steps in two epochs. Only used when use_hub_library is False.
|
model_dir
|
The location of the model checkpoint files. Only used when
use_hub_library is False.
|
do_train
|
Whether to run training.
|
Returns |
An instance based on ImageClassifier.
|
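The arithmetic behind learning_rate, steps_per_epoch, and warmup_steps can be sketched in plain Python. This is one reading of the descriptions above, not the library's implementation: it assumes the effective rate is the base rate multiplied by batch_size / 256, that a final partial batch counts as a step, and that the default warmup covers two epochs' worth of steps. All three function names are hypothetical.

```python
import math

REFERENCE_BATCH_SIZE = 256  # batch size at which the base learning rate applies


def scaled_learning_rate(base_learning_rate, batch_size):
    """Scale the base rate linearly with the batch size (assumed rule)."""
    return base_learning_rate * batch_size / REFERENCE_BATCH_SIZE


def steps_until_exhausted(num_examples, batch_size):
    """Steps in one epoch when steps_per_epoch is None.

    Assumes the last, possibly smaller, batch is still consumed.
    """
    return math.ceil(num_examples / batch_size)


def default_warmup_steps(steps_per_epoch):
    """Default warmup length: the total training steps in two epochs."""
    return 2 * steps_per_epoch
```

For example, with 1000 training images and a batch size of 32, one epoch takes 32 steps under these assumptions, so the default warmup would span 64 steps.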
create_model
create_model(
hparams=None, with_loss_and_metrics=False
)
Creates the classifier model for retraining.
create_serving_model
create_serving_model()
Returns the underlying Keras model for serving.
evaluate
evaluate(
data, batch_size=32
)
Evaluates the model.
Args |
data
|
Data to be evaluated.
|
batch_size
|
Number of samples per evaluation step.
|
Returns |
The loss value and accuracy.
|
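A minimal sketch of how create and evaluate might be combined, assuming train_data and test_data have already been prepared in whatever form the library accepts. The helper name retrain_and_evaluate and the choice of five epochs are illustrative only. The import is placed inside the function so that defining it does not require tflite_model_maker to be installed.

```python
def retrain_and_evaluate(train_data, test_data, epochs=5):
    """Retrain with the default model_spec, then score on held-out data.

    Hypothetical helper: both arguments are assumed to be data objects
    that ImageClassifier.create and evaluate accept.
    """
    # Imported lazily so this sketch can be defined without the package.
    from tflite_model_maker.image_classifier import ImageClassifier

    # validation_data=None skips the validation process during training.
    model = ImageClassifier.create(
        train_data,
        validation_data=None,
        epochs=epochs,
    )

    # evaluate is documented to return the loss value and accuracy.
    loss, accuracy = model.evaluate(test_data, batch_size=32)
    return model, loss, accuracy
```

Returning the model alongside the metrics lets the caller go on to export it or request predictions without retraining.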
evaluate_tflite
evaluate_tflite(
tflite_filepath, data, postprocess_fn=None
)
Evaluates the tflite model.
Args |
tflite_filepath
|
File path to the TFLite model.
|
data
|
Data to be evaluated.
|
postprocess_fn
|
Postprocessing function that will be applied to the output
of lite_runner.run before calculating the probabilities.
|
Returns |
The evaluation result of TFLite model - accuracy.
|
export
export(
export_dir,
tflite_filename='model.tflite',
label_filename='labels.txt',
vocab_filename='vocab.txt',
saved_model_filename='saved_model',
tfjs_folder_name='tfjs',
export_format=None,
**kwargs
)
Converts the retrained model based on export_format.
Args |
export_dir
|
The directory to save exported files.
|
tflite_filename
|
File name to save tflite model. The full export path is
{export_dir}/{tflite_filename}.
|
label_filename
|
File name to save labels. The full export path is
{export_dir}/{label_filename}.
|
vocab_filename
|
File name to save vocabulary. The full export path is
{export_dir}/{vocab_filename}.
|
saved_model_filename
|
Path to SavedModel or H5 file to save the model. The
full export path is
{export_dir}/{saved_model_filename}/{saved_model.pb|assets|variables}.
|
tfjs_folder_name
|
Folder name to save tfjs model. The full export path is
{export_dir}/{tfjs_folder_name}.
|
export_format
|
List of export formats, which can include saved_model, tflite,
label, and vocab.
|
**kwargs
|
Other parameters like quantized_config for TFLITE model.
|
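The export path rules above can be sketched as a small helper that composes the default locations for the four formats listed in ALLOWED_EXPORT_FORMAT. Both function names are hypothetical. The second function shows one way export might be called; it assumes the ExportFormat enum can be imported from tflite_model_maker.config, which this section does not state.

```python
import posixpath


def expected_export_paths(
    export_dir,
    tflite_filename="model.tflite",
    label_filename="labels.txt",
    saved_model_filename="saved_model",
    tfjs_folder_name="tfjs",
):
    """Compose {export_dir}/{name} for each export artifact."""
    return {
        "tflite": posixpath.join(export_dir, tflite_filename),
        "label": posixpath.join(export_dir, label_filename),
        "saved_model": posixpath.join(export_dir, saved_model_filename),
        "tfjs": posixpath.join(export_dir, tfjs_folder_name),
    }


def export_tflite_and_labels(model, export_dir):
    """Export the TFLite model and label file, then report their paths."""
    # Assumed import location for the ExportFormat enum; imported lazily.
    from tflite_model_maker.config import ExportFormat

    model.export(
        export_dir,
        export_format=[ExportFormat.TFLITE, ExportFormat.LABEL],
    )
    paths = expected_export_paths(export_dir)
    return paths["tflite"], paths["label"]
```

The helper uses posixpath so the composed paths look the same on every platform; swap in os.path if native separators are needed.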
predict_top_k
predict_top_k(
data, k=1, batch_size=32
)
Predicts the top-k predictions.
Args |
data
|
Data to be evaluated. Either an instance of DataLoader or raw
data entries such as a TF tensor or NumPy array.
|
k
|
Number of top results to be predicted.
|
batch_size
|
Number of samples per evaluation step.
|
Returns |
The top k results. Each one is a (label, probability) pair.
|
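To illustrate the shape of the result, the sketch below picks the top k (label, probability) pairs for a single example in plain Python. It is not the library's implementation: the function name top_k_labels is hypothetical, and it assumes the class probabilities arrive as a flat sequence aligned with index_to_label.

```python
def top_k_labels(probabilities, index_to_label, k=1):
    """Return the k most probable (label, probability) pairs, best first."""
    # Rank class indices by descending probability.
    ranked = sorted(
        range(len(probabilities)),
        key=lambda i: probabilities[i],
        reverse=True,
    )
    # Map each of the k best indices back to its label class name.
    return [(index_to_label[i], probabilities[i]) for i in ranked[:k]]


labels = ["daisy", "rose", "tulip"]
best_two = top_k_labels([0.1, 0.7, 0.2], labels, k=2)
```

Here best_two holds the rose entry first and the tulip entry second, since those are the two highest scores.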
summary
summary()
train
train(
train_data, validation_data=None, hparams=None, steps_per_epoch=None
)
Feeds the training data for training.
Args |
train_data
|
Training data.
|
validation_data
|
Validation data. If None, skips validation process.
|
hparams
|
An instance of hub_lib.HParams or
train_image_classifier_lib.HParams. A namedtuple of hyperparameters.
|
steps_per_epoch
|
Integer or None. Total number of steps (batches of
samples) before declaring one epoch finished and starting the next
epoch. If 'steps_per_epoch' is None, the epoch will run until the input
dataset is exhausted.
|
Returns |
The tf.keras.callbacks.History object returned by tf.keras.Model.fit*().
|
Class Variables |
ALLOWED_EXPORT_FORMAT
|
(<ExportFormat.TFLITE: 'TFLITE'>,
<ExportFormat.LABEL: 'LABEL'>,
<ExportFormat.SAVED_MODEL: 'SAVED_MODEL'>,
<ExportFormat.TFJS: 'TFJS'>)
|
DEFAULT_EXPORT_FORMAT
|
(<ExportFormat.TFLITE: 'TFLITE'>, <ExportFormat.LABEL: 'LABEL'>)
|