Vision Transformer: Farewell Convolutions?

Vision Transformer (ViT) is a pure self-attention-based architecture (a Transformer) that uses no convolutions. ViT stays as close as possible to the Transformer architecture originally designed for text-based tasks.
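To make "the Transformer architecture" concrete, here is a minimal NumPy sketch of one pre-norm encoder block of the kind ViT stacks: LayerNorm, multi-head self-attention and a residual connection, followed by LayerNorm, a two-layer GELU MLP and another residual connection. It is an illustration under simplifying assumptions, not the released ViT code. Biases and the LayerNorm scale/shift are omitted, the weights are random, and the names (`encoder_block`, the `params` keys) are hypothetical.

```python
import numpy as np

def layer_norm(x, eps=1e-6):
    # Normalize each token vector to zero mean and unit variance
    # (the learnable scale and shift are omitted for brevity).
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def gelu(x):
    # tanh approximation of the GELU activation
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x ** 3)))

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def split_heads(t, num_heads):
    # (seq_len, dim) -> (num_heads, seq_len, dim // num_heads)
    n, d = t.shape
    return t.reshape(n, num_heads, d // num_heads).transpose(1, 0, 2)

def encoder_block(x, params, num_heads):
    """One pre-norm Transformer encoder block; x has shape (seq_len, dim)."""
    n, d = x.shape
    head_dim = d // num_heads
    # 1) LayerNorm -> multi-head self-attention -> residual connection.
    h = layer_norm(x)
    q = split_heads(h @ params["wq"], num_heads)
    k = split_heads(h @ params["wk"], num_heads)
    v = split_heads(h @ params["wv"], num_heads)
    attn = softmax(q @ k.transpose(0, 2, 1) / np.sqrt(head_dim))  # (heads, n, n)
    heads_out = (attn @ v).transpose(1, 0, 2).reshape(n, d)        # concatenate heads
    x = x + heads_out @ params["wo"]
    # 2) LayerNorm -> two-layer MLP with GELU -> residual connection.
    h = layer_norm(x)
    return x + gelu(h @ params["w1"]) @ params["w2"]

rng = np.random.default_rng(0)
dim, mlp_dim, num_heads, seq_len = 64, 128, 4, 10
params = {name: rng.normal(0, 0.02, (dim, dim)) for name in ("wq", "wk", "wv", "wo")}
params["w1"] = rng.normal(0, 0.02, (dim, mlp_dim))
params["w2"] = rng.normal(0, 0.02, (mlp_dim, dim))
tokens = rng.normal(size=(seq_len, dim))  # e.g. a sequence of 10 embedded patches
out = encoder_block(tokens, params, num_heads)
print(out.shape)  # (10, 64)
```

The block maps a sequence to a sequence of the same shape, which is what lets ViT stack many of them exactly as text Transformers do.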

One of the key characteristics of ViT is its extremely simple way of encoding inputs, combined with a vanilla Transformer architecture with no fancy tricks. In ViT, we first extract patches from the input image. Then we flatten each patch into a single vector by concatenating the channels, and use a linear projection layer to embed the patches [1]. Next, we add learnable one-dimensional position embeddings to each patch and feed this input, as a sequence of image patch embeddings, to a Transformer encoder, similar to the sequence of word embeddings used when applying Transformers to text.
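The input pipeline described above can be sketched in a few lines of NumPy. This is a toy illustration, not the open-sourced implementation: the 224x224 image, 16x16 patches and 768-dimensional embeddings are assumed sizes, the weights are random rather than learned, the helper `extract_patches` is hypothetical, and the learnable classification token that ViT prepends to the sequence is left out.

```python
import numpy as np

def extract_patches(image, patch_size):
    """Split an (H, W, C) image into non-overlapping patches and flatten each
    patch (all of its pixels across all channels) into one vector.
    Returns an array of shape (num_patches, patch_size * patch_size * C)."""
    h, w, c = image.shape
    p = patch_size
    assert h % p == 0 and w % p == 0, "image size must be divisible by the patch size"
    patches = image.reshape(h // p, p, w // p, p, c)  # (rows, p, cols, p, C)
    patches = patches.transpose(0, 2, 1, 3, 4)        # (rows, cols, p, p, C)
    return patches.reshape(-1, p * p * c)             # row-major patch order

rng = np.random.default_rng(0)
image = rng.random((224, 224, 3))  # toy input image
patch_size, hidden_dim = 16, 768

patches = extract_patches(image, patch_size)  # (196, 768)
# Linear projection of each flattened patch (random stand-in for learned weights).
w_embed = rng.normal(0, 0.02, (patches.shape[1], hidden_dim))
b_embed = np.zeros(hidden_dim)
tokens = patches @ w_embed + b_embed          # (196, 768)
# Learnable 1D position embeddings: one vector per patch index.
pos_embed = rng.normal(0, 0.02, (tokens.shape[0], hidden_dim))
encoder_input = tokens + pos_embed            # sequence fed to the Transformer encoder
print(patches.shape, encoder_input.shape)     # (196, 768) (196, 768)
```

With these sizes, the image becomes a sequence of 14 x 14 = 196 tokens, and each flattened patch holds 16 x 16 x 3 = 768 values.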

Vision Transformer architecture, gif from Google AI blog.

ViT, in particular, shows excellent performance when trained on sufficient data, outperforming comparable state-of-the-art CNN-based models while using four times fewer computational resources. In our paper, we present the results of ViT pre-trained on ImageNet-21k (14M images, 21k classes) or JFT (300M images, 18k classes) and evaluated in fine-tuning and linear few-shot setups on different downstream datasets, such as ImageNet.

In the paper, we studied the impact of the amount of compute involved in pre-training by training several ViT and CNN-based models on JFT with different model sizes and training durations, so that they require varying amounts of compute for training. We observe that, for a given amount of compute, ViT yields better performance than the equivalent CNNs.

I also want to point out two other cool observations we had in ViT:

First, we looked at the learned position embeddings and found that ViT is able to recover the 2D structure of the input, even though there is no explicit signal about the 2D grid structure of the patches and everything is presented to ViT as a sequence. When visualizing the position embeddings, we see that each position embedding is most similar to others in the same row and column.
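One way to produce a similarity map like the one described above is to take the cosine similarity between each position embedding and all the others, then reshape the result back onto the patch grid. The NumPy sketch below uses a hypothetical toy stand-in for learned embeddings (each one is a row vector plus a column vector), so it only illustrates the computation; it does not reproduce ViT's actual learned embeddings.

```python
import numpy as np

def position_similarity(pos_embed, grid_h, grid_w):
    """Cosine similarity between every pair of position embeddings, reshaped so
    that sim[r, c] is a (grid_h, grid_w) map for the patch at row r, column c."""
    normed = pos_embed / np.linalg.norm(pos_embed, axis=-1, keepdims=True)
    sim = normed @ normed.T  # (N, N) with N = grid_h * grid_w
    return sim.reshape(grid_h, grid_w, grid_h, grid_w)

# Toy stand-in (assumption): embedding = row vector + column vector.
# Real ViT position embeddings are learned, not constructed like this.
rng = np.random.default_rng(0)
grid, dim = 14, 256
row_vecs = rng.normal(size=(grid, dim))
col_vecs = rng.normal(size=(grid, dim))
pos_embed = (row_vecs[:, None, :] + col_vecs[None, :, :]).reshape(grid * grid, dim)

sim = position_similarity(pos_embed, grid, grid)
r, c = 3, 5                                       # an arbitrary query position
same_row = np.delete(sim[r, c][r], c).mean()      # other patches in row r
same_col = np.delete(sim[r, c][:, c], r).mean()   # other patches in column c
mask = np.ones((grid, grid), dtype=bool)
mask[r, :] = False
mask[:, c] = False
elsewhere = sim[r, c][mask].mean()                # patches sharing neither
print(f"same row {same_row:.2f}, same column {same_col:.2f}, elsewhere {elsewhere:.2f}")
```

Plotting `sim[r, c]` as an image for each (r, c) gives one small heatmap per patch position, which is the usual way such visualizations are laid out.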

Second, for each Transformer block, we looked into the average spatial distance between a query patch and the patches it attends to, weighted by the attention weights. We observed that ViT makes good use of its global receptive field: while some attention heads combine local information (i.e., they have small attention distances), others already attend to most of the tokens in the lowest layers, showing that the model does use its ability to integrate information globally.
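The attention distance can be computed from a head's attention weights: for each query patch, average the pixel distance to every other patch, weighted by how much attention the query pays to it, then average over all queries. A minimal NumPy sketch, assuming attention over patch tokens only (no classification token) and using two hypothetical hand-built heads, one purely local and one uniform, in place of real ViT attention maps:

```python
import numpy as np

def mean_attention_distance(attn, grid_h, grid_w, patch_size):
    """attn: (num_heads, N, N) attention weights over N = grid_h * grid_w patches,
    with each row summing to 1. Returns, per head, the attention-weighted pixel
    distance between a query patch and the patches it attends to, averaged over queries."""
    rows, cols = np.meshgrid(np.arange(grid_h), np.arange(grid_w), indexing="ij")
    coords = np.stack([rows.ravel(), cols.ravel()], axis=-1) * patch_size  # (N, 2) pixels
    dist = np.linalg.norm(coords[:, None, :] - coords[None, :, :], axis=-1)  # (N, N)
    per_query = (attn * dist[None]).sum(axis=-1)  # (num_heads, N)
    return per_query.mean(axis=-1)                # (num_heads,)

grid, patch = 14, 16
n = grid * grid
local_head = np.eye(n)                  # each patch attends only to itself
global_head = np.full((n, n), 1.0 / n)  # uniform attention over all patches
attn = np.stack([local_head, global_head])
distances = mean_attention_distance(attn, grid, grid, patch)
print(distances)
```

The purely local head gets a distance of 0, while the uniform head gets a distance of roughly half the 224-pixel image width; real heads fall somewhere in between, and the observation above is that some heads in the lowest layers already sit near the global end.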

Does this mean that we will no longer need convolutions? 

Maybe... maybe not!
Every time this question comes up, I think about two main things. First of all, vanilla Transformers lack the inductive biases of convolutions, which can be extremely helpful in a low-data regime. There may, however, be solutions to this, like the idea of "distilling the effect of inductive bias" of CNNs into ViT, as DeiT does (please read this super nice blog post by Samira Abnar about it).
Second of all, I look back at how Transformers took over the whole field of NLP: only a small percentage of new NLP papers are based on other classes of models, like LSTMs. This is of course related to many factors other than the fact that Transformers are powerful, such as the hypothesis explained in "The Hardware Lottery", or the echo-chamber effect caused by the success of Transformers on many tasks and benchmarks.

All in all, no one can predict the future, but the trend seems to favour employing Transformers more and more in computer vision, and in the near future we will see Transformer-based models in many vision setups and tasks.

To learn more about Vision Transformer, please check out our paper:

Or read this Google AI blog post about Vision Transformer. By the way, we open-sourced both the code and the models to foster further research in this area.


[1] You can see this "extracting patches plus the linear projection" as applying a convolutional layer whose kernel size and stride both equal the patch size!
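The footnote's claim is easy to check numerically. The NumPy sketch below, with small hypothetical sizes, compares a naive strided convolution (kernel size and stride equal to the patch size, no padding) against patch extraction followed by a linear projection whose weight matrix is the flattened convolution kernel. The helper names `conv2d_valid` and `patch_embed` are illustrative, not from any library.

```python
import numpy as np

def conv2d_valid(image, kernels, stride):
    """Naive 2D convolution (cross-correlation, as in deep-learning libraries),
    no padding. image: (H, W, C); kernels: (k, k, C, D). Returns (H_out, W_out, D)."""
    h, w, _ = image.shape
    k, _, _, d = kernels.shape
    h_out, w_out = (h - k) // stride + 1, (w - k) // stride + 1
    out = np.zeros((h_out, w_out, d))
    for i in range(h_out):
        for j in range(w_out):
            window = image[i * stride:i * stride + k, j * stride:j * stride + k, :]
            # Contract the (k, k, C) window against each of the D kernels.
            out[i, j] = np.tensordot(window, kernels, axes=([0, 1, 2], [0, 1, 2]))
    return out

def patch_embed(image, w_embed, p):
    """Extract non-overlapping p x p patches, flatten them, and project linearly."""
    h, w, c = image.shape
    patches = image.reshape(h // p, p, w // p, p, c).transpose(0, 2, 1, 3, 4)
    return patches.reshape(h // p, w // p, p * p * c) @ w_embed

rng = np.random.default_rng(0)
p, c, d = 4, 3, 8
image = rng.random((16, 16, c))
kernels = rng.normal(size=(p, p, c, d))
# The projection matrix is the conv kernels flattened in (row, col, channel) order.
w_embed = kernels.reshape(p * p * c, d)

conv_out = conv2d_valid(image, kernels, stride=p)
proj_out = patch_embed(image, w_embed, p)
print(conv_out.shape, np.allclose(conv_out, proj_out))  # (4, 4, 8) True
```

Because the windows do not overlap, each output position of the convolution sees exactly one patch, which is why the two views coincide up to floating-point error.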