Modify a tf.keras layer or model to be pruned during training.
tfmot.sparsity.keras.prune_low_magnitude(
    to_prune,
    pruning_schedule=pruning_sched.ConstantSparsity(0.5, 0),
    block_size=(1, 1),
    block_pooling_type='AVG',
    pruning_policy=None,
    sparsity_m_by_n=None,
    **kwargs
)
This function wraps a tf.keras model or layer with pruning functionality that sparsifies the layer's weights during training. For example, using this with 50% sparsity will ensure that 50% of the layer's weights are zero.

The function accepts either a single keras layer (a subclass of tf.keras.layers.Layer), a list of keras layers, or a Sequential or Functional tf.keras model, and handles each case appropriately. If it encounters a layer it does not know how to handle, it throws an error; when pruning an entire model, even a single unsupported layer causes the whole call to fail.
Prune a model:
pruning_params = {
    'pruning_schedule': ConstantSparsity(0.5, 0),
    'block_size': (1, 1),
    'block_pooling_type': 'AVG'
}

model = prune_low_magnitude(
    keras.Sequential([
        layers.Dense(10, activation='relu', input_shape=(100,)),
        layers.Dense(2, activation='sigmoid')
    ]), **pruning_params)
Prune a layer:
pruning_params = {
    'pruning_schedule': PolynomialDecay(initial_sparsity=0.2,
                                        final_sparsity=0.8,
                                        begin_step=1000,
                                        end_step=2000),
    'block_size': (2, 3),
    'block_pooling_type': 'MAX'
}

model = keras.Sequential([
    layers.Dense(10, activation='relu', input_shape=(100,)),
    prune_low_magnitude(layers.Dense(2, activation='tanh'), **pruning_params)
])
Pretrained models: you must first load the weights and then apply the prune API:
model.load_weights(...)
model = prune_low_magnitude(model)
Optimizer: this function removes the optimizer, so the model must be compiled again after wrapping. It is easiest to rely on the default step counter (which starts at 0) and use it to determine the desired begin_step for the pruning schedules.
Checkpointing: checkpoints should include the optimizer state, not just the weights. Pruning supports checkpointing, though upon inspection the weights stored in checkpoints are not sparse (https://github.com/tensorflow/model-optimization/issues/206).
Arguments | |
---|---|
to_prune | A single keras layer, a list of keras layers, or a tf.keras.Model instance. |
pruning_schedule | A PruningSchedule object that controls the pruning rate throughout training. |
block_size | (optional) The dimensions (height, width) for the block sparse pattern in rank-2 weight tensors. |
block_pooling_type | (optional) The function to use to pool weights in the block. Must be 'AVG' or 'MAX'. |
pruning_policy | (optional) The object that controls to which layers the PruneLowMagnitude wrapper will be applied. This API is experimental and subject to change. |
sparsity_m_by_n | Default None; otherwise a tuple of two integers indicating pruning with m-by-n sparsity, e.g. (2, 4): 2 zeros out of every 4 consecutive elements. The function checks whether pruning with m-by-n sparsity is possible; if this type of sparsity is not applicable, an error is thrown. |
**kwargs | Additional keyword arguments to be passed to the keras layer. Ignored when to_prune is not a keras layer. |
Returns | |
---|---|
Layer or model modified with pruning wrappers. Optimizer is removed. |

Raises | |
---|---|
ValueError | If the keras layer is unsupported, or the keras model contains an unsupported layer. |