Efficient Transformers

Transformers have garnered immense interest lately due to their effectiveness across a range of domains such as language, vision, and reinforcement learning. The self-attention mechanism is a key defining characteristic of Transformer models. The mechanism can be viewed as a graph-like inductive bias that connects all tokens in a sequence with a relevance-based pooling operation. A well-known concern with self-attention is its quadratic time and memory complexity, which can hinder model scalability in many settings.
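To make the quadratic cost concrete, here is a minimal sketch of vanilla scaled dot-product attention in PyTorch (single head, no masking or projections); the tensor names and shapes are illustrative choices, not code from any particular implementation:

```python
import torch

def full_attention(q, k, v):
    """Vanilla scaled dot-product attention.

    q, k, v: [batch, seq_len, dim]. The `scores` tensor is
    [batch, seq_len, seq_len], which is the source of the quadratic
    time and memory cost in sequence length.
    """
    dim = q.shape[-1]
    scores = torch.einsum("bnd,bmd->bnm", q, k) / dim ** 0.5  # N x N matrix
    weights = torch.softmax(scores, dim=-1)
    return torch.einsum("bnm,bmd->bnd", weights, v)

q = k = v = torch.randn(2, 1024, 64)
out = full_attention(q, k, v)  # materializes a 2 x 1024 x 1024 score tensor
```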

Recently, a dizzying number of “X-former” models have been proposed, many of which make improvements around computational and memory efficiency. We refer to this class of models as “efficient Transformers”. We wrote a survey that sets out to provide a comprehensive overview of the recent advances made in this class of models. In our survey, we propose a taxonomy of efficient Transformer models, characterizing them by their technical innovation and primary use case. We also provide a detailed walk-through of many of these models, including: Memory Compressed, Image Transformer, Set Transformer, Transformer-XL, Sparse Transformer, Reformer, Routing Transformer, Axial Transformer, Compressive Transformer, Sinkhorn Transformer, Longformer, ETC, Synthesizer, Performer, Linformer, Linear Transformers, and Big Bird.

Taxonomy of Efficient Transformer Architectures

We can outline a general taxonomy of efficient Transformer models, characterized by their core techniques and primary use case. This grouping is not only helpful for drawing connections between existing models, but also for understanding how future research in this direction will fit into the current set of solutions. We can of course come up with completely new ideas that are orthogonal to the existing ones... which reminds me of Oriol's tweet.

Here is a short summary of the buckets we came up with for grouping efficient X-former models. The primary goal of most of these models, with the exception of those based on segment-level recurrence, is to approximate the quadratic-cost attention matrix. Each method applies some notion of sparsity to the otherwise dense attention mechanism. After the list, I include minimal code sketches that illustrate each bucket.

  • Fixed Patterns: The earliest modifications to self-attention simply sparsify the attention matrix by limiting the field of view to fixed, predefined patterns such as local windows and block patterns of fixed strides.
    • Blockwise Patterns: The simplest example of this technique in practice is the blockwise (or chunking) paradigm, which considers blocks of local receptive fields by chunking input sequences into fixed blocks. Examples of models that do this include Blockwise Attention and Local Attention. Chunking input sequences into blocks of size B (with B ≪ N) means each block only incurs a B^2 cost, for an overall cost of NB rather than N^2, a significant reduction. These blockwise or chunking methods serve as a basis for many more complex models.
    • Strided Patterns: Another approach is to consider strided attention patterns, i.e., only attending at fixed intervals. Models such as the Sparse Transformer and Longformer employ strided or “dilated” windows.
    • Compressed Patterns: Another line of attack is to use a pooling operator to down-sample the sequence, yielding another form of fixed pattern. For instance, Memory Compressed Attention uses strided convolution to effectively reduce the sequence length.
  • Combination of Patterns: The key idea of combined approaches is to improve coverage by combining two or more distinct access patterns. For example, the Sparse Transformer combines strided and local attention by assigning half of its heads to each pattern. Similarly, the Axial Transformer applies a sequence of self-attention computations given a high-dimensional tensor as input, each along a single axis of the input tensor. In essence, the combination of patterns reduces memory complexity in the same way that fixed patterns do. The difference, however, is that aggregating and combining multiple patterns improves the overall coverage of the self-attention mechanism.
  • Learnable Patterns: An extension of fixed, pre-determined patterns is learnable ones. Unsurprisingly, models using learnable patterns aim to learn the access pattern in a data-driven fashion. A key characteristic of learnable patterns is to determine a notion of token relevance and then assign tokens to buckets or clusters. Notably, the Reformer introduces a hash-based similarity measure to efficiently cluster tokens into chunks. In a similar vein, the Routing Transformer employs online k-means clustering on the tokens. Meanwhile, the Sinkhorn Sorting Network exposes the sparsity in attention weights by learning to sort blocks of the input sequence. In all these models, the similarity function is trained end-to-end jointly with the rest of the network. The key idea of learnable patterns is still to exploit fixed (chunked) patterns. However, this class of methods learns to sort/cluster the input tokens, enabling a more optimal global view of the sequence while maintaining the efficiency benefits of fixed-pattern approaches.
  • Memory: Another prominent method is to leverage a side memory module that can access multiple tokens at once. A common form is global memory, which is able to access the entire sequence. The global tokens act as a form of memory that learns to gather from the input sequence tokens. This was first introduced in the Set Transformer as the inducing points method. These parameters are often interpreted as “memory” and are used as a form of temporary context for future processing. This can be thought of as a form of parameter attention. Global memory is also used in ETC and Longformer. With a limited number of memory tokens (or inducing points), we are able to perform a preliminary pooling-like operation to compress the input sequence, a neat trick to have at one's disposal when designing efficient self-attention modules.
  • Low-Rank Methods: Another emerging technique is to improve efficiency by leveraging low-rank approximations of the self-attention matrix. The key idea is to assume low-rank structure in the N × N matrix. The Linformer is a classic example of this technique, as it projects the length dimension of the keys and values to a lower-dimensional representation (from N to k). It is easy to see that the low-rank method ameliorates the memory complexity problem of self-attention, because the N × N matrix is now decomposed into N × k.
  • Kernels: Another recently popular method to improve the efficiency of Transformers is to view the attention mechanism through kernelization. The use of kernels enables clever mathematical re-writing of the self-attention mechanism that avoids explicitly computing the N × N matrix. Since kernels are a form of approximation of the attention matrix, they can also be viewed as a form of low-rank method.
  • Recurrence: A natural extension to the blockwise method is to connect these blocks via recurrence. Transformer-XL proposed a segment-level recurrence mechanism that connects multiple segments and blocks. These models can, in some sense, be viewed as fixed-pattern models. However, we decided to give this approach its own category because it deviates from the other block/local approaches.
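To make the buckets above more concrete, the following sketches show the core computation behind each idea in PyTorch. They are minimal, single-head illustrations (no masking, multi-head projections, or normalization), and all tensor names and sizes are my own choices rather than code from the cited papers. First, blockwise (local) attention: chunk the sequence into fixed blocks and attend only within each block.

```python
import torch

def blockwise_attention(q, k, v, block_size=128):
    """Attend only within non-overlapping blocks of size B, so the cost
    scales as N*B instead of N^2. Assumes seq_len is divisible by
    block_size (padding logic omitted for brevity)."""
    b, n, d = q.shape
    nb = n // block_size
    q = q.reshape(b, nb, block_size, d)
    k = k.reshape(b, nb, block_size, d)
    v = v.reshape(b, nb, block_size, d)
    scores = torch.einsum("bgnd,bgmd->bgnm", q, k) / d ** 0.5  # B x B per block
    weights = torch.softmax(scores, dim=-1)
    out = torch.einsum("bgnm,bgmd->bgnd", weights, v)
    return out.reshape(b, n, d)
```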
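Strided patterns and their combination with local windows can be expressed as attention masks. The sketch below builds a local mask and a strided mask and ORs them together, in the spirit of the Sparse Transformer's combined coverage; note that a mask-based implementation like this still materializes the full score matrix, so it illustrates the access pattern rather than the memory savings.

```python
import torch

def local_mask(n, window):
    """Each position attends to the previous `window` positions (causal)."""
    i = torch.arange(n)[:, None]
    j = torch.arange(n)[None, :]
    return (j <= i) & (i - j < window)

def strided_mask(n, stride):
    """Each position attends to earlier positions at fixed intervals."""
    i = torch.arange(n)[:, None]
    j = torch.arange(n)[None, :]
    return (j <= i) & ((i - j) % stride == 0)

def masked_attention(q, k, v, mask):
    d = q.shape[-1]
    scores = torch.einsum("bnd,bmd->bnm", q, k) / d ** 0.5
    scores = scores.masked_fill(~mask, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v

n = 1024
combined = local_mask(n, 128) | strided_mask(n, 128)  # combination of patterns
q = k = v = torch.randn(2, n, 64)
out = masked_attention(q, k, v, combined)
```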
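Compressed patterns down-sample the keys and values before attending. A minimal version, loosely following the Memory Compressed Attention idea (not the paper's exact layer), uses a strided 1D convolution so the score matrix becomes N × (N / stride):

```python
import torch
import torch.nn as nn

class CompressedAttention(nn.Module):
    """Compress keys and values along the length dimension with a strided
    convolution before attending (an illustrative sketch)."""
    def __init__(self, dim, stride=4):
        super().__init__()
        self.compress_k = nn.Conv1d(dim, dim, kernel_size=stride, stride=stride)
        self.compress_v = nn.Conv1d(dim, dim, kernel_size=stride, stride=stride)

    def forward(self, q, k, v):
        d = q.shape[-1]
        k = self.compress_k(k.transpose(1, 2)).transpose(1, 2)  # [b, N/stride, d]
        v = self.compress_v(v.transpose(1, 2)).transpose(1, 2)
        scores = torch.einsum("bnd,bmd->bnm", q, k) / d ** 0.5   # N x (N/stride)
        return torch.softmax(scores, dim=-1) @ v
```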
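Learnable patterns route tokens into buckets and restrict attention to within-bucket pairs. This toy sketch uses a random projection as the bucketing function (a stand-in for Reformer's LSH or the Routing Transformer's k-means, not either paper's actual scheme), and implements the restriction with a mask for clarity rather than efficiency:

```python
import torch

def bucketed_attention(q, k, v, n_buckets=8):
    """Assign each position to a bucket via a shared random projection and
    attend only within buckets. Real models learn or hash this assignment
    and avoid materializing the full N x N score matrix."""
    b, n, d = q.shape
    proj = torch.randn(d, n_buckets)
    buckets = (q @ proj).argmax(dim=-1)                       # [b, n]
    same_bucket = buckets[:, :, None] == buckets[:, None, :]  # [b, n, n]
    scores = torch.einsum("bnd,bmd->bnm", q, k) / d ** 0.5
    scores = scores.masked_fill(~same_bucket, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v
```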
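The memory bucket can be sketched with a small set of learned inducing vectors, in the spirit of the Set Transformer: the memory first attends over the full sequence (an m × N computation), and the sequence then attends back over that compressed summary (N × m), so no N × N matrix appears. Layer norms, projections, and feed-forward blocks are omitted.

```python
import torch
import torch.nn as nn

class InducedAttention(nn.Module):
    """m learned 'memory' tokens pool the sequence, and the sequence then
    reads from that pooled summary (a sketch of the inducing-point idea)."""
    def __init__(self, dim, n_inducing=64):
        super().__init__()
        self.inducing = nn.Parameter(torch.randn(n_inducing, dim))

    @staticmethod
    def attend(q, k, v):
        d = q.shape[-1]
        scores = torch.einsum("bnd,bmd->bnm", q, k) / d ** 0.5
        return torch.softmax(scores, dim=-1) @ v

    def forward(self, x):
        b = x.shape[0]
        mem = self.inducing.unsqueeze(0).expand(b, -1, -1)   # [b, m, d]
        summary = self.attend(mem, x, x)         # memory gathers from the sequence
        return self.attend(x, summary, summary)  # sequence reads the memory back
```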
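Low-rank methods project the length dimension of the keys and values from N down to k. A minimal Linformer-style sketch (single head, fixed maximum length, illustrative projection size):

```python
import torch
import torch.nn as nn

class LowRankAttention(nn.Module):
    """Project keys and values along the length dimension so the score
    matrix is N x k instead of N x N (Linformer-style sketch)."""
    def __init__(self, seq_len, dim, proj_len=256):
        super().__init__()
        self.proj_k = nn.Parameter(torch.randn(proj_len, seq_len) / seq_len ** 0.5)
        self.proj_v = nn.Parameter(torch.randn(proj_len, seq_len) / seq_len ** 0.5)

    def forward(self, q, k, v):
        d = q.shape[-1]
        k = torch.einsum("kn,bnd->bkd", self.proj_k, k)          # [b, k, d]
        v = torch.einsum("kn,bnd->bkd", self.proj_v, v)
        scores = torch.einsum("bnd,bkd->bnk", q, k) / d ** 0.5   # N x k
        return torch.softmax(scores, dim=-1) @ v
```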
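Kernel methods replace softmax(QK^T)V with phi(Q)(phi(K)^T V) for some feature map phi, so the product can be computed right-to-left without ever forming the N × N matrix. The sketch below uses the elu(x) + 1 feature map from the Linear Transformers paper and covers only the non-causal case (the causal case needs a running prefix sum):

```python
import torch
import torch.nn.functional as F

def linear_attention(q, k, v):
    """Kernelized attention: softmax(QK^T)V is approximated by
    phi(Q) (phi(K)^T V) with a normalizer, avoiding the N x N matrix."""
    phi = lambda x: F.elu(x) + 1                     # feature map (one common choice)
    q, k = phi(q), phi(k)
    kv = torch.einsum("bnd,bne->bde", k, v)          # d x d summary, independent of N
    z = 1.0 / (torch.einsum("bnd,bd->bn", q, k.sum(dim=1)) + 1e-6)
    return torch.einsum("bnd,bde,bn->bne", q, kv, z)
```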
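Finally, segment-level recurrence can be sketched by caching the previous segment's keys and values (detached from the graph) and prepending them when processing the current segment. This is only the caching skeleton; Transformer-XL itself caches hidden states and relies on relative position encodings, both omitted here.

```python
import torch

def attention_with_memory(q, k, v, mem=None):
    """Attend over the current segment plus a cached previous segment,
    and return a cache for the next step (fixed-length memory)."""
    d, n = q.shape[-1], q.shape[1]
    if mem is not None:
        mem_k, mem_v = mem
        k = torch.cat([mem_k, k], dim=1)
        v = torch.cat([mem_v, v], dim=1)
    scores = torch.einsum("bnd,bmd->bnm", q, k) / d ** 0.5
    out = torch.softmax(scores, dim=-1) @ v
    return out, (k[:, -n:].detach(), v[:, -n:].detach())  # cache current segment only
```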

In our paper, we walk through the details of each of these models, explain their computational and memory costs, and discuss the advantages and disadvantages of each.

Please check out the paper and let us know if you have any comments or suggestions.