Pruning preserving quantization aware training (PQAT) Keras example


Overview

This is an end-to-end example showing the usage of the pruning preserving quantization aware training (PQAT) API, part of the TensorFlow Model Optimization Toolkit's collaborative optimization pipeline.

Other pages

For an introduction to the pipeline and other available techniques, see the collaborative optimization overview page.

Contents

In the tutorial, you will:

  1. Train a Keras model for the MNIST dataset from scratch.
  2. Fine-tune the model with pruning, using the sparsity API, and check the accuracy.
  3. Apply QAT and observe the loss of sparsity.
  4. Apply PQAT and observe that the sparsity applied earlier has been preserved.
  5. Generate a TFLite model and observe the effects of applying PQAT on it.
  6. Compare the accuracy of the PQAT model with that of a model quantized using post-training quantization.

Setup

You can run this Jupyter Notebook in your local virtualenv or in Colab. For details on setting up dependencies, refer to the installation guide.

 pip install -q tensorflow-model-optimization
import tensorflow as tf
import tf_keras as keras

import numpy as np
import tempfile
import zipfile
import os

Train a Keras model for MNIST without pruning

# Load MNIST dataset
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Normalize the input images so that each pixel value is between 0 and 1.
train_images = train_images / 255.0
test_images  = test_images / 255.0

model = keras.Sequential([
  keras.layers.InputLayer(input_shape=(28, 28)),
  keras.layers.Reshape(target_shape=(28, 28, 1)),
  keras.layers.Conv2D(filters=12, kernel_size=(3, 3),
                         activation=tf.nn.relu),
  keras.layers.MaxPooling2D(pool_size=(2, 2)),
  keras.layers.Flatten(),
  keras.layers.Dense(10)
])

# Train the digit classification model
model.compile(optimizer='adam',
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])

model.fit(
    train_images,
    train_labels,
    validation_split=0.1,
    epochs=10
)
Epoch 1/10
1688/1688 [==============================] - 8s 5ms/step - loss: 0.2810 - accuracy: 0.9208 - val_loss: 0.1189 - val_accuracy: 0.9678
Epoch 2/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.1100 - accuracy: 0.9684 - val_loss: 0.0838 - val_accuracy: 0.9778
Epoch 3/10
1688/1688 [==============================] - 8s 4ms/step - loss: 0.0805 - accuracy: 0.9764 - val_loss: 0.0779 - val_accuracy: 0.9788
Epoch 4/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0661 - accuracy: 0.9804 - val_loss: 0.0632 - val_accuracy: 0.9827
Epoch 5/10
1688/1688 [==============================] - 7s 4ms/step - loss: 0.0583 - accuracy: 0.9826 - val_loss: 0.0637 - val_accuracy: 0.9823
Epoch 6/10
1688/1688 [==============================] - 8s 4ms/step - loss: 0.0516 - accuracy: 0.9843 - val_loss: 0.0573 - val_accuracy: 0.9843
Epoch 7/10
1688/1688 [==============================] - 8s 4ms/step - loss: 0.0460 - accuracy: 0.9860 - val_loss: 0.0620 - val_accuracy: 0.9835
Epoch 8/10
1688/1688 [==============================] - 8s 4ms/step - loss: 0.0411 - accuracy: 0.9877 - val_loss: 0.0549 - val_accuracy: 0.9853
Epoch 9/10
1688/1688 [==============================] - 8s 4ms/step - loss: 0.0376 - accuracy: 0.9886 - val_loss: 0.0591 - val_accuracy: 0.9850
Epoch 10/10
1688/1688 [==============================] - 8s 4ms/step - loss: 0.0343 - accuracy: 0.9892 - val_loss: 0.0622 - val_accuracy: 0.9837
<tf_keras.src.callbacks.History at 0x7ff16cea98e0>

Evaluate the baseline model and save it for later usage

_, baseline_model_accuracy = model.evaluate(
    test_images, test_labels, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy)

_, keras_file = tempfile.mkstemp('.h5')
print('Saving model to: ', keras_file)
keras.models.save_model(model, keras_file, include_optimizer=False)
Baseline test accuracy: 0.9803000092506409
Saving model to:  /tmpfs/tmp/tmpe6y2pzu6.h5
/tmpfs/tmp/ipykernel_38386/3680774635.py:8: UserWarning: You are saving your model as an HDF5 file via `model.save()`. This file format is considered legacy. We recommend using instead the native TF-Keras format, e.g. `model.save('my_model.keras')`.
  keras.models.save_model(model, keras_file, include_optimizer=False)

Prune and fine-tune the model to 50% sparsity

Apply the prune_low_magnitude() API to the whole pre-trained model to demonstrate how pruning reduces the size of the zipped model while maintaining accuracy. For how best to use the API to achieve the best compression rate for your target accuracy, refer to the pruning comprehensive guide.
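Conceptually, magnitude pruning zeroes the weights with the smallest absolute values until the target sparsity is reached. The NumPy sketch below is a simplified, one-shot illustration of that idea, assuming a single threshold per tensor; it is not tfmot's implementation, and `magnitude_prune` is a hypothetical helper.

```python
import numpy as np

def magnitude_prune(weights: np.ndarray, sparsity: float) -> tuple[np.ndarray, np.ndarray]:
    """Hypothetical helper: zero the `sparsity` fraction of smallest-magnitude entries.

    Returns the pruned weights and a binary mask (1 = kept, 0 = pruned).
    Ties are broken by sort order, so exactly round(sparsity * size) entries
    are pruned. This is a one-shot sketch, not tfmot's internal algorithm.
    """
    flat = np.abs(weights).ravel()
    num_pruned = int(round(sparsity * flat.size))
    mask = np.ones(flat.size, dtype=weights.dtype)
    if num_pruned > 0:
        # Indices of the smallest-magnitude entries.
        pruned_idx = np.argsort(flat, kind="stable")[:num_pruned]
        mask[pruned_idx] = 0
    mask = mask.reshape(weights.shape)
    return weights * mask, mask

rng = np.random.default_rng(0)
# Random weights with the same shape as the conv2d kernel in this tutorial.
kernel = rng.normal(size=(3, 3, 1, 12)).astype(np.float32)
pruned, mask = magnitude_prune(kernel, 0.5)
print(f"{np.count_nonzero(pruned == 0)}/{pruned.size} weights are zero")  # → 54/108 weights are zero
```

The 54/108 ratio matches what the sparsity check later in this tutorial reports for the real conv2d kernel at 50% sparsity.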

Define the model and apply the sparsity API

The model needs to be pre-trained before using the sparsity API.

import tensorflow_model_optimization as tfmot

prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

pruning_params = {
      'pruning_schedule': tfmot.sparsity.keras.ConstantSparsity(0.5, begin_step=0, frequency=100)
  }

callbacks = [
  tfmot.sparsity.keras.UpdatePruningStep()
]

pruned_model = prune_low_magnitude(model, **pruning_params)

# Use smaller learning rate for fine-tuning
opt = keras.optimizers.Adam(learning_rate=1e-5)

pruned_model.compile(
  loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
  optimizer=opt,
  metrics=['accuracy'])

pruned_model.summary()
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 prune_low_magnitude_reshap  (None, 28, 28, 1)         1         
 e (PruneLowMagnitude)                                           
                                                                 
 prune_low_magnitude_conv2d  (None, 26, 26, 12)        230       
  (PruneLowMagnitude)                                            
                                                                 
 prune_low_magnitude_max_po  (None, 13, 13, 12)        1         
 oling2d (PruneLowMagnitude                                      
 )                                                               
                                                                 
 prune_low_magnitude_flatte  (None, 2028)              1         
 n (PruneLowMagnitude)                                           
                                                                 
 prune_low_magnitude_dense   (None, 10)                40572     
 (PruneLowMagnitude)                                             
                                                                 
=================================================================
Total params: 40805 (159.41 KB)
Trainable params: 20410 (79.73 KB)
Non-trainable params: 20395 (79.69 KB)
_________________________________________________________________
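The parameter counts in the summary can be checked by hand. This arithmetic sketch assumes that every pruning wrapper holds a non-trainable `pruning_step` scalar and that, for layers with a kernel, it also holds a non-trainable mask of the kernel's shape plus a `threshold` scalar. Under that assumption the totals match the summary above. `wrapped_params` is a hypothetical helper.

```python
# Shapes taken from the model summary above.
conv_kernel = 3 * 3 * 1 * 12      # kernel height x width x in_channels x filters
conv_bias = 12
dense_in = 13 * 13 * 12           # flattened max-pooling output (2028 values)
dense_kernel = dense_in * 10
dense_bias = 10

def wrapped_params(kernel_size: int, bias_size: int) -> tuple[int, int]:
    """Hypothetical helper returning (trainable, non_trainable) for a pruned layer.

    Assumes the wrapper adds a mask the size of the kernel, one threshold
    scalar and one pruning_step scalar, all non-trainable.
    """
    trainable = kernel_size + bias_size
    non_trainable = kernel_size + 1 + 1
    return trainable, non_trainable

conv_t, conv_nt = wrapped_params(conv_kernel, conv_bias)
dense_t, dense_nt = wrapped_params(dense_kernel, dense_bias)
# Reshape, MaxPooling2D and Flatten have no weights; their wrappers are
# assumed to hold only a pruning_step scalar each.
weightless_nt = 3 * 1

trainable = conv_t + dense_t
non_trainable = conv_nt + dense_nt + weightless_nt
print(conv_t + conv_nt, dense_t + dense_nt)                 # → 230 40572
print(trainable, non_trainable, trainable + non_trainable)  # → 20410 20395 40805
```

Roughly half of all parameters are non-trainable because the masks duplicate the kernel shapes; `strip_pruning` removes these wrapper variables before export.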

Fine-tune the model and evaluate the accuracy against baseline

Fine-tune the model with pruning for 3 epochs.

# Fine-tune model
pruned_model.fit(
  train_images,
  train_labels,
  epochs=3,
  validation_split=0.1,
  callbacks=callbacks)
Epoch 1/3
1688/1688 [==============================] - 10s 5ms/step - loss: 0.0967 - accuracy: 0.9677 - val_loss: 0.0848 - val_accuracy: 0.9747
Epoch 2/3
1688/1688 [==============================] - 8s 5ms/step - loss: 0.0605 - accuracy: 0.9813 - val_loss: 0.0697 - val_accuracy: 0.9788
Epoch 3/3
1688/1688 [==============================] - 8s 5ms/step - loss: 0.0503 - accuracy: 0.9855 - val_loss: 0.0660 - val_accuracy: 0.9808
<tf_keras.src.callbacks.History at 0x7ff0be0233d0>
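The `pruning_params` above use `ConstantSparsity(0.5, begin_step=0, frequency=100)`, and the `UpdatePruningStep` callback passed to `fit` advances the step counter that the schedule reads. The pure-Python sketch below shows when such a schedule would recompute the mask, assuming it prunes at every `frequency`-th step inside `[begin_step, end_step]`, with `end_step=-1` meaning no end. `constant_sparsity_should_prune` is hypothetical, not a tfmot function.

```python
def constant_sparsity_should_prune(step: int, begin_step: int = 0,
                                   end_step: int = -1, frequency: int = 100) -> bool:
    """Hypothetical, simplified check for whether the mask is updated at `step`.

    Assumed behavior: prune only inside [begin_step, end_step]
    (end_step < 0 means "no end") and only every `frequency` steps.
    """
    in_range = step >= begin_step and (end_step < 0 or step <= end_step)
    on_turn = (step - begin_step) % frequency == 0
    return in_range and on_turn

# Steps at which the mask would be recomputed during the first 350 steps.
print([s for s in range(350) if constant_sparsity_should_prune(s)])  # → [0, 100, 200, 300]
```

Because the target sparsity is constant, each mask update re-selects the 50% smallest-magnitude weights; a schedule such as polynomial decay would instead ramp the target up over time.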

Define a helper function to calculate and print the sparsity of the model.

def print_model_weights_sparsity(model):

    for layer in model.layers:
        if isinstance(layer, keras.layers.Wrapper):
            weights = layer.trainable_weights
        else:
            weights = layer.weights
        for weight in weights:
            # ignore auxiliary quantization weights
            if "quantize_layer" in weight.name:
                continue
            weight_size = weight.numpy().size
            zero_num = np.count_nonzero(weight == 0)
            print(
                f"{weight.name}: {zero_num/weight_size:.2%} sparsity ",
                f"({zero_num}/{weight_size})",
            )

Check that the model was correctly pruned. We need to strip the pruning wrapper first.

stripped_pruned_model = tfmot.sparsity.keras.strip_pruning(pruned_model)

print_model_weights_sparsity(stripped_pruned_model)
conv2d/kernel:0: 50.00% sparsity  (54/108)
conv2d/bias:0: 0.00% sparsity  (0/12)
dense/kernel:0: 50.00% sparsity  (10140/20280)
dense/bias:0: 0.00% sparsity  (0/10)

For this example, there is minimal loss in test accuracy after pruning compared to the baseline (about 0.3 percentage points).

_, pruned_model_accuracy = pruned_model.evaluate(
  test_images, test_labels, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy)
print('Pruned test accuracy:', pruned_model_accuracy)
Baseline test accuracy: 0.9803000092506409
Pruned test accuracy: 0.977400004863739

Apply QAT and PQAT and check effect on model sparsity in both cases

Next, apply both QAT and pruning-preserving QAT (PQAT) to the pruned model and observe that only PQAT preserves its sparsity. Note that the pruning wrappers were stripped from the pruned model with tfmot.sparsity.keras.strip_pruning before the PQAT API was applied.

# QAT
qat_model = tfmot.quantization.keras.quantize_model(stripped_pruned_model)

qat_model.compile(optimizer='adam',
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
print('Train qat model:')
qat_model.fit(train_images, train_labels, batch_size=128, epochs=1, validation_split=0.1)

# PQAT
quant_aware_annotate_model = tfmot.quantization.keras.quantize_annotate_model(
              stripped_pruned_model)
pqat_model = tfmot.quantization.keras.quantize_apply(
              quant_aware_annotate_model,
              tfmot.experimental.combine.Default8BitPrunePreserveQuantizeScheme())

pqat_model.compile(optimizer='adam',
              loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
print('Train pqat model:')
pqat_model.fit(train_images, train_labels, batch_size=128, epochs=1, validation_split=0.1)
Train qat model:
422/422 [==============================] - 4s 7ms/step - loss: 0.0382 - accuracy: 0.9886 - val_loss: 0.0549 - val_accuracy: 0.9852
Train pqat model:
422/422 [==============================] - 4s 7ms/step - loss: 0.0385 - accuracy: 0.9891 - val_loss: 0.0566 - val_accuracy: 0.9837
<tf_keras.src.callbacks.History at 0x7ff0a35992e0>
print("QAT Model sparsity:")
print_model_weights_sparsity(qat_model)
print("PQAT Model sparsity:")
print_model_weights_sparsity(pqat_model)
QAT Model sparsity:
conv2d/kernel:0: 24.07% sparsity  (26/108)
conv2d/bias:0: 0.00% sparsity  (0/12)
dense/kernel:0: 14.67% sparsity  (2975/20280)
dense/bias:0: 0.00% sparsity  (0/10)
PQAT Model sparsity:
conv2d/kernel:0: 50.00% sparsity  (54/108)
conv2d/bias:0: 0.00% sparsity  (0/12)
dense/kernel:0: 50.00% sparsity  (10140/20280)
dense/bias:0: 0.00% sparsity  (0/10)
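The QAT model lost sparsity because ordinary fine-tuning updates every weight, including the ones that pruning set to zero. The NumPy sketch below contrasts a plain gradient-descent step with one that re-applies the pruning mask after the update. It is a conceptual illustration of the behavior PQAT preserves, using a random stand-in gradient; it does not show how `Default8BitPrunePreserveQuantizeScheme` is implemented internally.

```python
import numpy as np

rng = np.random.default_rng(1)
weights = rng.normal(size=20).astype(np.float32)
# Prune the 50% smallest-magnitude weights, as in the pruning step above.
threshold = np.median(np.abs(weights))
mask = (np.abs(weights) > threshold).astype(np.float32)
weights *= mask

def sgd_step(w, grad, lr=0.1, mask=None):
    """One plain gradient-descent update, optionally re-applying a sparsity mask."""
    w = w - lr * grad
    if mask is not None:
        w = w * mask  # keep pruned positions at exactly zero
    return w

# Random stand-in for a real gradient (an assumption for illustration only).
grad = rng.normal(size=20).astype(np.float32)
unmasked = sgd_step(weights, grad)
masked = sgd_step(weights, grad, mask=mask)
print("zeros before:", np.count_nonzero(weights == 0))
print("zeros after plain update:", np.count_nonzero(unmasked == 0))
print("zeros after masked update:", np.count_nonzero(masked == 0))
```

A real gradient for a pruned weight is generally nonzero too, which is why the QAT model's kernels drift away from 50% sparsity after a single epoch.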

See compression benefits of PQAT model

Define a helper function to get the size of the zipped model file.

def get_gzipped_model_size(file):
  # Returns the size of the zip-compressed (DEFLATE) model file in kilobytes.

  _, zipped_file = tempfile.mkstemp('.zip')
  with zipfile.ZipFile(zipped_file, 'w', compression=zipfile.ZIP_DEFLATED) as f:
    f.write(file)

  return os.path.getsize(zipped_file)/1000

Since this is a small model, the size difference between the two models isn't very noticeable. Applying pruning and PQAT to a bigger production model would yield more significant compression.

# QAT model
converter = tf.lite.TFLiteConverter.from_keras_model(qat_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
qat_tflite_model = converter.convert()
qat_model_file = 'qat_model.tflite'
# Save the model.
with open(qat_model_file, 'wb') as f:
    f.write(qat_tflite_model)

# PQAT model
converter = tf.lite.TFLiteConverter.from_keras_model(pqat_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
pqat_tflite_model = converter.convert()
pqat_model_file = 'pqat_model.tflite'
# Save the model.
with open(pqat_model_file, 'wb') as f:
    f.write(pqat_tflite_model)

print("QAT model size: ", get_gzipped_model_size(qat_model_file), ' KB')
print("PQAT model size: ", get_gzipped_model_size(pqat_model_file), ' KB')
QAT model size:  16.134  KB
PQAT model size:  14.061  KB
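The size gain survives conversion because the TensorFlow Lite 8-bit quantization specification represents weights symmetrically with a zero point of 0, so a pruned 0.0 becomes an int8 0 and the resulting zero bytes compress well. The sketch below assumes per-tensor symmetric quantization with `scale = max|w| / 127` (TFLite may use per-channel scales for some layers), uses random weights shaped like the dense kernel, and compresses with `zlib` instead of the `zipfile` helper above; both use DEFLATE. `symmetric_int8` is a hypothetical helper.

```python
import zlib
import numpy as np

def symmetric_int8(w: np.ndarray) -> np.ndarray:
    """Hypothetical per-tensor symmetric int8 quantization (no zero point)."""
    scale = np.abs(w).max() / 127.0
    return np.clip(np.round(w / scale), -127, 127).astype(np.int8)

rng = np.random.default_rng(2)
# Random weights with as many entries as the dense kernel (2028 x 10).
dense_w = rng.normal(size=20280).astype(np.float32)
# Zero the 50% smallest-magnitude weights.
sparse_w = np.where(np.abs(dense_w) > np.median(np.abs(dense_w)), dense_w, 0.0).astype(np.float32)

q_dense = symmetric_int8(dense_w)
q_sparse = symmetric_int8(sparse_w)

# With no zero point, 0.0 maps to int8 0, so the pruned positions stay zero.
assert np.all(q_sparse[sparse_w == 0] == 0)

dense_size = len(zlib.compress(q_dense.tobytes(), 9))
sparse_size = len(zlib.compress(q_sparse.tobytes(), 9))
print(f"compressed int8 kernel: dense {dense_size} B, 50% sparse {sparse_size} B")
```

The absolute byte counts depend on the random data; the point is only that the half-zero tensor compresses to noticeably fewer bytes than the dense one.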

See the persistence of accuracy from TF to TFLite

Define a helper function to evaluate the TFLite model on the test dataset.

def eval_model(interpreter):
  input_index = interpreter.get_input_details()[0]["index"]
  output_index = interpreter.get_output_details()[0]["index"]

  # Run predictions on every image in the "test" dataset.
  prediction_digits = []
  for i, test_image in enumerate(test_images):
    if i % 1000 == 0:
      print(f"Evaluated on {i} results so far.")
    # Pre-processing: add batch dimension and convert to float32 to match with
    # the model's input data format.
    test_image = np.expand_dims(test_image, axis=0).astype(np.float32)
    interpreter.set_tensor(input_index, test_image)

    # Run inference.
    interpreter.invoke()

    # Post-processing: remove batch dimension and find the digit with the
    # highest score (the model outputs logits).
    output = interpreter.tensor(output_index)
    digit = np.argmax(output()[0])
    prediction_digits.append(digit)

  print('\n')
  # Compare prediction results with ground truth labels to calculate accuracy.
  prediction_digits = np.array(prediction_digits)
  accuracy = (prediction_digits == test_labels).mean()
  return accuracy
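The model's final Dense layer outputs logits (the loss uses `from_logits=True`), so `eval_model` takes the argmax of raw scores rather than probabilities. Because softmax preserves the ordering of entries within a row, the predicted digit is the same either way. A small NumPy check on hypothetical random logits and labels:

```python
import numpy as np

def softmax(logits: np.ndarray) -> np.ndarray:
    """Row-wise softmax, shifted by the row max for numerical stability."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(3)
logits = rng.normal(scale=5.0, size=(1000, 10))  # hypothetical batch of 10-class logits
labels = rng.integers(0, 10, size=1000)          # hypothetical ground-truth labels

pred_from_logits = np.argmax(logits, axis=-1)
pred_from_probs = np.argmax(softmax(logits), axis=-1)
assert np.array_equal(pred_from_logits, pred_from_probs)

# Accuracy computed the same way as at the end of eval_model.
accuracy = (pred_from_logits == labels).mean()
print(f"accuracy on random data: {accuracy:.3f}")
```

Skipping the softmax saves work per image and does not change the reported accuracy.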

Evaluate the pruned and quantized model, and check that the accuracy from TensorFlow persists in the TFLite backend.

interpreter = tf.lite.Interpreter(pqat_model_file)
interpreter.allocate_tensors()

pqat_test_accuracy = eval_model(interpreter)

print('Pruned and quantized TFLite test_accuracy:', pqat_test_accuracy)
print('Pruned TF test accuracy:', pruned_model_accuracy)
/tmpfs/src/tf_docs_env/lib/python3.9/site-packages/tensorflow/lite/python/interpreter.py:457: UserWarning:     Warning: tf.lite.Interpreter is deprecated and is scheduled for deletion in
    TF 2.20. Please use the LiteRT interpreter from the ai_edge_litert package.
    See the [migration guide](https://ai.google.dev/edge/litert/migration)
    for details.
    
  warnings.warn(_INTERPRETER_DELETION_WARNING)
INFO: Created TensorFlow Lite XNNPACK delegate for CPU.
Evaluated on 0 results so far.
Evaluated on 1000 results so far.
Evaluated on 2000 results so far.
Evaluated on 3000 results so far.
Evaluated on 4000 results so far.
Evaluated on 5000 results so far.
Evaluated on 6000 results so far.
Evaluated on 7000 results so far.
Evaluated on 8000 results so far.
Evaluated on 9000 results so far.


Pruned and quantized TFLite test_accuracy: 0.9816
Pruned TF test accuracy: 0.977400004863739

Apply post-training quantization and compare to PQAT model

Next, apply standard post-training quantization (no fine-tuning) to the pruned model and compare its accuracy against the PQAT model. This demonstrates why you would use PQAT to improve the quantized model's accuracy.

First, define a generator for the calibration dataset from the first 1000 training images.

def mnist_representative_data_gen():
  for image in train_images[:1000]:  
    image = np.expand_dims(image, axis=0).astype(np.float32)
    yield [image]
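Post-training quantization uses the representative dataset to estimate the value ranges of activations. The sketch below assumes simple min/max calibration with asymmetric int8 parameters (the range is widened to include 0 so that zero is exactly representable); the TFLite converter's calibration is more involved, and all function names here are hypothetical.

```python
import numpy as np

def asymmetric_int8_params(samples: np.ndarray) -> tuple[float, int]:
    """Hypothetical helper: derive (scale, zero_point) from observed min/max."""
    lo = min(float(samples.min()), 0.0)
    hi = max(float(samples.max()), 0.0)
    scale = (hi - lo) / 255.0 if hi > lo else 1.0
    zero_point = int(np.clip(np.round(-128 - lo / scale), -128, 127))
    return scale, zero_point

def quantize(x, scale, zero_point):
    """Map float values to int8 with the given affine parameters."""
    return np.clip(np.round(x / scale) + zero_point, -128, 127).astype(np.int8)

def dequantize(q, scale, zero_point):
    """Map int8 values back to float."""
    return (q.astype(np.float32) - zero_point) * scale

# Hypothetical ReLU-like activations collected from a representative dataset.
rng = np.random.default_rng(4)
activations = np.maximum(rng.normal(size=(1000, 64)), 0.0).astype(np.float32)
scale, zero_point = asymmetric_int8_params(activations)
roundtrip = dequantize(quantize(activations, scale, zero_point), scale, zero_point)
print(f"scale={scale:.5f}, zero_point={zero_point}, "
      f"max abs error={np.abs(roundtrip - activations).max():.5f}")
```

Values inside the calibrated range are reproduced to within half a quantization step; inputs far outside the range seen during calibration would be clipped, which is one reason a representative sample matters.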

Quantize the model and compare accuracy to the previously acquired PQAT model. Note that the model quantized with fine-tuning achieves higher accuracy.

converter = tf.lite.TFLiteConverter.from_keras_model(stripped_pruned_model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = mnist_representative_data_gen
post_training_tflite_model = converter.convert()
post_training_model_file = 'post_training_model.tflite'
# Save the model.
with open(post_training_model_file, 'wb') as f:
    f.write(post_training_tflite_model)

# Compare accuracy
interpreter = tf.lite.Interpreter(post_training_model_file)
interpreter.allocate_tensors()

post_training_test_accuracy = eval_model(interpreter)

print('PQAT TFLite test_accuracy:', pqat_test_accuracy)
print('Post-training (no fine-tuning) TFLite test accuracy:', post_training_test_accuracy)
Evaluated on 0 results so far.
Evaluated on 1000 results so far.
Evaluated on 2000 results so far.
Evaluated on 3000 results so far.
Evaluated on 4000 results so far.
Evaluated on 5000 results so far.
Evaluated on 6000 results so far.
Evaluated on 7000 results so far.
Evaluated on 8000 results so far.
Evaluated on 9000 results so far.


PQAT TFLite test_accuracy: 0.9816
Post-training (no fine-tuning) TFLite test accuracy: 0.9778

Conclusion

In this tutorial, you learned how to create a model, prune it using the sparsity API, and apply pruning-preserving quantization aware training (PQAT) to preserve sparsity while using QAT. The final PQAT model was compared to the QAT one to show that sparsity is preserved in the former and lost in the latter. Next, the models were converted to TFLite to show the compression benefits of chaining the pruning and PQAT model optimization techniques, and the TFLite model was evaluated to ensure that the accuracy persists in the TFLite backend. Finally, the PQAT model was compared to a quantized pruned model obtained with the post-training quantization API to demonstrate the advantage of PQAT in recovering the accuracy lost to ordinary quantization.