Vision Transformers: From Idea to Applications (Part Two)
Understanding the Vision Transformer (ViT) architecture...
This newsletter is sponsored by Rebuy, where AI is being used to build the future of e-commerce personalization. If you like this newsletter, please subscribe, share it, or follow me on Twitter. Thank you for your support!
this post. This is part two of a six-part series (written by Sairam Sundaresan and myself) that explores the (vision) transformer deep learning architecture and its many impactful applications. I will write parts two, four, and six as part of Deep (Learning) Focus, while Sairam will release the other parts on his Gradient Ascent newsletter. Part one is available on Gradient Ascent.
The transformer architecture [3] was a revolutionary proposal that drastically improved the state of the art on numerous natural language benchmarks. For example, the transformer architecture was the backbone of impactful proposals such as BERT [2].
Given that the transformer is so successful in the natural language domain (and beyond!), we might start to wonder: can we use the same architecture for computer vision? This question was answered positively in [1] with the proposal of the Vision Transformer (ViT) architecture for image classification, an alternative to commonly used convolutional neural network (CNN) architectures (e.g., ResNets [4]).
The Vision Transformer
The ViT architecture is not much different from the transformer architecture that we explored in part one of this series. However, we don’t use the full encoder-decoder architecture, and we have to modify our input images a bit to convert them into a transformer-compatible format. Let’s explore how this works in more detail.
Encoder-Only Architecture
The ViT architecture is just the encoder portion of the transformer architecture (i.e., an encoder-only transformer); see above. Notably, this is the same architecture that is used for BERT [2]. The architecture is composed of several blocks, each of which contains two layers: a multi-head self-attention layer and a feed-forward neural network.
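Each of these layers within a block is paired with layer normalization and a residual connection; in ViT, layer normalization is applied before each layer, and the residual connection is added around it. Several of these blocks are then stacked on top of each other to form the full, encoder-only architecture.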
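The wiring of a block can be sketched in a few lines of plain Python. This is a minimal illustration, not the implementation from [1]: it assumes a single attention head with identity query/key/value projections and uses a weight-free GELU in place of the learned two-layer MLP, so it only shows how layer normalization, each sublayer, and the residual connection fit together.

```python
import math


def layer_norm(x, eps=1e-6):
    """Normalize one token vector to zero mean and unit variance (no learned scale/shift)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]


def self_attention(tokens):
    """Single-head scaled dot-product attention with identity Q/K/V projections."""
    d = len(tokens[0])
    out = []
    for q in tokens:
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in tokens]
        m = max(scores)  # subtract the max for a numerically stable softmax
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        out.append([sum(w * v[i] for w, v in zip(weights, tokens)) / total for i in range(d)])
    return out


def feed_forward(x):
    """Toy position-wise step: a tanh-approximated GELU with no learned weights."""
    return [0.5 * v * (1 + math.tanh(math.sqrt(2 / math.pi) * (v + 0.044715 * v ** 3))) for v in x]


def encoder_block(tokens, attention=self_attention, ffn=feed_forward):
    """Pre-norm block: x + Attn(LN(x)), then x + FFN(LN(x))."""
    normed = [layer_norm(t) for t in tokens]
    tokens = [[a + b for a, b in zip(t, u)] for t, u in zip(tokens, attention(normed))]
    return [[a + b for a, b in zip(t, ffn(layer_norm(t)))] for t in tokens]


def encoder(tokens, depth):
    """Stack `depth` blocks; a real ViT gives each block its own learned weights."""
    for _ in range(depth):
        tokens = encoder_block(tokens)
    return tokens


x = [[1.0, 0.0, -1.0, 2.0], [0.5, 0.5, 0.5, -0.5], [3.0, -2.0, 1.0, 0.0]]
y = encoder(x, depth=2)
print(len(y), len(y[0]))  # 3 4: every block preserves the sequence shape
```

Because each sublayer is added back onto its input, passing sublayers that output zeros returns the input unchanged, which is a quick way to see the residual path at work.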
what size of model should we use? We scale up this architecture by increasing the model’s width (i.e., the dimension of the vectors used for self-attention), its depth (i.e., the number of blocks), or its number of attention heads. Several different ViT sizes may be used. Because the model shares the same architecture as BERT, however, we usually inherit the model sizes that are popular for BERT (e.g., BERT-Base and BERT-Large); see below.
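To get a feel for how width and depth drive model size, here is a rough parameter-count sketch. It assumes standard multi-head attention, where the heads split the width rather than add weights (so the head count does not appear in the formula), and it ignores the patch embedding, position embeddings, and classification head. The configurations are the ViT-Base/Large/Huge settings from [1]; `encoder_params` is a hypothetical helper.

```python
def encoder_params(width, depth, mlp_dim):
    """Approximate parameter count of a BERT/ViT-style encoder stack.

    Per block: Q, K, V, and output projections (4 * width^2 weights + 4 * width biases),
    a two-layer MLP (2 * width * mlp_dim weights + mlp_dim + width biases),
    and two layer norms (scale and shift each: 4 * width).
    """
    attn = 4 * width * width + 4 * width
    mlp = 2 * width * mlp_dim + mlp_dim + width
    norms = 4 * width
    return depth * (attn + mlp + norms)


# Layer count, hidden width, and MLP size for the three ViT variants in [1]
# (they use 12, 16, and 16 heads, which do not change the count here).
configs = {
    "ViT-Base": dict(width=768, depth=12, mlp_dim=3072),
    "ViT-Large": dict(width=1024, depth=24, mlp_dim=4096),
    "ViT-Huge": dict(width=1280, depth=32, mlp_dim=5120),
}
for name, cfg in configs.items():
    print(f"{name}: ~{encoder_params(**cfg) / 1e6:.0f}M parameters")
```

This prints roughly 85M, 302M, and 630M parameters, in the same ballpark as the 86M, 307M, and 632M reported in [1] for the full models.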
Constructing the Input Sequence
We cannot directly pass an image into a transformer encoder. Instead, we need a sequence of vectors! In language applications, this sequence is just the set of (ordered) embeddings that correspond to the tokens within our input. But how can we form such a token sequence if our input is an image?
The approach followed in [1] is actually quite simple: we just (i) decompose the image into several non-overlapping patches and (ii) linearly project each of these patches to form a vector. This process is illustrated in the figure above. We then have a transformer-compatible sequence of vectors, where each vector is associated with a patch from the original image.
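Steps (i) and (ii) can be sketched in pure Python as follows, assuming the image is stored as nested lists of shape H x W x C and using random values in place of the learned projection weights. The helper names `patchify` and `linear_projection` are hypothetical; many implementations compute the same thing with a convolution whose kernel size and stride both equal the patch size.

```python
import random


def patchify(image, patch_size):
    """Split an H x W x C image into non-overlapping P x P patches.

    Patches are visited in row-major order over the patch grid, and each one
    is flattened into a vector of length P * P * C.
    """
    h, w = len(image), len(image[0])
    assert h % patch_size == 0 and w % patch_size == 0, "image must tile evenly"
    patches = []
    for top in range(0, h, patch_size):
        for left in range(0, w, patch_size):
            patches.append([
                channel
                for row in image[top:top + patch_size]
                for pixel in row[left:left + patch_size]
                for channel in pixel
            ])
    return patches


def linear_projection(patches, weight, bias):
    """Map each flattened patch to a D-dim embedding; `weight` has D rows of length P*P*C."""
    return [
        [sum(x * w for x, w in zip(patch, w_row)) + b for w_row, b in zip(weight, bias)]
        for patch in patches
    ]


random.seed(0)
P, C, D = 16, 3, 8  # patch size, channels, toy embedding width (ViT-Base uses 768)
image = [[[random.random() for _ in range(C)] for _ in range(224)] for _ in range(224)]
patches = patchify(image, P)
weight = [[random.gauss(0, 0.02) for _ in range(P * P * C)] for _ in range(D)]
bias = [0.0] * D
tokens = linear_projection(patches, weight, bias)
print(len(patches), len(patches[0]), len(tokens[0]))  # 196 768 8
```

A 224 x 224 image with 16 x 16 patches yields a 14 x 14 grid, i.e., a sequence of 196 patch vectors, which is where the "16x16 words" in the title of [1] comes from.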
other details. From here, we must do a few more things to construct our final input sequence for the ViT, such as:
Adding position embeddings to each patch embedding
Adding an extra [class] token at the beginning of the sequence
Both additions are depicted below. Position embeddings are necessary because self-attention has no way of knowing a patch’s position within the sequence unless we inject positional information directly. Self-attention is just a set of matrix multiplications and dot products that do not consider each token’s position!
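The claim that self-attention ignores position can be checked directly. In this minimal sketch (single head, identity query/key/value projections as a simplifying assumption), shuffling the input tokens only shuffles the outputs in the same way; adding a position vector tied to each slot (hypothetical values) breaks that symmetry.

```python
import math


def self_attention(tokens):
    """Single-head scaled dot-product self-attention with identity Q/K/V projections."""
    d = len(tokens[0])
    outputs = []
    for q in tokens:
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in tokens]
        m = max(scores)  # subtract the max for a numerically stable softmax
        weights = [math.exp(s - m) for s in scores]
        total = sum(weights)
        outputs.append([sum(w * v[i] for w, v in zip(weights, tokens)) / total for i in range(d)])
    return outputs


def is_permuted_copy(out, out_shuffled, perm):
    """True if out_shuffled[i] matches out[perm[i]] for every position i."""
    return all(
        math.isclose(a, b)
        for i, j in enumerate(perm)
        for a, b in zip(out_shuffled[i], out[j])
    )


def add_positions(seq, pos):
    """Add one position vector per slot in the sequence."""
    return [[x + p for x, p in zip(tok, p_vec)] for tok, p_vec in zip(seq, pos)]


tokens = [[1.0, 0.0], [0.0, 2.0], [1.0, 1.0]]
perm = [2, 0, 1]
shuffled = [tokens[j] for j in perm]

# Without positional information, reordering the inputs just reorders the outputs.
print(is_permuted_copy(self_attention(tokens), self_attention(shuffled), perm))  # True

# With position vectors tied to slots, the same reordering changes the outputs.
pos = [[0.0, 0.5], [0.5, 0.0], [-0.5, 0.5]]
print(is_permuted_copy(
    self_attention(add_positions(tokens, pos)),
    self_attention(add_positions(shuffled, pos)),
    perm,
))  # False
```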
The [class] token that we add to the beginning of the sequence is often used as an aggregate representation of the full sequence. We can take the ViT’s final output for this embedding and use it for classification and more.
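Both additions can be sketched as a single function, assuming the patch embeddings have already been computed. In a real ViT the [class] token and the position embeddings are learned parameters ([1] uses learnable 1D position embeddings); the zeros and random values below are placeholders, and `build_input_sequence` is a hypothetical helper.

```python
import random


def build_input_sequence(patch_embeddings, cls_token, pos_embeddings):
    """Prepend the [class] token, then add one position embedding per slot.

    patch_embeddings: N vectors of width D (output of the linear projection).
    cls_token: one vector of width D.
    pos_embeddings: N + 1 vectors of width D (slot 0 belongs to [class]).
    """
    sequence = [cls_token] + patch_embeddings
    assert len(pos_embeddings) == len(sequence), "need one position embedding per token"
    return [[x + p for x, p in zip(tok, p_vec)] for tok, p_vec in zip(sequence, pos_embeddings)]


random.seed(0)
N, D = 196, 8  # 14 x 14 patches of a 224 x 224 image with P = 16; toy width D
patches = [[random.gauss(0, 1) for _ in range(D)] for _ in range(N)]
cls_token = [0.0] * D  # placeholder for a learned vector
pos = [[random.gauss(0, 0.02) for _ in range(D)] for _ in range(N + 1)]
sequence = build_input_sequence(patches, cls_token, pos)
print(len(sequence))  # 197
```

After the encoder runs, the classifier reads only the output at index 0, i.e., the slot that started as the [class] token.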
The Full Architecture
Now that we’ve covered the basic components, the full ViT architecture is simple to understand. We construct the input sequence by embedding image patches as described above, pass this sequence into our encoder-only transformer, extract the final [class] embedding, and pass this embedding through a final classification layer; see above.
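One way to summarize the full pipeline is to trace the shape of the data through each stage. The defaults below assume a ViT-Base/16-style configuration and a 1000-class head (an ImageNet-style assumption); no actual computation is performed, and `vit_shapes` is a hypothetical helper.

```python
def vit_shapes(image_size=224, patch_size=16, channels=3, width=768, num_classes=1000):
    """Trace the shape of the data at each stage of a ViT forward pass."""
    n = (image_size // patch_size) ** 2  # number of patches
    return [
        ("input image", (image_size, image_size, channels)),
        ("flattened patches", (n, patch_size * patch_size * channels)),
        ("patch embeddings (linear projection)", (n, width)),
        ("+ [class] token and position embeddings", (n + 1, width)),
        ("encoder output (shape unchanged)", (n + 1, width)),
        ("final [class] embedding", (width,)),
        ("classification logits", (num_classes,)),
    ]


for stage, shape in vit_shapes():
    print(f"{stage:45s} {shape}")
```

With a 384 x 384 input and the same patch size, the sequence grows to 576 patches (577 tokens). Since self-attention compares every pair of tokens, its cost grows quadratically with this sequence length.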
How do ViTs perform?
After developing an understanding of the ViT architecture, it’s useful to look at its experimental analysis to see how the model performs on popular benchmarks relative to other common image classification architectures. We will now review these results, as well as a few notable extensions that have been proposed to fix common problems with the ViT.
Image Classification with ViT
As mentioned previously, the authors of [1] test three different sizes of ViT (Base, Large, and Huge) on several image classification benchmarks; see the table above for these results.
training pipeline. In these experiments, ViT follows a two-step training procedure. The model is first pre-trained over an initial dataset, then fine-tuned on a downstream dataset (i.e., the target dataset on which evaluation occurs). Several datasets are used for pre-training, including JFT-300M (a huge internal dataset from Google), ImageNet-21K (a larger version of ImageNet with roughly 21K classes), and ImageNet. The entire pre-training process is supervised, and performing this supervised pre-training over a very large dataset improves the quality of the model after fine-tuning on the downstream dataset.
bigger = better for pre-training. One major finding we can immediately see from the results shown above is that ViTs perform well (i.e., better than state-of-the-art CNNs) only given a sufficiently large pre-training dataset. Without extensive pre-training, ViTs are outperformed by popular CNN architectures. After sufficient pre-training, however, ViTs perform quite well; see below.
This trend is probably caused by the fact that CNNs have built-in inductive biases for images, such as locality and translation equivariance, while the ViT architecture (and the transformer architecture in general) lacks these inductive biases and must learn them directly from the data. Thus, pre-training over a large dataset may enable the ViT to learn valuable patterns that CNNs have by construction.
Notable Extensions
The ViTs explored in [1] performed well, but only with something extra (i.e., a lot of pre-training), whereas CNNs work well without any pre-training. With this in mind, we might wonder whether ViTs are even worth using. Luckily, later work mostly eliminated this pre-training requirement, making ViTs a more practical architecture for computer vision.
mitigating pre-training requirements. In [5], the authors propose a ViT variant (DeiT) that adds an extra, special distillation token to the model’s input; see above. We can use this token to add a distillation component to the model’s loss. In particular, the authors adopt a hard distillation loss (as opposed to soft distillation).
What does this mean? Throughout training, we have access to a teacher model (typically a CNN) that already performs well. We use the argmax output (i.e., the predicted class) of this teacher network as an additional target for training the ViT. Interestingly, this approach achieves impressive performance without requiring the ViT to be extensively pre-trained; see below.
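The hard distillation objective can be sketched as two cross-entropy terms with equal weight, assuming the student ViT produces one set of logits from its [class] token and another from its distillation token, and that the teacher’s logits are given. This is a sketch after the description in [5], not its released code; the helper names and toy logits are hypothetical.

```python
import math


def cross_entropy(logits, target):
    """Cross-entropy between softmax(logits) and an integer class label."""
    m = max(logits)  # log-sum-exp trick for numerical stability
    log_partition = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_partition - logits[target]


def hard_distillation_loss(cls_logits, dist_logits, teacher_logits, label):
    """Equal-weight hard distillation objective.

    The [class]-token head is trained on the true label; the distillation-token
    head is trained on the teacher's hard prediction (its argmax class).
    """
    teacher_label = max(range(len(teacher_logits)), key=lambda c: teacher_logits[c])
    return 0.5 * cross_entropy(cls_logits, label) + 0.5 * cross_entropy(dist_logits, teacher_label)


cls_logits = [2.0, 0.5, -1.0]      # student's [class]-token head
dist_logits = [0.2, 1.5, -0.3]     # student's distillation-token head
teacher_logits = [0.1, 2.4, -0.5]  # e.g., a pre-trained CNN; its argmax is class 1
label = 0                          # the true label disagrees with the teacher here
print(round(hard_distillation_loss(cls_logits, dist_logits, teacher_logits, label), 4))
```

Because the teacher contributes only its argmax class, the distillation term is an ordinary classification loss against a second, teacher-provided label, which is what distinguishes hard from soft distillation.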
relative positional embeddings. As mentioned before, one downside of the transformer architecture is its lack of translation invariance. If we take the same object and put it at two different spatial locations within an image, we get different outputs. Intuitively, this is not desirable, as an object’s location does not fundamentally change its semantic meaning.
CNNs build in translation equivariance through their use of convolutions (and, combined with pooling, approximate translation invariance). For ViTs, however, we must learn this property from data. In their original form, transformers use global positional embeddings, meaning that a unique, additive positional embedding is associated with each position in the model’s input sequence; such embeddings are not translation invariant. To mitigate this problem, we can use relative positional embeddings, as proposed in [6].
As an alternative to global positional embeddings, relative positional embeddings are based on the distance between two tokens in the input sequence, as opposed to each token’s global position. Such an approach is translation invariant because it depends only on the relative distance between tokens; see above. The resulting model tends to perform slightly better; e.g., see the evaluations in [7] and the results below from the original proposal of relative positional embeddings in [6].
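A minimal 1D sketch of the clipped scheme from [6]: a learned vector a_ij, selected by the offset j - i clipped to [-k, k], is added to the key when computing the attention score between tokens i and j. The toy table values and identity query/key projections are assumptions; [6] also adds relative terms to the values (omitted here), and image models typically use 2D variants with separate offsets along height and width.

```python
import math


def relative_position_index(seq_len, max_distance):
    """Table index for each (query i, key j) pair, based only on the offset j - i.

    Offsets are clipped to [-k, k] and shifted by k so they can index a table
    of 2k + 1 learned embeddings.
    """
    k = max_distance
    return [[min(max(j - i, -k), k) + k for j in range(seq_len)] for i in range(seq_len)]


def attention_scores(tokens, rel_table, max_distance):
    """Scores e_ij = x_i . (x_j + a_ij) / sqrt(d), where a_ij = rel_table[index(i, j)]."""
    d = len(tokens[0])
    idx = relative_position_index(len(tokens), max_distance)
    return [
        [
            sum(q * (kv + a) for q, kv, a in zip(tokens[i], tokens[j], rel_table[idx[i][j]]))
            / math.sqrt(d)
            for j in range(len(tokens))
        ]
        for i in range(len(tokens))
    ]


k = 2
rel_table = [[0.1, 0.0], [0.2, 0.1], [0.0, 0.0], [-0.1, 0.3], [0.4, -0.2]]  # 2k + 1 toy vectors
a, b, empty = [1.0, 0.0], [0.0, 1.0], [0.0, 0.0]

# The same two-token "object" placed at positions (1, 2) and then at (4, 5).
seq_a = [empty, a, b, empty, empty, empty, empty]
seq_b = [empty, empty, empty, empty, a, b, empty]
scores_a = attention_scores(seq_a, rel_table, k)
scores_b = attention_scores(seq_b, rel_table, k)
print(scores_a[1][2] == scores_b[4][5])  # True: the score depends only on j - i
```

With global embeddings, the object at positions (4, 5) would instead receive different additive vectors than at (1, 2), so the same pair of tokens would produce a different score.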
Conclusion
In this overview, we learned about the ViT architecture, reviewed its performance, and covered some notable extensions that can help to improve it. The main takeaways are outlined below.
ViTs are encoder-only transformers. The ViT architecture is an encoder-only transformer that is quite similar to BERT. To make this model compatible with image inputs, we can just separate the image into patches, then embed these patches (using a linear transformation) to create a sequence of input vectors that represent each patch. Then, we can train the network normally on image classification tasks by passing the transformer’s output into a separate classification module.
how much pre-training do we need? Originally, the ViT needed to be pre-trained extensively to surpass CNN performance. This is likely due to the fact that the transformer architecture lacks useful inductive biases that CNNs have naturally. However, later research mitigated this requirement by finding novel training procedures (based on knowledge distillation) that enable ViTs to perform well without being too data hungry!
how should we encode position? Usually, transformers encode the position of each token using a global, additive embedding. Since this method is not translationally invariant, we can adopt a relative positional encoding scheme for ViTs to make the model more translationally invariant and improve its performance.
ViTs are here to stay. Although it took some time for transformer-based models to match or surpass CNN performance in computer vision tasks, they are now one of the most commonly used architectures. Modern ViTs can match the efficiency of CNNs while maintaining or exceeding their performance.
New to the newsletter?
Hello! I am Cameron R. Wolfe, a research scientist at Alegion and PhD student at Rice University. I study the empirical and theoretical foundations of deep learning.
This is the Deep (Learning) Focus newsletter, where I help readers to build a deeper understanding of topics in deep learning research via understandable overviews of popular papers on that topic. If you like this newsletter, please subscribe, share it with your friends, or follow me on Twitter!
This post is part two of a six-part series on Vision Transformers that I wrote with Sairam Sundaresan. Please subscribe to his Gradient Ascent newsletter as well!
Bibliography
[1] Dosovitskiy, Alexey, et al. "An image is worth 16x16 words: Transformers for image recognition at scale." arXiv preprint arXiv:2010.11929 (2020).
[2] Devlin, Jacob, et al. "BERT: Pre-training of deep bidirectional transformers for language understanding." arXiv preprint arXiv:1810.04805 (2018).
[3] Vaswani, Ashish, et al. "Attention is all you need." Advances in Neural Information Processing Systems 30 (2017).
[4] He, Kaiming, et al. "Deep residual learning for image recognition." Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (2016).
[5] Touvron, Hugo, et al. "Training data-efficient image transformers & distillation through attention." International Conference on Machine Learning, PMLR (2021).
[6] Shaw, Peter, Jakob Uszkoreit, and Ashish Vaswani. "Self-attention with relative position representations." arXiv preprint arXiv:1803.02155 (2018).
[7] Li, Yanghao, et al. "MViTv2: Improved multiscale vision transformers for classification and detection." Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (2022).