tfmot.sparsity.keras.prune_low_magnitude
Modify a tf.keras layer or model to be pruned during training.
    tfmot.sparsity.keras.prune_low_magnitude(
        to_prune,
        pruning_schedule=pruning_sched.ConstantSparsity(0.5, 0),
        block_size=(1, 1),
        block_pooling_type='AVG',
        pruning_policy=None,
        sparsity_m_by_n=None,
        **kwargs
    )
Used in the guide: Pruning comprehensive guide (https://www.tensorflow.org/model_optimization/guide/pruning/comprehensive_guide)
This function wraps a tf.keras model or layer with pruning functionality that
sparsifies the layer's weights during training. For example, using this with
50% sparsity will ensure that 50% of the layer's weights are zero.
The function accepts a single keras layer (a subclass of tf.keras.layers.Layer),
a list of keras layers, or a Sequential or Functional tf.keras model, and
handles each appropriately.
If it encounters a layer it does not know how to handle, it raises an error;
when pruning an entire model, a single unsupported layer is enough to make the
call fail.
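If a model mixes supported and unsupported layers, one way to avoid this error is to wrap only the layers you want pruned. The sketch below, adapted from the pruning comprehensive guide, uses tf.keras.models.clone_model with a clone_function; the helper name apply_pruning_to_dense is illustrative:

    import tensorflow as tf
    import tensorflow_model_optimization as tfmot

    def apply_pruning_to_dense(layer):
        # Wrap only Dense layers; return every other layer unchanged.
        if isinstance(layer, tf.keras.layers.Dense):
            return tfmot.sparsity.keras.prune_low_magnitude(layer)
        return layer

    base_model = tf.keras.Sequential([
        tf.keras.layers.Dense(10, activation='relu', input_shape=(100,)),
        tf.keras.layers.Dense(2, activation='sigmoid'),
    ])

    model_for_pruning = tf.keras.models.clone_model(
        base_model, clone_function=apply_pruning_to_dense)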
Prune a model:
    pruning_params = {
        'pruning_schedule': ConstantSparsity(0.5, 0),
        'block_size': (1, 1),
        'block_pooling_type': 'AVG'
    }

    model = prune_low_magnitude(
        keras.Sequential([
            layers.Dense(10, activation='relu', input_shape=(100,)),
            layers.Dense(2, activation='sigmoid')
        ]), **pruning_params)
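Before training, the pruned model must be compiled, and model.fit needs the tfmot.sparsity.keras.UpdatePruningStep callback to advance the pruning step each batch. A minimal sketch continuing the example above (the random data is illustrative):

    import numpy as np
    import tensorflow_model_optimization as tfmot

    model.compile(optimizer='adam',
                  loss='binary_crossentropy',
                  metrics=['accuracy'])

    # Illustrative random data matching the (100,) input shape above.
    x = np.random.rand(128, 100).astype('float32')
    y = np.random.randint(0, 2, size=(128, 2)).astype('float32')

    model.fit(x, y, epochs=2,
              callbacks=[tfmot.sparsity.keras.UpdatePruningStep()])

    # After training, strip the wrappers to inspect the sparse kernels.
    stripped = tfmot.sparsity.keras.strip_pruning(model)
    for w in stripped.get_weights():
        print(w.shape, 'zero fraction:', float(np.mean(w == 0)))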
Prune a layer:
    pruning_params = {
        'pruning_schedule': PolynomialDecay(initial_sparsity=0.2,
                                            final_sparsity=0.8,
                                            begin_step=1000,
                                            end_step=2000),
        'block_size': (2, 3),
        'block_pooling_type': 'MAX'
    }

    model = keras.Sequential([
        layers.Dense(10, activation='relu', input_shape=(100,)),
        prune_low_magnitude(layers.Dense(2, activation='tanh'), **pruning_params)
    ])
Pretrained models: you must first load the weights and then apply the
prune API:

    model.load_weights(...)
    model = prune_low_magnitude(model)
Optimizer: this function removes the optimizer, so the user is expected to
compile the model again. It is easiest to rely on the default (the step starts
at 0) and use that to determine the desired begin_step for the pruning
schedules.
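As a hedged sketch of that workflow (the checkpoint path and step counts are illustrative, not from this page), assuming a built model from the earlier examples: load the pretrained weights, derive begin_step/end_step from the planned fine-tuning length, apply the prune API, and compile again:

    import tensorflow_model_optimization as tfmot

    # Steps count from 0 once the model is recompiled, so begin_step and
    # end_step can be derived from the planned fine-tuning length.
    epochs = 2
    steps_per_epoch = 100  # e.g. num_train_samples // batch_size

    schedule = tfmot.sparsity.keras.PolynomialDecay(
        initial_sparsity=0.0, final_sparsity=0.5,
        begin_step=0, end_step=epochs * steps_per_epoch)

    model.load_weights('pretrained_weights.h5')  # hypothetical path
    model = tfmot.sparsity.keras.prune_low_magnitude(
        model, pruning_schedule=schedule)

    # prune_low_magnitude removed the optimizer, so compile again.
    model.compile(optimizer='adam', loss='binary_crossentropy',
                  metrics=['accuracy'])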
Checkpointing: checkpoints should include the optimizer, not just the
weights. Pruning supports checkpointing, although the weights stored in
checkpoints are not sparse
(https://github.com/tensorflow/model-optimization/issues/206).
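A minimal sketch of checkpointing during pruned training (the file name and data are illustrative): the full model, including the optimizer, is saved, and the exported kernels only become truly sparse after strip_pruning:

    import numpy as np
    import tensorflow as tf
    import tensorflow_model_optimization as tfmot

    x = np.random.rand(128, 100).astype('float32')
    y = np.random.randint(0, 2, size=(128, 2)).astype('float32')

    callbacks = [
        tfmot.sparsity.keras.UpdatePruningStep(),
        # Save the whole model (optimizer included), as recommended above.
        tf.keras.callbacks.ModelCheckpoint(
            filepath='pruning_ckpt.h5', save_weights_only=False),
    ]
    model.fit(x, y, epochs=2, callbacks=callbacks)

    # For deployment, strip the pruning wrappers; the resulting kernels
    # contain the pruned (zeroed) values.
    final_model = tfmot.sparsity.keras.strip_pruning(model)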
Arguments

| Argument | Description |
|---|---|
| `to_prune` | A single keras layer, a list of keras layers, or a `tf.keras.Model` instance. |
| `pruning_schedule` | A `PruningSchedule` object that controls the pruning rate throughout training. |
| `block_size` | (optional) The dimensions (height, width) of the block-sparse pattern in rank-2 weight tensors. |
| `block_pooling_type` | (optional) The function used to pool weights within a block. Must be 'AVG' or 'MAX'. |
| `pruning_policy` | (optional) An object that controls which layers the `PruneLowMagnitude` wrapper is applied to. This API is experimental and subject to change. |
| `sparsity_m_by_n` | Defaults to None; otherwise a tuple of 2 integers indicating pruning with m-by-n sparsity, e.g. (2, 4): 2 zeros out of every 4 consecutive elements. The function checks whether pruning with m-by-n sparsity is applicable and raises an error if it is not. |
| `**kwargs` | Additional keyword arguments passed to the keras layer. Ignored when `to_prune` is not a keras layer. |
Returns

Layer or model modified with pruning wrappers. The optimizer is removed.
Raises

`ValueError`: if the keras layer is unsupported, or the keras model contains
an unsupported layer.