Almost all natural language processing tasks, ranging from language modeling and masked word prediction to translation and question answering, were revolutionized when the Transformer architecture made its debut in 2017. It took no more than two or three years for Transformers to also excel at computer vision tasks. In this story, we explore two fundamental architectures that allowed transformers to break into the world of computer vision.
Table of Contents
· The vision transformer
∘ Key idea
∘ Operation
∘ Hybrid architecture
∘ Loss of structure
∘ Results
∘ Self-supervised learning using masking
· Vision transformer with masked autoencoder
∘ Key idea
∘ Architecture
∘ Final observation and example
The vision transformer
Key idea
The vision transformer simply aims to generalize the standard transformer architecture to process and learn from image input. There is one key insight about the architecture that the authors were transparent enough to highlight:
“Inspired by the successes of Transformer scaling in NLP, we experimented with applying a standard Transformer directly to images, with as few modifications as possible.”
Operation
It is valid to take “as few modifications as possible” literally, because they make virtually no modifications to the architecture itself. What they modify is the input structure:
- In NLP, the transformer encoder takes a sequence of one-hot vectors (or equivalent token indices) that represents the input sentence/paragraph and returns a sequence of contextual embedding vectors that can be used for downstream tasks (e.g., classification)
- To generalize to computer vision, the vision transformer takes a sequence of patch vectors that represents the input image and returns a sequence of contextual embedding vectors that can be used for other tasks (e.g., classification)
In particular, suppose the input images have dimensions (n, n, 3). To pass such an image as input to the transformer, the vision transformer does the following:
- It divides the image into k² patches for some k (e.g., k=3), as in the figure above.
- Each patch is now (n/k, n/k, 3); the next step is to flatten each patch into a vector
The patch vector will have a dimensionality of 3*(n/k)*(n/k). For example, if the image is (900, 900, 3) and we use k=3, then a patch vector will have a dimensionality of 300*300*3, representing the pixel values in the flattened patch. In the paper, the authors use patches of 16×16 pixels (i.e., n/k = 16). Hence the title of the paper, “An image is worth 16×16 words: Transformers for image recognition at scale”: instead of feeding a one-hot vector that represents a word, they feed a vector of pixel values that represents a patch of the image.
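To make this concrete, here is a minimal sketch of the patch-extraction step in PyTorch. The function name image_to_patch_vectors is just illustrative, and it assumes a square, channels-last image whose side n is divisible by k:

```python
import torch

def image_to_patch_vectors(image: torch.Tensor, k: int) -> torch.Tensor:
    """Split an (n, n, 3) image into k*k flattened patch vectors.

    Returns a tensor of shape (k*k, 3*(n/k)*(n/k)).
    """
    n, _, c = image.shape            # assume a square (n, n, 3) image
    p = n // k                       # patch side length n/k
    # (n, n, c) -> (k, p, k, p, c): split rows and columns into k groups of p pixels
    patches = image.reshape(k, p, k, p, c)
    # reorder to (k, k, p, p, c) so each patch is contiguous, then flatten each patch
    patches = patches.permute(0, 2, 1, 3, 4).reshape(k * k, p * p * c)
    return patches

# Example from the text: a (900, 900, 3) image with k=3 gives 9 vectors of length 270,000
x = torch.rand(900, 900, 3)
print(image_to_patch_vectors(x, k=3).shape)  # torch.Size([9, 270000])
```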
The rest of the operations remain as in the original transformer encoder:
- These patch vectors pass through a trainable embedding layer
- Positional embeddings are added to each vector to maintain a sense of spatial information in the image.
- The output is num_patches encoder representations (one for each patch), which can be used for patch-level or image-level classification
- Most often (and as in the paper), a learnable CLS token is prepended to the sequence, and its corresponding output representation is used to make a prediction on the entire image (similar to BERT); a minimal sketch of this pipeline follows below
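Putting these steps together, the sketch below shows a simplified ViT-style encoder built from standard PyTorch modules (patch projection, CLS token, positional embeddings, transformer encoder, classification head). It illustrates the pipeline described above rather than reproducing the authors' implementation, and the model sizes are arbitrary:

```python
import torch
import torch.nn as nn

class TinyViT(nn.Module):
    def __init__(self, n=224, patch=16, dim=192, depth=4, heads=3, num_classes=1000):
        super().__init__()
        num_patches = (n // patch) ** 2
        # trainable embedding layer: flattened patch (3*patch*patch) -> dim
        self.patch_embed = nn.Linear(3 * patch * patch, dim)
        # learnable CLS token plus positional embeddings (one per patch + one for CLS)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, dim))
        layer = nn.TransformerEncoderLayer(d_model=dim, nhead=heads,
                                           dim_feedforward=4 * dim, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, num_layers=depth)
        self.head = nn.Linear(dim, num_classes)

    def forward(self, patch_vectors):               # (B, num_patches, 3*patch*patch)
        tokens = self.patch_embed(patch_vectors)    # (B, num_patches, dim)
        cls = self.cls_token.expand(tokens.size(0), -1, -1)
        tokens = torch.cat([cls, tokens], dim=1) + self.pos_embed
        tokens = self.encoder(tokens)               # contextual embeddings per token
        return self.head(tokens[:, 0])              # image-level prediction from CLS

model = TinyViT()
patches = torch.rand(2, (224 // 16) ** 2, 3 * 16 * 16)  # e.g. from image_to_patch_vectors
print(model(patches).shape)                              # torch.Size([2, 1000])
```

For patch-level tasks, one would instead use the remaining output tokens (one per patch) rather than the CLS token.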
How about the transformer decoder?
Well, remember that the decoder is essentially like the encoder; the difference is that it uses masked (causal) self-attention instead of full self-attention (the input signature stays the same). In any case, a decoder-only transformer architecture is rarely useful here, because simply predicting the next patch is not a task of great interest.
Hybrid architecture
The authors also mention that it is possible to start from a CNN feature map instead of the raw image, forming a hybrid architecture (a CNN feeding its output into the vision transformer). In this case, we think of the input as a generic (n, n, p) feature map, and a patch vector will have dimensions (n/k)*(n/k)*p.
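As a rough sketch of this hybrid idea, assuming a recent torchvision and a ResNet-50 trunk truncated before pooling (the choice of backbone and truncation point is an illustrative assumption, not something fixed by the paper):

```python
import torch
import torch.nn as nn
from torchvision.models import resnet50

# keep the convolutional trunk only: (B, 3, 224, 224) -> (B, 2048, 7, 7)
backbone = nn.Sequential(*list(resnet50(weights=None).children())[:-2])

feature_map = backbone(torch.rand(1, 3, 224, 224))   # (1, 2048, 7, 7)
# treat each spatial location of the feature map as a "patch" vector
tokens = feature_map.flatten(2).transpose(1, 2)      # (1, 49, 2048) sequence for the ViT
print(tokens.shape)
```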
Loss of structure
It may cross your mind that this architecture shouldn't work that well, because it treats the image as a linear (sequential) structure when it isn't one. The authors acknowledge that this is intentional:
“The two-dimensional neighborhood structure is used very sparingly… position embeddings at initialization time do not contain information about the 2D positions of the patches and all spatial relationships between patches must be learned from scratch”
We will see that the transformer is able to learn this, as evidenced by its good performance in the paper's experiments and, more importantly, by the architecture covered in the next section.
Results
The main verdict from the results is that vision transformers tend not to outperform CNN-based models on small datasets, but match or outperform them on larger datasets and, in either case, require significantly less computation:
Here we see that for the JFT-300M dataset (which has 300 million images), the ViT models pre-trained on it outperform the ResNet-based baselines while requiring substantially less computational resources for pre-training. The largest vision transformer they used (ViT-Huge, with 632M parameters) used about 25% of the compute of the ResNet-based baseline and still outperformed it. Performance doesn't even drop that much with ViT-Large, which uses less than 6.8% of the compute.
Meanwhile, others also report results where ResNet performed significantly better when trained on ImageNet-1K, which has only 1.3 million images.
Self-supervised learning using masking
The authors performed a preliminary exploration of predicting masked patches for self-supervision, mimicking the masked language modeling task used in BERT (i.e., masking patches and attempting to predict them).
“We use the masked patch prediction objective for preliminary self-supervision experiments. To do so, we corrupt 50% of the patch embeddings, either by replacing them with a learnable embedding (mask) (80%), a random embedding from another patch (10%), or simply leaving them as is (10%).”
Using self-supervised pre-training, their smallest ViT-Base/16 model achieves 79.9% accuracy on ImageNet, a significant 2% improvement over training from scratch, but still 4% behind supervised pre-training.
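The corruption scheme quoted above can be sketched roughly as follows (a simplified per-image version, not the authors' code; embeddings is assumed to be the (num_patches, dim) tensor of patch embeddings and mask_embedding a learnable vector):

```python
import torch

def corrupt_patches(embeddings: torch.Tensor, mask_embedding: torch.Tensor,
                    corrupt_ratio: float = 0.5):
    """Corrupt a fraction of patch embeddings, BERT-style (80/10/10)."""
    num_patches = embeddings.size(0)
    corrupted = embeddings.clone()
    is_corrupted = torch.rand(num_patches) < corrupt_ratio   # ~50% of patches
    roll = torch.rand(num_patches)                           # decides the corruption type
    for i in torch.nonzero(is_corrupted).flatten().tolist():
        if roll[i] < 0.8:
            # 80%: replace with the learnable [mask] embedding
            corrupted[i] = mask_embedding
        elif roll[i] < 0.9:
            # 10%: replace with the embedding of a random other patch
            j = torch.randint(num_patches, (1,)).item()
            corrupted[i] = embeddings[j]
        # remaining 10%: leave the embedding as is
    return corrupted, is_corrupted   # is_corrupted marks which patches to predict
```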
Vision transformer with masked autoencoder
Key idea
As we saw in the section on the Vision Transformer, the gains from pre-training by masking patches in the input images were not as significant as in NLP, where masked pre-training can lead to state-of-the-art results on some fine-tuning tasks.
This paper proposes a vision transformer architecture involving an encoder and a decoder that, when pre-trained with masking, results in significant improvements over the base vision transformer model (up to 6% improvement compared to training a base-sized vision transformer in a supervised manner).
This is an example of (input, output, true labels). The model is an autoencoder in the sense that it attempts to reconstruct the input while filling in the missing patches.
Architecture
Its encoder is simply the ordinary vision transformer encoder that we explained above. During training and inference, only the “observed” (unmasked) patches are fed to it.
Meanwhile, its decoder is also just an ordinary vision transformer encoder, but its input consists of:
- Mask token vectors for the missing patches
- The encoder's output vectors for the known (visible) patches
So for an image ((A, B, x), (C, x, x), (x, D, E)), where x denotes a missing patch, the decoder will take the sequence of patch vectors (Enc(A), Enc(B), Vec(x), Enc(C), Vec(x), Vec(x), Vec(x), Enc(D), Enc(E)). Here Enc returns the encoder's output vector given the patch, and Vec(x) is a shared, learnable vector representing a missing (masked) patch.
The last layer of the decoder is a linear layer that maps the contextual embeddings (produced by the vision transformer encoder acting as the decoder) to vectors of length equal to the flattened patch size. The loss function is the mean squared error between the original patch vector and the one predicted by this layer. The loss only considers the decoder predictions for the masked tokens and ignores the ones corresponding to the known patches (i.e., Dec(A), Dec(B), Dec(C), etc.).
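Here is a simplified sketch of how the decoder input and the reconstruction loss could be assembled (per image, ignoring the decoder's positional embeddings and any projection between encoder and decoder widths; enc_visible, mask_token, decoder, and to_pixels are stand-ins for the encoder outputs, the shared learnable mask token, the small transformer decoder, and its final linear layer):

```python
import torch

def mae_decoder_loss(enc_visible, visible_idx, masked_idx, patch_vectors,
                     mask_token, decoder, to_pixels):
    """Rebuild the full token sequence, decode it, and compute MSE on masked patches only."""
    num_patches = patch_vectors.size(0)
    dim = enc_visible.size(1)
    # start with the shared mask token at every position ...
    tokens = mask_token.expand(num_patches, dim).clone()
    # ... then place the encoder's outputs at the visible positions
    tokens[visible_idx] = enc_visible
    decoded = decoder(tokens)        # contextual embeddings for all patch positions
    pred = to_pixels(decoded)        # linear map back to flattened-pixel space
    # mean squared error, computed only at the masked positions
    return torch.mean((pred[masked_idx] - patch_vectors[masked_idx]) ** 2)
```

For the 3×3 example above (in row-major order), visible_idx would be [0, 1, 3, 7, 8] and masked_idx would be [2, 4, 5, 6].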
Final observation and example
It may come as a surprise that the authors suggest masking about 75% of patches in images, while BERT would mask only about 15% of words. They justify this as follows:
“Images are natural signals with high spatial redundancy; for example, a missing patch can be recovered from neighboring patches with little high-level understanding of parts, objects, and scenes. To overcome this difference and encourage learning useful features, we mask a very large portion of random patches.”
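One simple way to implement this high-ratio random masking (a sketch, per image; not necessarily the authors' exact procedure) is to shuffle the patch indices and keep only the first 25% as visible:

```python
import torch

def random_masking(num_patches: int, mask_ratio: float = 0.75):
    """Randomly split patch indices into visible (kept) and masked (dropped) sets."""
    num_visible = int(num_patches * (1 - mask_ratio))
    perm = torch.randperm(num_patches)     # random permutation of patch indices
    visible_idx = perm[:num_visible]       # ~25% of patches go through the encoder
    masked_idx = perm[num_visible:]        # ~75% must be reconstructed by the decoder
    return visible_idx, masked_idx

visible_idx, masked_idx = random_masking(num_patches=196)  # e.g. a 14x14 patch grid
print(len(visible_idx), len(masked_idx))                   # 49 147
```

The visible indices are what the encoder sees; the masked indices are what the decoder must reconstruct, as in the sketch above.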
Want to try it yourself? Check out this demo notebook by Niels Rogge.
That’s all for this story. We’ve been on a journey to understand how fundamental transformer models generalize to the world of computer vision. I hope you found it clear, insightful, and worth your time.
References:
(1) Dosovitskiy, A. et al. (2021) An image is worth 16×16 words: Transformers for image recognition at scale, arXiv.org. Available at: https://arxiv.org/abs/2010.11929 (Accessed: June 28, 2024).
(2) He, K. et al. (2021) Masked autoencoders are scalable vision learners, arXiv.org. Available at: https://arxiv.org/abs/2111.06377 (Accessed: June 28, 2024).