Pruning for on-device inference w/ XNNPACK


Welcome to the guide on Keras weight pruning for improving the latency of on-device inference via XNNPACK.

This guide presents the usage of the newly introduced tfmot.sparsity.keras.PruningPolicy API and demonstrates how it can be used to accelerate mostly convolutional models on modern CPUs using XNNPACK sparse inference.

The guide covers the following steps of the model creation process:

  • Build and train the dense baseline
  • Fine-tune model with pruning
  • Convert to TFLite
  • On-device benchmark

The guide doesn't cover best practices for fine-tuning with pruning. For more detailed information on this topic, please check out our comprehensive guide.

Setup

 pip install -q tensorflow
 pip install -q tensorflow-model-optimization
import tempfile

import tensorflow as tf
import numpy as np

import tensorflow_datasets as tfds
import tensorflow_model_optimization as tfmot
import tf_keras as keras

%load_ext tensorboard

Build and train the dense model

We build and train a simple baseline CNN for the classification task on the CIFAR10 dataset.

# Load CIFAR10 dataset.
(ds_train, ds_val, ds_test), ds_info = tfds.load(
    'cifar10',
    split=['train[:90%]', 'train[90%:]', 'test'],
    as_supervised=True,
    with_info=True,
)

# Normalize the input image so that each pixel value is between 0 and 1.
def normalize_img(image, label):
  """Normalizes images: `uint8` -> `float32`."""
  return tf.image.convert_image_dtype(image, tf.float32), label

# Load the data in batches of 128 images.
batch_size = 128
def prepare_dataset(ds, buffer_size=None):
  ds = ds.map(normalize_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
  ds = ds.cache()
  if buffer_size:
    ds = ds.shuffle(buffer_size)
  ds = ds.batch(batch_size)
  ds = ds.prefetch(tf.data.experimental.AUTOTUNE)
  return ds

ds_train = prepare_dataset(ds_train,
                           buffer_size=ds_info.splits['train'].num_examples)
ds_val = prepare_dataset(ds_val)
ds_test = prepare_dataset(ds_test)

# Build the dense baseline model.
dense_model = keras.Sequential([
    keras.layers.InputLayer(input_shape=(32, 32, 3)),
    keras.layers.ZeroPadding2D(padding=1),
    keras.layers.Conv2D(
        filters=8,
        kernel_size=(3, 3),
        strides=(2, 2),
        padding='valid'),
    keras.layers.BatchNormalization(),
    keras.layers.ReLU(),
    keras.layers.DepthwiseConv2D(kernel_size=(3, 3), padding='same'),
    keras.layers.BatchNormalization(),
    keras.layers.ReLU(),
    keras.layers.Conv2D(filters=16, kernel_size=(1, 1)),
    keras.layers.BatchNormalization(),
    keras.layers.ReLU(),
    keras.layers.ZeroPadding2D(padding=1),
    keras.layers.DepthwiseConv2D(
        kernel_size=(3, 3), strides=(2, 2), padding='valid'),
    keras.layers.BatchNormalization(),
    keras.layers.ReLU(),
    keras.layers.Conv2D(filters=32, kernel_size=(1, 1)),
    keras.layers.BatchNormalization(),
    keras.layers.ReLU(),
    keras.layers.GlobalAveragePooling2D(),
    keras.layers.Flatten(),
    keras.layers.Dense(10)
])

# Compile and train the dense model for 10 epochs.
dense_model.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer='adam',
    metrics=['accuracy'])

dense_model.fit(
  ds_train,
  epochs=10,
  validation_data=ds_val)

# Evaluate the dense model.
_, dense_model_accuracy = dense_model.evaluate(ds_test, verbose=0)
Epoch 1/10
352/352 [==============================] - 12s 19ms/step - loss: 1.9401 - accuracy: 0.2834 - val_loss: 2.3729 - val_accuracy: 0.1454
Epoch 2/10
352/352 [==============================] - 5s 14ms/step - loss: 1.6801 - accuracy: 0.3720 - val_loss: 1.6550 - val_accuracy: 0.3958
Epoch 3/10
352/352 [==============================] - 5s 14ms/step - loss: 1.6059 - accuracy: 0.4076 - val_loss: 1.5743 - val_accuracy: 0.4222
Epoch 4/10
352/352 [==============================] - 5s 14ms/step - loss: 1.5476 - accuracy: 0.4376 - val_loss: 1.5349 - val_accuracy: 0.4386
Epoch 5/10
352/352 [==============================] - 5s 14ms/step - loss: 1.5056 - accuracy: 0.4560 - val_loss: 1.4769 - val_accuracy: 0.4576
Epoch 6/10
352/352 [==============================] - 5s 14ms/step - loss: 1.4735 - accuracy: 0.4655 - val_loss: 1.4941 - val_accuracy: 0.4610
Epoch 7/10
352/352 [==============================] - 5s 14ms/step - loss: 1.4505 - accuracy: 0.4768 - val_loss: 1.4739 - val_accuracy: 0.4548
Epoch 8/10
352/352 [==============================] - 5s 14ms/step - loss: 1.4332 - accuracy: 0.4781 - val_loss: 1.4993 - val_accuracy: 0.4546
Epoch 9/10
352/352 [==============================] - 5s 14ms/step - loss: 1.4181 - accuracy: 0.4856 - val_loss: 1.4890 - val_accuracy: 0.4470
Epoch 10/10
352/352 [==============================] - 5s 14ms/step - loss: 1.4068 - accuracy: 0.4896 - val_loss: 1.4430 - val_accuracy: 0.4746

Build the sparse model

Using the instructions from the comprehensive guide, we apply the tfmot.sparsity.keras.prune_low_magnitude function with parameters that target on-device acceleration via pruning, i.e. the tfmot.sparsity.keras.PruneForLatencyOnXNNPack policy.

prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

# Compute the end step to finish pruning after 5 epochs.
end_epoch = 5

num_iterations_per_epoch = len(ds_train)
end_step = num_iterations_per_epoch * end_epoch

# Define parameters for pruning.
pruning_params = {
      'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0.25,
                                                               final_sparsity=0.75,
                                                               begin_step=0,
                                                               end_step=end_step),
      'pruning_policy': tfmot.sparsity.keras.PruneForLatencyOnXNNPack()
}

# Try to apply pruning wrapper with pruning policy parameter.
try:
  model_for_pruning = prune_low_magnitude(dense_model, **pruning_params)
except ValueError as e:
  print(e)

The call to prune_low_magnitude results in a ValueError with the message Could not find a GlobalAveragePooling2D layer with keepdims = True in all output branches. The message indicates that the model isn't supported for pruning with the tfmot.sparsity.keras.PruneForLatencyOnXNNPack policy; specifically, the GlobalAveragePooling2D layer requires the parameter keepdims = True. Let's fix that and reapply the prune_low_magnitude function.

fixed_dense_model = keras.Sequential([
    keras.layers.InputLayer(input_shape=(32, 32, 3)),
    keras.layers.ZeroPadding2D(padding=1),
    keras.layers.Conv2D(
        filters=8,
        kernel_size=(3, 3),
        strides=(2, 2),
        padding='valid'),
    keras.layers.BatchNormalization(),
    keras.layers.ReLU(),
    keras.layers.DepthwiseConv2D(kernel_size=(3, 3), padding='same'),
    keras.layers.BatchNormalization(),
    keras.layers.ReLU(),
    keras.layers.Conv2D(filters=16, kernel_size=(1, 1)),
    keras.layers.BatchNormalization(),
    keras.layers.ReLU(),
    keras.layers.ZeroPadding2D(padding=1),
    keras.layers.DepthwiseConv2D(
        kernel_size=(3, 3), strides=(2, 2), padding='valid'),
    keras.layers.BatchNormalization(),
    keras.layers.ReLU(),
    keras.layers.Conv2D(filters=32, kernel_size=(1, 1)),
    keras.layers.BatchNormalization(),
    keras.layers.ReLU(),
    keras.layers.GlobalAveragePooling2D(keepdims=True),
    keras.layers.Flatten(),
    keras.layers.Dense(10)
])

# Use the pretrained model for pruning instead of training from scratch.
fixed_dense_model.set_weights(dense_model.get_weights())

# Try to reapply pruning wrapper.
model_for_pruning = prune_low_magnitude(fixed_dense_model, **pruning_params)

The invocation of prune_low_magnitude finished without any errors, meaning that the model is fully supported by the tfmot.sparsity.keras.PruneForLatencyOnXNNPack policy and can be accelerated using XNNPACK sparse inference.
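
If you want to run that compatibility check up front, the policy object exposes the check that prune_low_magnitude applies internally. This is a minimal sketch, assuming the ensure_model_supports_pruning method of the PruningPolicy API:

# Optional: verify XNNPACK compatibility explicitly before applying the wrapper.
policy = tfmot.sparsity.keras.PruneForLatencyOnXNNPack()
try:
  # Raises ValueError if the model topology is unsupported for sparse inference.
  policy.ensure_model_supports_pruning(fixed_dense_model)
  print('Model is supported by the PruneForLatencyOnXNNPack policy.')
except ValueError as e:
  print('Model is not supported:', e)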

Fine-tune the sparse model

Following the pruning example, we fine-tune the sparse model using the weights of the dense model. We start fine-tuning the model at 25% sparsity (25% of the weights are set to zero) and end with 75% sparsity.

logdir = tempfile.mkdtemp()

callbacks = [
  tfmot.sparsity.keras.UpdatePruningStep(),
  tfmot.sparsity.keras.PruningSummaries(log_dir=logdir),
]

model_for_pruning.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer='adam',
    metrics=['accuracy'])

model_for_pruning.fit(
  ds_train,
  epochs=15,
  validation_data=ds_val,
  callbacks=callbacks)

# Evaluate the pruned model.
_, pruned_model_accuracy = model_for_pruning.evaluate(ds_test, verbose=0)

print('Dense model test accuracy:', dense_model_accuracy)
print('Pruned model test accuracy:', pruned_model_accuracy)
Epoch 1/15
352/352 [==============================] - 8s 15ms/step - loss: 1.4156 - accuracy: 0.4864 - val_loss: 1.5387 - val_accuracy: 0.4482
Epoch 2/15
352/352 [==============================] - 5s 14ms/step - loss: 1.4474 - accuracy: 0.4761 - val_loss: 1.9037 - val_accuracy: 0.3348
Epoch 3/15
352/352 [==============================] - 5s 14ms/step - loss: 1.4695 - accuracy: 0.4673 - val_loss: 1.6044 - val_accuracy: 0.3986
Epoch 4/15
352/352 [==============================] - 5s 14ms/step - loss: 1.4859 - accuracy: 0.4590 - val_loss: 4.3158 - val_accuracy: 0.1766
Epoch 5/15
352/352 [==============================] - 5s 14ms/step - loss: 1.4680 - accuracy: 0.4659 - val_loss: 5.7806 - val_accuracy: 0.1308
Epoch 6/15
352/352 [==============================] - 5s 14ms/step - loss: 1.4515 - accuracy: 0.4743 - val_loss: 1.4982 - val_accuracy: 0.4474
Epoch 7/15
352/352 [==============================] - 5s 14ms/step - loss: 1.4402 - accuracy: 0.4759 - val_loss: 1.4497 - val_accuracy: 0.4676
Epoch 8/15
352/352 [==============================] - 5s 14ms/step - loss: 1.4295 - accuracy: 0.4813 - val_loss: 1.4707 - val_accuracy: 0.4604
Epoch 9/15
352/352 [==============================] - 5s 14ms/step - loss: 1.4228 - accuracy: 0.4839 - val_loss: 1.4692 - val_accuracy: 0.4590
Epoch 10/15
352/352 [==============================] - 5s 14ms/step - loss: 1.4156 - accuracy: 0.4875 - val_loss: 1.5493 - val_accuracy: 0.4440
Epoch 11/15
352/352 [==============================] - 5s 14ms/step - loss: 1.4088 - accuracy: 0.4884 - val_loss: 6.1619 - val_accuracy: 0.2098
Epoch 12/15
352/352 [==============================] - 5s 14ms/step - loss: 1.4069 - accuracy: 0.4879 - val_loss: 8.7068 - val_accuracy: 0.1784
Epoch 13/15
352/352 [==============================] - 5s 14ms/step - loss: 1.4003 - accuracy: 0.4938 - val_loss: 7.8199 - val_accuracy: 0.2046
Epoch 14/15
352/352 [==============================] - 5s 14ms/step - loss: 1.3972 - accuracy: 0.4936 - val_loss: 8.5771 - val_accuracy: 0.1998
Epoch 15/15
352/352 [==============================] - 5s 14ms/step - loss: 1.3907 - accuracy: 0.4941 - val_loss: 2.0665 - val_accuracy: 0.3506
Dense model test accuracy: 0.4632999897003174
Pruned model test accuracy: 0.3368000090122223

The logs show the progression of sparsity on a per-layer basis.

#docs_infra: no_execute
%tensorboard --logdir={logdir}
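
If TensorBoard isn't available, a quick sanity check of the achieved per-layer sparsity can also be done directly from the pruned model's weights. A minimal sketch (the exact weight names depend on the layer types):

# Report the fraction of zero-valued weights for each prunable kernel.
for weight in model_for_pruning.weights:
  if 'kernel' in weight.name and 'mask' not in weight.name:
    values = weight.numpy()
    sparsity = 1.0 - np.count_nonzero(values) / values.size
    print(f'{weight.name}: {sparsity:.2%} zeros')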

After fine-tuning with pruning, test accuracy should stay roughly on par with the dense model (a modest improvement, e.g. from 43% to 44%, is possible). In the run above the pruned model lags behind, which suggests the fine-tuning schedule needs more care; see the comprehensive guide for best practices. Next, let's compare on-device latency using the TFLite benchmark tool.

Model conversion and benchmarking

To convert the pruned model into TFLite, we need to replace the PruneLowMagnitude wrappers with the original layers via the strip_pruning function. Also, since the weights of the pruned model (model_for_pruning) are mostly zeros, we may apply the optimization tf.lite.Optimize.EXPERIMENTAL_SPARSITY to store the resulting TFLite model efficiently. This optimization flag is not required for the dense model.

converter = tf.lite.TFLiteConverter.from_keras_model(dense_model)
dense_tflite_model = converter.convert()

_, dense_tflite_file = tempfile.mkstemp('.tflite')
with open(dense_tflite_file, 'wb') as f:
  f.write(dense_tflite_model)

model_for_export = tfmot.sparsity.keras.strip_pruning(model_for_pruning)

converter = tf.lite.TFLiteConverter.from_keras_model(model_for_export)
converter.optimizations = [tf.lite.Optimize.EXPERIMENTAL_SPARSITY]
pruned_tflite_model = converter.convert()

_, pruned_tflite_file = tempfile.mkstemp('.tflite')
with open(pruned_tflite_file, 'wb') as f:
  f.write(pruned_tflite_model)
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmp7jl2_bo6/assets
INFO:tensorflow:Assets written to: /tmpfs/tmp/tmp7tz7yd7n/assets

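Since the pruned weights are mostly zeros, the EXPERIMENTAL_SPARSITY optimization also reduces the serialized model size. A quick way to compare the on-disk footprint is to compress both files; this is a minimal sketch and the helper name below is ours:

import os
import zipfile

def get_zipped_model_size(file):
  # Returns the size of the model file after ZIP/DEFLATE compression, in bytes.
  _, zipped_file = tempfile.mkstemp('.zip')
  with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
    f.write(file)
  return os.path.getsize(zipped_file)

print('Compressed dense TFLite model: %.2f KB' % (get_zipped_model_size(dense_tflite_file) / 1024))
print('Compressed pruned TFLite model: %.2f KB' % (get_zipped_model_size(pruned_tflite_file) / 1024))
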
Following the instructions for the TFLite Model Benchmark Tool, we build the tool, push it to the Android device together with the dense and pruned TFLite models, and benchmark both models on the device.

! adb shell /data/local/tmp/benchmark_model \
    --graph=/data/local/tmp/dense_model.tflite \
    --use_xnnpack=true \
    --num_runs=100 \
    --num_threads=1
! adb shell /data/local/tmp/benchmark_model \
    --graph=/data/local/tmp/pruned_model.tflite \
    --use_xnnpack=true \
    --num_runs=100 \
    --num_threads=1
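If an Android device with adb isn't available (as in this notebook environment), a rough desktop-side comparison can be made with the Python TFLite Interpreter. This is only a sketch: absolute numbers will differ from on-device results, and the sparse XNNPACK kernels may not be exercised in the same way as with the benchmark tool.

import time

def measure_tflite_latency(tflite_file, num_runs=100):
  # Average single-threaded invoke latency on random input, in microseconds.
  interpreter = tf.lite.Interpreter(model_path=tflite_file, num_threads=1)
  interpreter.allocate_tensors()
  input_details = interpreter.get_input_details()[0]
  dummy_input = np.random.rand(*input_details['shape']).astype(np.float32)
  interpreter.set_tensor(input_details['index'], dummy_input)
  start = time.perf_counter()
  for _ in range(num_runs):
    interpreter.invoke()
  return (time.perf_counter() - start) / num_runs * 1e6

print('Dense model: %.1f us/run' % measure_tflite_latency(dense_tflite_file))
print('Pruned model: %.1f us/run' % measure_tflite_latency(pruned_tflite_file))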

Benchmarks on Pixel 4 resulted in an average inference time of 17us for the dense model and 12us for the pruned model. The on-device benchmarks demonstrate a clear 5us, or roughly 30%, improvement in latency even for such a small model. In our experience, larger models based on MobileNetV3 or EfficientNet-lite show similar performance improvements. The speed-up varies based on the relative contribution of 1x1 convolutions to the overall model.

Conclusion

In this tutorial, we showed how to create sparse models for faster on-device performance using the new functionality introduced by the TF MOT API and XNNPACK. These sparse models are smaller and faster than their dense counterparts while retaining or even surpassing their quality.

We encourage you to try this new capability, which can be particularly important for deploying your models on device.