Universal Transformers

Thanks to Stephan Gouws for his help on writing and improving this blog post.

Transformers have recently become a competitive alternative to RNNs for a range of sequence modeling tasks. They address a significant shortcoming of RNNs, namely their inherently sequential computation, which prevents parallelization across elements of the input sequence, while still addressing the vanishing-gradient problem through their self-attention mechanism.

In fact, Transformers rely entirely on a self-attention mechanism to compute a series of context-informed vector-space representations of the symbols in their input (see this blog post to learn more about the details of the Transformer). This leads to two main properties for Transformers:

  • Straightforward to parallelize: There are no connections in time as in RNNs, allowing one to fully parallelize per-symbol computations.
  • Global receptive field: Each symbol’s representation is directly informed by all other symbols’ representations (in contrast to e.g. convolutional architectures which typically have a limited receptive field).

Although Transformers continue to achieve great improvements in many tasks, they have some shortcomings:

  • The Transformer is not Turing Complete: While the Transformer executes a total number of operations that scales with the input size, the number of sequential operations is constant and independent of the input size, determined solely by the number of layers. Assuming finite precision, this means that the Transformer cannot be computationally universal. An intuitive example is a function whose execution requires the sequential processing of each input element: in this case, for any given choice of depth T, one can construct an input sequence of length N > T that cannot be processed correctly by a Transformer (a toy illustration follows this list).
  • Lack of Conditional Computation: The Transformer applies the same amount of computation to all inputs (as well as to all parts of a single input). However, not all inputs need the same amount of computation, and the amount could be conditioned on the complexity of the input.
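As a toy illustration (not from the paper, and purely for intuition), consider a simple stateful function in Python: evaluating it in the straightforward way takes one step per input element, which is exactly the kind of sequential dependency the argument above appeals to.

```python
# Toy illustration (not from the paper): a function whose straightforward
# evaluation is inherently sequential, because the state after element t
# depends on the state after element t-1.

def sequential_target(inputs):
    state = 0
    for x in inputs:
        # Each update consumes the previous state, so evaluating the function
        # naively takes one step per input element, i.e. N sequential steps
        # for an input of length N.
        state = (3 * state + x) % 7
    return state

print(sequential_target([1, 4, 2, 5]))  # the result depends on the full left-to-right history
```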

Universal Transformers (UTs) address these shortcomings. In the following sections, we'll talk more about the UT and its properties.

Universal Transformer: A Concurrent-Recurrent Sequence Model

The Universal Transformer is an extension of the Transformer model that combines the parallelizability and global receptive field of the Transformer with the recurrent inductive bias of RNNs, which seems to be better suited to a range of algorithmic and natural language understanding sequence-to-sequence problems. Moreover, as the name implies, in contrast to the standard Transformer, under certain assumptions the Universal Transformer can be shown to be computationally universal.

How Does the UT Work?

In the standard Transformer, we have a "fixed" stack of Transformer blocks, where each block is applied to all the input symbols in parallel. In the Universal Transformer, however, instead of having a fixed number of layers, we iteratively apply a Universal Transformer block (a self-attention mechanism followed by a recurrent transition function) to refine the representations of all positions in the sequence in parallel, for an arbitrary number of steps (which is possible thanks to the recurrence).

The Universal Transformer encoder. It repeatedly refines a series of vector representations for each position of the sequence in parallel, by combining information from different positions using self-attention and applying a recurrent transition function. Arrows denote dependencies between operations.
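To make this recurrence concrete, here is a minimal Python/NumPy sketch of the encoder loop under simplified assumptions: residual connections only, no layer normalization or dropout, and `self_attention`, `transition`, and `coordinate_embedding` are placeholder callables rather than the Tensor2Tensor implementation.

```python
import numpy as np

def universal_transformer_encoder(h, self_attention, transition,
                                  coordinate_embedding, num_steps):
    """Minimal sketch of the UT encoder recurrence (layer norm and dropout
    omitted; the three callables are placeholders, not the real implementation).

    h: [seq_len, d_model] initial symbol embeddings.
    """
    seq_len, d_model = h.shape
    for t in range(num_steps):
        # The *same* block (shared weights) is applied at every step t.
        h = h + coordinate_embedding(seq_len, d_model, t)  # position + time-step signal
        h = h + self_attention(h)                          # combine information across positions
        h = h + transition(h)                              # per-position transition function
    return h

# Toy usage with trivial components, just to show the call pattern.
h = np.zeros((5, 8))
out = universal_transformer_encoder(
    h,
    self_attention=lambda x: np.zeros_like(x),
    transition=np.tanh,
    coordinate_embedding=lambda n, d, t: np.zeros((n, d)),
    num_steps=3)
```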

In fact, the Universal Transformer is a recurrent function (recurrent not in time, but in depth) that evolves per-symbol hidden states in parallel, based at each step on the sequence of previous hidden states. In that sense, the UT is similar to architectures such as the Neural GPU and the Neural Turing Machine. This gives UTs the attractive computational efficiency of the original feed-forward Transformer model, but with the added recurrent inductive bias of RNNs.

Note that when running for a fixed number of steps, the Universal Transformer is equivalent to a multi-layer Transformer with tied parameters across its layers.

Universal Transformer with Dynamic Halting

In sequence processing systems, certain symbols (e.g. some words or phonemes) are usually more ambiguous than others. It is, therefore, reasonable to allocate more processing resources to these more ambiguous symbols.

As stated before, the standard Transformer applies the same amount of computation (a fixed number of layers) to all symbols in all inputs. To address this, the Universal Transformer with dynamic halting dynamically modulates the number of computational steps needed to process each input symbol, based on a scalar pondering value that the model predicts at each step. The pondering values are, in a sense, the model's estimate of how much further computation is required for each input symbol at a given processing step.

The Universal Transformer with dynamic halting uses an Adaptive Computation Time (ACT) mechanism, originally proposed for RNNs, to enable conditional computation.

More precisely, the Universal Transformer with dynamic halting adds a dynamic ACT halting mechanism to each position in the input sequence. Once the per-symbol recurrent block halts (indicating that a sufficient number of revisions has been made for that symbol), its state is simply copied to the next step until all blocks halt or a maximum number of steps is reached. The final output of the encoder is then the final layer of representations produced in this way:

Adaptive Universal Transformer encoder over 6 recurrent steps with early halting, i.e. a different number of recurrent revisions per position. The blue dotted arrows show the states attended to when encoding position 2 at different time steps.
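For intuition, here is a hedged NumPy sketch of the per-position halting bookkeeping, loosely following ACT-style updates. `ut_block` and `halting_unit` are placeholders, the final state is the usual ACT-style weighted combination of intermediate states, and this simplifies rather than reproduces the Tensor2Tensor implementation.

```python
import numpy as np

def act_halting_loop(state, ut_block, halting_unit, max_steps=6, threshold=0.99):
    """Sketch of per-position dynamic halting (ACT-style bookkeeping).

    state:        [seq_len, d_model] per-symbol states
    ut_block:     callable refining all positions in parallel
    halting_unit: callable mapping states to per-position halting scores in (0, 1)
    """
    seq_len = state.shape[0]
    halting_probability = np.zeros(seq_len)   # accumulated halting probability
    remainders = np.zeros(seq_len)
    n_updates = np.zeros(seq_len)             # per-position "ponder time"
    previous_state = np.zeros_like(state)

    for _ in range(max_steps):
        p = halting_unit(state)                                   # [seq_len] scores for this step
        still_running = (halting_probability < 1.0).astype(float)
        new_halted = ((halting_probability + p * still_running) > threshold).astype(float) * still_running
        still_running = ((halting_probability + p * still_running) <= threshold).astype(float) * still_running

        halting_probability += p * still_running
        remainders += new_halted * (1.0 - halting_probability)    # leftover probability mass
        halting_probability += new_halted * remainders
        n_updates += still_running + new_halted

        # Positions that halted in earlier steps get weight 0: their state is
        # simply carried over; running positions are refined by the shared block.
        update_weights = (p * still_running + new_halted * remainders)[:, None]
        transformed = ut_block(state)
        previous_state = transformed * update_weights + previous_state * (1.0 - update_weights)
        state = transformed

        if float(still_running.sum()) == 0.0:
            break
    return previous_state, n_updates
```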

Universality and Relation to Other Models

Unlike the standard Transformer, which cannot be computationally universal because its number of sequential operations is constant, the Universal Transformer lets us choose the number of steps as a function of the input length. This holds independently of whether or not adaptive computation time is employed, but it does assume a non-constant (even if possibly deterministic) number of steps. Note that varying the number of steps dynamically after training is possible in Universal Transformers because the model shares weights across its sequential computation steps.

Given sufficient memory, the Universal Transformer is computationally universal, i.e. it belongs to the class of models that can be used to simulate any Turing machine (you can check this blog post on "What exactly is Turing Completeness?").

To show this, we can reduce a Neural GPU (which is Turing Complete) to a Universal Transformer: ignore the decoder, parameterize the self-attention module (i.e., self-attention with the residual connection) to be the identity function, and assume the transition function is a convolution. Then, if we set the total number of recurrent steps T to be equal to the input length, we obtain exactly a Neural GPU.
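As a toy sketch of this reduction (illustrative only; `conv_transition` stands in for the Neural GPU's gated convolutional update, and the precise construction is given in the paper):

```python
import numpy as np

def ut_as_neural_gpu(h, conv_transition):
    """Toy sketch of the reduction: identity self-attention plus a convolutional
    transition, run for T = N steps with shared weights.

    h: [seq_len, d_model] input state.
    """
    num_steps = h.shape[0]            # T equals the input length N
    for _ in range(num_steps):
        h = h + np.zeros_like(h)      # self-attention sub-layer outputs zero, so the residual acts as the identity
        h = conv_transition(h)        # shared convolutional transition, as in a Neural GPU
    return h

# Toy usage: a fixed 1-D convolution along the sequence dimension.
kernel = np.array([0.25, 0.5, 0.25])
conv = lambda x: np.stack(
    [np.convolve(x[:, j], kernel, mode="same") for j in range(x.shape[1])], axis=1)
out = ut_as_neural_gpu(np.random.randn(6, 4), conv)
```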

Note that the last step is where the Universal Transformer crucially differs from the vanilla Transformer whose depth cannot scale dynamically with the size of the input. A similar relationship exists between the Universal Transformer and the Neural Turing Machine, whose single read/write operations per step can be expressed by the global, parallel representation revisions of the Universal Transformer.

The cool thing about the Universal Transformer is that not only is it theoretically appealing (Turing complete), but, in contrast to other computationally universal models like the Neural GPU, which perform well only on algorithmic tasks, the Universal Transformer also achieves competitive results on realistic natural language tasks such as LAMBADA and machine translation. This closes the gap between practical sequence models that are competitive on large-scale tasks such as machine translation and computationally universal models like the Neural GPU.

Universal Transformers for Language Understanding and Reasoning

We applied the Universal Transformer to a variety of algorithmic tasks and a diverse set of large-scale language understanding tasks. These tasks were chosen because they are challenging in different respects. For instance, the bAbI question answering and reasoning tasks with 1k training samples require data-efficient models that are capable of multi-hop reasoning. Likewise, a set of algorithmic tasks like copy, reverse, and addition is designed to assess the length-generalization capabilities of a model (by training on short examples and evaluating on much longer ones). The subject-verb agreement task requires modeling hierarchical structure, which calls for a recurrent inductive bias. LAMBADA is a challenging language modeling task that requires capturing a broad context. And finally, machine translation (MT) is a very important large-scale task that is one of the standard benchmarks for evaluating language processing models. Results on all these tasks are reported in the paper.

Here, we present some analysis of the bAbI question-answering task as an example. In bAbI tasks, the goal is to answer a question given a series of facts forming a story; the tasks measure various forms of language understanding by requiring a certain type of reasoning over the linguistic facts presented in each story.

A standard Transformer does not achieve good generalization on this task, no matter how much one tunes the hyper-parameters and the model. However, we can design a model based on the Universal Transformer that achieves state-of-the-art (SOTA) results on bAbI. To encode the input, we first encode each fact in the story by applying a learned multiplicative positional mask to each word's embedding and then summing all the embeddings. We embed the questions in the same way, and then feed the UT with these embeddings of the facts and questions. The UT both with and without dynamic halting achieves SOTA results in terms of average error and number of failed tasks, in both the 10K and 1K training regimes.
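A minimal NumPy sketch of this fact/question encoding, with illustrative names and shapes (the learned mask is just a random array here):

```python
import numpy as np

def encode_fact(word_embeddings, positional_mask):
    """Sketch of the fact/question encoding described above.

    word_embeddings: [num_words, d_model] embeddings of the words in one fact or question
    positional_mask: [max_words, d_model] learned multiplicative positional mask
    Returns a single [d_model] vector representing the fact.
    """
    num_words = word_embeddings.shape[0]
    masked = word_embeddings * positional_mask[:num_words]  # element-wise positional weighting
    return masked.sum(axis=0)                               # sum-pool into one embedding per fact

# Toy usage: a 3-word fact with d_model = 4 (in the real model the mask is learned).
rng = np.random.default_rng(0)
fact_vector = encode_fact(rng.normal(size=(3, 4)), rng.normal(size=(10, 4)))
```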

Here is a visualization of the attention distributions over multiple processing steps of the UT on one of the examples from the test set of task 2:

An example from task 2 (requiring two supporting facts to solve):

Story:

John went to the hallway. 
John went back to the bathroom.
John grabbed the milk there.
Sandra went back to the office.
Sandra journeyed to the kitchen.
Sandra got the apple there.
Sandra dropped the apple there.
John dropped the milk.

Question:

Where is the milk?

Model's Output:

bathroom

Visualization of the attention distributions when encoding the question “Where is the milk?”.

Attention maps are shown for recurrent steps 1 through 4 of the encoder.

In this example, and in fact in most cases, the attention distributions start out very uniform but become progressively sharper (peakier) in later steps around the correct supporting facts required to answer the question, which is indeed very similar to how humans would solve the task (i.e. from coarse to fine).

Here is a visualization of the per-symbol pondering times for a sample input processed by UT with adaptive halting:

Ponder time of Adaptive Universal Transformer for encoding facts in a story and question in a bAbI task requiring three supporting facts.

As can be seen, the network learns to ponder more over the relevant facts than over the facts in the story that provide no support for the answer to the question.

Recurrent Inductive Bias, Data Efficiency and the Notion of State in Depth

Following the intuitions behind weight sharing in CNNs and RNNs, UTs extend the Transformer with a simple form of weight sharing that strikes an effective balance between inductive bias and model expressivity.

Sharing weights across the depth of the network introduces a recurrence into the model. This recurrent inductive bias appears to be crucial for learning generalizable solutions in some tasks, such as those that require modeling the hierarchical structure of the input or capturing dependencies in a broader context. Moreover, weight sharing in depth leads to better performance of UTs (compared to the standard Transformer) on very small datasets and makes the UT a very data-efficient model, which is attractive for domains and tasks with limited available data.

There has been a long line of research on RNNs, and many works have built on the idea of recurrence in time to improve sequence processing. The UT is a recurrent model in which the recurrence is in depth, not in time. So there is a notion of state along the depth of the model, and one interesting direction is to take ideas that worked for RNNs, "flip them vertically", and see if they can help improve the flow of information along the depth of the model. For instance, we can introduce memory/state with forget gates in depth by simply using an LSTM as the recurrent transition function, as sketched below.
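Here is a minimal NumPy sketch of that idea; the LSTM cell and the loop are simplified and the parameter shapes are illustrative, not the Tensor2Tensor implementation.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_cell(x, c, m, W, U, b):
    """Minimal per-position LSTM cell (shared across positions and depth steps).
    x: [seq_len, d] input at this depth step; c, m: [seq_len, d] cell/hidden state;
    W, U: [d, 4d]; b: [4d]."""
    gates = x @ W + m @ U + b
    i, f, o, g = np.split(gates, 4, axis=-1)
    c = sigmoid(f) * c + sigmoid(i) * np.tanh(g)   # forget old memory, write new content
    m = sigmoid(o) * np.tanh(c)                    # expose gated output
    return c, m

def ut_with_lstm_transition(h, self_attention, W, U, b, num_steps):
    """Sketch of a UT-style loop whose transition function is an LSTM cell,
    so the recurrence in depth carries a memory cell with forget gates."""
    c = np.zeros_like(h)
    m = np.zeros_like(h)
    for _ in range(num_steps):
        a = h + self_attention(h)        # residual self-attention over all positions
        c, m = lstm_cell(a, c, m, W, U, b)
        h = m                            # the LSTM output becomes the refined per-symbol state
    return h

# Toy usage with random parameters.
d = 8
rng = np.random.default_rng(0)
h0 = rng.normal(size=(5, d))
out = ut_with_lstm_transition(h0, lambda x: np.zeros_like(x),
                              rng.normal(size=(d, 4 * d)), rng.normal(size=(d, 4 * d)),
                              np.zeros(4 * d), num_steps=4)
```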

Many of these ideas are already implemented in Tensor2Tensor and ready to be explored (for instance, check the UT with an LSTM as the transition function here).

Want to know more?

The code used to train and evaluate Universal Transformers can be found here:

https://github.com/tensorflow/tensor2tensor

The code for training as well as attention and ponder time visualization of bAbI tasks can be found here:

https://github.com/MostafaDehghani/bAbI-T2T

For more details about the model as well as results and analysis on all tasks, please take a look at the paper:

  • M. Dehghani, S. Gouws, O. Vinyals, J. Uszkoreit, and L. Kaiser. "Universal Transformers". International Conference on Learning Representations (ICLR'19).