Complete code examples for Machine Translation with Attention, Image Captioning, Text Generation, and DCGAN implemented with tf.keras and eager execution

By Yash Katariya

I’ve always found generative and sequence models fascinating: they ask a different flavor of question than we usually encounter when we first begin studying machine learning. When I first started studying ML, I learned (as many of us do) about classification and regression. These help us ask and answer questions like:

  • Is this a picture of a cat or a dog? (Classification)
  • What’s the probability that it will rain tomorrow? (Regression)

Classification and regression are incredibly useful skills to master, and there’s nearly no limit to the applications of these areas to useful, real-world problems. But there are other types of questions we might ask that feel very different.

  • Can we generate a poem? (Text generation)
  • Can we generate a photo of a cat? (GANs)
  • Can we translate a sentence from one language to another? (NMT)
  • Can we generate a caption for an image? (Image captioning)

During my summer internship, I developed examples for these using two of TensorFlow’s latest APIs, tf.keras and eager execution, and I’ve shared them all below. I hope you find them useful, and fun!

  • Eager execution is an imperative, define-by-run interface where operations are executed immediately as they are called from Python. This makes it easier to get started with TensorFlow, and can make research and development more intuitive.
  • tf.keras is a high-level API for defining models with lego-like building blocks. I implemented these examples using Model subclassing, which lets you build fully customizable models by subclassing tf.keras.Model and defining your own forward pass. Model subclassing is particularly useful when eager execution is enabled, since the forward pass can be written imperatively.

If you’re new to these APIs, you can learn more about them by exploring the sequence of notebooks on, which contains recently updated examples.

Each of the examples below is end-to-end, and follows a similar pattern:

  1. Automatically download the training data.
  2. Preprocess the training data, and create a dataset for use in our input pipeline.
  3. Define the model using the tf.keras model subclassing API.
  4. Train the model using eager execution.
  5. Demonstrate how to use the trained model.

Example #1: Text Generation

Our first example is for text generation, where we use an RNN to generate text in a similar style to Shakespeare. You can run it on Colaboratory with the link above (or download it as a Jupyter notebook from GitHub). The code is explained in detail in the notebook.

Given a large collection of Shakespeare’s writings, this example learns to generate text that is stylistically similar:

Example text generated by the notebook after training for 30 epochs on a collection of Shakespeare’s writing.

While most of the sentences will not make sense (of course, this simple model has not learned the meaning of language), what’s impressive is that most of the words *are* valid, and the structure of the plays it emits looks similar to that of the original text. (This is a character-based model, and in the short time we’ve trained it, it has learned both of those things from scratch.) If you like, you can change the dataset by changing a single line of code.
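To make "character-based" concrete, here is a minimal plain-Python sketch (not the notebook's code) of two steps such a model typically relies on: mapping text to integer ids and building (input, target) pairs where the target is the input shifted by one character, and sampling the next character from the model's output scores with a temperature. The function names, the non-overlapping chunking, and the `temperature` parameter are illustrative assumptions; in the notebook, the scores would come from the trained RNN.

```python
import math
import random


def build_vocab(text):
    """Map each distinct character to an integer id, and ids back to characters."""
    chars = sorted(set(text))
    char2idx = {c: i for i, c in enumerate(chars)}
    return char2idx, chars


def make_examples(text, char2idx, seq_length):
    """Cut the text into non-overlapping chunks of seq_length + 1 ids.

    Each chunk yields an (input, target) pair in which the target is the
    input shifted one character to the right, so the model learns to
    predict the next character at every position.
    """
    ids = [char2idx[c] for c in text]
    examples = []
    for start in range(0, len(ids) - seq_length, seq_length + 1):
        chunk = ids[start:start + seq_length + 1]
        examples.append((chunk[:-1], chunk[1:]))
    return examples


def sample_next_id(logits, temperature=1.0, rng=random):
    """Sample a character id from unnormalized scores (logits).

    Lower temperatures make the choice more conservative (closer to argmax);
    higher temperatures make the generated text more surprising.
    """
    scaled = [x / temperature for x in logits]
    top = max(scaled)
    weights = [math.exp(x - top) for x in scaled]  # stable softmax numerators
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]


if __name__ == "__main__":
    char2idx, idx2char = build_vocab("hello")
    print(make_examples("hello", char2idx, seq_length=2))  # [([1, 0], [0, 2])]
```

During generation, the sampled id is converted back with `idx2char`, appended to the output, and fed to the model as the next input, one character at a time.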

The best place to learn more about RNNs is Andrej Karpathy’s excellent article, The Unreasonable Effectiveness of Recurrent Neural Networks. If you’d like to learn more about implementing RNNs with Keras or tf.keras, we recommend these notebooks by Francois Chollet.

Example #2: DCGAN

In this example, we generate handwritten digits using DCGAN. A Generative Adversarial Network (GAN) consists of a generator and a discriminator. The job of the generator is to create images convincing enough to fool the discriminator. The job of the discriminator is to distinguish real images from fake images (created by the generator). The output you see below was generated after training the generator and discriminator for 150 epochs using the architecture and hyperparameters described in this paper.
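The adversarial setup can be summarized by two loss functions. The NumPy sketch below assumes the common binary cross-entropy formulation (computed from raw discriminator logits), with the generator trained to make the discriminator output "real" for its images; the post does not spell out the notebook's exact losses, so treat the function names and this formulation as assumptions.

```python
import numpy as np


def sigmoid_cross_entropy(labels, logits):
    """Mean binary cross-entropy computed directly from raw logits.

    Uses the numerically stable form max(x, 0) - x * z + log(1 + exp(-|x|)).
    """
    return np.mean(np.maximum(logits, 0) - logits * labels
                   + np.log1p(np.exp(-np.abs(logits))))


def discriminator_loss(real_logits, fake_logits):
    # The discriminator should output "real" (1) for real images
    # and "fake" (0) for images produced by the generator.
    real_loss = sigmoid_cross_entropy(np.ones_like(real_logits), real_logits)
    fake_loss = sigmoid_cross_entropy(np.zeros_like(fake_logits), fake_logits)
    return real_loss + fake_loss


def generator_loss(fake_logits):
    # The generator is rewarded when the discriminator labels its images as real.
    return sigmoid_cross_entropy(np.ones_like(fake_logits), fake_logits)
```

In a full training loop, the gradient of each loss is applied only to its own network's weights: the discriminator loss updates the discriminator, and the generator loss updates the generator.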

GIF of images generated every 10 epochs out of 150 epochs. You can find code to create a GIF like this in the notebook.

Example #3: Neural Machine Translation with Attention

This example trains a model to translate Spanish sentences to English sentences. After training the model, you will be able to input a Spanish sentence, such as “¿todavia estan en casa?”, and get back the English translation: “are you still at home?”
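Before training, sentences are usually normalized so the model sees a consistent vocabulary. The standard-library sketch below shows one plausible preprocessing step, under stated assumptions: lowercase the text, strip accents, put spaces around punctuation (including the Spanish "¿"), and wrap the sentence in `<start>`/`<end>` tokens that tell the decoder where a sentence begins and ends. The function names and token strings are hypothetical, not taken from the notebook.

```python
import re
import unicodedata


def strip_accents(text):
    # Decompose accented characters (e.g. "í" -> "i" + combining accent)
    # and drop the combining marks (Unicode category "Mn").
    return "".join(c for c in unicodedata.normalize("NFD", text)
                   if unicodedata.category(c) != "Mn")


def preprocess_sentence(sentence):
    s = strip_accents(sentence.lower().strip())
    # Put spaces around punctuation so "casa?" becomes "casa ?".
    s = re.sub(r"([?.!,¿])", r" \1 ", s)
    # Collapse runs of whitespace into single spaces.
    s = re.sub(r"\s+", " ", s).strip()
    return "<start> " + s + " <end>"


if __name__ == "__main__":
    print(preprocess_sentence("¿Todavía están en casa?"))
    # <start> ¿ todavia estan en casa ? <end>
```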

The image you see below is the attention plot. It shows which parts of the input sentence the model attends to while translating each word. For example, when the model translated the word “cold”, it was looking at “mucho”, “frio”, and “aqui”. We implemented Bahdanau Attention from scratch using tf.keras and eager execution; it is explained in detail in the notebook. You can also use this implementation as a base for implementing your own custom models.
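As a rough illustration of what Bahdanau (additive) attention computes at a single decoding step, here is a NumPy sketch rather than the notebook's tf.keras layer: each encoder output is scored against the decoder state with `v · tanh(W1 h_t + W2 s)`, the scores are softmaxed into weights that sum to 1, and the context vector is the weighted sum of encoder outputs. The weight matrices `W1`, `W2`, and `v` stand in for learned parameters; here they are plain arrays, and the shapes are assumptions.

```python
import numpy as np


def softmax(x):
    e = np.exp(x - np.max(x))  # subtract the max for numerical stability
    return e / e.sum()


def bahdanau_attention(encoder_outputs, decoder_state, W1, W2, v):
    """Additive attention for one decoder step.

    encoder_outputs: (T, d_enc), one vector per input token
    decoder_state:   (d_dec,), the decoder's hidden state
    W1: (d_enc, units), W2: (d_dec, units), v: (units,)
    Returns the context vector (d_enc,) and attention weights (T,).
    """
    # score_t = v . tanh(W1 h_t + W2 s), computed for all T tokens at once
    scores = np.tanh(encoder_outputs @ W1 + decoder_state @ W2) @ v
    weights = softmax(scores)              # how much each input token matters
    context = weights @ encoder_outputs    # weighted sum of encoder outputs
    return context, weights
```

Stacking the `weights` returned at each decoding step gives an (output length × input length) matrix, which is what an attention plot like the one below visualizes.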

Attention plot for the above translation.

Example #4: Image Captioning with Attention

In this example, we train our model to predict a caption for an image. We also generate an attention plot, which shows the parts of the image the model focuses on as it generates the caption. For example, the model focuses near the surfboard in the image when it predicts the word “surfboard”. This model is trained using a subset of the MS-COCO dataset, which will be downloaded automatically by the notebook.
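For images, the same additive scoring can be applied over spatial locations instead of words. The NumPy sketch below assumes the image has already been passed through a CNN, producing a feature map of shape (height, width, channels); each grid cell is treated as one "location" to attend over, and the resulting weights are reshaped back into a grid that can be overlaid on the image for a given word. The grid size, shapes, and function name are illustrative assumptions, not the notebook's code.

```python
import numpy as np


def spatial_attention(feature_map, hidden, W1, W2, v):
    """Additive attention over image locations for one generated word.

    feature_map: (h, w, c) CNN features; hidden: (d,) decoder state
    W1: (c, units), W2: (d, units), v: (units,)
    Returns the context vector (c,) and an (h, w) grid of attention weights.
    """
    h, w, c = feature_map.shape
    features = feature_map.reshape(h * w, c)           # one row per location
    scores = np.tanh(features @ W1 + hidden @ W2) @ v  # (h * w,) scores
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()                           # softmax over locations
    context = weights @ features                       # weighted image summary
    return context, weights.reshape(h, w)              # grid for plotting
```

Collecting one grid per predicted word, and upsampling each grid to the image's size, yields a per-word attention plot like the one shown below.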

Predicted Caption for the image.
Attention plot of each word for the above image.

Next steps

To learn more about tf.keras and eager, keep your eyes on for updated content, and periodically check this blog, and TensorFlow’s twitter feed. Thanks for reading!


Thanks very much to Josh Gordon, Mark Daoust, Alexandre Passos, Asim Shankar, Billy Lamberta, Daniel ‘Wolff’ Dobson, and Francois Chollet for their contributions and help!
