Model that adds a Counterfactual loss component to another model during training.
```python
model_remediation.counterfactual.keras.CounterfactualModel(
    original_model: tf.keras.Model,
    loss=losses.PairwiseMSELoss(),
    loss_weight=1.0,
    **kwargs
)
```
Inherits from: tf.keras.Model
Arguments | |
---|---|
original_model | Instance of tf.keras.Model that will be trained with the additional counterfactual_loss. |
loss | Instance of counterfactual.losses.CounterfactualLoss or string of loss name that will be used to calculate the counterfactual_loss. Defaults to PairwiseMSELoss. |
loss_weight | Scalar applied to the counterfactual_loss before it is included in training. Defaults to 1.0. |
`**kwargs` | Named parameters that will be passed directly to the base class' `__init__` function. |
CounterfactualModel wraps the model passed in, original_model, and adds a component to the loss during training and optionally during evaluation.
Construction
There are two ways to construct a CounterfactualModel instance:

1 - Directly wrap your model with CounterfactualModel. This is the simplest usage and is most likely what you will want to use (unless your original model has some custom implementations that need to be taken into account).
```python
import tensorflow as tf

model = tf.keras.Sequential([...])
model = CounterfactualModel(model, ...)
```
In this case, all methods other than the ones listed below will use the default implementations of tf.keras.Model.

If you are in this use case, the next section is not relevant to you and you can skip to the section on usage.
2 - Subclassing CounterfactualModel to integrate custom implementations.

This will likely be needed if the original_model is itself a customized subclass of tf.keras.Model. If that is the case and you want to preserve the custom implementations, you can create a new custom class that inherits first from CounterfactualModel and second from your custom class.
```python
import tensorflow as tf

class CustomSequential(tf.keras.Sequential):

    def train_step(self, data):
        print("In a custom train_step!")
        # train_step must return the logs dict produced by the parent class.
        return super().train_step(data)

class CustomCounterfactualModel(CounterfactualModel, CustomSequential):
    pass  # No additional implementation is required.

model = CustomSequential([...])
model = CustomCounterfactualModel(model, ...)  # This will use the custom
                                               # train_step.
```
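To see why the base-class order matters, the sketch below replaces the real classes with hypothetical, TensorFlow-free stand-ins (Model, CustomSequential, CounterfactualModel) whose methods only return strings. It assumes the real classes cooperate through ordinary super() calls. Under that assumption, Python's method resolution order (MRO) places CounterfactualModel before CustomSequential, so methods the stand-in CounterfactualModel defines take precedence, while methods it does not define resolve to the custom class.

```python
# Hypothetical stand-ins (no TensorFlow needed) that mimic the inheritance
# pattern above so the method resolution order (MRO) is visible.

class Model:
    def train_step(self, data):
        return "Model.train_step"

    def compile(self):
        return "Model.compile"


class CustomSequential(Model):
    def train_step(self, data):
        # Custom behavior, then defer to the next class in the MRO.
        return "CustomSequential.train_step -> " + super().train_step(data)


class CounterfactualModel(Model):
    def compile(self):
        # Stand-in for a method that CounterfactualModel itself overrides.
        return "CounterfactualModel.compile -> " + super().compile()


class CustomCounterfactualModel(CounterfactualModel, CustomSequential):
    pass  # No additional implementation is required.


mro_names = [cls.__name__ for cls in CustomCounterfactualModel.__mro__]
print(mro_names)
# ['CustomCounterfactualModel', 'CounterfactualModel', 'CustomSequential', 'Model', 'object']

m = CustomCounterfactualModel()
print(m.train_step(None))  # CustomSequential.train_step -> Model.train_step
print(m.compile())         # CounterfactualModel.compile -> Model.compile
```

Reversing the two base classes would put CustomSequential first in the MRO, which is why the guidance above asks you to inherit from CounterfactualModel first.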
If you need to customize methods defined by CounterfactualModel, then you can create a direct subclass and override whatever is needed.
```python
import tensorflow as tf

class CustomCounterfactualModel(CounterfactualModel):

    def update_metrics(self, *args, **kwargs):
        print("In a custom CounterfactualModel method!")
        return super().update_metrics(*args, **kwargs)

model = tf.keras.Sequential([...])
model = CustomCounterfactualModel(model, ...)  # This will use the custom
                                               # update_metrics method.
```
Usage:

Once you have created an instance of CounterfactualModel, it can be used almost exactly the same way as the model it wraps. The two main exceptions to this are:
- During training, the inputs must include counterfactual_data; see CounterfactualModel.compute_counterfactual_loss for details.
- Saving and loading a model can have slightly different behavior if you are subclassing CounterfactualModel. See CounterfactualModel.save and CounterfactualModel.save_original_model for details.
Optionally, inputs containing counterfactual_data can be passed in to evaluate and predict. For the former, this will result in the counterfactual_loss appearing in the metrics. For predict, this should have no visible effect.
Attributes | |
---|---|
original_model | tf.keras.Model to be trained with the additional counterfactual_loss. Inference and evaluation will also come from the results this model provides. |
Methods
call
```python
call(
    inputs, *args, **kwargs
)
```
Calls the model on new inputs and returns the outputs as tensors. In this case, call() just reapplies all ops in the graph to the new inputs (i.e., it builds a new computational graph from the provided inputs).
Args | |
---|---|
inputs | Input tensor, or dict/list/tuple of input tensors. |
training | Boolean or boolean scalar tensor, indicating whether to run the Network in training mode or inference mode. |
mask | A mask or list of masks. A mask can be either a boolean tensor or None (no mask). For more details, see the Keras guide on masking and padding. |
Returns | |
---|---|
A tensor if there is a single output, or a list of tensors if there is more than one output. |
compile
```python
compile(
    *args, **kwargs
)
```
Compile both self and original_model using the same parameters. See tf.keras.Model.compile for details.
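A minimal sketch of the "compile both with the same parameters" behavior, using hypothetical stand-in classes (FakeModel, WrapperModel) rather than the library's implementation. Each stand-in simply records the arguments it was compiled with, so the sketch can check that the wrapper and the wrapped model receive identical ones. Compiling the wrapped model first is a choice made in this sketch, not something the reference above specifies.

```python
# Hypothetical stand-ins (no TensorFlow): each records the arguments it was
# compiled with so we can check that both objects receive the same ones.

class FakeModel:
    def __init__(self):
        self.compile_args = None

    def compile(self, *args, **kwargs):
        self.compile_args = (args, kwargs)


class WrapperModel(FakeModel):
    """Mimics a wrapper that compiles itself and the model it wraps."""

    def __init__(self, original_model):
        super().__init__()
        self.original_model = original_model

    def compile(self, *args, **kwargs):
        # Forward the identical arguments to the wrapped model ...
        self.original_model.compile(*args, **kwargs)
        # ... then compile the wrapper itself with the same arguments.
        super().compile(*args, **kwargs)


inner = FakeModel()
wrapper = WrapperModel(inner)
wrapper.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])
print(inner.compile_args == wrapper.compile_args)  # True
```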
compute_counterfactual_loss
```python
compute_counterfactual_loss(
    original_predictions,
    counterfactual_predictions,
    counterfactual_sample_weight
)
```
Computes counterfactual_loss(es) corresponding to counterfactual_data.
Arguments | |
---|---|
original_predictions | Predictions on original data. |
counterfactual_predictions | Predictions of a model on counterfactual data. |
counterfactual_sample_weight | Per-sample weight to scale counterfactual loss. |

Returns | |
---|---|
Scalar (if only one) or list of counterfactual_loss values calculated from counterfactual_data. |
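For intuition about what the default PairwiseMSELoss measures, here is a pure-Python sketch of a weighted pairwise mean squared error between original_predictions and counterfactual_predictions. The function name pairwise_mse, the one-scalar-per-example shape, and the weighted-average reduction (sum of weighted squared errors divided by the sum of weights) are assumptions for illustration; the library's loss operates on tensors and may reduce differently.

```python
def pairwise_mse(original_predictions, counterfactual_predictions,
                 counterfactual_sample_weight=None):
    """Weighted mean of squared differences between paired predictions.

    Hypothetical sketch: assumes one scalar prediction per example and a
    weighted-average reduction. Weights default to 1.0 per example.
    """
    if len(original_predictions) != len(counterfactual_predictions):
        raise ValueError("Predictions must be paired one-to-one.")
    if counterfactual_sample_weight is None:
        counterfactual_sample_weight = [1.0] * len(original_predictions)
    # Squared gap between each original prediction and its counterfactual,
    # scaled by that pair's sample weight.
    weighted_sq = [
        w * (o - c) ** 2
        for o, c, w in zip(original_predictions, counterfactual_predictions,
                           counterfactual_sample_weight)
    ]
    total_weight = sum(counterfactual_sample_weight)
    return sum(weighted_sq) / total_weight if total_weight else 0.0


print(pairwise_mse([0.9, 0.2, 0.6], [0.7, 0.2, 0.1]))
# A weight of 0.0 removes a pair from the loss entirely.
print(pairwise_mse([1.0, 0.0], [0.0, 0.0], [2.0, 0.0]))  # 1.0
```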
compute_total_loss
```python
compute_total_loss(
    y,
    y_pred,
    y_pred_original,
    y_pred_counterfactual,
    sample_weight,
    cf_sample_weight
)
```
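The reference gives no description for compute_total_loss. Based on the description of loss_weight ("Scalar applied to the counterfactual_loss before being included in training"), one plausible reading is that the counterfactual_loss is scaled by loss_weight and added to the compiled loss. The sketch below illustrates that reading in plain Python; the helper names, the use of mean squared error as a stand-in for the compiled loss, and the way the arguments are mapped are all assumptions, not the library's implementation.

```python
def weighted_mse(a, b, weights):
    """Weighted mean of squared differences (weights default to 1.0)."""
    weights = weights if weights is not None else [1.0] * len(a)
    total = sum(weights)
    return sum(w * (x - z) ** 2 for x, z, w in zip(a, b, weights)) / total


def total_loss_sketch(y, y_pred, y_pred_original, y_pred_counterfactual,
                      sample_weight, cf_sample_weight, loss_weight=1.0):
    # Hypothetical: task loss on the labels, standing in for the compiled loss.
    compiled_loss = weighted_mse(y, y_pred, sample_weight)
    # How far predictions on counterfactual examples drift from predictions
    # on their original counterparts.
    counterfactual_loss = weighted_mse(y_pred_original, y_pred_counterfactual,
                                       cf_sample_weight)
    # loss_weight scales the counterfactual term before it is added.
    return compiled_loss + loss_weight * counterfactual_loss


loss = total_loss_sketch(
    y=[1.0, 0.0], y_pred=[0.5, 0.0],
    y_pred_original=[0.5, 0.0], y_pred_counterfactual=[0.0, 0.0],
    sample_weight=None, cf_sample_weight=None, loss_weight=2.0)
print(loss)  # 0.375 (0.125 compiled + 2.0 * 0.125 counterfactual)
```

Raising loss_weight in this sketch makes disagreement between original and counterfactual predictions cost more relative to the task loss.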
save
```python
save(
    *args, **kwargs
)
```
Exports the model as described in tf.keras.Model.save.

For subclasses of CounterfactualModel that have not been registered as Keras objects, this method will likely be what you want to call to continue training your model with Counterfactual after having loaded it. If you want to use the loaded model purely for inference, you will likely want to use CounterfactualModel.save_original_model instead.

The exception noted above for unregistered CounterfactualModel subclasses is the only difference from tf.keras.Model.save. To avoid these subtle differences, we strongly recommend registering CounterfactualModel subclasses as Keras objects. See the documentation of tf.keras.utils.register_keras_serializable for details.
save_original_model
```python
save_original_model(
    *args, **kwargs
)
```
Exports the original_model.

This model will be the type of original_model and will no longer be able to train or evaluate with Counterfactual data.
update_metrics
```python
update_metrics(
    y, y_pred, sample_weight, total_loss, compiled_loss, counterfactual_loss
)
```
Updates mean metrics being tracked for Counterfactual losses.
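update_metrics is described as tracking mean metrics for the Counterfactual losses. The sketch below shows the idea with a hypothetical MeanTracker class (similar in spirit to tf.keras.metrics.Mean, but not that class) and illustrative metric names taken from the loss arguments above: each tracker accumulates one loss value per step and reports the running average.

```python
class MeanTracker:
    """Hypothetical running mean, similar in spirit to tf.keras.metrics.Mean."""

    def __init__(self, name):
        self.name = name
        self.total = 0.0
        self.count = 0

    def update_state(self, value):
        # Accumulate one scalar observation.
        self.total += value
        self.count += 1

    def result(self):
        # Average of everything seen so far (0.0 before any update).
        return self.total / self.count if self.count else 0.0


# One tracker per loss value passed to update_metrics (names are illustrative).
trackers = {name: MeanTracker(name)
            for name in ("total_loss", "compiled_loss", "counterfactual_loss")}

# Two made-up training steps: (total_loss, compiled_loss, counterfactual_loss).
for step_losses in [(0.5, 0.4, 0.1), (0.3, 0.25, 0.05)]:
    for tracker, value in zip(trackers.values(), step_losses):
        tracker.update_state(value)

results = {name: t.result() for name, t in trackers.items()}
print(results)
```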