Introduction
It is possible to integrate MinDiff directly into your model's implementation. While doing so does not have the convenience of using MinDiffModel, this option offers the highest level of control, which can be particularly useful when your model is a subclass of tf.keras.Model.
This guide demonstrates how you can integrate MinDiff directly into a custom model's implementation by adding to the train_step method.
Setup
pip install --upgrade tensorflow-model-remediation
import tensorflow as tf
tf.get_logger().setLevel('ERROR') # Avoid TF warnings.
from tensorflow_model_remediation import min_diff
from tensorflow_model_remediation.tools.tutorials_utils import uci as tutorials_utils
First, download the data. For succinctness, the input preparation logic has been factored out into helper functions as described in the input preparation guide. You can read the full guide for details on this process.
# Original Dataset for training, sampled at 0.3 for reduced runtimes.
train_df = tutorials_utils.get_uci_data(split='train', sample=0.3)
train_ds = tutorials_utils.df_to_dataset(train_df, batch_size=128)
# Dataset needed to train with MinDiff.
train_with_min_diff_ds = (
    tutorials_utils.get_uci_with_min_diff_dataset(split='train', sample=0.3))
Original Custom Model Customizations
tf.keras.Model is designed to be easily customized via subclassing. This usually involves changing what happens in the call to fit as described here.
This guide uses a custom implementation where the train_step closely resembles the default tf.keras.Model.train_step. Normally, there would be no benefit to doing so, but here, it will help demonstrate how to integrate MinDiff.
class CustomModel(tf.keras.Model):

  def train_step(self, data):
    # Unpack the data.
    x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(data)

    with tf.GradientTape() as tape:
      y_pred = self(x, training=True)  # Forward pass.
      # Compute the loss value.
      loss = self.compiled_loss(
          y, y_pred, sample_weight, regularization_losses=self.losses)

    # Compute gradients and update weights.
    self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
    # Update and return metrics.
    self.compiled_metrics.update_state(y, y_pred, sample_weight)
    return {m.name: m.result() for m in self.metrics}
Train the model as you would a typical Model using the Functional API.
model = tutorials_utils.get_uci_model(model_class=CustomModel) # Use CustomModel.
model.compile(optimizer='adam', loss='binary_crossentropy')
_ = model.fit(train_ds, epochs=1)
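The train_step above begins by unpacking each batch into x, y, and sample_weight. As a rough illustration of what that unpacking step does for tuple-shaped batches, here is a pure-Python stand-in. The function name unpack_like_keras is hypothetical and this is not the Keras implementation; it only mirrors the tuple cases (missing elements become None) and makes no attempt to reproduce how other container types are handled.

```python
def unpack_like_keras(data):
  """Hypothetical stand-in: splits a batch into (x, y, sample_weight)."""
  if not isinstance(data, tuple):
    # A bare object is treated as inputs only.
    return data, None, None
  if len(data) == 1:
    return data[0], None, None
  if len(data) == 2:
    return data[0], data[1], None
  if len(data) == 3:
    return data[0], data[1], data[2]
  raise ValueError(
      f'Expected a tuple of length 1, 2, or 3 but got length {len(data)}.')


# A batch of (features, labels) has no sample weights.
features = {'age': [39, 50]}
labels = [0, 1]
x, y, sample_weight = unpack_like_keras((features, labels))
print(x is features, y is labels, sample_weight is None)
```

This is why the same train_step works whether or not the dataset yields sample weights: the third element is simply None when it is absent.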
Integrating MinDiff directly into your model
Adding MinDiff to the train_step
To integrate MinDiff, you will need to add some lines to the CustomModel, which is renamed here as CustomModelWithMinDiff.
For clarity, this guide uses a boolean flag called apply_min_diff. All of the code relevant to MinDiff will only be run if it is set to True. If set to False, the model behaves exactly the same as CustomModel.
min_diff_loss_fn = min_diff.losses.MMDLoss() # Hard coded for convenience.
min_diff_weight = 2 # Arbitrary number for example, hard coded for convenience.
apply_min_diff = True # Flag to help show where the additional lines are.
class CustomModelWithMinDiff(tf.keras.Model):

  def train_step(self, data):
    # Unpack the data.
    x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(data)

    # Unpack the MinDiff data.
    if apply_min_diff:
      min_diff_data = min_diff.keras.utils.unpack_min_diff_data(x)
      min_diff_x, membership, min_diff_sample_weight = (
          tf.keras.utils.unpack_x_y_sample_weight(min_diff_data))
      x = min_diff.keras.utils.unpack_original_inputs(x)

    with tf.GradientTape() as tape:
      y_pred = self(x, training=True)  # Forward pass.
      # Compute the loss value.
      loss = self.compiled_loss(
          y, y_pred, sample_weight, regularization_losses=self.losses)

      # Calculate and add the min_diff_loss. This must be done within the scope
      # of tf.GradientTape().
      if apply_min_diff:
        min_diff_predictions = self(min_diff_x, training=True)
        # Keyword arguments avoid relying on positional argument order.
        min_diff_loss = min_diff_weight * min_diff_loss_fn(
            predictions=min_diff_predictions,
            membership=membership,
            sample_weight=min_diff_sample_weight)
        loss += min_diff_loss

    # Compute gradients and update weights.
    self.optimizer.minimize(loss, self.trainable_variables, tape=tape)
    # Update and return metrics.
    self.compiled_metrics.update_state(y, y_pred, sample_weight)
    return {m.name: m.result() for m in self.metrics}
Training with this model looks exactly the same as with the previous one, except for the dataset used.
model = tutorials_utils.get_uci_model(model_class=CustomModelWithMinDiff)
model.compile(optimizer='adam', loss='binary_crossentropy')
_ = model.fit(train_with_min_diff_ds, epochs=1)
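The central change in the train_step is that a weighted MinDiff term is added to the original loss before gradients are computed. The sketch below illustrates that combination in plain Python on scalar predictions. Everything in it is an assumption made for illustration: toy_mmd_penalty is a simplified Gaussian-kernel MMD-style statistic and is not the library's MMDLoss, and total_loss is a hypothetical helper rather than part of any API.

```python
import math


def gaussian_kernel(a, b, width=1.0):
  """Similarity between two scalar predictions."""
  return math.exp(-((a - b) ** 2) / (2.0 * width ** 2))


def mean_kernel(xs, ys, width=1.0):
  """Average kernel value over all pairs drawn from xs and ys."""
  total = sum(gaussian_kernel(x, y, width) for x in xs for y in ys)
  return total / (len(xs) * len(ys))


def toy_mmd_penalty(predictions, membership, width=1.0):
  """Toy penalty: grows as the two groups' predictions diverge."""
  group_a = [p for p, m in zip(predictions, membership) if m == 1]
  group_b = [p for p, m in zip(predictions, membership) if m == 0]
  return (mean_kernel(group_a, group_a, width)
          + mean_kernel(group_b, group_b, width)
          - 2.0 * mean_kernel(group_a, group_b, width))


def total_loss(original_loss, predictions, membership,
               min_diff_weight=2, apply_min_diff=True):
  """Mirrors `loss += min_diff_weight * min_diff_loss` from the train_step."""
  if not apply_min_diff:
    return original_loss
  return original_loss + min_diff_weight * toy_mmd_penalty(
      predictions, membership)


membership = [1, 1, 0, 0]
matched = [0.2, 0.8, 0.2, 0.8]    # Both groups get the same predictions.
separated = [0.9, 0.8, 0.1, 0.2]  # Groups get very different predictions.

print('matched:  ', total_loss(0.5, matched, membership))
print('separated:', total_loss(0.5, separated, membership))
print('disabled: ', total_loss(0.5, separated, membership,
                               apply_min_diff=False))
```

With the flag disabled, the result is just the original loss, which matches the guide's point that the model then behaves like CustomModel. A larger min_diff_weight makes the penalty dominate the original loss more strongly.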
Reshaping your input (optional)
Given that this approach provides full control, you can take this opportunity to reshape the input into a slightly cleaner form. When using MinDiffModel, the min_diff_data needs to be packed into the first component of every batch. This is the case with the train_with_min_diff_ds dataset.
for x, y in train_with_min_diff_ds.take(1):
  print('Type of x:', type(x))  # MinDiffPackedInputs
  print('Type of y:', type(y))  # Tensor (original labels)
With this requirement lifted, you can reorganize the data in a slightly more intuitive structure with the original and MinDiff data cleanly separated.
def _reformat_input(inputs, original_labels):
  min_diff_data = min_diff.keras.utils.unpack_min_diff_data(inputs)
  original_inputs = min_diff.keras.utils.unpack_original_inputs(inputs)
  original_data = (original_inputs, original_labels)

  return {
      'min_diff_data': min_diff_data,
      'original_data': original_data}

customized_train_with_min_diff_ds = train_with_min_diff_ds.map(_reformat_input)
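To make the before and after shapes concrete, the following pure-Python sketch applies the same reorganization to a single hand-written batch. PackedInputs is a hypothetical stand-in for the packed input container (assumed here to expose original_inputs and min_diff_data fields), and the feature values are made up; no TensorFlow objects are involved.

```python
import collections

# Hypothetical stand-in for the packed container found in each batch's x.
PackedInputs = collections.namedtuple(
    'PackedInputs', ['original_inputs', 'min_diff_data'])


def reformat_batch(inputs, original_labels):
  """Separates original data from MinDiff data, as _reformat_input does."""
  original_data = (inputs.original_inputs, original_labels)
  return {
      'min_diff_data': inputs.min_diff_data,
      'original_data': original_data}


# Before: MinDiff data rides along inside the first component of the batch.
batch_x = PackedInputs(
    original_inputs={'age': [39, 50]},
    min_diff_data=({'age': [28, 37]}, [1, 0]))  # (min_diff_x, membership)
batch_y = [0, 1]

# After: a dictionary with the two kinds of data side by side.
reformatted = reformat_batch(batch_x, batch_y)
print(sorted(reformatted))
print(reformatted['original_data'])
print(reformatted['min_diff_data'])
```

The labels now travel with the original inputs, and the MinDiff examples and their membership values sit under their own key.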
This step is completely optional but can be useful to better organize the data. If you do so, the only difference in how you implement CustomModelWithMinDiff will be how you unpack data at the beginning.
class CustomModelWithMinDiff(tf.keras.Model):

  def train_step(self, data):
    # Unpack the MinDiff data from the custom structure.
    if apply_min_diff:
      min_diff_data = data['min_diff_data']
      min_diff_x, membership, min_diff_sample_weight = (
          tf.keras.utils.unpack_x_y_sample_weight(min_diff_data))
      data = data['original_data']

    ...  # possible preprocessing or validation on data before unpacking.

    x, y, sample_weight = tf.keras.utils.unpack_x_y_sample_weight(data)

    ...
With this last step, you can fully control both the input format and how it is used within the model to apply MinDiff.
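Because the input format is now your own, it can help to check its structure before unpacking. The sketch below shows one possible structural check in plain Python. The function validate_batch_structure is hypothetical and not part of any library; it assumes the dictionary layout with 'min_diff_data' and 'original_data' keys used in this guide, and that each value is a tuple of two or three elements.

```python
REQUIRED_KEYS = ('min_diff_data', 'original_data')


def validate_batch_structure(data):
  """Hypothetical check: raises if a batch lacks the expected layout."""
  if not isinstance(data, dict):
    raise TypeError(f'Expected a dict batch, got {type(data).__name__}.')
  missing = [key for key in REQUIRED_KEYS if key not in data]
  if missing:
    raise KeyError(f'Batch is missing keys: {missing}')
  for key in REQUIRED_KEYS:
    value = data[key]
    # Each entry should look like (inputs, labels_or_membership[, weights]).
    if not isinstance(value, tuple) or not 2 <= len(value) <= 3:
      raise ValueError(f'{key!r} must be a tuple of length 2 or 3.')
  return data


good_batch = {
    'min_diff_data': ({'age': [28, 37]}, [1, 0]),
    'original_data': ({'age': [39, 50]}, [0, 1])}
validate_batch_structure(good_batch)  # Passes silently.

try:
  validate_batch_structure({'original_data': ({'age': [39, 50]}, [0, 1])})
except KeyError as error:
  print('Rejected:', error)
```

A check like this only inspects the Python structure of the batch, not tensor values, so it is best suited to catching a dataset that was not reformatted before training.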
Additional Resources
- For an in-depth discussion on fairness evaluation, see the Fairness Indicators guidance.
- For general information on Remediation and MinDiff, see the remediation overview.
- For details on requirements surrounding MinDiff, see this guide.
- To see an end-to-end tutorial on using MinDiff in Keras, see this tutorial.