In a real-world application of federated learning, the raw training data is typically distributed across many devices or data silos -- requiring special preprocessing and loading before it's usable.
This tutorial describes how to load examples stored in those remote locations with TFF's DataBackend and DataExecutor interfaces, and how to use them to train a model with federated learning. We'll demonstrate the data loading APIs with a training dataset stored locally, and simulate the sampling of examples as if the dataset were partitioned over distinct remote clients. As you adapt this tutorial to your use case, you can simply swap that dataset with your own distributed data.
If you're new to federated learning or TFF, consider reading Federated Learning for Image Classification for a primer.
Before we start
Before we start, please run the following to make sure that your environment is correctly set up. Refer to the Installation guide for more information.
Set up open-source environment
!pip install --quiet --upgrade tensorflow-federated
!pip install --quiet --upgrade nest-asyncio
import nest_asyncio
nest_asyncio.apply()
Import packages
import collections
import random
from typing import Any, List, Optional, Sequence
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff
np.random.seed(0)
Preparing the input data
Let's begin by loading TFF's federated version of the EMNIST dataset from the built-in repository:
emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()
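Both returned objects are tff.simulation.datasets.ClientData instances. As a quick orientation, we can check how many clients the train split contains and what a raw example looks like:

# Number of distinct clients (writers) in the federated train split.
print(len(emnist_train.client_ids))
# Structure of a raw example: a `28x28` float image and an integer label.
print(emnist_train.element_type_structure)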
Next, let's construct a preprocessing function to transform the raw examples in the EMNIST dataset.
NUM_EPOCHS = 5
SHUFFLE_BUFFER = 100

def preprocess(dataset):

  def map_fn(element):
    # Rename the features from `pixels` and `label` to `x` and `y` for use
    # with Keras.
    return collections.OrderedDict(
        # Transform each `28x28` image into a `784`-element array.
        x=tf.reshape(element['pixels'], [-1, 784]),
        y=tf.reshape(element['label'], [-1, 1]))

  # Shuffle the individual examples and `repeat` over several epochs.
  return dataset.repeat(NUM_EPOCHS).shuffle(SHUFFLE_BUFFER, seed=1).map(map_fn)
Let's verify this works:
# The local dataset corresponding to a single client, as a `tf.data.Dataset`.
example_dataset = emnist_train.create_tf_dataset_for_client(
    emnist_train.client_ids[0])

preprocessed_example_dataset = preprocess(example_dataset)
print(preprocessed_example_dataset)
<MapDataset element_spec=OrderedDict([('x', TensorSpec(shape=(1, 784), dtype=tf.float32, name=None)), ('y', TensorSpec(shape=(1, 1), dtype=tf.int32, name=None))])>
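Beyond the element spec, we can also pull one preprocessed example to confirm the shapes directly:

# Materialize a single preprocessed example and check its shapes.
sample = next(iter(preprocessed_example_dataset))
print(sample['x'].shape)  # (1, 784)
print(sample['y'].shape)  # (1, 1)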
Next, we'll construct an implementation of DataBackend that loads and preprocesses the local examples from clients in the EMNIST dataset, making trainable examples available for fetching during federated learning.
Defining how to fetch client data
We need an instance of DataBackend to instruct the TFF workers how to load and transform the local data. TFF workers are the processes that run on edge machines and perform the work for one or more logical clients. In this example, the EMNIST dataset we'll use for training is already partitioned by logical clients, and all the workers are going to be running in the same local environment, so our DataBackend can reference the data corresponding to any client. But in a non-experimental setting, the TFF workers will be distributed over individual remote machines, each mapping to a distinct set of clients, and you need to ensure that the DataBackend can correctly resolve data references according to its local context.
# A `DataBackend` is a programmatic construct that resolves symbolic
# references, represented as application-specific URIs, to materialized
# examples that TFF operations can process.
class MyDataBackend(tff.framework.DataBackend):

  async def materialize(self, data, type_spec):
    # In this example, the URI encodes the index of a client, e.g. `uri://0`.
    # Parsing the full suffix (rather than a single character) keeps this
    # correct for indices of any number of digits.
    client_id = int(data.uri.split('://')[-1])
    # The client index is used to retrieve the corresponding local data.
    client_dataset = emnist_train.create_tf_dataset_for_client(
        emnist_train.client_ids[client_id])
    # We preprocess the client dataset before returning so it's compatible
    # with our model definition.
    return preprocess(client_dataset)
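Before wiring the backend into the runtime, we can exercise materialize directly. Note that FakeDataReference below is a stand-in of our own for this check: at runtime TFF passes an object carrying a uri attribute, which is all our backend reads (our implementation ignores type_spec).

import asyncio

class FakeDataReference:
  # Emulates just the `uri` attribute that our backend reads.
  uri = 'uri://0'

# `nest_asyncio.apply()` above allows `asyncio.run` inside the notebook loop.
materialized = asyncio.run(
    MyDataBackend().materialize(FakeDataReference(), type_spec=None))
print(materialized.element_spec)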
Setting up the runtime environment
TFF computations are invoked by an ExecutionContext. In order for data URIs defined in TFF computations to be understood at runtime, a custom context that includes a pointer to the DataBackend we just created must be defined for the TFF workers, so the URIs can be properly resolved.
def ex_fn(device: tf.config.LogicalDevice) -> tff.framework.DataExecutor:
  # A `DataBackend` object is wrapped by a `DataExecutor`, which queries the
  # backend whenever a TFF worker encounters an operation that requires
  # fetching local data.
  return tff.framework.DataExecutor(
      tff.framework.EagerTFExecutor(device), data_backend=MyDataBackend())

# In a distributed setting, this needs to run in the TFF worker as a service
# connecting to some port. The top-level controller feeding TFF computations
# would then connect to this port.
factory = tff.framework.local_executor_factory(leaf_executor_fn=ex_fn)
ctx = tff.framework.SyncExecutionContext(executor_fn=factory)
tff.framework.set_default_context(ctx)
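As a quick sanity check that the new context is active, we can invoke a trivial TFF computation; it will now execute through the executor stack we just assembled:

@tff.tf_computation(tf.int32)
def add_one(x):
  return x + 1

print(add_one(2))  # Prints `3`, executed via the context set above.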
Training the model
Now we are ready to train a model using federated learning. Let's define a Keras model:
def create_keras_model():
  return tf.keras.models.Sequential([
      tf.keras.layers.InputLayer(input_shape=(784,)),
      tf.keras.layers.Dense(10, kernel_initializer='zeros'),
      tf.keras.layers.Softmax(),
  ])

def model_fn():
  keras_model = create_keras_model()
  return tff.learning.from_keras_model(
      keras_model,
      input_spec=preprocessed_example_dataset.element_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
We can pass this TFF-wrapped definition of our model to a Federated Averaging algorithm by invoking the helper function tff.learning.algorithms.build_weighted_fed_avg, as follows:
iterative_process = tff.learning.algorithms.build_weighted_fed_avg(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.02),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0))
state = iterative_process.initialize()
The initialize computation returns the initial state of the Federated Averaging process.
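The state carries, among other things, the global model weights; if you want to inspect them, the learning process exposes a get_model_weights helper:

model_weights = iterative_process.get_model_weights(state)
# The Dense kernel was zero-initialized, so the initial weights are all zeros.
print(model_weights.trainable[0].shape)  # (784, 10)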
To run a round of training, we need to construct a sample of data by collecting URI references to the participating clients' data, as follows:
NUM_CLIENTS = 10

element_type = tff.types.StructWithPythonType(
    preprocessed_example_dataset.element_spec,
    container_type=collections.OrderedDict)
dataset_type = tff.types.SequenceType(element_type)

round_data_uris = [f'uri://{i}' for i in range(NUM_CLIENTS)]
round_train_data = tff.framework.CreateDataDescriptor(
    arg_uris=round_data_uris, arg_type=dataset_type)
Now we can run a round of training:
result = iterative_process.next(state, round_train_data)
state = result.state
metrics = result.metrics
print('round 1, metrics={}'.format(metrics))
round 1, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.11234568), ('loss', 11.965633), ('num_examples', 4860), ('num_batches', 4860)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])
Training over multiple rounds
We can define a FederatedDataSource container for selecting clients and assembling the inputs for retrieving local data. This makes it convenient to loop over multiple rounds of training, and it can be reused across multiple training jobs.
class MyFederatedDataSourceIterator(tff.program.FederatedDataSourceIterator):

  def __init__(self, client_ids: Sequence[str],
               federated_type: tff.FederatedType):
    self._client_ids = client_ids
    self._federated_type = federated_type

  @property
  def federated_type(self) -> tff.FederatedType:
    return self._federated_type

  def select(self, num_clients: Optional[int] = None) -> Any:
    # Sample client indices rather than client ID strings, so the URIs match
    # the `uri://{index}` format that `MyDataBackend` expects to resolve.
    client_indices_sample = random.sample(
        range(len(self._client_ids)), num_clients)
    data_uris = [f'uri://{i}' for i in client_indices_sample]
    return tff.framework.CreateDataDescriptor(
        arg_uris=data_uris, arg_type=self._federated_type)

class MyFederatedDataSource(tff.program.FederatedDataSource):

  def __init__(self, client_ids: Sequence[str],
               federated_type: tff.FederatedType):
    self._client_ids = client_ids
    self._federated_type = federated_type
    self._capabilities = [tff.program.Capability.RANDOM_UNIFORM]

  @property
  def federated_type(self) -> tff.FederatedType:
    return self._federated_type

  @property
  def capabilities(self) -> List[tff.program.Capability]:
    return self._capabilities

  def iterator(self) -> tff.program.FederatedDataSourceIterator:
    return MyFederatedDataSourceIterator(self._client_ids, self._federated_type)
train_data_source = MyFederatedDataSource(
    client_ids=emnist_train.client_ids, federated_type=dataset_type)
train_data_iterator = train_data_source.iterator()
Now we can run our federated learning training loop like so:
NUM_ROUNDS = 10
for round_num in range(2, NUM_ROUNDS + 1):
  round_train_data = train_data_iterator.select(NUM_CLIENTS)
  result = iterative_process.next(state, round_train_data)
  state = result.state
  metrics = result.metrics
  print('round {:2d}, metrics={}'.format(round_num, metrics))
round 2, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.12357217), ('loss', 9.161968), ('num_examples', 4815), ('num_batches', 4815)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])
round 3, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.20563674), ('loss', 7.0862083), ('num_examples', 4790), ('num_batches', 4790)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])
round 4, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.30241227), ('loss', 5.6945825), ('num_examples', 4560), ('num_batches', 4560)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])
round 5, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.3867347), ('loss', 4.7210026), ('num_examples', 4900), ('num_batches', 4900)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])
round 6, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.42311886), ('loss', 4.205554), ('num_examples', 4585), ('num_batches', 4585)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])
round 7, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.4501548), ('loss', 4.1297464), ('num_examples', 4845), ('num_batches', 4845)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])
round 8, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.56590474), ('loss', 2.8927681), ('num_examples', 5250), ('num_batches', 5250)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])
round 9, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.59917355), ('loss', 2.7431731), ('num_examples', 4840), ('num_batches', 4840)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])
round 10, metrics=OrderedDict([('distributor', ()), ('client_work', OrderedDict([('train', OrderedDict([('sparse_categorical_accuracy', 0.5717234), ('loss', 2.9738288), ('num_examples', 4845), ('num_batches', 4845)]))])), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', ())])
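As an optional final check, we can copy the trained server weights back into a Keras model and evaluate them centrally. This is a minimal sketch rather than a proper federated evaluation: it reuses the preprocess function from above and scores only a single test client's data.

keras_model = create_keras_model()
keras_model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
# Assign the trained weights from the server state to the Keras model.
iterative_process.get_model_weights(state).assign_weights_to(keras_model)

test_dataset = preprocess(
    emnist_test.create_tf_dataset_for_client(emnist_test.client_ids[0]))
keras_model.evaluate(test_dataset, verbose=2)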
Conclusion
This concludes the tutorial. We encourage you to explore the other tutorials we've developed to learn about the many other features of the TFF framework.