Given an image like the example below, your goal is to generate a caption such as "a surfer riding on a wave".
A man surfing, from Wikimedia.
The model architecture used here is inspired by Show, Attend and Tell: Neural Image Caption Generation with Visual Attention, but has been updated to use a 2-layer Transformer-decoder. To get the most out of this tutorial you should have some experience with text generation, seq2seq models & attention, or transformers.
The model architecture built in this tutorial is shown below. Features are extracted from the image, and passed to the cross-attention layers of the Transformer-decoder.
The model architecture.
The transformer decoder is mainly built from attention layers. It uses self-attention to process the sequence being generated, and it uses cross-attention to attend to the image.
By inspecting the attention weights of the cross attention layers you will see what parts of the image the model is looking at as it generates words.
This notebook is an end-to-end example. When you run the notebook, it downloads a dataset, extracts and caches the image features, and trains a decoder model. It then uses the model to generate captions on new images.
Setup
apt install --allow-change-held-packages libcudnn8=8.6.0.163-1+cuda11.8
pip uninstall -y tensorflow estimator keras
pip install -U tensorflow_text tensorflow tensorflow_datasets
pip install einops
This tutorial uses lots of imports, mostly for loading the dataset(s).
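Only a handful of those imports are needed for the code shown in this section; a representative subset (the data-handling section may pull in more, for example `tensorflow_datasets`, `tensorflow_text`, and `tqdm`) is:

import einops
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf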
[Optional] Data handling
This section downloads a captions dataset and prepares it for training. It tokenizes the input text, and caches the results of running all the images through a pretrained feature-extractor model. It's not critical to understand everything in this section.
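Later code references a few objects produced by this section: a frozen image feature extractor (`mobilenet`), a `load_image` helper, and a text `tokenizer`. The exact implementations aren't important; as a rough sketch (the input shape, vocabulary size, and standardization are choices you could change, and the full data-handling code also caches the extracted features to disk):

import re
import string

IMAGE_SHAPE = (224, 224, 3)

# A frozen, pretrained convolutional backbone used as the image feature extractor.
mobilenet = tf.keras.applications.MobileNetV3Small(
    input_shape=IMAGE_SHAPE,
    include_top=False,
    include_preprocessing=True)
mobilenet.trainable = False

def load_image(image_path):
  # Read a JPEG from disk and resize it to the feature extractor's input shape.
  img = tf.io.read_file(image_path)
  img = tf.io.decode_jpeg(img, channels=3)
  img = tf.image.resize(img, IMAGE_SHAPE[:-1])
  return img

def standardize(s):
  # Lowercase, strip punctuation, and add [START]/[END] markers to each caption.
  s = tf.strings.lower(s)
  s = tf.strings.regex_replace(s, f'[{re.escape(string.punctuation)}]', '')
  s = tf.strings.join(['[START]', s, '[END]'], separator=' ')
  return s

# A word-level tokenizer; the vocabulary is built from the training captions.
tokenizer = tf.keras.layers.TextVectorization(
    max_tokens=5000,
    standardize=standardize,
    ragged=True)
# tokenizer.adapt(...)  # fit the vocabulary on the training captions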
Data ready for training
After those preprocessing steps, here are the datasets:
train_ds = load_dataset('train_cache')
test_ds = load_dataset('test_cache')
train_ds.element_spec
The dataset now returns `(input, label)` pairs suitable for training with Keras. The `inputs` are `(images, input_tokens)` pairs. The `images` have been processed with the feature-extractor model. For each location in the `input_tokens`, the model looks at the text so far and tries to predict the next token, which is lined up at the same location in the `labels`.
for (inputs, ex_labels) in train_ds.take(1):
  (ex_img, ex_in_tok) = inputs

print(ex_img.shape)
print(ex_in_tok.shape)
print(ex_labels.shape)
The input tokens and the labels are the same, just shifted by 1 step:
print(ex_in_tok[0].numpy())
print(ex_labels[0].numpy())
A Transformer decoder model
This model assumes that the pretrained image encoder is sufficient, and just focuses on building the text decoder. This tutorial uses a 2-layer Transformer-decoder.
The implementations are almost identical to those in the Transformers tutorial. Refer back to it for more details.
The Transformer encoder and decoder.
The model will be implemented in three main parts:

- Input - The token embedding and positional encoding (`SeqEmbedding`).
- Decoder - A stack of transformer decoder layers (`DecoderLayer`) where each contains:
  - A causal self attention layer (`CausalSelfAttention`), where each output location can attend to the output so far.
  - A cross attention layer (`CrossAttention`) where each output location can attend to the input image.
  - A feed forward network (`FeedForward`) layer which further processes each output location independently.
- Output - A multiclass-classification over the output vocabulary.
Input
The input text has already been split up into tokens and converted to sequences of IDs.
Remember that, unlike a CNN or RNN, the Transformer's attention layers are invariant to the order of the sequence. Without some positional input, it just sees an unordered set, not a sequence. So in addition to a simple vector embedding for each token ID, the embedding layer will also include an embedding for each position in the sequence.
The `SeqEmbedding` layer defined below:

- It looks up the embedding vector for each token.
- It looks up an embedding vector for each sequence location.
- It adds the two together.
- It uses `mask_zero=True` to initialize the keras-masks for the model.
class SeqEmbedding(tf.keras.layers.Layer):
  def __init__(self, vocab_size, max_length, depth):
    super().__init__()
    self.pos_embedding = tf.keras.layers.Embedding(input_dim=max_length, output_dim=depth)

    self.token_embedding = tf.keras.layers.Embedding(
        input_dim=vocab_size,
        output_dim=depth,
        mask_zero=True)

    self.add = tf.keras.layers.Add()

  def call(self, seq):
    seq = self.token_embedding(seq)  # (batch, seq, depth)

    x = tf.range(tf.shape(seq)[1])  # (seq)
    x = x[tf.newaxis, :]  # (1, seq)
    x = self.pos_embedding(x)  # (1, seq, depth)

    return self.add([seq, x])
Decoder
The decoder is a standard Transformer-decoder. It contains a stack of `DecoderLayer`s, where each contains three sublayers: a `CausalSelfAttention`, a `CrossAttention`, and a `FeedForward`. The implementations are almost identical to the Transformer tutorial; refer to it for more details.
The `CausalSelfAttention` layer is below:
class CausalSelfAttention(tf.keras.layers.Layer):
  def __init__(self, **kwargs):
    super().__init__()
    self.mha = tf.keras.layers.MultiHeadAttention(**kwargs)
    # Use Add instead of + so the keras mask propagates through.
    self.add = tf.keras.layers.Add()
    self.layernorm = tf.keras.layers.LayerNormalization()

  def call(self, x):
    attn = self.mha(query=x, value=x,
                    use_causal_mask=True)
    x = self.add([x, attn])
    return self.layernorm(x)
The `CrossAttention` layer is below. Note the use of `return_attention_scores`.
class CrossAttention(tf.keras.layers.Layer):
  def __init__(self, **kwargs):
    super().__init__()
    self.mha = tf.keras.layers.MultiHeadAttention(**kwargs)
    self.add = tf.keras.layers.Add()
    self.layernorm = tf.keras.layers.LayerNormalization()

  def call(self, x, y, **kwargs):
    attn, attention_scores = self.mha(
        query=x, value=y,
        return_attention_scores=True)

    self.last_attention_scores = attention_scores

    x = self.add([x, attn])
    return self.layernorm(x)
The `FeedForward` layer is below. Remember that a `layers.Dense` layer is applied to the last axis of the input. The input will have a shape of `(batch, sequence, channels)`, so it automatically applies pointwise across the `batch` and `sequence` axes.
class FeedForward(tf.keras.layers.Layer):
  def __init__(self, units, dropout_rate=0.1):
    super().__init__()
    self.seq = tf.keras.Sequential([
        tf.keras.layers.Dense(units=2*units, activation='relu'),
        tf.keras.layers.Dense(units=units),
        tf.keras.layers.Dropout(rate=dropout_rate),
    ])

    self.layernorm = tf.keras.layers.LayerNormalization()

  def call(self, x):
    x = x + self.seq(x)
    return self.layernorm(x)
Next arrange these three layers into a larger `DecoderLayer`. Each decoder layer applies the three smaller layers in sequence. After each sublayer the shape of `out_seq` is `(batch, sequence, channels)`. The decoder layer also returns the `attention_scores` for later visualizations.
class DecoderLayer(tf.keras.layers.Layer):
  def __init__(self, units, num_heads=1, dropout_rate=0.1):
    super().__init__()

    self.self_attention = CausalSelfAttention(num_heads=num_heads,
                                              key_dim=units,
                                              dropout=dropout_rate)
    self.cross_attention = CrossAttention(num_heads=num_heads,
                                          key_dim=units,
                                          dropout=dropout_rate)
    self.ff = FeedForward(units=units, dropout_rate=dropout_rate)

  def call(self, inputs, training=False):
    in_seq, out_seq = inputs

    # Text input
    out_seq = self.self_attention(out_seq)
    out_seq = self.cross_attention(out_seq, in_seq)

    self.last_attention_scores = self.cross_attention.last_attention_scores

    out_seq = self.ff(out_seq)

    return out_seq
Output
At minimum the output layer needs a `layers.Dense` layer to generate logit-predictions for each token at each location.
But there are a few other features you can add to make this work a little better:
- Handle bad tokens: The model will be generating text. It should never generate a pad, unknown, or start token (`''`, `'[UNK]'`, `'[START]'`). So set the bias for these to a large negative value.

- Smart initialization: The default initialization of a dense layer will give a model that initially predicts each token with almost uniform likelihood. The actual token distribution is far from uniform. The optimal value for the initial bias of the output layer is the log of the probability of each token. So include an `adapt` method to count the tokens and set the optimal initial bias. This reduces the initial loss from the entropy of the uniform distribution (`log(vocabulary_size)`) to the marginal entropy of the distribution (`-p*log(p)`).
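The `TokenOutput` layer used in the next snippet isn't defined in this excerpt. As a rough sketch of how such a layer could implement the two points above (a dense projection plus a frequency-based, banned-token-aware output bias), assuming the label dataset yields batches of padded token IDs:

import collections

class TokenOutput(tf.keras.layers.Layer):
  def __init__(self, tokenizer, banned_tokens=('', '[UNK]', '[START]'), **kwargs):
    super().__init__()
    self.dense = tf.keras.layers.Dense(units=tokenizer.vocabulary_size(), **kwargs)
    self.tokenizer = tokenizer
    self.banned_tokens = banned_tokens
    self.bias = None

  def adapt(self, ds):
    # Count how often each token ID occurs in the label dataset.
    counts = collections.Counter()
    vocab_dict = {name: id for id, name in enumerate(self.tokenizer.get_vocabulary())}
    for tokens in ds:
      counts.update(tokens.numpy().flatten())

    counts_arr = np.zeros(shape=(self.tokenizer.vocabulary_size(),))
    counts_arr[np.array(list(counts.keys()), dtype=np.int32)] = list(counts.values())

    # The model should never produce padding, unknown, or start tokens.
    for token in self.banned_tokens:
      counts_arr[vocab_dict[token]] = 0

    # The optimal initial bias is the log-probability of each token;
    # banned and unseen tokens get a large negative bias instead.
    p = counts_arr / counts_arr.sum()
    p[counts_arr == 0] = 1.0  # placeholder so log() is finite; overwritten below
    log_p = np.log(p)
    log_p[counts_arr == 0] = -1e9
    self.bias = log_p.astype(np.float32)

  def call(self, x):
    return self.dense(x) + self.bias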
The smart initialization will significantly reduce the initial loss:
output_layer = TokenOutput(tokenizer, banned_tokens=('', '[UNK]', '[START]'))
# This might run a little faster if the dataset didn't also have to load the image data.
output_layer.adapt(train_ds.map(lambda inputs, labels: labels))
Build the model
To build the model, you need to combine several parts:
- The image `feature_extractor` and the text `tokenizer`.
- The `seq_embedding` layer, to convert batches of token-IDs to vectors `(batch, sequence, channels)`.
- The stack of `DecoderLayer` layers that will process the text and image data.
- The `output_layer` which returns a pointwise prediction of what the next word should be.
class Captioner(tf.keras.Model):
  @classmethod
  def add_method(cls, fun):
    setattr(cls, fun.__name__, fun)
    return fun

  def __init__(self, tokenizer, feature_extractor, output_layer, num_layers=1,
               units=256, max_length=50, num_heads=1, dropout_rate=0.1):
    super().__init__()
    self.feature_extractor = feature_extractor
    self.tokenizer = tokenizer
    self.word_to_index = tf.keras.layers.StringLookup(
        mask_token="",
        vocabulary=tokenizer.get_vocabulary())
    self.index_to_word = tf.keras.layers.StringLookup(
        mask_token="",
        vocabulary=tokenizer.get_vocabulary(),
        invert=True)

    self.seq_embedding = SeqEmbedding(
        vocab_size=tokenizer.vocabulary_size(),
        depth=units,
        max_length=max_length)

    self.decoder_layers = [
        DecoderLayer(units, num_heads=num_heads, dropout_rate=dropout_rate)
        for n in range(num_layers)]

    self.output_layer = output_layer
When you call the model for training, it receives an `(image, txt)` pair. To make this function more usable, be flexible about the input:

- If the image has 3 channels, run it through the `feature_extractor`. Otherwise assume that it has already been run.
- Similarly, if the text has dtype `tf.string`, run it through the tokenizer.
After that, running the model is only a few steps:

- Flatten the extracted image features, so they can be input to the decoder layers.
- Look up the token embeddings.
- Run the stack of `DecoderLayer`s on the image features and text embeddings.
- Run the output layer to predict the next token at each position.
@Captioner.add_method
def call(self, inputs):
  image, txt = inputs

  if image.shape[-1] == 3:
    # Apply the feature-extractor, if you get an RGB image.
    image = self.feature_extractor(image)

  # Flatten the feature map
  image = einops.rearrange(image, 'b h w c -> b (h w) c')

  if txt.dtype == tf.string:
    # Apply the tokenizer if you get string inputs.
    txt = self.tokenizer(txt)

  txt = self.seq_embedding(txt)

  # Look at the image
  for dec_layer in self.decoder_layers:
    txt = dec_layer(inputs=(image, txt))

  txt = self.output_layer(txt)

  return txt
model = Captioner(tokenizer, feature_extractor=mobilenet, output_layer=output_layer,
                  units=256, dropout_rate=0.5, num_layers=2, num_heads=2)
Generate captions
Before getting into training, write a bit of code to generate captions. You'll use this to see how training is progressing.
Start by downloading a test image:
image_url = 'https://tensorflow.org/images/surf.jpg'
image_path = tf.keras.utils.get_file('surf.jpg', origin=image_url)
image = load_image(image_path)
To caption an image with this model:

- Extract the `img_features`.
- Initialize the list of output tokens with a `[START]` token.
- Pass `img_features` and `tokens` into the model.
  - It returns a list of logits.
  - Choose the next token based on those logits.
  - Add it to the list of tokens, and continue the loop.
  - If it generates an `'[END]'` token, break out of the loop.
So add a "simple" method to do just that:
@Captioner.add_method
def simple_gen(self, image, temperature=1):
  initial = self.word_to_index([['[START]']])  # (batch, sequence)
  img_features = self.feature_extractor(image[tf.newaxis, ...])

  tokens = initial  # (batch, sequence)
  for n in range(50):
    preds = self((img_features, tokens)).numpy()  # (batch, sequence, vocab)
    preds = preds[:, -1, :]  # (batch, vocab)
    if temperature == 0:
      next = tf.argmax(preds, axis=-1)[:, tf.newaxis]  # (batch, 1)
    else:
      next = tf.random.categorical(preds/temperature, num_samples=1)  # (batch, 1)
    tokens = tf.concat([tokens, next], axis=1)  # (batch, sequence)

    if next[0] == self.word_to_index('[END]'):
      break

  words = self.index_to_word(tokens[0, 1:-1])
  result = tf.strings.reduce_join(words, axis=-1, separator=' ')
  return result.numpy().decode()
Here are some generated captions for that image. The model is untrained, so they don't make much sense yet:
for t in (0.0, 0.5, 1.0):
  result = model.simple_gen(image, temperature=t)
  print(result)
The temperature parameter allows you to interpolate between 3 modes:

- Greedy decoding (`temperature=0.0`) - Chooses the most likely next token at each step.
- Random sampling according to the logits (`temperature=1.0`).
- Uniform random sampling (`temperature >> 1.0`).
Since the model is untrained, and it used the frequency-based initialization, the "greedy" output (first) usually only contains the most common tokens: `['a', '.', '[END]']`.
Train
To train the model you'll need several additional components:
- The Loss and metrics
- The Optimizer
- Optional Callbacks
Losses and metrics
Here's an implementation of a masked loss and accuracy:
When calculating the mask for the loss, note the `loss < 1e8`. This term discards the artificial, impossibly high losses for the `banned_tokens`.
def masked_loss(labels, preds):
  loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels, preds)

  mask = (labels != 0) & (loss < 1e8)
  mask = tf.cast(mask, loss.dtype)

  loss = loss*mask
  loss = tf.reduce_sum(loss)/tf.reduce_sum(mask)
  return loss

def masked_acc(labels, preds):
  mask = tf.cast(labels != 0, tf.float32)
  preds = tf.argmax(preds, axis=-1)
  labels = tf.cast(labels, tf.int64)
  match = tf.cast(preds == labels, mask.dtype)
  acc = tf.reduce_sum(match*mask)/tf.reduce_sum(mask)
  return acc
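As a quick sanity check (this toy example is not part of the original notebook), the padded positions (label `0`) should contribute to neither the loss nor the accuracy:

# Hypothetical toy batch: one sequence of length 4 where the last two
# positions are padding, and random logits over a 10-token vocabulary.
dummy_labels = tf.constant([[2, 5, 0, 0]])      # (batch=1, sequence=4)
dummy_logits = tf.random.normal([1, 4, 10])     # (batch, sequence, vocab)

print(masked_loss(dummy_labels, dummy_logits).numpy())  # averaged over the 2 real tokens
print(masked_acc(dummy_labels, dummy_logits).numpy())   # 0.0, 0.5, or 1.0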
Callbacks
For feedback during training, set up a `keras.callbacks.Callback` to generate some captions for the surfer image at the end of each epoch.
class GenerateText(tf.keras.callbacks.Callback):
  def __init__(self):
    image_url = 'https://tensorflow.org/images/surf.jpg'
    image_path = tf.keras.utils.get_file('surf.jpg', origin=image_url)
    self.image = load_image(image_path)

  def on_epoch_end(self, epochs=None, logs=None):
    print()
    print()
    for t in (0.0, 0.5, 1.0):
      result = self.model.simple_gen(self.image, temperature=t)
      print(result)
    print()
It generates three output strings, like the earlier example. As before, the first is "greedy": it chooses the argmax of the logits at each step.
g = GenerateText()
g.model = model
g.on_epoch_end(0)
Also use `callbacks.EarlyStopping` to terminate training when the model starts to overfit.
callbacks = [
    GenerateText(),
    tf.keras.callbacks.EarlyStopping(
        patience=5, restore_best_weights=True)]
Train
Configure and execute the training.
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
              loss=masked_loss,
              metrics=[masked_acc])
For more frequent reporting, use the `Dataset.repeat()` method, and set the `steps_per_epoch` and `validation_steps` arguments to `Model.fit`.
With this setup on `Flickr8k`, a full pass over the dataset is 900+ batches, but below the reporting-epochs are 100 steps.
history = model.fit(
    train_ds.repeat(),
    steps_per_epoch=100,
    validation_data=test_ds.repeat(),
    validation_steps=20,
    epochs=100,
    callbacks=callbacks)
Plot the loss and accuracy over the training run:
plt.plot(history.history['loss'], label='loss')
plt.plot(history.history['val_loss'], label='val_loss')
plt.ylim([0, max(plt.ylim())])
plt.xlabel('Epoch #')
plt.ylabel('CE/token')
plt.legend()
plt.figure()
plt.plot(history.history['masked_acc'], label='accuracy')
plt.plot(history.history['val_masked_acc'], label='val_accuracy')
plt.ylim([0, max(plt.ylim())])
plt.xlabel('Epoch #')
plt.ylabel('Accuracy')
plt.legend()
Attention plots
Now, using the trained model, run that `simple_gen` method on the image:
result = model.simple_gen(image, temperature=0.0)
result
Split the output back into tokens:
str_tokens = result.split()
str_tokens.append('[END]')
The `DecoderLayer`s each cache the attention scores for their `CrossAttention` layer. The shape of each attention map is `(batch=1, heads, sequence, image)`:
attn_maps = [layer.last_attention_scores for layer in model.decoder_layers]
[map.shape for map in attn_maps]
So stack the maps along the `batch` axis, then average over the `(batch, heads)` axes, while splitting the `image` axis back into `height, width`:
attention_maps = tf.concat(attn_maps, axis=0)
attention_maps = einops.reduce(
    attention_maps,
    'batch heads sequence (height width) -> sequence height width',
    height=7, width=7,
    reduction='mean')
Now you have a single attention map for each sequence prediction. The values in each map should sum to 1.
einops.reduce(attention_maps, 'sequence height width -> sequence', reduction='sum')
So here is where the model was focusing attention while generating each token of the output:
def plot_attention_maps(image, str_tokens, attention_map):
  fig = plt.figure(figsize=(16, 9))

  len_result = len(str_tokens)

  titles = []
  for i in range(len_result):
    map = attention_map[i]
    grid_size = max(int(np.ceil(len_result/2)), 2)
    ax = fig.add_subplot(3, grid_size, i+1)
    titles.append(ax.set_title(str_tokens[i]))
    img = ax.imshow(image)
    ax.imshow(map, cmap='gray', alpha=0.6, extent=img.get_extent(),
              clim=[0.0, np.max(map)])

  plt.tight_layout()
plot_attention_maps(image/255, str_tokens, attention_maps)
Now put that together into a more usable function:
@Captioner.add_method
def run_and_show_attention(self, image, temperature=0.0):
  result_txt = self.simple_gen(image, temperature)
  str_tokens = result_txt.split()
  str_tokens.append('[END]')

  attention_maps = [layer.last_attention_scores for layer in self.decoder_layers]
  attention_maps = tf.concat(attention_maps, axis=0)
  attention_maps = einops.reduce(
      attention_maps,
      'batch heads sequence (height width) -> sequence height width',
      height=7, width=7,
      reduction='mean')

  plot_attention_maps(image/255, str_tokens, attention_maps)
  t = plt.suptitle(result_txt)
  t.set_y(1.05)
run_and_show_attention(model, image)
Try it on your own images
For fun, below you're provided a method you can use to caption your own images with the model you've just trained. Keep in mind that it was trained on a relatively small amount of data, and your images may be different from the training data (so be prepared for strange results!).
image_url = 'https://tensorflow.org/images/bedroom_hrnet_tutorial.jpg'
image_path = tf.keras.utils.get_file(origin=image_url)
image = load_image(image_path)
run_and_show_attention(model, image)