Specifies what layers to prune in the model.
PruningPolicy controls application of the PruneLowMagnitude wrapper on a per-layer basis and checks that the model contains only supported layers.
PruningPolicy works together with prune_low_magnitude, through which it provides fine-grained control over pruning in the model.
    pruning_params = {
        'pruning_schedule': ConstantSparsity(0.5, 0),
        'block_size': (1, 1),
        'block_pooling_type': 'AVG'
    }

    model = prune_low_magnitude(
        keras.Sequential([
            layers.Dense(10, activation='relu', input_shape=(100,)),
            layers.Dense(2, activation='sigmoid')
        ]),
        pruning_policy=PruneForLatencyOnXNNPack(),
        **pruning_params)
You can inherit from this class to write your own custom pruning policy.
The API is experimental and is subject to change.
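As a rough illustration of what a custom policy has to provide, the sketch below mirrors the two abstract methods listed under Methods. It runs without TensorFlow, so several names are assumptions made only for this example: `PolicyInterface` stands in for PruningPolicy, `FakeLayer` stands in for a Keras layer, the model is a plain list of layers, and `layers_to_wrap` is a hypothetical helper. The order it uses (validate the model, then ask about each layer) is also an assumption, not a description of what prune_low_magnitude does internally. In real code the policy would inherit from PruningPolicy and inspect actual Keras layers.

```python
import abc
from dataclasses import dataclass


class PolicyInterface(abc.ABC):
    """Hypothetical stand-in for PruningPolicy: same two abstract methods."""

    @abc.abstractmethod
    def allow_pruning(self, layer):
        """Return True if the pruning wrapper should be applied to `layer`."""

    @abc.abstractmethod
    def ensure_model_supports_pruning(self, model):
        """Raise ValueError if `model` contains an unsupported layer."""


@dataclass
class FakeLayer:
    """Hypothetical layer record; a real policy would receive Keras layers."""
    name: str
    kind: str


class PruneDenseOnly(PolicyInterface):
    """Illustrative policy: prune only layers whose kind is 'Dense'."""

    # Assumption for this sketch: the layer kinds this policy accepts.
    SUPPORTED_KINDS = frozenset({"Dense", "ReLU", "Flatten"})

    def allow_pruning(self, layer):
        return layer.kind == "Dense"

    def ensure_model_supports_pruning(self, model):
        unsupported = [layer.name for layer in model
                       if layer.kind not in self.SUPPORTED_KINDS]
        if unsupported:
            raise ValueError(f"Unsupported layers: {unsupported}")


def layers_to_wrap(model, policy):
    """Validate the whole model first, then decide layer by layer."""
    policy.ensure_model_supports_pruning(model)
    return [layer.name for layer in model if policy.allow_pruning(layer)]


# Here `model` is just a list of FakeLayer objects, not a tf.keras.Model.
model = [FakeLayer("fc1", "Dense"), FakeLayer("act", "ReLU"),
         FakeLayer("fc2", "Dense")]
print(layers_to_wrap(model, PruneDenseOnly()))  # ['fc1', 'fc2']
```

Splitting the policy into a whole-model check and a per-layer predicate lets an unsupported architecture fail early with a ValueError, before any layer is wrapped.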
Methods
allow_pruning
@abc.abstractmethod
allow_pruning(layer)
Checks whether the pruning wrapper should be applied to the current layer.
Args | |
---|---|
layer | Current layer in the model. |
Returns | |
---|---|
True if the pruning wrapper should be applied to the layer, False otherwise. |
ensure_model_supports_pruning
@abc.abstractmethod
ensure_model_supports_pruning(model)
Checks that the model contains only supported layers.
Args | |
---|---|
model | A tf.keras.Model instance which is going to be pruned. |
Raises | |
---|---|
ValueError | If the Keras model doesn't support the pruning policy, i.e. the model contains an unsupported layer. |