Chapter 7 Learning Representations for Real-World Data

"The best theory is inspired by practice, and the best practice is inspired by theory."

  — Donald Knuth

The previous chapters have presented a systematic introduction to the mathematical problems, computational frameworks, and practical algorithms associated with learning low-dimensional distributions from high-dimensional data. Although most theoretical justifications of these methods have been established for idealized models of data, such as (mixtures of) subspaces or Gaussians, the principles and ideas behind these computational methods are nevertheless powerful and general, and they are in fact meant to be applicable to real-world datasets and tasks.

To help you, the reader, better understand the material in this book and learn how to apply it to real-world data, we provide, toward the end of the book, demonstrations and vignettes of several representative applications. Each application proposes a solution to a real-world task on a real-world dataset (such as visual or text data), using the methods we have introduced in this book. The results presented in this chapter are meant to serve the following two purposes:

  • firstly, to provide additional experimental details and empirical evidence which validate the methods presented earlier in the book, and demonstrate their significant potential in real-world contexts;

  • secondly, to introduce the reader to certain modern empirical methods and tasks in deep learning which are not well-documented outside of research or production codebases.

To be clear, however, the solutions and results given here are designed simply to verify that the methodology works. As such, there is great room for improvement, in both engineering and theoretical understanding, that could advance the state of the art. We will discuss some future directions in Chapter 8.

7.1 Technical Setup and Outline of the Chapter

In previous chapters, we alluded to different setups in which we used representation-learning techniques to process real data at scale. In this chapter, we describe these setups in detail. The objective is to enable you, the reader, to reproduce any experiment discussed in this chapter (or indeed the book) using only the descriptions given here, the principles introduced in previous chapters and expanded upon in this one, and hyperparameters taken from the papers whose results we discuss. To this end, we precisely describe all procedures in detailed language, pseudocode, or mathematical notation that can be directly implemented in code. Wherever possible, we discuss how the concrete implementations connect to the principles presented earlier in the book.

Figure 7.1: A diagram of the encoder pipeline. Data $\bm{X}\in\mathcal{D}$ is fed through the embedding $f_{\theta}^{\mathrm{emb}}$ to get a sequence in $(\mathbb{R}^{d})^{*}$. The embedding is fed through a backbone $f_{\theta}^{\mathrm{bb}}$ to get features $\bm{Z}_{\theta}(\bm{X})$ for each token. We can extract an aggregate feature $\bm{z}_{\theta}(\bm{X})$ using the extraction map $f_{\theta}^{\mathrm{ext}}$. Finally, to use the aggregate feature in downstream tasks, we can use the task-specific head $h_{\theta}$.
Figure 7.2: A diagram of the autoencoder pipeline. Data $\bm{X}\in\mathcal{D}$ is fed through the embedding $f_{\theta}^{\mathrm{emb}}$ to get a sequence in $(\mathbb{R}^{d})^{*}$. The embedding is fed through an encoder backbone $f_{\theta}^{\mathrm{bb}}$ to get features $\bm{Z}_{\theta}(\bm{X})$ for each token. To decode $\bm{Z}_{\theta}(\bm{X})$, we pass it through a decoder backbone $g_{\eta}^{\mathrm{bb}}$. To map the decoder backbone output back to data space $\mathcal{D}$, we use an unembedding layer $g_{\eta}^{\mathrm{unemb}}$, overall obtaining a reconstruction $\hat{\bm{X}}_{\theta,\eta}(\bm{X})$ (here stylized to be a pixelated reconstruction of the input).

Let us define the set of possible data as $\mathcal{D}$ (eventually this will be the set of images $\mathcal{I}$, for example, or the set of text $\mathcal{T}$), and the set of finite sequences of tokens in $\mathbb{R}^{d}$ (i.e., the set of matrices with $d$ rows) as $(\mathbb{R}^{d})^{*}\doteq\bigcup_{T=1}^{\infty}\mathbb{R}^{d\times T}$. In order to discuss the wealth of applications we introduce in this chapter, we first recall that in the rest of the book we discuss two different types of model architectures.

  • An encoder architecture, parameterized by $\theta$, which is composed of several components:

    • An embedding $f_{\theta}^{\mathrm{emb}}\colon\mathcal{D}\to(\mathbb{R}^{d})^{*}$, which converts the input data in $\mathcal{D}$ into a sequence of tokens which are mapped into, or embedded in, $d$-dimensional space. In the rest of the chapter, we will often identify tokens and embeddings with each other.

    • An encoder backbone $f_{\theta}^{\mathrm{bb}}\colon(\mathbb{R}^{d})^{*}\to(\mathbb{R}^{d})^{*}$, which processes the sequence of embeddings using a sequence-to-sequence operation. This backbone is implemented by the network architectures discussed in the previous chapters, but we will give a more formal description as we go along.

    • An aggregate feature extractor $f_{\theta}^{\mathrm{ext}}\colon(\mathbb{R}^{d})^{*}\to\mathbb{R}^{d}$, which extracts an aggregate representation of the whole sequence. This is used to define a single feature for the entire data sample.

    • A task-specific head $h_{\theta}\colon\mathbb{R}^{d}\to\mathbb{R}^{m}$, which extracts an $m$-dimensional output for prediction.

    We also define $f_{\theta}\doteq f_{\theta}^{\mathrm{bb}}\circ f_{\theta}^{\mathrm{emb}}\colon\mathcal{D}\to(\mathbb{R}^{d})^{*}$. Given an input $\bm{X}$, we write $\bm{Z}_{\theta}(\bm{X})\doteq f_{\theta}(\bm{X})$ and $\bm{z}_{\theta}(\bm{X})\doteq f_{\theta}^{\mathrm{ext}}(\bm{Z}_{\theta}(\bm{X}))$. The overall pipeline is depicted in Figure 7.1.

  • An autoencoder architecture, which is composed of several components:

    • An embedding $f_{\theta}^{\mathrm{emb}}\colon\mathcal{D}\to(\mathbb{R}^{d})^{*}$, which converts the input data in $\mathcal{D}$ into a sequence of tokens which are embedded in $d$-dimensional space.

    • An encoder backbone $f_{\theta}^{\mathrm{bb}}\colon(\mathbb{R}^{d})^{*}\to(\mathbb{R}^{d})^{*}$, which processes the sequence of embeddings using a sequence-to-sequence operation.

    • A decoder backbone $g_{\eta}^{\mathrm{bb}}\colon(\mathbb{R}^{d})^{*}\to(\mathbb{R}^{d})^{*}$, which conceptually undoes the operation of the encoder backbone.

    • An unembedding $g_{\eta}^{\mathrm{unemb}}\colon(\mathbb{R}^{d})^{*}\to\mathcal{D}$, which acts as an inverse of the embedding.

    We also define $f_{\theta}\doteq f_{\theta}^{\mathrm{bb}}\circ f_{\theta}^{\mathrm{emb}}\colon\mathcal{D}\to(\mathbb{R}^{d})^{*}$ and $g_{\eta}\doteq g_{\eta}^{\mathrm{unemb}}\circ g_{\eta}^{\mathrm{bb}}\colon(\mathbb{R}^{d})^{*}\to\mathcal{D}$. Given an input $\bm{X}$, we write $\bm{Z}_{\theta}(\bm{X})\doteq f_{\theta}(\bm{X})$ and $\hat{\bm{X}}_{\theta,\eta}(\bm{X})\doteq g_{\eta}(\bm{Z}_{\theta}(\bm{X}))$. The overall pipeline is depicted in Figure 7.2.

We will use this notation repeatedly in this chapter, so please refer back to it if something does not make sense. This decomposition of our networks also closely mirrors most code implementations, and you can start your coding projects by defining these networks, as in the sketch below.
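To make this decomposition concrete, here is a minimal PyTorch-style sketch of the two pipelines. The wrapper classes and names (`Encoder`, `Autoencoder`, `f_emb`, and so on) are illustrative assumptions, not the exact classes used in our experiments; any concrete embedding, backbone, extractor, and head can be slotted in.

```python
import torch.nn as nn

class Encoder(nn.Module):
    """Encoder pipeline: embedding -> backbone -> aggregate extractor -> head."""

    def __init__(self, f_emb: nn.Module, f_bb: nn.Module,
                 f_ext: nn.Module, head: nn.Module):
        super().__init__()
        self.f_emb, self.f_bb, self.f_ext, self.head = f_emb, f_bb, f_ext, head

    def forward(self, X):
        Z = self.f_bb(self.f_emb(X))       # token features Z_theta(X) in (R^d)^*
        z = self.f_ext(Z)                  # aggregate feature z_theta(X) in R^d
        return self.head(z)                # task-specific m-dimensional output


class Autoencoder(nn.Module):
    """Autoencoder pipeline: embedding -> encoder backbone -> decoder backbone -> unembedding."""

    def __init__(self, f_emb: nn.Module, f_bb: nn.Module,
                 g_bb: nn.Module, g_unemb: nn.Module):
        super().__init__()
        self.f_emb, self.f_bb, self.g_bb, self.g_unemb = f_emb, f_bb, g_bb, g_unemb

    def forward(self, X):
        Z = self.f_bb(self.f_emb(X))       # token features Z_theta(X)
        return self.g_unemb(self.g_bb(Z))  # reconstruction X_hat_{theta,eta}(X)
```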

In this chapter, we first discuss applications of the book's principles to contrastive learning in Section 7.2. This serves both as an introduction to image data, data augmentation techniques, and the common architecture known as the transformer, and as a first demonstration of the drastic simplifications that these principles enable. We continue with modifications to the network architecture in Sections 7.3 and 7.4, which demonstrate the capabilities of simplified architectures for encoding within the image and text domains. We then demonstrate simplified architectures for autoencoding in Section 7.6.

7.2 Simplified Contrastive Learning

Learning high-quality and faithful representations of data without labels is a fundamental problem in deep learning, known as self-supervised learning. There have been many approaches proposed for this task, many of which do not evidently use the techniques and principles outlined in this manuscript. One such approach is called contrastive learning, so named because the learning objective (roughly speaking) encourages features of “similar” data to be close together and features of “dissimilar” data to be far apart. Contrastive learning solutions are often highly engineered, empirically designed approaches. In this section, we will describe one such approach, named DINO [CTM+21], and use the principles described in the previous chapters to drastically simplify its design decisions while improving the learned representations.

7.2.1 Data

The data that we will use to explore and simplify the DINO methodology are all 2-dimensional image data. For training, we will use the ImageNet-1K and ImageNet-21K datasets. Each sample in the dataset is an RGB image of varying resolution, together with a label indicating the object or scene that the image contains (i.e., the class of the image). The ImageNet-1K dataset contains 1.28M training images and 50K validation images partitioned into 1K classes. The ImageNet-21K dataset contains 14.2M training images and 21.8K classes, but the classes are not disjoint (i.e., some classes are subsets of others). Since we are doing self-supervised learning, the labels will not be used during training, only during evaluation. For evaluation, we will use a variety of datasets. Of these, the most common is CIFAR-10, a dataset of 60K RGB $32\times 32$ natural images partitioned into 10 classes, with a pre-established training set of 50K samples and a validation set of 10K samples. The purpose of using CIFAR-10 is to ensure that models trained on one distribution of images (ImageNet) can generalize to another distribution of images (CIFAR-10). We also refer to other similar datasets, such as CIFAR-100 (disjoint from CIFAR-10), Oxford Flowers, and Oxford Pets. Exemplars of ImageNet-1K and CIFAR-10 data are shown in Figure 7.3. In order to increase the robustness of our model, we often apply small data augmentations to each image during processing, such as flips, small additive random noise, or large random crops; we do not include this in our notation, as each augmentation of a natural image is itself (very close to) a natural image in our dataset.

(a) ImageNet-1K samples.
(b) CIFAR-10 samples.

Figure 7.3: Images from ImageNet-1K (left) and CIFAR-10 (right). Notice that the CIFAR-10 images are much lower resolution, generally speaking, reducing the complexity of learning that distribution.

On a slightly more formal level, our data $\bm{X}$ will be images; we let $\mathcal{I}$ be the set of all images. Since an image is a rectangular array of pixels, and each pixel has a color given in RGB, CMYK, or another color format, we say that an image is an element of $\mathbb{R}^{c\times h\times w}$, where $c$ is the number of channels (i.e., $3$ for RGB and $4$ for CMYK), $h$ is the image height, and $w$ is the image width. Consequently, the set of all images $\mathcal{I}\doteq\bigcup_{c,h,w=1}^{\infty}\mathbb{R}^{c\times h\times w}$ is the set of all possible such data. Again, we will use this notation repeatedly.
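As a concrete illustration of both the data pipeline and the $\mathbb{R}^{c\times h\times w}$ convention, the following sketch loads CIFAR-10 with torchvision and applies light augmentations of the kind mentioned above. The particular crop scale and the use of CIFAR-10 (rather than ImageNet) are illustrative choices, not the exact training recipe.

```python
import torch
from torchvision import datasets, transforms

# Light augmentations: each augmented image remains (close to) a natural image.
augment = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomResizedCrop(32, scale=(0.8, 1.0)),  # large random crops
    transforms.ToTensor(),  # maps a PIL image to a float tensor in R^{c x h x w}
])

train_set = datasets.CIFAR10(root="./data", train=True, download=True, transform=augment)

X, label = train_set[0]
print(X.shape)  # torch.Size([3, 32, 32]): c = 3 channels (RGB), h = w = 32
# The label is unused during self-supervised training; it is only for evaluation.
```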

7.2.2 Task and Objective Function

Our task is to learn a good representation of the data. Contrastive learning, by and large, does this by defining which properties of the input image we wish the features to reflect, constructing images which share these properties but vary others, and setting up a loss which encourages the features of images with shared properties to be close and the features of images with different properties to be far apart. The natural optimal solution to this learning problem is for the learned features to preserve the desired properties of the input. However, many practical and empirical complications arise in the course of training contrastive models.

In the case of DINO, the authors propose a methodology which produces a single feature vector for the whole image, intended to contain “global” (i.e., image-level) information. Accordingly, the loss encourages images with similar global information to have similar features and images with different global information to have different features.

This seems intuitive, but as previously mentioned, there are several empirical considerations, even while setting up the loss. First and foremost, how should we promote similarities and differences? The answer from DINO [CTM+21] is (in the author's view, inexplicably) to convert the output features into “logits” corresponding to some probability distribution and take their cross-entropy. More specifically, let $\Delta_{m}\doteq\{\bm{x}\in\mathbb{R}^{m}\colon x_{i}\geq 0\ \forall i\in[m],\ \sum_{i=1}^{m}x_{i}=1\}$ be the space of probability vectors in $\mathbb{R}^{m}$ and define the function $d_{\operatorname{CE}}\colon\mathbb{R}^{m}\times\mathbb{R}^{m}\to\mathbb{R}$ by

$$d_{\operatorname{CE}}(\bm{p},\bm{q})\doteq\operatorname{CE}(\bm{p},\bm{q}),\quad\forall\bm{p},\bm{q}\in\Delta_{m}, \tag{7.2.1}$$

where $\operatorname{CE}\colon\Delta_{m}\times\Delta_{m}\to\mathbb{R}$ is the cross-entropy, defined as

$$\operatorname{CE}(\bm{p},\bm{q})\doteq-\sum_{i=1}^{m}p_{i}\log q_{i},\quad\forall\bm{p}=(p_{1},\dots,p_{m}),\ \bm{q}=(q_{1},\dots,q_{m})\in\Delta_{m}. \tag{7.2.2}$$

Before we continue our discussion, let us build some intuition about this distance function. We have, in particular,

$$\operatorname{CE}(\bm{p},\bm{q})=-\sum_{i=1}^{m}p_{i}\log q_{i}=\sum_{i=1}^{m}p_{i}\log(p_{i}/q_{i})-\sum_{i=1}^{m}p_{i}\log p_{i}=\operatorname{\mathsf{KL}}(\bm{p}\,\|\,\bm{q})+H(\bm{p}), \tag{7.2.3}$$

where $\operatorname{\mathsf{KL}}\colon\Delta_{m}\times\Delta_{m}\to\mathbb{R}$ is the KL divergence, defined as

$$\operatorname{\mathsf{KL}}(\bm{p}\,\|\,\bm{q})\doteq\sum_{i=1}^{m}p_{i}\log(p_{i}/q_{i}), \tag{7.2.4}$$

and $H\colon\Delta_{m}\to\mathbb{R}$ is the entropy, defined as $H(\bm{p})\doteq-\sum_{i=1}^{m}p_{i}\log p_{i}$. Note that $\operatorname{\mathsf{KL}}(\bm{p}\,\|\,\bm{q})$ is minimized if and only if $\bm{p}=\bm{q}$. So minimizing $d_{\operatorname{CE}}$ does two things: it makes $\bm{p}=\bm{q}$, and it makes $\bm{p}$ and $\bm{q}$ have minimal entropy (i.e., it pushes them toward vectors with $1$ in one component and $0$ elsewhere, called one-hot vectors). Overall, the goal of this objective is not just to match $\bm{p}$ and $\bm{q}$ but also to shape them in a certain way, namely to make them low-entropy. Keep this in mind when we discuss the formulation.
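The identity $\operatorname{CE}(\bm{p},\bm{q})=\operatorname{\mathsf{KL}}(\bm{p}\,\|\,\bm{q})+H(\bm{p})$ in (7.2.3) is easy to check numerically; the following minimal sketch (with arbitrary example vectors) computes both sides.

```python
import torch

def cross_entropy(p, q):
    # CE(p, q) = -sum_i p_i log q_i, for probability vectors p, q
    return -(p * q.log()).sum()

def kl(p, q):
    # KL(p || q) = sum_i p_i log(p_i / q_i)
    return (p * (p / q).log()).sum()

def entropy(p):
    # H(p) = -sum_i p_i log p_i
    return -(p * p.log()).sum()

p = torch.tensor([0.7, 0.2, 0.1])
q = torch.tensor([0.5, 0.3, 0.2])

lhs = cross_entropy(p, q)
rhs = kl(p, q) + entropy(p)
print(lhs.item(), rhs.item())  # the two values agree: CE(p, q) = KL(p || q) + H(p)
```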

The next question is, how should we obtain samples with similar global information? The answer from DINO (as well as nearly all contrastive learning) is data augmentation: from each sample, make several correlated samples which share the desired properties. In the DINO case, we use different crops or views of the input image. Recall that we model an image as an element of the set $\mathcal{I}$. In this notation, a view is a function $v\colon\mathcal{I}\to\mathcal{I}$. In the DINO case, the view is a random resized crop: it takes a randomly chosen rectangular crop of the image (covering a fixed percentage $p_{v}\in[0,1]$ of the total area of the image), resizes it proportionally so that the shorter edge is $S_{\mathrm{rsz}}$ pixels long, then resizes it to a fixed shape $(C,S_{v},S_{v})$, where $S_{v}\geq 1$ is the size of the view and $C$ is the number of channels in the original image.

Figure 7.4: Local and global views in DINO. Local views and global views take a rectangular crop of the input image and resize it to a square shape, which is then input into the network for processing.

There are two types of views we want to use, depicted in Figure 7.4 and sketched in code after the following list:

  • global views, which are random resized crops with area percentage parameter $p_{\mathrm{glo}}\in[0,1]$ and output shape $(C,S_{\mathrm{glo}},S_{\mathrm{glo}})$;

  • local views, which are random resized crops with area percentage parameter $p_{\mathrm{loc}}\in[0,1]$ and output shape $(C,S_{\mathrm{loc}},S_{\mathrm{loc}})$. Here $p_{\mathrm{loc}}<p_{\mathrm{glo}}$ and $S_{\mathrm{loc}}<S_{\mathrm{glo}}$.
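Here is a minimal sketch of the two view types using torchvision's `RandomResizedCrop`. The crop scales and output sizes shown (224 and 96 pixels) are common choices, but they should be treated as illustrative assumptions rather than the exact DINO hyperparameters.

```python
from torchvision import transforms

# Global views: larger crops covering a larger fraction of the image area.
global_view = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.4, 1.0)),  # output shape (C, S_glo, S_glo)
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

# Local views: smaller crops covering a smaller fraction of the image area.
local_view = transforms.Compose([
    transforms.RandomResizedCrop(96, scale=(0.05, 0.4)),  # output shape (C, S_loc, S_loc)
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
])

# Given a PIL image X, a global view and another (here local) view of the same image:
# X_g, X_c = global_view(X), local_view(X)
```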

DINO desires that the aggregate features $\bm{z}_{\theta}(\bm{X}_{v})\doteq(f_{\theta}^{\mathrm{ext}}\circ f_{\theta})(\bm{X}_{v})$ of all views $\bm{X}_{v}\doteq v(\bm{X})$ of an input image $\bm{X}$ be consistent with each other. DINO does this by using a “DINO head” $h_{\bm{W},\bm{\mu}}$, parameterized by a matrix $\bm{W}\in\mathbb{R}^{s\times d}$ and a vector $\bm{\mu}\in\mathbb{R}^{s}$, to extract a probability vector $\bm{p}_{\theta,\bm{W},\bm{\mu}}(\bm{X}_{v})\doteq h_{\bm{W},\bm{\mu}}(\bm{z}_{\theta}(\bm{X}_{v}))$ from the aggregate feature $\bm{z}_{\theta}(\bm{X}_{v})$. (Note that $h_{\bm{W},\bm{\mu}}$ is the task-specific head, which in Section 7.1 is parameterized only by $\theta$ rather than by any specific named parameters; since we use two invocations of $h$ with different values of the second parameter, we keep the explicit notation.) The head uses the following simple recipe:

$$h_{\bm{W},\bm{\mu}}(\bm{z})\doteq\operatorname{softmax}([\bm{W}\bm{z}-\bm{\mu}]/\tau),\qquad\forall\bm{z}\in\mathbb{R}^{d}, \tag{7.2.5}$$

where the $\operatorname{softmax}\colon\mathbb{R}^{s}\to\Delta_{s}$ function is defined by

$$\operatorname{softmax}\left(\begin{bmatrix}x_{1}\\ \vdots\\ x_{s}\end{bmatrix}\right)\doteq\frac{1}{\sum_{i=1}^{s}e^{x_{i}}}\begin{bmatrix}e^{x_{1}}\\ \vdots\\ e^{x_{s}}\end{bmatrix} \tag{7.2.6}$$

and $\tau>0$ is a “temperature” parameter which controls the entropy of the softmax's output.
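As a sketch, the head in (7.2.5) can be written directly in a few lines; the dimensions and the temperature value below are illustrative placeholders.

```python
import torch
import torch.nn.functional as F

def dino_head(z, W, mu, tau=0.1):
    """Compute h_{W,mu}(z) = softmax((W z - mu) / tau), as in (7.2.5).

    z:  aggregate feature of shape (d,) or a batch of shape (B, d)
    W:  matrix of shape (s, d); mu: vector of shape (s,); tau: temperature > 0.
    """
    logits = z @ W.T - mu              # shape (..., s)
    return F.softmax(logits / tau, dim=-1)

# Example with arbitrary dimensions d = 8, s = 4:
d, s = 8, 4
z = torch.randn(d)
W, mu = torch.randn(s, d), torch.zeros(s)
p = dino_head(z, W, mu)
print(p.sum().item())  # approximately 1.0: a probability vector in Delta_s
```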

In particular, DINO minimizes the difference between the probability vector $\bm{p}_{\theta,\bm{W},\bm{\mu}}(\bm{X}_{g})\doteq h_{\bm{W},\bm{\mu}}(\bm{z}_{\theta}(\bm{X}_{g}))$ for each global view $\bm{X}_{g}\doteq v_{g}(\bm{X})$ and the probability vector $\bm{p}_{\theta,\bm{W}}(\bm{X}_{c})\doteq h_{\bm{W},\bm{0}_{m}}(\bm{z}_{\theta}(\bm{X}_{c}))$ for each view $\bm{X}_{c}\doteq v_{c}(\bm{X})$. Here, $v_{c}$ can be either a local view or a global view. We will discuss the implementation of $f_{\theta}$ and $f_{\theta}^{\mathrm{ext}}$ shortly, in Section 7.2.3. Overall, DINO solves the problem

$$\min_{\theta,\bm{W},\bm{\mu}}\ \mathcal{L}_{\mathrm{DINO}}(\theta,\bm{W},\bm{\mu})\qquad\text{where}\qquad\mathcal{L}_{\mathrm{DINO}}(\theta,\bm{W},\bm{\mu})\doteq\operatorname{\mathbb{E}}[d_{\operatorname{CE}}(\bm{p}_{\theta,\bm{W},\bm{\mu}}(\bm{X}_{g}),\bm{p}_{\theta,\bm{W}}(\bm{X}_{c}))], \tag{7.2.7}$$

where the expectation is over the data $\bm{X}$, the global view $v_{g}$, and the other view $v_{c}$.
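The following is a minimal one-sample sketch of the objective in (7.2.7), combining the head from (7.2.5) with the cross-entropy score $d_{\operatorname{CE}}$. It deliberately omits the engineering of the full DINO training recipe (e.g., how $\bm{\mu}$ and $\tau$ are handled across iterations), which we turn to next.

```python
import torch
import torch.nn.functional as F

def d_ce(p, q, eps=1e-8):
    # d_CE(p, q) = CE(p, q) = -sum_i p_i log q_i, for probability vectors p, q
    return -(p * (q + eps).log()).sum(dim=-1)

def head(z, W, mu, tau=0.1):
    # h_{W,mu}(z) = softmax((W z - mu) / tau), as in (7.2.5)
    return F.softmax((z @ W.T - mu) / tau, dim=-1)

def dino_loss_sample(z_g, z_c, W, mu, tau=0.1):
    """One-sample estimate of L_DINO in (7.2.7).

    z_g: aggregate feature of a global view; z_c: feature of another view.
    The second head invocation uses mu = 0, matching p_{theta,W}(X_c).
    """
    p_g = head(z_g, W, mu, tau)
    p_c = head(z_c, W, torch.zeros_like(mu), tau)
    return d_ce(p_g, p_c)

# Example with arbitrary dimensions:
d, s = 8, 4
W, mu = torch.randn(s, d), torch.zeros(s)
print(dino_loss_sample(torch.randn(d), torch.randn(d), W, mu).item())
```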

In this specific case, however, if you try to implement (7.2.7) and optimize it on a real network, it is very likely that you will run into a problem: after a few iterations of the learning algorithm, the feature mapping $f_{\theta}^{\mathrm{ext}}\circ f_{\theta}$ becomes a constant function! This certainly optimizes the above loss, since it minimizes the distance between features of different views of the same image, but we obviously do not want to learn this solution.

Avoiding collapse is actually a very common consideration in contrastive learning. So how do we do it in this case? The solution from DINO, again, is empirically designed: it carefully tunes the optimization of the parameter $\bm{\mu}$ (which is updated using all samples in the batch) and a “temperature” hyperparameter $\tau$ which is part of the implementation of $h_{\bm{W},\bm{\mu}}$ and discussed in Section 7.2.3. Given a certain special set of hyperparameters that work well, this is indeed enough to ensure non-collapse of the representation. However, outside of this special configuration, training models to convergence is difficult, and the training is highly unstable.

To amend this state of affairs, let us discuss simplifications to the formulation. First, instead of computing a probability vector using a learned transformation $h_{\bm{W},\bm{\mu}}$ of the aggregate features $\bm{z}_{\theta}$, we can directly use the aggregate representation, ignoring the task-specific head (or equivalently, setting it to the identity mapping). But now we need a way to compare the vectors directly. Using our hypothesis from Chapter 4 that good representations should have Euclidean (subspace) geometry, a much more natural measure of difference is the squared $\ell^{2}$ distance $d_{\ell^{2}}\colon\mathbb{R}^{d}\times\mathbb{R}^{d}\to\mathbb{R}$, defined as

$$d_{\ell^{2}}(\bm{x},\bm{y})\doteq\frac{1}{2}\|\bm{x}-\bm{y}\|_{2}^{2},\qquad\forall\bm{x},\bm{y}\in\mathbb{R}^{d}. \tag{7.2.8}$$

This distance-based score is even more efficient to compute than the cross-entropy score. Thus, $d_{\ell^{2}}$ takes the place of $d_{\operatorname{CE}}$ in our simplification.

Previously, collapse was avoided by using tricks to update $\bm{\mu}$ and $\tau$. In our simplification, where we compare features within the representation space instead of converting them to probabilities, we have neither of these parameters, and so must consider a different way to avoid collapse. To do this, we return to fundamentals. The basic idea of avoiding collapse is that, in order to ensure that all samples do not return exactly the same features, we need different samples to have different features. In other words, we would like the covariance of the features to be large in some sense. But from Chapters 3 and 4, we already have a quantity which measures the size of the covariance matrix: namely, the straightforward (population-level) Gaussian coding rate $R$, which we use to ensure that the features of global views of different images, which have different global information, are well-separated and not collapsed (hence expanded). The overall modified loss $\mathcal{L}_{\mathrm{SimDINO}}$ becomes:

$$\mathcal{L}_{\mathrm{SimDINO}}(\theta)\doteq\operatorname{\mathbb{E}}[d_{\ell^{2}}(\bm{z}_{\theta}(\bm{X}_{g}),\bm{z}_{\theta}(\bm{X}_{c}))]-\frac{\gamma}{2}\log\det\left(\bm{I}+\frac{d}{\varepsilon^{2}}\operatorname{Cov}(\bm{z}_{\theta}(\bm{X}_{g}))\right), \tag{7.2.9}$$

where $\varepsilon>0$ is fixed and the appropriate expectations are, as before, taken over the data $\bm{X}$, the global view $v_{g}$, and the other (local or global) view $v_{c}$. The loss in (7.2.9) is the loss used for the simplified DINO (“SimDINO”). As we will see, when properly implemented, it works at least as well as the original DINO.
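Below is a minimal batched sketch of (7.2.9), assuming the aggregate features of the global and other views have already been computed; $\gamma$ and $\varepsilon$ are placeholder values, and the covariance is estimated empirically over the batch.

```python
import torch

def simdino_loss(z_g, z_c, gamma=1.0, eps=0.5):
    """Empirical estimate of L_SimDINO in (7.2.9).

    z_g: aggregate features of global views, shape (B, d)
    z_c: aggregate features of the other (local or global) views, shape (B, d)
    """
    B, d = z_g.shape

    # Alignment term: E[ (1/2) || z(X_g) - z(X_c) ||_2^2 ]
    align = 0.5 * (z_g - z_c).pow(2).sum(dim=1).mean()

    # Expansion term: (gamma/2) logdet(I + (d / eps^2) Cov(z(X_g)))
    centered = z_g - z_g.mean(dim=0, keepdim=True)
    cov = centered.T @ centered / B                  # empirical covariance, shape (d, d)
    expand = 0.5 * gamma * torch.logdet(torch.eye(d) + (d / eps**2) * cov)

    return align - expand  # minimize alignment, maximize expansion

# Example with a batch of B = 16 feature pairs of dimension d = 32:
z_g, z_c = torch.randn(16, 32), torch.randn(16, 32)
print(simdino_loss(z_g, z_c).item())
```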

7.2.3 Architecture: Vision Transformer

For the architecture, we use a standard vision transformer. Here is how such an architecture works formally in the context of image data. Recall from Section 7.1 that there are four components to an encoder architecture, namely an embedding, a backbone, a feature extractor, and a task-specific head. We discuss these four parts presently.

Figure 7.5: An example of an image turned into $5\times 5$ square patches, which are placed in raster order. Each patch is of the same size, and the grid of patches is of shape $(N_{H},N_{W})=(5,5)$. The grid of patches is then unrolled into a sequence of length $5\times 5=25$ in raster order.
Figure 7.6: The transformer embedding pipeline. Given a sequence of unrolled patches in raster order $\bm{X}^{\mathrm{patch}}$, each unrolled patch is linearly projected into the feature space, and equipped with an (additive) positional encoding and an additional token known as the class token. The output is the first-layer input feature $\bm{Z}_{\theta}^{1}(\bm{X})=f_{\theta}^{\mathrm{emb}}(\bm{X})$.
Embedding.

Given image data $\bm{X}\in\mathcal{I}$, we embed it as a sequence of tokens in $\mathbb{R}^{d}$ using the map $f_{\theta}^{\mathrm{emb}}$, as follows. The first two steps are depicted in Figure 7.5, and the latter two are depicted in Figure 7.6.

  1. First, we turn the image data $\bm{X}$ into a sequence of patches of shape $(C,P_{H},P_{W})$, where $P_{H}$ and $P_{W}$ are the patch dimensions. We assume that $P_{H}$ and $P_{W}$ evenly divide the height and width of $\bm{X}$, respectively (in the notation of Section 7.2.2, we assume that $P_{H}$ and $P_{W}$ evenly divide $S_{\mathrm{loc}}$ and $S_{\mathrm{glo}}$). Let the resulting grid of patches have $N_{H}$ rows and $N_{W}$ columns.

  2. We unroll each patch into a vector of length $D\doteq CP_{H}P_{W}$. There are $N\doteq N_{H}N_{W}$ patch vectors, which we place in “raster order” (top left $\to$ top right $\to$ bottom left $\to$ bottom right) into a matrix $\bm{X}^{\mathrm{patch}}\in\mathbb{R}^{D\times N}$, where $\bm{X}^{\mathrm{patch}}\doteq f^{\mathrm{patch}}(\bm{X})$. Notice that $D$ depends only on the patch size and the number of channels. Since the latter quantity is normally constant among samples in the same dataset, $D$ is the same for all images in the dataset, while $N$ differs between larger and smaller images.

  3. We then perform the following operation on $\bm{X}^{\mathrm{patch}}\in\mathbb{R}^{D\times N}$ to project it to $\mathbb{R}^{d\times n}$, where $n\doteq N+1$:

     $$\bm{X}^{\mathrm{patch}}\mapsto[\bm{z}_{\mathrm{cls}}^{1},\ \bm{W}^{\mathrm{emb}}\bm{X}^{\mathrm{patch}}+\bm{E}^{\mathrm{pos}}]. \tag{7.2.10}$$

    Here we have three trainable parameters, $\bm{W}^{\mathrm{emb}}$, $\bm{z}_{\mathrm{cls}}^{1}$, and $\bm{E}^{\mathrm{pos}}$, whose purposes are as follows:

    • $\bm{W}^{\mathrm{emb}}\in\mathbb{R}^{d\times D}$ is a matrix which projects each patch vector to a token feature.

    • $\bm{z}_{\mathrm{cls}}^{1}\in\mathbb{R}^{d}$ is a so-called class token or register token. The class token heuristically holds global information about the whole data sample and is used for downstream tasks. In the framework of compressive deep networks from Chapter 4, we expect that the class token is projected onto the same subspaces as the salient or semantically relevant tokens during the progression of the forward pass.

    • $\bm{E}^{\mathrm{pos}}\in\mathbb{R}^{d\times N}$ is a so-called positional encoding, which distinguishes tokens of different patches from each other. That is, token features should carry positional information, so that the overall embedding map $f_{\theta}^{\mathrm{emb}}$ is not invariant to permutations of the patches; $\bm{E}^{\mathrm{pos}}$ inserts this positional information.

      • In this DINO case, where the transformer receives differently-sized images, we learn a positional encoding for the largest size received during training, and interpolate it to obtain the positional encodings for smaller-sized inputs.

Thus, in the end we have

$$f_{\theta}^{\mathrm{emb}}(\bm{X})\doteq\begin{bmatrix}\bm{z}_{\mathrm{cls}}^{1},\ \bm{W}^{\mathrm{emb}}f^{\mathrm{patch}}(\bm{X})+\bm{E}^{\mathrm{pos}}\end{bmatrix}. \tag{7.2.11}$$

All parameters $\bm{z}_{\mathrm{cls}}^{1},\bm{W}^{\mathrm{emb}},\bm{E}^{\mathrm{pos}}$ are contained in the parameter set $\theta$.
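To make the embedding concrete, here is a minimal sketch of $f_{\theta}^{\mathrm{emb}}$ following (7.2.10) and (7.2.11); the patch size, feature dimension, and maximum token count are illustrative assumptions, and practical vision-transformer implementations typically fuse the patchify-and-project steps into a single strided convolution.

```python
import torch
import torch.nn as nn

class PatchEmbedding(nn.Module):
    """f_emb: image of shape (C, H, W) -> token matrix of shape (d, N+1), as in (7.2.11)."""

    def __init__(self, C=3, P=16, d=384, N_max=196):
        super().__init__()
        self.P = P
        D = C * P * P                                     # unrolled patch length D = C P_H P_W
        self.W_emb = nn.Linear(D, d, bias=False)          # projection W^emb in R^{d x D}
        self.z_cls = nn.Parameter(torch.zeros(d))         # class token z_cls^1
        self.E_pos = nn.Parameter(torch.zeros(d, N_max))  # positional encoding E^pos

    def forward(self, X):
        C, H, W = X.shape
        P = self.P
        # Steps 1-2: split into (H/P) x (W/P) patches of shape (C, P, P) and unroll
        # them in raster order into a matrix X_patch of shape (N, D).
        patches = X.unfold(1, P, P).unfold(2, P, P)       # shape (C, N_H, N_W, P, P)
        N = patches.shape[1] * patches.shape[2]
        X_patch = patches.permute(1, 2, 0, 3, 4).reshape(N, -1)
        # Step 3: project, add the positional encoding, and prepend the class token.
        tokens = self.W_emb(X_patch).T + self.E_pos[:, :N]      # shape (d, N)
        return torch.cat([self.z_cls[:, None], tokens], dim=1)  # shape (d, N+1)

# Example: a 224 x 224 RGB image gives a 14 x 14 grid, i.e. 196 patches plus 1 class token.
emb = PatchEmbedding()
print(emb(torch.randn(3, 224, 224)).shape)  # torch.Size([384, 197])
```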

Figure 7.7: One layer $f_{\theta}^{\ell}$ of the transformer backbone. The input features go through layer-normalization, multi-head self-attention, and multi-layer perceptron blocks in sequence to form the output features of the layer.
Backbone.

Given a sequence of embeddings $\bm{Z}_{\theta}^{1}(\bm{X})\doteq f_{\theta}^{\mathrm{emb}}(\bm{X})\in(\mathbb{R}^{d})^{*}$, we process it using the backbone map $f_{\theta}^{\mathrm{bb}}$ as follows, and as depicted in Figure 7.7. The function $f_{\theta}^{\mathrm{bb}}$ is composed of $L$ layers $f_{\theta}^{\ell}$, i.e.,

$$f_{\theta}^{\mathrm{bb}}=f_{\theta}^{L}\circ\cdots\circ f_{\theta}^{1}. \tag{7.2.12}$$

The layer $f_{\theta}^{\ell}$ has the following implementation:

$$\bm{Z}_{\theta}^{\ell+1/2}(\bm{X})=\bm{Z}_{\theta}^{\ell}(\bm{X})+\operatorname{MHSA}_{\theta}^{\ell}(\operatorname{LN}_{\theta}^{1,\ell}(\bm{Z}_{\theta}^{\ell}(\bm{X}))) \tag{7.2.13}$$
$$\bm{Z}_{\theta}^{\ell+1}(\bm{X})=\bm{Z}_{\theta}^{\ell+1/2}(\bm{X})+\operatorname{MLP}_{\theta}^{\ell}(\operatorname{LN}_{\theta}^{2,\ell}(\bm{Z}_{\theta}^{\ell+1/2}(\bm{X}))) \tag{7.2.14}$$

and $f_{\theta}^{\ell}$ is defined such that $f_{\theta}^{\ell}(\bm{Z}_{\theta}^{\ell}(\bm{X})) \doteq \bm{Z}_{\theta}^{\ell+1}(\bm{X})$. Here we have used the operators $\operatorname{MHSA}_{\theta}^{\ell}$, $\operatorname{MLP}_{\theta}^{\ell}$, and $\operatorname{LN}_{\theta}^{i,\ell}$, which are defined as follows:

  • The $\operatorname{MHSA}_{\theta}^{\ell}$ operator is multi-head self-attention, the predecessor of the multi-head subspace self-attention (cf. Chapter 4). The formulation is as follows:

    $$\operatorname{MHSA}_{\theta}^{\ell}(\bm{Z}) \doteq \bm{U}_{\mathrm{out}}^{\ell}\begin{bmatrix}\operatorname{SA}([\bm{U}_{\mathrm{qry}}^{1,\ell}]^{\top}\bm{Z},[\bm{U}_{\mathrm{key}}^{1,\ell}]^{\top}\bm{Z},[\bm{U}_{\mathrm{val}}^{1,\ell}]^{\top}\bm{Z})\\ \vdots\\ \operatorname{SA}([\bm{U}_{\mathrm{qry}}^{K,\ell}]^{\top}\bm{Z},[\bm{U}_{\mathrm{key}}^{K,\ell}]^{\top}\bm{Z},[\bm{U}_{\mathrm{val}}^{K,\ell}]^{\top}\bm{Z})\end{bmatrix} + \bm{b}_{\mathrm{out}}^{\ell}\bm{1}_{n}^{\top}, \qquad (7.2.15)$$
    $$\text{where}\qquad \operatorname{SA}(\bm{Q},\bm{K},\bm{V}) \doteq \bm{V}\,\underbrace{\operatorname{softmax}\!\left(\frac{\bm{K}^{\top}\bm{Q}}{\sqrt{p}}\right)}_{\doteq\,\bm{A}(\bm{Q},\bm{K})} \qquad (7.2.16)$$

    where $p$ is a positive integer, $\bm{U}_{\mathrm{qry}}^{k,\ell}, \bm{U}_{\mathrm{key}}^{k,\ell}, \bm{U}_{\mathrm{val}}^{k,\ell} \in \mathbb{R}^{d \times p}$, $\bm{U}_{\mathrm{out}}^{\ell} \in \mathbb{R}^{d \times Kp}$, and $\bm{b}_{\mathrm{out}}^{\ell} \in \mathbb{R}^{d}$ are trainable parameters contained in the parameter set $\theta$, and the $\operatorname{softmax}$ is defined column-wise as

    $$\operatorname{softmax}(\bm{M}) \doteq \begin{bmatrix}\operatorname{softmax}(\bm{m}_{1}) & \cdots & \operatorname{softmax}(\bm{m}_{p})\end{bmatrix}, \qquad (7.2.17)$$
    $$\forall\,\bm{M} = \begin{bmatrix}\bm{m}_{1},\dots,\bm{m}_{p}\end{bmatrix} \in \mathbb{R}^{n \times p}. \qquad (7.2.18)$$

    In practice, the dimensions are usually picked such that $Kp = d$. The terms

    $$\bm{A}_{\theta}^{k,\ell}(\bm{Z}) \doteq \bm{A}([\bm{U}_{\mathrm{qry}}^{k,\ell}]^{\top}\bm{Z},[\bm{U}_{\mathrm{key}}^{k,\ell}]^{\top}\bm{Z}), \qquad \operatorname{SA}_{\theta}^{k,\ell}(\bm{Z}) \doteq \operatorname{SA}([\bm{U}_{\mathrm{qry}}^{k,\ell}]^{\top}\bm{Z},[\bm{U}_{\mathrm{key}}^{k,\ell}]^{\top}\bm{Z},[\bm{U}_{\mathrm{val}}^{k,\ell}]^{\top}\bm{Z}) \qquad (7.2.19)$$

    are also known as the $k$th attention map and $k$th attention head output at layer $\ell$, respectively. Furthermore, the operation $\operatorname{SA}(\bm{Q},\bm{K},\bm{V})$ can be computed extremely efficiently using specialized software such as FlashAttention [SBZ+25].

  • The $\operatorname{MLP}_{\theta}^{\ell}$ operator is a two-layer perceptron, a standard nonlinear block in deep networks, and has the form

    $$\operatorname{MLP}_{\theta}^{\ell}(\bm{Z}) \doteq \bm{W}_{\mathrm{down}}^{\ell}\operatorname{ReLU}(\bm{W}_{\mathrm{up}}^{\ell}\bm{Z} + \bm{b}_{\mathrm{up}}^{\ell}\bm{1}_{n}^{\top}) + \bm{b}_{\mathrm{down}}^{\ell}\bm{1}_{n}^{\top} \qquad (7.2.20)$$

    where $\bm{W}_{\mathrm{up}}^{\ell} \in \mathbb{R}^{q \times d}$, $\bm{W}_{\mathrm{down}}^{\ell} \in \mathbb{R}^{d \times q}$, $\bm{b}_{\mathrm{up}}^{\ell} \in \mathbb{R}^{q}$, $\bm{b}_{\mathrm{down}}^{\ell} \in \mathbb{R}^{d}$ are trainable parameters also contained in the parameter set $\theta$, and $\operatorname{ReLU}$ is the element-wise ReLU nonlinearity, i.e., $\operatorname{ReLU}(\bm{M})_{ij} = \max\{M_{ij}, 0\}$.

  • Each layer-norm $\operatorname{LN}_{\theta}^{i,\ell}$ for $i \in \{1,2\}$ is a standard normalization, which applies column-wise to each token feature independently:

    $$\operatorname{LN}_{\theta}^{i,\ell}(\bm{Z}) = \operatorname{LN}_{\theta}^{i,\ell}\big(\begin{bmatrix}\bm{z}_{1},\dots,\bm{z}_{n}\end{bmatrix}\big) = \begin{bmatrix}\operatorname{LN}_{\theta}^{i,\ell}(\bm{z}_{1}),\dots,\operatorname{LN}_{\theta}^{i,\ell}(\bm{z}_{n})\end{bmatrix} \qquad (7.2.21)$$

    and has the form

    $$\operatorname{LN}_{\theta}^{i,\ell}(\bm{z}) = \frac{\bm{z} - \operatorname{mean}(\bm{z})\bm{1}_{d}}{\|\bm{z} - \operatorname{mean}(\bm{z})\bm{1}_{d}\|_{2}} \odot \bm{\alpha}^{i,\ell} + \bm{\beta}^{i,\ell} \qquad \text{where} \qquad \operatorname{mean}(\bm{z}) = \frac{1}{d}\bm{1}_{d}^{\top}\bm{z} \qquad (7.2.22)$$

    where $\odot$ denotes element-wise multiplication, and $\bm{\alpha}^{i,\ell}, \bm{\beta}^{i,\ell} \in \mathbb{R}^{d}$ are trainable parameters contained in the parameter set $\theta$. The layer-norm operator serves as a sort of normalization on each token, where the scale of each token afterwards is learnable and shared amongst all tokens.

The transformer is one of the most popular neural network architectures in history, powering applications in almost all fields of deep learning.
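
To make the preceding definitions concrete, here is a minimal PyTorch sketch of one pre-norm transformer layer implementing (7.2.13)–(7.2.22). It is an illustrative sketch rather than the exact implementation used in our experiments: tokens are stored as rows of an $n \times d$ tensor (the transpose of the $d \times n$ convention used in the text), the layer-norm follows (7.2.22), and all class and variable names are our own.

```python
# Minimal sketch of one pre-norm transformer layer (7.2.13)-(7.2.22).
# Tokens are rows of an (n, d) tensor; names are illustrative, not canonical.
import torch
import torch.nn as nn
import torch.nn.functional as F

class BookLayerNorm(nn.Module):
    """Per-token normalization following (7.2.22): center, divide by the l2
    norm of the centered vector, then apply learnable scale alpha and shift beta."""
    def __init__(self, d):
        super().__init__()
        self.alpha = nn.Parameter(torch.ones(d))
        self.beta = nn.Parameter(torch.zeros(d))

    def forward(self, Z):                       # Z: (n, d), one token per row
        Zc = Z - Z.mean(dim=-1, keepdim=True)   # subtract mean(z) per token
        Zn = Zc / Zc.norm(dim=-1, keepdim=True)
        return Zn * self.alpha + self.beta

class TransformerLayer(nn.Module):
    def __init__(self, d, K, q):
        super().__init__()
        assert d % K == 0                        # pick p = d / K so that K p = d
        self.K, self.p = K, d // K
        self.ln1, self.ln2 = BookLayerNorm(d), BookLayerNorm(d)
        self.qkv = nn.Linear(d, 3 * d, bias=False)  # stacked U_qry, U_key, U_val
        self.out = nn.Linear(d, d)                  # U_out and b_out
        self.mlp = nn.Sequential(nn.Linear(d, q), nn.ReLU(), nn.Linear(q, d))

    def mhsa(self, Z):                           # multi-head self-attention (7.2.15)-(7.2.16)
        n, d = Z.shape
        q, k, v = self.qkv(Z).chunk(3, dim=-1)   # each (n, d)
        # reshape to (K, n, p): one slice per attention head
        q, k, v = (t.reshape(n, self.K, self.p).transpose(0, 1) for t in (q, k, v))
        A = F.softmax(q @ k.transpose(-2, -1) / self.p ** 0.5, dim=-1)  # (K, n, n)
        heads = (A @ v).transpose(0, 1).reshape(n, d)                   # concatenate head outputs
        return self.out(heads)

    def forward(self, Z):
        Z = Z + self.mhsa(self.ln1(Z))           # (7.2.13)
        Z = Z + self.mlp(self.ln2(Z))            # (7.2.14)
        return Z
```

For instance, `TransformerLayer(d=64, K=8, q=256)(torch.randn(16, 64))` returns a tensor of the same shape, i.e., the output token features of the layer.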

Feature extractor.

We use a post-processing step $f_{\theta}^{\mathrm{ext}}$ which extracts the class token feature, which (recall) is the feature meant to contain aggregate information about the input image, and applies an MLP and normalization to it. Namely, we have

$$\bm{z}_{\theta}(\bm{X}) \doteq f_{\theta}^{\mathrm{ext}}(\bm{Z}_{\theta}(\bm{X})) = f_{\theta}^{\mathrm{ext}}([\bm{z}_{\theta}^{1}(\bm{X}),\dots,\bm{z}_{\theta}^{n}(\bm{X})]) \doteq \frac{\operatorname{MLP}_{\theta}^{\mathrm{ext}}(\bm{z}_{\theta}^{1}(\bm{X}))}{\|\operatorname{MLP}_{\theta}^{\mathrm{ext}}(\bm{z}_{\theta}^{1}(\bm{X}))\|_{2}}. \qquad (7.2.23)$$
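
Continuing the sketch above, and under the same caveats, the extraction map (7.2.23) simply reads off the class-token feature, applies a small MLP head, and $\ell^{2}$-normalizes the result; the hidden width and the GELU nonlinearity below are illustrative assumptions.

```python
# Sketch of the extraction map f_ext (7.2.23): take the class-token feature,
# pass it through a small MLP head, and l2-normalize. Head width is assumed.
import torch
import torch.nn as nn
import torch.nn.functional as F

class Extractor(nn.Module):
    def __init__(self, d, hidden=2048):
        super().__init__()
        self.mlp = nn.Sequential(nn.Linear(d, hidden), nn.GELU(), nn.Linear(hidden, d))

    def forward(self, Z):                # Z: (n, d) token features; row 0 is the class token
        z = self.mlp(Z[0])
        return F.normalize(z, dim=-1)    # unit l2 norm, as in (7.2.23)
```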
Task-specific (“DINO”) head.

For DINO, we use the task-specific DINO head $h_{\bm{W},\bm{\mu}}$. For SimDINO, we use no task-specific head at all, as previously described.

7.2.4 Optimization Strategy

Figure 7.8: The DINO pipeline. Student features and teacher features are computed for each input. The objective attempts to align the student features with the teacher features by projecting both sets of features into a high-dimensional probability simplex and computing a cross-entropy loss. Notably, because of the “stop-grad”, the gradient is only computed w.r.t. the student parameters’ outputs.
Optimizing DINO.

We have a loss function and an architecture, so we now discuss the optimization strategy. The optimization strategy for DINO uses two sets of weights for the same architecture: student weights $\theta_{\mathrm{s}}$ and teacher weights $\theta_{\mathrm{t}}$. These correspond to two different neural networks, called the teacher network and the student network, with the same architecture. The teacher network encodes the global views, while the student network encodes all the “other” views, i.e., the local views together with the remaining global views. The goal of the loss is to distill the teacher's outputs into the student model. Namely, we train on the loss $\mathcal{L}_{\mathrm{DINO}\text{-}\mathrm{st}}$:

$$\mathcal{L}_{\mathrm{DINO}\text{-}\mathrm{st}}(\theta_{\mathrm{s}},\theta_{\mathrm{t}},\bm{W}_{\mathrm{s}},\bm{W}_{\mathrm{t}},\bm{\mu}) \doteq \operatorname{\mathbb{E}}\big[d_{\operatorname{CE}}\big(\bm{p}_{\theta_{\mathrm{t}},\bm{W}_{\mathrm{t}},\bm{\mu}}(\bm{X}_{g}),\, \bm{p}_{\theta_{\mathrm{s}},\bm{W}_{\mathrm{s}}}(\bm{X}_{c})\big)\big]. \qquad (7.2.24)$$
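
Here $d_{\operatorname{CE}}$ is the cross-entropy distance between the teacher and student probability vectors produced by the DINO heads. Assuming the standard definition $d_{\operatorname{CE}}(\bm{p},\bm{q}) = -\sum_{i} p_{i}\log q_{i}$, a minimal sketch is:

```python
# Sketch of the cross-entropy distance between two probability vectors,
# assuming the standard definition d_CE(p, q) = -sum_i p_i log q_i.
import torch

def d_ce(p_teacher, q_student, eps=1e-8):
    # eps guards against log(0); batching over the last dimension
    return -(p_teacher * torch.log(q_student + eps)).sum(dim=-1)
```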

Now, we can fully describe the overall pipeline of DINO, depicted in Figure 7.8.

While it is easy to reason about (7.2.24), it is impossible in practice to run optimization algorithms such as gradient descent directly on the loss $\mathcal{L}_{\mathrm{DINO}\text{-}\mathrm{st}}$. This is because the expectations in the loss are impossible to evaluate exactly, much less to differentiate. In this extremely common situation, we approximate the expectation via finite samples. That is, at each timestep $k$ we:

  • Subsample $B$ data points $\{\bm{X}_{1}^{(k)},\dots,\bm{X}_{B}^{(k)}\} \subset \mathcal{I}$ from our dataset.

  • For each data point $\bm{X}_{b}^{(k)}$, sample $M_{\mathrm{glo}}$ global views $v_{b,g}^{(k),i}$ and $M_{\mathrm{loc}}$ local views $v_{b,\ell}^{(k),i}$. Apply the views to $\bm{X}_{b}^{(k)}$ to obtain $\bm{X}_{b,g}^{(k),i} \doteq v_{b,g}^{(k),i}(\bm{X}_{b}^{(k)})$ and $\bm{X}_{b,\ell}^{(k),i} \doteq v_{b,\ell}^{(k),i}(\bm{X}_{b}^{(k)})$.

  • For each local view $\bm{X}_{b,\ell}^{(k),i}$, compute the following quantities:

    $$\bm{z}_{\theta_{\mathrm{s}}}(\bm{X}_{b,\ell}^{(k),i}) \doteq (f_{\theta_{\mathrm{s}}}^{\mathrm{ext}} \circ f_{\theta_{\mathrm{s}}})(\bm{X}_{b,\ell}^{(k),i}), \qquad \bm{p}_{\theta_{\mathrm{s}},\bm{W}_{\mathrm{s}}}(\bm{X}_{b,\ell}^{(k),i}) \doteq h_{\bm{W}_{\mathrm{s}},\bm{0}_{m}}\big(\bm{z}_{\theta_{\mathrm{s}}}(\bm{X}_{b,\ell}^{(k),i})\big) \qquad (7.2.25)$$

    and for each global view $\bm{X}_{b,g}^{(k),i}$, compute the following quantities (by an abuse of notation):

    $$\bm{z}_{\theta_{\mathrm{s}}}(\bm{X}_{b,g}^{(k),i}) \doteq (f_{\theta_{\mathrm{s}}}^{\mathrm{ext}} \circ f_{\theta_{\mathrm{s}}})(\bm{X}_{b,g}^{(k),i}), \qquad \bm{p}_{\theta_{\mathrm{s}},\bm{W}_{\mathrm{s}}}(\bm{X}_{b,g}^{(k),i}) \doteq h_{\bm{W}_{\mathrm{s}},\bm{0}_{m}}\big(\bm{z}_{\theta_{\mathrm{s}}}(\bm{X}_{b,g}^{(k),i})\big), \qquad (7.2.26)$$
    $$\bm{z}_{\theta_{\mathrm{t}}}(\bm{X}_{b,g}^{(k),i}) \doteq (f_{\theta_{\mathrm{t}}}^{\mathrm{ext}} \circ f_{\theta_{\mathrm{t}}})(\bm{X}_{b,g}^{(k),i}), \qquad \bm{p}_{\theta_{\mathrm{t}},\bm{W}_{\mathrm{t}},\bm{\mu}}(\bm{X}_{b,g}^{(k),i}) \doteq h_{\bm{W}_{\mathrm{t}},\bm{\mu}}\big(\bm{z}_{\theta_{\mathrm{t}}}(\bm{X}_{b,g}^{(k),i})\big). \qquad (7.2.27)$$
  • Compute the surrogate, approximate loss $\hat{\mathcal{L}}_{\mathrm{DINO}\text{-}\mathrm{st}}^{(k)}$, defined as follows:

    $$\hat{\mathcal{L}}_{\mathrm{DINO}\text{-}\mathrm{st}}^{(k)}(\theta_{\mathrm{s}},\theta_{\mathrm{t}},\bm{W}_{\mathrm{s}},\bm{W}_{\mathrm{t}},\bm{\mu}) \doteq \frac{1}{B M_{\mathrm{glo}}(M_{\mathrm{glo}} + M_{\mathrm{loc}} - 1)} \sum_{b=1}^{B}\sum_{i=1}^{M_{\mathrm{glo}}} \Bigg[\sum_{j=1}^{M_{\mathrm{loc}}} d_{\operatorname{CE}}\big(\bm{p}_{\theta_{\mathrm{t}},\bm{W}_{\mathrm{t}},\bm{\mu}}(\bm{X}_{b,g}^{(k),i}),\, \bm{p}_{\theta_{\mathrm{s}},\bm{W}_{\mathrm{s}}}(\bm{X}_{b,\ell}^{(k),j})\big) + \sum_{\substack{j=1\\ j\neq i}}^{M_{\mathrm{glo}}} d_{\operatorname{CE}}\big(\bm{p}_{\theta_{\mathrm{t}},\bm{W}_{\mathrm{t}},\bm{\mu}}(\bm{X}_{b,g}^{(k),i}),\, \bm{p}_{\theta_{\mathrm{s}},\bm{W}_{\mathrm{s}}}(\bm{X}_{b,g}^{(k),j})\big)\Bigg] \qquad (7.2.28)$$

    as well as its gradients with respect to $\theta_{\mathrm{s}}$ and $\bm{W}_{\mathrm{s}}$, which should be computed under the assumption that $\theta_{\mathrm{t}}$, $\bm{W}_{\mathrm{t}}$, and $\bm{\mu}$ are constants, i.e., they are detached from the computational graph and do not depend on $\theta_{\mathrm{s}}$ and $\bm{W}_{\mathrm{s}}$.

  • Update the student parameters $\theta_{\mathrm{s}}$ and $\bm{W}_{\mathrm{s}}$ via an iterative gradient-based optimization algorithm, and update $\theta_{\mathrm{t}}$, $\bm{W}_{\mathrm{t}}$, and $\bm{\mu}$ via exponential moving averages with decay parameters $\nu^{(k)}$, $\nu^{(k)}$, and $\rho^{(k)}$ respectively, i.e.,

    $$(\theta_{\mathrm{s}}^{(k+1)}, \bm{W}_{\mathrm{s}}^{(k+1)}) = \textsc{OptUpdate}^{(k)}\big(\theta_{\mathrm{s}}^{(k)}, \bm{W}_{\mathrm{s}}^{(k)};\ \nabla_{(\theta_{\mathrm{s}},\bm{W}_{\mathrm{s}})}\hat{\mathcal{L}}_{\mathrm{DINO}\text{-}\mathrm{st}}^{(k)}\big) \qquad (7.2.29)$$
    $$\theta_{\mathrm{t}}^{(k+1)} = \nu^{(k)}\theta_{\mathrm{t}}^{(k)} + (1-\nu^{(k)})\theta_{\mathrm{s}}^{(k+1)} \qquad (7.2.30)$$
    $$\bm{W}_{\mathrm{t}}^{(k+1)} = \nu^{(k)}\bm{W}_{\mathrm{t}}^{(k)} + (1-\nu^{(k)})\bm{W}_{\mathrm{s}}^{(k+1)} \qquad (7.2.31)$$
    $$\bm{\mu}^{(k+1)} = \rho^{(k)}\bm{\mu}^{(k)} + (1-\rho^{(k)}) \cdot \frac{1}{B M_{\mathrm{glo}}}\sum_{b=1}^{B}\sum_{i=1}^{M_{\mathrm{glo}}} \bm{W}_{\mathrm{t}}^{(k)}\bm{z}_{\theta_{\mathrm{t}}}(\bm{X}_{b,g}^{(k),i}). \qquad (7.2.32)$$

    For example, if the chosen optimization algorithm were stochastic gradient descent, we would have the update $\theta_{\mathrm{s}}^{(k+1)} \doteq \theta_{\mathrm{s}}^{(k)} - \delta^{(k)}\nabla_{\theta_{\mathrm{s}}}\hat{\mathcal{L}}_{\mathrm{DINO}\text{-}\mathrm{st}}^{(k)}$, and so on.

Notice that the optimization procedure is rather irregular: although all of the parameters $\theta_{\mathrm{s}}, \bm{W}_{\mathrm{s}}, \theta_{\mathrm{t}}, \bm{W}_{\mathrm{t}}, \bm{\mu}$ change at each iteration, only $\theta_{\mathrm{s}}$ and $\bm{W}_{\mathrm{s}}$ are directly updated by a gradient-based method. The others are updated via exponential moving averages, and are indeed treated as constants when computing any gradients. After training, we discard the student weights and use the teacher weights for our trained network $f$, as this exponential moving average has been empirically shown to stabilize the resulting model (this idea is known as Polyak averaging or iterate averaging).
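
The following sketch illustrates this structure for a single update: the teacher forward passes are wrapped in a no-gradient context (the “stop-grad”), the optimizer only touches the student parameters, and the teacher is then updated by an exponential moving average. Here `student`, `teacher`, `dino_loss` (an implementation of (7.2.28)), and the view tensors are assumed to be defined elsewhere; this is a sketch of the structure, not the authors' training code.

```python
# One DINO-style update step: gradients flow only through the student; the
# teacher is updated by an exponential moving average (7.2.30)-(7.2.31).
import torch

@torch.no_grad()
def ema_update(teacher, student, nu):
    for p_t, p_s in zip(teacher.parameters(), student.parameters()):
        p_t.mul_(nu).add_(p_s, alpha=1.0 - nu)

def train_step(student, teacher, optimizer, global_views, local_views, nu):
    with torch.no_grad():                        # "stop-grad": teacher outputs are constants
        p_teacher = [teacher(v) for v in global_views]
    p_student_g = [student(v) for v in global_views]
    p_student_l = [student(v) for v in local_views]
    loss = dino_loss(p_teacher, p_student_g, p_student_l)   # assumed surrogate loss (7.2.28)
    optimizer.zero_grad()
    loss.backward()                              # gradient w.r.t. student parameters only
    optimizer.step()
    ema_update(teacher, student, nu)
    return loss.item()
```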

The way that $\nu$ and $\rho$ change over the optimization trajectory (i.e., the functions $k \mapsto \nu^{(k)}$ and $k \mapsto \rho^{(k)}$) is a hyperparameter or design decision; usually $\nu^{(1)} < 1$ and $\lim_{k\to\infty}\nu^{(k)} = 1$, and similarly for $\rho$. The temperature hyperparameter $\tau$, used in the DINO head $h_{\bm{W},\bm{\mu}}$, also changes over the optimization trajectory (though this dependence is not explicitly notated).
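
As one illustrative choice (an assumption, not a prescription of the book), the schedule $k \mapsto \nu^{(k)}$ can be taken to increase from an initial value toward $1$ along a cosine curve:

```python
# Illustrative cosine schedule for the teacher EMA decay nu^(k); the endpoint
# values 0.996 and 1.0 are assumed examples, not prescribed hyperparameters.
import math

def cosine_momentum(k, K_total, nu_start=0.996, nu_end=1.0):
    return nu_end - (nu_end - nu_start) * (math.cos(math.pi * k / K_total) + 1) / 2
```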

Using the surrogate (“empirical”) loss transforms our intractable optimization problem, namely optimizing the loss in (7.2.24), into a tractable stochastic optimization problem of the kind used to train essentially every deep learning model in the world. This conversion is extremely natural once you have seen some examples, and we give several such examples throughout the chapter.

Figure 7.9: The SimDINO pipeline. Here, in contrast to the DINO pipeline in Figure 7.8, the loss is computed directly on the features without need of further manipulation. This shaves off several large matrices’ worth of parameters and simplifies the pipeline, simultaneously making it more stable to train.
Optimizing SimDINO.

The simplified DINO population-level objective is very similar in spirit but much simpler in execution, i.e.,

$$\mathcal{L}_{\mathrm{SimDINO}\text{-}\mathrm{st}}(\theta_{\mathrm{s}},\theta_{\mathrm{t}}) \doteq \operatorname{\mathbb{E}}\big[d_{\ell^{2}}\big(\bm{z}_{\theta_{\mathrm{t}}}(\bm{X}_{g}),\, \bm{z}_{\theta_{\mathrm{s}}}(\bm{X}_{c})\big)\big] - \frac{\gamma}{2}\log\det\left(\bm{I} + \frac{d}{\varepsilon^{2}}\operatorname{Cov}\big(\bm{z}_{\theta_{\mathrm{s}}}(\bm{X}_{g})\big)\right). \qquad (7.2.33)$$

Thus, as elaborated in Figure 7.9, the SimDINO pipeline is strictly simpler than the DINO pipeline. We can use a simpler version of the DINO training pipeline to optimize SimDINO. At each timestep $k$, we:

  • Subsample $B$ data points $\{\bm{X}_{1}^{(k)},\dots,\bm{X}_{B}^{(k)}\} \subset \mathcal{I}$ from our dataset.

  • For each data point $\bm{X}_{b}^{(k)}$, sample $M_{\mathrm{glo}}$ global views $v_{b,g}^{(k),i}$ and $M_{\mathrm{loc}}$ local views $v_{b,\ell}^{(k),i}$. Apply the views to $\bm{X}_{b}^{(k)}$ to obtain $\bm{X}_{b,g}^{(k),i} \doteq v_{b,g}^{(k),i}(\bm{X}_{b}^{(k)})$ and $\bm{X}_{b,\ell}^{(k),i} \doteq v_{b,\ell}^{(k),i}(\bm{X}_{b}^{(k)})$.

  • For each local view $\bm{X}_{b,\ell}^{(k),i}$, compute $\bm{z}_{\theta_{\mathrm{s}}}(\bm{X}_{b,\ell}^{(k),i}) \doteq (f_{\theta_{\mathrm{s}}}^{\mathrm{ext}} \circ f_{\theta_{\mathrm{s}}})(\bm{X}_{b,\ell}^{(k),i})$. For each global view $\bm{X}_{b,g}^{(k),i}$, compute $\bm{z}_{\theta_{\mathrm{s}}}(\bm{X}_{b,g}^{(k),i}) \doteq (f_{\theta_{\mathrm{s}}}^{\mathrm{ext}} \circ f_{\theta_{\mathrm{s}}})(\bm{X}_{b,g}^{(k),i})$ and $\bm{z}_{\theta_{\mathrm{t}}}(\bm{X}_{b,g}^{(k),i}) \doteq (f_{\theta_{\mathrm{t}}}^{\mathrm{ext}} \circ f_{\theta_{\mathrm{t}}})(\bm{X}_{b,g}^{(k),i})$.

  • Compute the surrogate, approximate loss $\hat{\mathcal{L}}_{\mathrm{SimDINO}\text{-}\mathrm{st}}^{(k)}$, defined as follows:

    $$\hat{\mathcal{L}}_{\mathrm{SimDINO}\text{-}\mathrm{st}}^{(k)}(\theta_{\mathrm{s}},\theta_{\mathrm{t}}) \doteq \frac{1}{B M_{\mathrm{glo}}(M_{\mathrm{glo}} + M_{\mathrm{loc}} - 1)} \sum_{b=1}^{B}\sum_{i=1}^{M_{\mathrm{glo}}} \Bigg[\sum_{j=1}^{M_{\mathrm{loc}}} d_{\ell^{2}}\big(\bm{z}_{\theta_{\mathrm{t}}}(\bm{X}_{b,g}^{(k),i}),\, \bm{z}_{\theta_{\mathrm{s}}}(\bm{X}_{b,\ell}^{(k),j})\big) + \sum_{j=1}^{M_{\mathrm{glo}}} d_{\ell^{2}}\big(\bm{z}_{\theta_{\mathrm{t}}}(\bm{X}_{b,g}^{(k),i}),\, \bm{z}_{\theta_{\mathrm{s}}}(\bm{X}_{b,g}^{(k),j})\big)\Bigg] - \frac{\gamma}{M_{\mathrm{glo}}}\sum_{i=1}^{M_{\mathrm{glo}}} R_{\varepsilon}\big([\bm{z}_{\theta_{\mathrm{s}}}(\bm{X}_{1,g}^{(k),i}),\dots,\bm{z}_{\theta_{\mathrm{s}}}(\bm{X}_{B,g}^{(k),i})]\big) \qquad (7.2.34)$$

    where $R_{\varepsilon}$ is the Gaussian coding rate estimated on finite samples, described in Chapter 4. The gradient of $\hat{\mathcal{L}}_{\mathrm{SimDINO}\text{-}\mathrm{st}}^{(k)}$ with respect to $\theta_{\mathrm{s}}$ should (again) be computed under the assumption that $\theta_{\mathrm{t}}$ is constant.

  • Update the student parameters $\theta_{\mathrm{s}}$ via an iterative gradient-based optimization algorithm, and update $\theta_{\mathrm{t}}$ via an exponential moving average with decay parameter $\nu^{(k)}$, i.e.,

    $$\theta_{\mathrm{s}}^{(k+1)} = \textsc{OptUpdate}^{(k)}\big(\theta_{\mathrm{s}}^{(k)};\ \nabla_{\theta_{\mathrm{s}}}\hat{\mathcal{L}}_{\mathrm{SimDINO}\text{-}\mathrm{st}}^{(k)}\big) \qquad (7.2.35)$$
    $$\theta_{\mathrm{t}}^{(k+1)} = \nu^{(k)}\theta_{\mathrm{t}}^{(k)} + (1-\nu^{(k)})\theta_{\mathrm{s}}^{(k+1)}. \qquad (7.2.36)$$

Again, we re-iterate that the gradient is taken only with respect to $\theta_{\mathrm{s}}$, treating $\theta_{\mathrm{t}}$ as a constant. Note that while the choice of the schedule $\nu$ is still a design decision, the DINO hyperparameters $\rho$ and $\tau$ are no longer needed.
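To make the update rule concrete, the following is a minimal PyTorch sketch, assuming hypothetical `student` and `teacher` modules with identical architectures and a `loss` computed as above with the teacher's outputs detached; it is not the exact training code used in the experiments.

```python
import torch

@torch.no_grad()
def ema_update(teacher: torch.nn.Module, student: torch.nn.Module, nu: float) -> None:
    """Teacher <- nu * teacher + (1 - nu) * student, parameter by parameter (7.2.36)."""
    for p_t, p_s in zip(teacher.parameters(), student.parameters()):
        p_t.mul_(nu).add_(p_s, alpha=1.0 - nu)

def train_step(loss, student, teacher, optimizer, nu):
    """One iteration of (7.2.35)-(7.2.36); `loss` is the stochastic objective for the batch."""
    optimizer.zero_grad(set_to_none=True)
    loss.backward()                      # gradient w.r.t. the student parameters only
    optimizer.step()                     # OptUpdate^{(k)}, e.g. AdamW
    ema_update(teacher, student, nu)     # teacher follows the student via EMA
```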

7.2.5 Evaluation Methodology

There are several ways to evaluate a trained transformer model; we highlight the main ones in this section. Let us first define the center crop view $v_{\mathrm{cc}}\colon\mathcal{I}\to\mathcal{I}$, which is a deterministic resized crop:

  • it resizes the image so that the shortest edge has size $S_{\mathrm{rsz}}$ (similar to a random resized crop with area-percentage parameter $1$);

  • then it takes the center $S_{\mathrm{cc}}\times S_{\mathrm{cc}}$ crop;

so that the final shape is $(C,S_{\mathrm{cc}},S_{\mathrm{cc}})$, where $S_{\mathrm{cc}}\leq S_{\mathrm{rsz}}$. Notice that the view $v_{\mathrm{cc}}$ is completely deterministic given an input. For an input $\bm{X}$, we write $\bm{X}_{\mathrm{cc}}\doteq v_{\mathrm{cc}}(\bm{X})$.
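A minimal sketch of $v_{\mathrm{cc}}$ using torchvision, assuming the example values $S_{\mathrm{rsz}}=256$ and $S_{\mathrm{cc}}=224$ used later in this chapter (channel normalization is omitted):

```python
from torchvision import transforms

S_RSZ, S_CC = 256, 224   # example values with S_cc <= S_rsz

# Deterministic center crop view v_cc: resize the shortest edge to S_rsz (preserving aspect
# ratio), then take the central S_cc x S_cc crop, yielding a (C, S_cc, S_cc) tensor.
center_crop_view = transforms.Compose([
    transforms.Resize(S_RSZ),
    transforms.CenterCrop(S_CC),
    transforms.ToTensor(),
])
```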

Linear probing.

The first, and most architecture-agnostic, way to evaluate an encoder model $\bm{X}\mapsto\bm{z}_{\theta}(\bm{X})$ is linear probing. Linear probing is, in a sentence, logistic regression on the aggregate features computed by the encoder. It tells us how much semantic information the representations contain, and how easily this information can be extracted (that is, to what extent the features of images with different semantics live on different subspaces of the feature space).

More formally, suppose that we want to evaluate the quality and faithfulness of the encoder's features on image-label data $(\bm{X},\bm{y})$, where there are $N_{\mathrm{cls}}$ classes and $\bm{y}\in\{0,1\}^{N_{\mathrm{cls}}}$ is a "one-hot encoding" (namely, zeros in all positions except a $1$ in the $i$th position if $\bm{X}$ belongs to the $i$th class). One way to do this is to solve the logistic regression problem

$$\min_{\bm{W}\in\mathbb{R}^{N_{\mathrm{cls}}\times d}}\operatorname{\mathbb{E}}\big[\operatorname{CE}\big(\bm{y},\,\bm{W}\bm{z}_{\theta}(\bm{X}_{\mathrm{cc}})\big)\big].$$ (7.2.37)

More practically, if we have labeled data $\{(\bm{X}_b,\bm{y}_b)\}_{b=1}^{B}$, we can solve the empirical logistic regression problem (akin to the relationship between (7.2.24) and (7.2.28)) given by

$$\min_{\bm{W}\in\mathbb{R}^{N_{\mathrm{cls}}\times d}}\frac{1}{B}\sum_{b=1}^{B}\operatorname{CE}\big(\bm{y}_b,\,\bm{W}\bm{z}_{\theta}(\bm{X}_{b,\mathrm{cc}})\big).$$ (7.2.38)

This is a convex optimization problem in $\bm{W}$, and thus can be solved efficiently via (stochastic) gradient descent or a litany of other algorithms. The linear probe, together with the encoder, may be used as a classifier, and we can evaluate its classification accuracy. The usual practice is to first pre-train the model on a large dataset (such as ImageNet-1K), then train the linear probe on a second dataset (such as the training split of CIFAR-10), and finally evaluate it on a third ("holdout") dataset drawn from the same distribution as the second (such as the evaluation split of CIFAR-10).
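As a minimal sketch, assuming the aggregate features have been precomputed from frozen encoder weights and that labels are given as integer class indices rather than one-hot vectors, the empirical problem (7.2.38) can be solved in a few lines of PyTorch; the hyperparameter values below are illustrative only.

```python
import torch
import torch.nn.functional as F

def train_linear_probe(features: torch.Tensor, labels: torch.Tensor,
                       num_classes: int, epochs: int = 100, lr: float = 1e-3) -> torch.Tensor:
    """Minimize the empirical cross-entropy (7.2.38) over W on frozen features.
    features: (B, d) aggregate features z_theta(X_cc); labels: (B,) integer class indices."""
    d = features.shape[1]
    W = torch.zeros(num_classes, d, requires_grad=True)
    opt = torch.optim.AdamW([W], lr=lr, weight_decay=0.01)
    for _ in range(epochs):
        opt.zero_grad(set_to_none=True)
        loss = F.cross_entropy(features @ W.T, labels)   # CE with the softmax folded in
        loss.backward()
        opt.step()
    return W.detach()
```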

$k$-nearest neighbors.

We can also evaluate the performance of the features on classification tasks, without explicitly training a classifier, by using the $k$-nearest neighbor algorithm to get an average predicted label. Namely, given a dataset $\{\bm{z}_b\}_{b=1}^{B}\subseteq\mathbb{R}^{d}$, define the $k$-nearest neighbors of another point $\bm{z}\in\mathbb{R}^{d}$ as $\operatorname{NN}_{k}(\bm{z},\{\bm{z}_b\}_{b=1}^{B})$. Using this notation, we can compute the predicted label $\hat{\bm{y}}_{\theta}(\bm{X}\mid\{(\bm{X}_b,\bm{y}_b)\}_{b=1}^{B})$ as

$$\hat{\bm{y}}_{\theta}\big(\bm{X}\mid\{(\bm{X}_b,\bm{y}_b)\}_{b=1}^{B}\big)=\bm{1}(i^{\star})\quad\text{where}\quad i^{\star}\doteq\operatorname*{arg\,max}_{i\in[N_{\mathrm{cls}}]}\ \sum_{b=1}^{B}(\bm{y}_b)_{i}\,\bm{1}\big[\bm{z}_{\theta}(\bm{X}_{b,\mathrm{cc}})\in\operatorname{NN}_{k}\big(\bm{z}_{\theta}(\bm{X}_{\mathrm{cc}}),\{\bm{z}_{\theta}(\bm{X}_{b',\mathrm{cc}})\}_{b'=1}^{B}\big)\big].$$ (7.2.39)

Here, $\bm{1}(i)\in\Delta_{N_{\mathrm{cls}}}$ is (by an abuse of notation, cf. indicator variables) the one-hot probability vector supported at $i$, i.e., $1$ in the $i$th coordinate and $0$ elsewhere. That is, this procedure takes the most common label among the $k$ nearest points in feature space. The $k$-nearest neighbor classification accuracy is simply the accuracy of this predicted label, namely,

$$\operatorname{\mathbb{E}}_{\bm{X},\bm{y}}\big[\bm{1}\big(\hat{\bm{y}}_{\theta}(\bm{X}\mid\{(\bm{X}_b,\bm{y}_b)\}_{b=1}^{B})=\bm{y}\big)\big],$$ (7.2.40)

or, more commonly, its corresponding empirical version, where $(\bm{X},\bm{y})$ ranges over a finite holdout dataset (disjoint from the existing samples $(\bm{X}_b,\bm{y}_b)$ used to form the $k$ neighbors).
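A minimal sketch of this evaluation, assuming precomputed features and integer labels, and implementing the plain (unweighted) majority vote of (7.2.39):

```python
import torch

def knn_predict(z_query: torch.Tensor, z_train: torch.Tensor,
                y_train: torch.Tensor, k: int = 20) -> torch.Tensor:
    """Majority vote over the k nearest training features in Euclidean distance.
    z_query: (M, d), z_train: (B, d), y_train: (B,) integer class indices."""
    dists = torch.cdist(z_query, z_train)                 # (M, B) pairwise distances
    nn_idx = dists.topk(k, dim=1, largest=False).indices  # (M, k) nearest-neighbor indices
    return torch.mode(y_train[nn_idx], dim=1).values      # most common neighbor label

# k-NN accuracy: (knn_predict(z_eval, z_train, y_train) == y_eval).float().mean()
```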

Fidelity of the attention maps.

Another way to check the performance of the representations, for a transformer-based encoder, is to examine the fidelity of the attention maps $\bm{A}^{k,L}\in\mathbb{R}^{n\times n}$, as defined in Equation (7.2.19), at the last layer $L$, given by the following pipeline:

$$\bm{X}\mapsto\cdots\mapsto\bm{Z}^{L-1}=[\underbrace{\bm{z}_{1}^{L-1}}_{\text{class token}},\underbrace{\bm{z}_{2}^{L-1},\dots,\bm{z}_{n}^{L-1}}_{\text{patch tokens}}]\mapsto\bm{A}^{k,L}=\begin{bmatrix}\bm{A}_{1,1}^{k,L}&\bm{A}_{1,2:}^{k,L}\\ \bm{A}_{2:,1}^{k,L}&\bm{A}_{2:,2:}^{k,L}\end{bmatrix}.$$ (7.2.41)

In particular, we examine what the attention maps for a given input reveal about the salient objects in the input image, i.e., which parts of the image provide the most globally relevant information to the class token. One way to do this is to examine the component of the attention map where the class token serves as the query and the patch tokens serve as the keys, i.e., $\bm{A}_{1,2:}^{k,L}\in\mathbb{R}^{1\times(n-1)}=\mathbb{R}^{1\times N}$, or its transpose $\bm{a}^{k,L}=(\bm{A}_{1,2:}^{k,L})^{\top}\in\mathbb{R}^{N}$. Notice that this vector $\bm{a}^{k,L}$, which we call the "saliency vector at the $k$th attention head of layer $L$," has one value for each patch $1,\dots,N$; this value describes how relevant each patch is to the global information. For visualization's sake, we create a new image in which each patch is replaced by its corresponding value in the saliency vector, showcasing the contribution of each patch; we call this image the "saliency map at the $k$th attention head of layer $L$." To visualize the total relevance of each patch to the global information across all heads, we can average the saliency vectors, i.e., $\tilde{\bm{a}}^{L}\doteq\frac{1}{K}\sum_{k=1}^{K}\bm{a}^{k,L}$, and expand the result into the average saliency map. The average saliency map should highlight the relevant parts of the input image.
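A minimal sketch of computing the average saliency map from the last-layer attention, assuming it is available as a $(K, n, n)$ tensor with the class token at index $0$ and that the $N$ patches form an $H_p \times W_p$ grid:

```python
import torch

def average_saliency_map(attn: torch.Tensor, grid_hw: tuple) -> torch.Tensor:
    """attn: (K, n, n) last-layer attention maps, with token 0 the class token.
    Returns an (H_p, W_p) map whose entries are the head-averaged attention that the
    class token (as query) places on each patch token (as key)."""
    cls_to_patch = attn[:, 0, 1:]     # (K, N): saliency vector a^{k,L} for each head k
    avg = cls_to_patch.mean(dim=0)    # average saliency vector over the K heads
    return avg.reshape(grid_hw)       # arrange the per-patch values back into the grid

# For display, upsample to the input resolution (e.g. with torch.nn.functional.interpolate)
# and overlay on the image.
```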

Object detection and segmentation.

We can evaluate how well the representations capture fine-grained (i.e., smaller or more detailed) properties of the input by using them for object detection and semantic segmentation. Roughly, this means that we use the features to construct bounding boxes or masks for all objects in the input. There are several ways to do this, and several ways to score the resulting predictions against the ground truth; each combination of methods corresponds to a particular detection or segmentation metric. We do not formally describe them here, as they are not particularly insightful, but the DINO paper [CTM+21] and DINOv2 paper [ODM+23] contain references to all metrics used in practice.

7.2.6 Experimental Setup and Results

Since SimDINO is directly built upon DINO, we compare the optimal settings for DINO as given by their original paper [CTM+21] with the same settings applied to SimDINO for a fair comparison.

Objective function.

We use $10$ local views (i.e., $M_{\mathrm{loc}}=10$) of resolution $96\times 96$ (i.e., $S_{\mathrm{loc}}=96$) and $2$ global views (i.e., $M_{\mathrm{glo}}=2$) of resolution $224\times 224$ (i.e., $S_{\mathrm{glo}}=224$) for all experiments. The corresponding portions of the original images cropped for local and global views are $p_{\mathrm{loc}}\in[\tfrac{1}{20},\tfrac{3}{10}]$ and $p_{\mathrm{glo}}\in[\tfrac{3}{10},1]$ (chosen randomly per view). The shorter-edge size for resized crops is $S_{\mathrm{rsz}}=256$, and the center crop (evaluation) view edge size is $S_{\mathrm{cc}}=224$. All of these settings apply to both DINO and SimDINO.

Model architecture.

For all inputs, we set the patch size to $16\times 16$ (namely, $P_{H}=P_{W}=16$). We use the small, base, and large models of the ViT [DBK+21] architecture as the embedding and backbone. The feature extractor is a three-layer MLP with hidden size $2048$ and output dimension $256$, followed by $\ell^{2}$-normalization, as specified in Section 7.2.3. For DINO architectures (i.e., not SimDINO architectures), the DINO head $\bm{W}$ is a matrix in $\mathbb{R}^{65536\times 256}$, and the parameter $\bm{\mu}$ is a vector in $\mathbb{R}^{65536}$.

Datasets and optimization.

For pre-training, both our DINO reproduction and SimDINO use the ImageNet-1K dataset. We use AdamW [LH17] as the optimizer, a very standard choice, and follow these hyperparameter recommendations:

  • The batch size is $B=1024$.

  • The learning rate (for AdamW and the student model) has "base" value $2\times 10^{-3}$. Over the first $10$ epochs, the learning rate increases linearly from $0$ to the base value (i.e., at the $i$th epoch the learning rate is $(i/10)\cdot 2\times 10^{-3}$, for $1\leq i\leq 10$). Over the next $90$ epochs, the learning rate decays back down to $0$ via a so-called cosine schedule. The definition of a cosine schedule is given in many places, including the PyTorch documentation, and it is commonly used when training deep vision models; a short sketch of this warmup-plus-cosine schedule follows this list.

  • The weight decay (the "W" in AdamW) follows a cosine schedule from $0.04$ to $0.4$ over training.

  • The EMA rate $\nu$ follows a cosine schedule from $0.996$ to $1.0$ over training. Specifically for DINO, the centering EMA rate $\rho$ is fixed at $0.9$.

  • Specifically for DINO, the student temperature $\tau_{\mathrm{s}}$ is fixed at $0.1$, while the teacher temperature $\tau_{\mathrm{t}}$ increases linearly from $0.04$ to $0.07$ during the first $30$ epochs and is fixed at $0.07$ thereafter.
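As referenced above, here is a minimal sketch of the warmup-plus-cosine learning rate schedule using the base value and epoch counts listed; the weight decay and EMA schedules follow the same cosine shape between their respective endpoints.

```python
import math

def lr_at_epoch(epoch: float, base_lr: float = 2e-3,
                warmup_epochs: int = 10, total_epochs: int = 100) -> float:
    """Linear warmup from 0 to base_lr over `warmup_epochs`, then cosine decay to 0."""
    if epoch < warmup_epochs:
        return base_lr * epoch / warmup_epochs
    progress = (epoch - warmup_epochs) / (total_epochs - warmup_epochs)  # in [0, 1]
    return 0.5 * base_lr * (1.0 + math.cos(math.pi * progress))
```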

We apply some (essentially information-preserving) data augmentations, such as flips, color jittering, Gaussian blur, and solarization, to each image seen during training, before taking the local and global views. The exact hyperparameters governing these augmentations are not listed here; they can be found in the DINO paper [CTM+21].

For linear probing, the linear probe is usually trained using the AdamW optimizer with learning rate $2\times 10^{-4}$, weight decay $0.01$, and batch size $512$, but these are often modified on a case-by-case basis to minimize the loss.

Method Model Epochs 20-NN Linear Probing
DINO ViT-B 100 72.9 76.3
SimDINO ViT-B 100 74.9 77.3
DINO ViT-L 100
SimDINO ViT-L 100 75.6 77.4
SwAV ViT-S 800 66.3 73.5
MoCov3 ViT-B 300 76.7
Table 7.1: Classification performance on hold-out test data for DINO and SimDINO, using both $k$-nearest neighbor accuracy ($k=20$) and linear probing. At the same number of epochs ($100$), SimDINO is clearly better in terms of performance, and is more stable (DINO training with a ViT-L backbone under the provided settings has very unstable optimization and reaches NaN loss in short order). We also compare to other standout methods, namely SwAV and MoCov3, upon which DINO was built.
Figure 7.10: A qualitative comparison of saliency maps generated by DINO (middle row) and by SimDINO (bottom row). For each image, we compute and display the average saliency map at the last layer $L$. The saliency maps are similar across models, meaning that all models converge to a similar notion of which objects are important. Note that although $X_{\mathrm{eval}}$ is a square image, it is interpolated back into a rectangular shape for this visualization.
Detection \uparrow Segmentation \uparrow
Method Model AP50 AP75 AP AP50 AP75 AP
SimDINO ViT-L/16 5.4 1.9 2.4 4.5 1.4 1.9
SimDINO ViT-B/16 5.2 2.0 2.5 4.7 1.5 2.0
DINO ViT-B/16 3.9 1.5 1.8 3.1 1.0 1.4
DINO ViT-B/8 5.1 2.3 2.5 4.1 1.3 1.8
Table 7.2: Detection and segmentation performance of pre-trained DINO and SimDINO models on COCO val2017 [LMB+14], a segmentation dataset which contains object location metadata. We do not train on COCO; we merely use the pre-trained embedding and backbone, and the bounding boxes are extracted from the features via a method called MaskCut [WGY+23]. Nevertheless, SimDINO surpasses DINO at object detection and segmentation under a fair comparison, and even surpasses DINO with a smaller patch size (side length $8$ instead of $16$). Smaller patch sizes are known to help performance, especially on detection and segmentation tasks, so this result is quite surprising and encouraging.
Evaluation results.

In terms of downstream classification performance, we obtain the results in Table 7.1. We observe that the performance of SimDINO is much higher than that of DINO under a fair comparison. It is also much more stable: the prescribed settings of DINO cannot train a ViT-L(arge) model. Figure 7.10 shows visualizations of the average saliency maps of DINO and our simplified DINO; the saliency maps look quite similar across models, indicating that SimDINO learns features which are at least as good as DINO's at capturing fine-grained details. The segmentation and object detection results in Table 7.2 confirm this claim quantitatively: SimDINO features show a substantive improvement over those of DINO.

7.3 Image Classification

In the previous section, we simplified an overly complex learning objective using our intuition about representation learning through the lens of compression. However, many of the most popular learning procedures are incredibly simple. In these cases, it is difficult to further simplify the objective. Thus, in this and future sections, we will focus on principled ways to modify the deep network architectures for a variety of tasks.

Let us start with arguably the most classical task in machine learning: image classification, which is often used as a standard task for evaluating pattern recognition algorithms and deep network architectures. From our discussion of white-box architectures in Chapter 4, all we need in order to learn good representations with such architectures is a semantically meaningful task. We will validate this idea in this section.

First, the dataset stays largely the same as in Section 7.2.1. Both the training and test data consist of labeled images, i.e., image-label pairs $(\bm{X},\bm{y})\in\mathbb{R}^{C\times H\times W}\times\{0,1\}^{N_{\mathrm{cls}}}$. We still apply various data augmentations (e.g., flips, Gaussian blurring, solarization, etc.) to each sample in each new batch.

7.3.1 Task and Objective

Unlike before, our task is not just to learn a good representation of the data, but also to simultaneously build a classifier. Formally, we have labeled data pairs $(\bm{X},\bm{y})$, where $\bm{y}\in\{0,1\}^{N_{\mathrm{cls}}}$ is a one-hot vector denoting the class membership of $\bm{X}$. We consider a deterministic center crop view $v_{\mathrm{cc}}$ of the input data $\bm{X}$ (cf. Section 7.2.2). We want to jointly train a feature mapping $(f_{\theta},f_{\theta}^{\mathrm{ext}})$ and a classification head $h_{\theta}$, defined as follows:

$$h_{\theta}(\bm{z})\doteq\operatorname{softmax}\big(\bm{W}^{\mathrm{head}}\bm{z}+\bm{b}^{\mathrm{head}}\big),\qquad\forall\,\bm{z}\in\mathbb{R}^{d},$$ (7.3.1)

where $(\bm{W}^{\mathrm{head}},\bm{b}^{\mathrm{head}})\in\mathbb{R}^{N_{\mathrm{cls}}\times d}\times\mathbb{R}^{N_{\mathrm{cls}}}$ are trainable parameters in the parameter set $\theta$, such that the map $\bm{X}_{\mathrm{cc}}\mapsto\bm{p}_{\theta}(\bm{X}_{\mathrm{cc}})\doteq h_{\theta}(\bm{z}_{\theta}(\bm{X}_{\mathrm{cc}}))$ predicts a smoothed label for the view $\bm{X}_{\mathrm{cc}}=v_{\mathrm{cc}}(\bm{X})$ of the input $\bm{X}$. The learning problem attempts to minimize the distance between $\bm{p}_{\theta}$ and $\bm{y}$, measured through the cross-entropy:

$$\min_{\theta}\left\{\mathcal{L}_{\operatorname{CE}}(\theta)\doteq\operatorname{\mathbb{E}}\big[\operatorname{CE}\big(\bm{y},\,\bm{p}_{\theta}(\bm{X}_{\mathrm{cc}})\big)\big]\right\}.$$ (7.3.2)

7.3.2 The CRATE Architecture

The architecture that we use is the CRATE architecture, described in some detail in Chapter 4. The overall setup is similar to that of the regular transformer in Section 7.2.3, with a few changes. The embedding step is the same as in both DINO and SimDINO (Section 7.2.3); the feature extraction step is the same as in SimDINO, in that it just extracts the feature corresponding to the class token; and the classification head is as described in Section 7.3.1. The backbone architecture, however, is different. Each layer takes the form

$$\bm{Z}_{\theta}^{\ell+1/2}(\bm{X})=\bm{Z}_{\theta}^{\ell}(\bm{X})+\operatorname{MSSA}_{\theta}^{\ell}\big(\operatorname{LN}_{\theta}^{1,\ell}(\bm{Z}_{\theta}^{\ell}(\bm{X}))\big),$$ (7.3.3)
$$\bm{Z}_{\theta}^{\ell+1}(\bm{X})=\operatorname{ISTA}_{\theta}^{\ell}\big(\operatorname{LN}_{\theta}^{2,\ell}(\bm{Z}_{\theta}^{\ell+1/2}(\bm{X}))\big),$$ (7.3.4)

where the $\operatorname{MSSA}_{\theta}^{\ell}$ and $\operatorname{ISTA}_{\theta}^{\ell}$ blocks are as described in Chapter 4, namely:

  • The $\operatorname{MSSA}$ operator is the multi-head subspace self-attention operator, defined as follows:

    $$\operatorname{MSSA}_{\theta}^{\ell}(\bm{Z})\doteq\bm{U}_{\mathrm{out}}^{\ell}\begin{bmatrix}\operatorname{SA}\big([\bm{U}^{1,\ell}]^{\top}\bm{Z},[\bm{U}^{1,\ell}]^{\top}\bm{Z},[\bm{U}^{1,\ell}]^{\top}\bm{Z}\big)\\ \vdots\\ \operatorname{SA}\big([\bm{U}^{K,\ell}]^{\top}\bm{Z},[\bm{U}^{K,\ell}]^{\top}\bm{Z},[\bm{U}^{K,\ell}]^{\top}\bm{Z}\big)\end{bmatrix}+\bm{b}_{\mathrm{out}}^{\ell}\bm{1}_{n}^{\top},$$ (7.3.5)

    where $\bm{U}^{k,\ell}\in\mathbb{R}^{d\times p}$, $\bm{U}_{\mathrm{out}}^{\ell}\in\mathbb{R}^{d\times Kp}$, and $\bm{b}_{\mathrm{out}}^{\ell}\in\mathbb{R}^{d}$ are trainable parameters belonging to the parameter set $\theta$, and (recall) the self-attention operator $\operatorname{SA}$ is defined in (7.2.16).

  • The $\operatorname{ISTA}$ operator is the iterative shrinkage-thresholding algorithm operator, defined as follows:

    $$\operatorname{ISTA}_{\theta}^{\ell}(\bm{Z})\doteq\operatorname{ReLU}\big(\bm{Z}-\beta(\bm{D}^{\ell})^{\top}(\bm{D}^{\ell}\bm{Z}-\bm{Z})-\beta\lambda\bm{1}_{d}\bm{1}_{n}^{\top}\big),$$ (7.3.6)

    so named because the map $\bm{X}\mapsto\operatorname{ReLU}\big(\bm{X}-\beta\bm{D}^{\top}(\bm{D}\bm{X}-\bm{Z})-\beta\lambda\bm{1}_{d}\bm{1}_{n}^{\top}\big)$ is one step of the well-established ISTA algorithm for finding an element-wise non-negative sparse representation of $\bm{Z}$ with respect to the complete dictionary $\bm{D}$ (cf. Section 2.3).

We call this architecture CRATE; a layer of the backbone is depicted in Figure 4.13. CRATE models, on top of being interpretable, are generally also highly performant and parameter-efficient.
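To make the ISTA block concrete, here is a minimal PyTorch sketch of (7.3.6). It is a didactic sketch only: it uses the $(d, n)$ token-matrix convention of the text rather than the batched layout of practical implementations, and it omits the LayerNorm applied before the block.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ISTABlock(nn.Module):
    """One ISTA step (7.3.6) with a learnable complete dictionary D in R^{d x d}."""

    def __init__(self, d: int, beta: float = 1.0, lam: float = 0.1):
        super().__init__()
        self.D = nn.Parameter(torch.randn(d, d) / d ** 0.5)
        self.beta, self.lam = beta, lam

    def forward(self, Z: torch.Tensor) -> torch.Tensor:
        # Z: (d, n). One proximal-gradient step toward a non-negative sparse code of Z
        # with respect to D, initialized at Z itself; the soft threshold is beta * lambda.
        grad = self.D.T @ (self.D @ Z - Z)
        return F.relu(Z - self.beta * grad - self.beta * self.lam)
```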

7.3.3 Optimization

We train our classifier using a simple end-to-end stochastic optimization procedure: we subsample data and views, compute the average loss and its gradient over these samples, and use an optimization algorithm to update the parameters. At each timestep $k$, we:

  • Subsample $B$ different labeled samples $\{(\bm{X}_b^{(k)},\bm{y}_b^{(k)})\}_{b=1}^{B}\subseteq\mathcal{I}\times\{0,1\}^{N_{\mathrm{cls}}}$.

  • For each labeled sample $(\bm{X}_b^{(k)},\bm{y}_b^{(k)})$, apply the (deterministic) center crop view to $\bm{X}_b^{(k)}$ to get $\bm{X}_{b,\mathrm{cc}}^{(k)}\doteq v_{\mathrm{cc}}(\bm{X}_b^{(k)})$.

  • Compute the predictions $\bm{p}_{\theta}(\bm{X}_{b,\mathrm{cc}}^{(k)})\doteq(h_{\theta}\circ f_{\theta}^{\mathrm{ext}}\circ f_{\theta})(\bm{X}_{b,\mathrm{cc}}^{(k)})$.

  • Form the surrogate stochastic loss

    $$\hat{\mathcal{L}}_{\operatorname{CE}}^{(k)}(\theta)\doteq\frac{1}{B}\sum_{b=1}^{B}\operatorname{CE}\big(\bm{y}_b^{(k)},\,\bm{p}_{\theta}(\bm{X}_{b,\mathrm{cc}}^{(k)})\big).$$ (7.3.7)
  • Compute one step of an optimization algorithm on θ\thetaitalic_θ, giving the following iteration:

    $$\theta^{(k+1)}\doteq\textsc{OptUpdate}^{(k)}\big(\theta^{(k)};\ \nabla_{\theta}\hat{\mathcal{L}}_{\operatorname{CE}}^{(k)}\big).$$ (7.3.8)
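A minimal sketch of this procedure in PyTorch, assuming a hypothetical dataloader yielding center-cropped image tensors and integer class labels, with the feature map and classification head wrapped in the hypothetical modules `model` and `head`:

```python
import torch
import torch.nn.functional as F

def train_epoch(model, head, loader, optimizer, device="cuda"):
    """One pass over the data implementing steps (7.3.7)-(7.3.8)."""
    model.train()
    head.train()
    for X_cc, y in loader:                 # X_cc: (B, C, S_cc, S_cc), y: (B,) class indices
        X_cc, y = X_cc.to(device), y.to(device)
        logits = head(model(X_cc))         # head outputs logits; the softmax of (7.3.1)
        loss = F.cross_entropy(logits, y)  # is folded into the cross-entropy loss (7.3.7)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()                   # OptUpdate^{(k)}, e.g. AdamW or LION
```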

7.3.4 Evaluation Methodology

We use the same evaluation procedures as in Section 7.2.5. To summarize, for all evaluations (as well as training) we use a center crop view $v_{\mathrm{cc}}$, which resizes the input image and takes a large central crop of size $(C,S_{\mathrm{cc}},S_{\mathrm{cc}})$, where $C$ is the number of channels in the input image. Given the output of this view, we can then perform linear probing, attention map visualization, and detection/segmentation benchmarks.

7.3.5 Experimental Setup and Results

Since CRATE is directly based on the transformer, we compare the optimal settings for ViT as given by [DBK+21, TCD+20] with the same settings applied to CRATE for a fair comparison.

Model architecture.

The center crop resizes the whole image so that the shorter edge has size $256$ (i.e., $S_{\mathrm{rsz}}=256$) before taking a center crop of size $224\times 224$ (i.e., $S_{\mathrm{cc}}=224$), both in evaluation and training. We take patch size $16$ (i.e., $P_{H}=P_{W}=16$). We use the tiny, small, base, and large models of the ViT [DBK+21] architecture as the embedding and backbone, swapping out the MHSA and MLP components for MSSA and ISTA, respectively, using the same number of heads and head dimension in the case of MSSA, thereby reducing the number of trainable parameters drastically. For CRATE, we set $(\beta,\lambda)=(1,0.1)$.

Datasets and optimization.

For pre-training, we use the ImageNet-1K dataset. We use the LION optimizer [CLH+24] to pre-train both our ViT replication and CRATE. We set the base learning rate to $2.4\times 10^{-4}$, the weight decay to $0.5$, and the batch size to $B=2048$. Our learning rate schedule increases the learning rate linearly to the base value over the first $5$ epochs, then decreases it to $0$ using a cosine schedule over the next $145$ epochs (training all models for $150$ epochs each). For pre-training, we apply the usual regime of data augmentations (flips, Gaussian blurs, solarization, etc.) to the image data, and also add small noise to the labels (this is called label smoothing [MKH19]).

For linear probing, we use several evaluation datasets, such as CIFAR10, Oxford-Flowers, and Oxford-IIIT-Pets. We use the AdamW optimizer to train the linear probe, with learning rate $5\times 10^{-5}$, weight decay $0.01$, and batch size $B=256$. We also apply the aforementioned data augmentations to the image data.

Model CRATE-T CRATE-S CRATE-B CRATE-L ViT-T ViT-S
# parameters 6.09M 13.12M 22.80M 77.64M 5.72M 22.05M
ImageNet-1K 66.7 69.2 70.8 71.3 71.5 72.4
ImageNet-1K ReaL 74.0 76.0 76.5 77.4 78.3 78.4
CIFAR10 95.5 96.0 96.8 97.2 96.6 97.2
CIFAR100 78.9 81.0 82.7 83.6 81.8 83.2
Oxford Flowers-102 84.6 87.1 88.7 88.3 85.1 88.5
Oxford-IIIT-Pets 81.4 84.9 85.3 87.4 88.5 88.6
Table 7.3: Linear probing classification accuracy of CRATE and ViT on various datasets with different model sizes when the backbone is pre-trained for classification on ImageNet-1K. We observe that given the same model configuration, CRATE has comparable classification performance with a simpler, more principled, and more parameter-efficient design.
Figure 7.11: Interpretable saliency maps in CRATE with patch size $8$. When given images with similar properties (perhaps, but not necessarily, from the same class), the saliency maps corresponding to different attention heads in the last layer each highlight a specific property. One can observe that the average saliency map (not included) then highlights all relevant objects in the image, showing that the model uses all fine-grained details of the input image for classification. To the authors' knowledge, this is the first machine learning system to do this, much less automatically and without training on any segmentation data.
Detection (\uparrow) Segmentation (\uparrow)
Model AP50 AP75 AP AP50 AP75 AP
CRATE-S/8 2.9 1.0 1.1 1.8 0.7 0.8
CRATE-B/8 2.9 1.0 1.3 2.2 0.7 1.0
ViT-S/8 0.1 0.1 0.0 0.0 0.0 0.0
ViT-B/8 0.8 0.2 0.4 0.7 0.5 0.4
Table 7.4: Object detection and fine-grained segmentation via MaskCut on COCO val2017 [LMB+14]. Here all models are trained with patch size $8$ instead of $16$. CRATE conclusively performs better than the ViT on detection and segmentation metrics when both are trained using supervised classification.
Experiment results.

Table 7.3 demonstrates that CRATE models achieve parity with, or improvement over, the popular Vision Transformer (ViT) architecture at similar parameter counts, at least in terms of the linear separability of their features with respect to different classes. In terms of attention map fidelity, Figure 7.11 demonstrates a truly extraordinary result: without training on any segmentation or object detection data, not only do the saliency maps effectively capture all relevant parts of the input image, they also self-organize so that each head corresponds to a discrete set of concepts, even across samples and classes. To the authors' knowledge, this is the first system to do so, and it does so using no extra data beyond the image classification data. Table 7.4 confirms these qualitative insights quantitatively, showing a significant improvement over ViTs trained in the same supervised classification setup.

7.4 Causal Language Modeling

We now study causal language modeling, a method for training large language models (LLMs). This is the same setup used to train GPT-2, among many other language models.

7.4.1 Data

The data we will use to investigate the performance of CRATE on language tasks is OpenWebText (OWT) [GC19], an open-source reproduction of the unreleased WebText dataset used by OpenAI to train GPT-2. Each sample in OWT is a web document, typically sourced from high-quality web pages, blogs, articles, or online discussions, that is written in well-formed natural language. The OpenWebText dataset contains around 8.01M documents of varying lengths, totaling around 41.70GB of text. For evaluation, we will use several datasets, such as WikiText [MXB+16] (the test splits of WikiText2 and WikiText103 are the same, so we merge them into a single dataset referred to as WikiText), LAMBADA [PKL+16] (to obtain accuracy on LAMBADA, we use greedy decoding), and PTB [MSM93]. PTB and OWT are generally easier than the other datasets: PTB focuses on simpler journalistic text, ideal for traditional language modeling, while OWT is diverse and informal, covering various topics but with less complexity in language structure or long-range dependencies. WikiText, with its formal structure and domain-specific content, requires a more complex understanding than OWT but remains manageable. LAMBADA is the most challenging, as it involves long-range dependencies, requiring the model to grasp broader contextual information to complete sentences accurately.

On a more formal level, our data $\bm{X}$ will be text, i.e., strings of characters; we let $\mathcal{T}$ be the set of all strings.

7.4.2 Task and Objective

For causal language modeling pre-training, the idea is that we want to train the model to output human-like text. The most popular way to do this, by far, is to use a two-stage training process. (Modern language model training has several additional training stages which demand different data distributions and algorithmic approaches; however, training a model to merely mimic human writing requires only the steps presented here.)

  • First, we wish to learn a way to optimally encode documents as a sequence of basic (“building block”) strings, called tokens. This is called tokenization, and we build a tokenizer.

  • Second, we wish to learn a way to predict the distribution of a token given all previous tokens. This is called next-token prediction, and we build a language model.

This procedure actually dates back to Markov, who first noticed that natural language could be modeled by the eponymous Markov chain structure [Mar06] given an appropriate tokenization, and then to Shannon, who proposed doing exactly this language modeling setup with a character-level tokenizer (i.e., each character is a token) and a so-called "$n$-gram" model (i.e., an explicit look-up table, calculated from the training data, for the distribution of a token given the $n$ previous tokens) in place of the language model [Sha48]. (A recent study [LMZ+24] scaling up $n$-gram models has shown that they can model text reasonably well for large $n$, but of course the memory required to store such a lookup table is of order $V^{n}$ and hence completely intractable.)

Training a Tokenizer

Building a tokenizer amounts to building a vocabulary $\mathcal{V}$, which is a set of tokens of some pre-specified size $V$. There are several methods to do this. One popular algorithm is known as Byte Pair Encoding (BPE), which can be described as follows:

  • Start with a list of all unique characters in the training data, along with their frequencies. Ensure that there are fewer than $V$ such characters, and add each character as a separate string ("token") to the vocabulary along with its frequency.

  • Until there are $V$ tokens in the vocabulary:

    • Construct a new token by merging the most frequently co-occurring (adjacent) pair of existing tokens.

    • Compute this token’s frequency in the dataset.

    • Add it to the vocabulary (along with its frequency).

  • At this point, the frequency information is no longer needed and can be discarded.

The overall process of BPE is illustrated in Figure 7.12. Note that this procedure is a modification of a classical information-theoretic compression procedure for learning a lossless encoding of bytestream data (such as text), and as such, one can interpret it as finding an optimal lossless compression of the data. Notice that this is possible because (unlike images) the data here are fundamentally discrete and noise-free.
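To make the procedure concrete, here is a toy character-level sketch of the vocabulary-building loop; real tokenizers operate on bytes, pre-split the corpus into words, track frequencies more carefully, and use far more efficient data structures.

```python
from collections import Counter

def build_bpe_vocab(corpus: str, vocab_size: int):
    """Toy BPE training on a single string: start from characters, then repeatedly merge
    the most frequent adjacent pair of tokens until the vocabulary reaches vocab_size."""
    tokens = list(corpus)                      # current tokenization of the corpus
    vocab = set(tokens)
    merges = []                                # ordered list of (left, right) merge rules
    while len(vocab) < vocab_size:
        pairs = Counter(zip(tokens, tokens[1:]))
        if not pairs:
            break
        (left, right), _ = pairs.most_common(1)[0]
        merges.append((left, right))
        vocab.add(left + right)
        # Re-tokenize the corpus with the new merge applied greedily, left to right.
        merged, i = [], 0
        while i < len(tokens):
            if i + 1 < len(tokens) and tokens[i] == left and tokens[i + 1] == right:
                merged.append(left + right)
                i += 2
            else:
                merged.append(tokens[i])
                i += 1
        tokens = merged
    return vocab, merges
```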

Figure 7.12: The process of tokenizing text data using BPE. (Image credit to https://huggingface.co/learn/nlp-course/chapter6/5). (Left) We begin by analyzing the given text corpus and constructing an initial vocabulary that consists of individual characters (or bytes in the case of byte-level BPE). Then, we compute the frequencies of adjacent character pairs in the corpus. This involves scanning the entire text and counting how often each two-character sequence (bigram) appears. (Right) After computing the frequencies of adjacent character pairs, we identify the most frequent pair in the corpus. This pair is then merged into a new subword unit, which is added to the vocabulary as a single token. This process is repeated iteratively until the predefined vocabulary size is reached.

After such a vocabulary is built, the tokenizer can break a document down into tokens (i.e., "tokenize" it). BPE tokenizes data using a procedure similar to the one used during training:

  • Separate the document into a list of one-character tokens. That is, if the document is "Hello", then the initial list is 'H', 'e', 'l', 'l', 'o'.

  • Whenever two adjacent tokens can be concatenated into another token in the vocabulary, we do so, i.e., we replace the pair with the merged token. For example, if 'He' is a token in the vocabulary, then 'H', 'e', 'l', 'l', 'o' would become 'He', 'l', 'l', 'o'.

  • Repeat the above process until no more merges can be done. At this point, the document is partitioned into the final list (sequence) of tokens.
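Continuing the toy sketch above, one simple way to tokenize a new document is to replay the learned merges in the order in which they were learned (the standard BPE convention):

```python
def bpe_tokenize(text: str, merges) -> list:
    """Apply learned merge rules, in training order, to tokenize `text` (character-level toy)."""
    tokens = list(text)
    for left, right in merges:                  # replay merges in the order they were learned
        merged, i = [], 0
        while i < len(tokens):
            if i + 1 < len(tokens) and tokens[i] == left and tokens[i + 1] == right:
                merged.append(left + right)
                i += 2
            else:
                merged.append(tokens[i])
                i += 1
        tokens = merged
    return tokens

# e.g. bpe_tokenize("Hello", merges) might give ['He', 'l', 'l', 'o'] if 'He' was a learned merge.
```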

There are many practical and efficiency-based considerations to take into account during tokenization. For instance, the above algorithm, as presented, is far from optimal if implemented naively. We do not cover this topic in great detail; there are many resources online to learn more, such as the HuggingFace tutorials.

Moreover, each token has a corresponding index, which is just its position in the vocabulary (which, after all, is just a list of length $V$). Thus, the output of most tokenizers is a list of indices, i.e., an element of $[V]^{*}$. Keep in mind that these indices correspond to substrings of the original document, as shown above.

Once a tokenizer is learned, it can be used as a black box by any language model. For instance, many models share the same (OpenAI-based) tokenizer from the tiktoken library. In the remainder of this section, we will use such a fixed, pre-built tokenizer throughout, and thus identify each text document $\bm{X}\in\mathcal{T}$ with its tokenized version in $[V]^{*}$. Therefore, we may as well consider the text space $\mathcal{T}$ as equal to the space of token sequences $[V]^{*}$ (and lose nothing essential).

Training a Language Model

Once we have each document as a sequence of tokens $\bm{X}\in[V]^{N}\subseteq[V]^{*}=\mathcal{T}$, we wish to perform next-token prediction. That is, given a context $\bm{X}_{:n}\in[V]^{n}$ (i.e., the first $n$ tokens $\bm{x}_{1},\dots,\bm{x}_{n}\in[V]$ in the document),\footnote{Note the incongruity with Python notation: here the notation includes index $n$.} we wish to predict the token $\bm{x}_{n+1}\in[V]$ at position $n+1$. To do this, we compute the aggregate feature of $\bm{X}_{:n}$ via $\bm{z}_{\theta}(\bm{X}_{:n})\doteq(f_{\theta}^{\mathrm{ext}}\circ f_{\theta})(\bm{X}_{:n})\in\mathbb{R}^{d}$, and use a classification head $h_{\theta}\colon\mathbb{R}^{d}\to\Delta_{V}$ (implemented as a linear layer, an MLP, or something slightly more complicated) to project this feature into the $V$-dimensional probability simplex $\Delta_{V}$. This projection $\bm{p}_{\theta}(\bm{X}_{:n})\doteq h_{\theta}(\bm{z}_{\theta}(\bm{X}_{:n}))$ serves as an estimated probability distribution of the next token. Then, using the notation $\bm{1}(\bm{x}_{n+1})\in\Delta_{V}$ for the vector which is $1$ in the $\bm{x}_{n+1}$th component and $0$ elsewhere, the causal language modeling loss is

\min_{\theta}\left\{\mathcal{L}_{\mathrm{CLM}}(\theta)\doteq\operatorname{\mathbb{E}}_{\bm{X}}\left[\frac{1}{N-1}\sum_{n=1}^{N-1}\operatorname{CE}\bigl(\bm{1}(\bm{x}_{n+1}),\bm{p}_{\theta}(\bm{X}_{:n})\bigr)\right]\right\}    (7.4.1)

Note how similar this is to a classification loss (say for images); one uses the cross-entropy and tries to align a predicted probability vector with the ground truth. The major difference between these two losses is that in this one, we compute the loss on a whole sequence, where each prediction is correlated with the others (unlike in the i.i.d. classification case).
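For concreteness, here is a minimal sketch of the loss (7.4.1) for a single document, assuming we already have a tensor of next-token logits, one row per position; the shapes and random inputs are purely illustrative.

```python
import torch
import torch.nn.functional as F

# A minimal sketch of the causal language modeling loss (7.4.1) for one document.
# Row n of `logits` is assumed to hold the unnormalized scores for predicting the
# token at position n+1 from the context X_{:n+1}.
def clm_loss(logits: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
    # logits: (N, V); tokens: (N,) token indices in [V].
    # Predictions at positions 1..N-1 are compared against tokens 2..N,
    # matching the sum over n = 1, ..., N-1 in (7.4.1).
    return F.cross_entropy(logits[:-1], tokens[1:])

V, N = 50257, 16
tokens = torch.randint(0, V, (N,))
logits = torch.randn(N, V)
print(clm_loss(tokens=tokens, logits=logits))  # averages CE over N-1 positions
```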

Optimizing this loss is usually called “pre-training” in the language model community (contrast with “post-training” and, more recently, “mid-training”, which are methodologies to modify a next-token-predictor for useful tasks).

Side note: why does the first term of (7.4.1) predict $\bm{1}(\bm{x}_{2})$, with no term measuring the loss of predicting the first token? It is because predicting the first token would require the empty sequence as context, and therefore a qualitatively different mechanism than the one which applies to the other tokens. So this model is actually not trained to predict the very first token of any document. The reason this is acceptable is an implementation detail of the tokenizer: often, after building the tokenizer, we insert a special token into its vocabulary, called the beginning-of-string (or document) token and labeled <|bos|>.\footnote{There are usually several special tokens for different purposes. Existing text containing the special tokens is specially processed.} Then, while processing each document, we add the <|bos|> token at the beginning of the document’s token sequence, increasing the length of the tokenized sequence by $1$. Thus the above causal language modeling objective has a term which involves trying to predict the first token of the document given only the <|bos|> token as context, and so it is a conceptually correct loss.

7.4.3 Architecture: Causal CRATE

For the architecture, we use a standard GPT-2-style transformer, substituting CRATE layers for the transformer layers.\footnote{In direct contravention of the conventions in this book and those of many other communities, the NLP community calls such GPT-2-style transformers (encompassing nearly all current LLMs) “decoder-only” transformers. “Encoder-only” transformers have a different architecture, and “encoder-decoder” transformers concatenate an “encoder-only” transformer with a “decoder-only” transformer. This is despite the fact that “decoder-only” transformers also compute an encoding of the data!} For completeness, we specify the architecture here.

Embedding.

We first embed the token sequence $\bm{X}\in[V]^{N}$ into Euclidean space. This is often done by associating each index in $[V]$ with a vector in $\mathbb{R}^{d}$ using a massive\footnote{By “massive” we mean that such a structure is often a large fraction of the language model’s total size.} array $\bm{E}\in\mathbb{R}^{V\times d}$, and directly forming the sequence $[\bm{E}_{\bm{x}_{1}},\dots,\bm{E}_{\bm{x}_{N}}]\in\mathbb{R}^{d\times N}$. The full embedding map $f_{\theta}^{\mathrm{emb}}$ also applies a positional encoding $\bm{E}^{\mathrm{pos}}\in\mathbb{R}^{d\times N_{\max}}$, where $N_{\max}$ is the maximum number of tokens which are possible to process,\footnote{Modern positional encoding methods have since taken care of this issue and allowed for (in theory) infinite extrapolation, but such methods are more complex to develop, and for simplicity we only introduce the absolute additive positional encoding here.} which yields the embedding map

f_{\theta}^{\mathrm{emb}}(\bm{X})\doteq[\bm{E}_{\bm{x}_{1}},\dots,\bm{E}_{\bm{x}_{N}}]+\bm{E}_{:N}^{\mathrm{pos}}    (7.4.2)

The parameters $\bm{E}$ and $\bm{E}^{\mathrm{pos}}$ are directly trainable. Since $\bm{E}$ is so large (and the gradient update is very sparse with respect to it, since only a small fraction of the vocabulary is used in each sample), specialized software is used to make sure the memory updates are not too onerous. Notice also that we do not use a class token like in the other sections; more on this later.
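A minimal sketch of the embedding map (7.4.2) is below, assuming sequences are stored row-wise as (N, d) rather than as d × N columns; all sizes are illustrative.

```python
import torch
import torch.nn as nn

# A minimal sketch of (7.4.2): a token embedding table E plus a learned
# absolute positional encoding E^pos.
class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size: int, d: int, n_max: int):
        super().__init__()
        self.tok = nn.Embedding(vocab_size, d)          # E in R^{V x d}
        self.pos = nn.Parameter(torch.zeros(n_max, d))  # E^pos (first N rows used)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (N,) token indices; output: (N, d) embedded sequence
        n = x.shape[0]
        return self.tok(x) + self.pos[:n]

emb = TokenEmbedding(vocab_size=50257, d=768, n_max=1024)
x = torch.randint(0, 50257, (128,))
print(emb(x).shape)  # torch.Size([128, 768])
```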

Backbone.

We process the embeddings using a CRATE-like backbone which uses causal masking. To motivate causal masking, consider the causal language modeling loss $\mathcal{L}_{\mathrm{CLM}}$ defined in (7.4.1). The most naive implementation would require us to compute the forward pass $N$ times in order to backpropagate once. Obviously this is extremely inefficient, since $N$ can often be in the thousands. In order to scale training with this loss efficiently, we impose a causal constraint, i.e.,

\bm{Z}_{\theta}(\bm{X}_{:n})=\bm{Z}_{\theta}(\bm{X})_{:n}    (7.4.3)

i.e., the $n$ columns of the token features $\bm{Z}_{\theta}(\bm{X}_{:n})\in\mathbb{R}^{d\times n}$ should be the same as the first $n$ columns of the token features $\bm{Z}_{\theta}(\bm{X})\in\mathbb{R}^{d\times N}$, regardless of the positive values of $n$ and $N$ such that $N\geq n$. In effect, this means we can apply the backbone once to the whole sequence and compute $\bm{Z}_{\theta}(\bm{X})$, then apply $f_{\theta}^{\mathrm{ext}}$ to each increasing prefix $\bm{Z}_{\theta}(\bm{X}_{:n})=\bm{Z}_{\theta}(\bm{X})_{:n}$ as $n$ grows to the sequence length $N$. Then we can use all of those to compute the loss.

So now that we want a causal architecture for the backbone, how can we get it? Since the MLP and layer normalizations inside each transformer layer act on each token individually, the only thing that matters for causality is the attention block (or $\operatorname{MSSA}$ in the case of CRATE). In order to make $\operatorname{MSSA}$ causal, we define the $\operatorname{CausalMSSA}$ block as

\operatorname{CausalMSSA}_{\theta}^{\ell}(\bm{Z})\doteq\bm{U}_{\mathrm{out}}^{\ell}\begin{bmatrix}\operatorname{CausalSA}([\bm{U}^{1,\ell}]^{\top}\bm{Z},\,[\bm{U}^{1,\ell}]^{\top}\bm{Z},\,[\bm{U}^{1,\ell}]^{\top}\bm{Z})\\ \vdots\\ \operatorname{CausalSA}([\bm{U}^{K,\ell}]^{\top}\bm{Z},\,[\bm{U}^{K,\ell}]^{\top}\bm{Z},\,[\bm{U}^{K,\ell}]^{\top}\bm{Z})\end{bmatrix}+\bm{b}_{\mathrm{out}}^{\ell}\bm{1}_{N}^{\top}    (7.4.4)
where \operatorname{CausalSA}(\bm{Q},\bm{K},\bm{V})\doteq\bm{V}\operatorname{softmax}\left(\frac{\operatorname{CausalMask}(\bm{K}^{\top}\bm{Q})}{\sqrt{p}}\right)    (7.4.5)
where \operatorname{CausalMask}(\bm{M})_{ij}=\begin{cases}M_{ij},&\text{if }i\leq j,\\ -\infty,&\text{if }i>j.\end{cases}    (7.4.6)

Here, the row index $i$ of $\bm{K}^{\top}\bm{Q}$ indexes keys (the tokens being attended to) and the column index $j$ indexes queries, so practitioners say that the causal mask allows each token to attend to tokens at or before its own position but not to future tokens. To see why, let us write out the expression for the $t$th column of $\operatorname{CausalSA}(\bm{Q},\bm{K},\bm{V})$:

\operatorname{CausalSA}(\bm{Q},\bm{K},\bm{V})_{t}=\sum_{i=1}^{t}\bm{V}_{i}\operatorname{softmax}\left(\frac{[\bm{K}_{:t}]^{\top}\bm{Q}_{t}}{\sqrt{p}}\right)_{i}    (7.4.7)

(where here the non-colon subscript denotes the column). This expression for the $t$th token uses no information about any token beyond index $t$. Therefore $\operatorname{CausalSA}$, hence $\operatorname{CausalMSSA}$, hence the whole causal CRATE backbone is causal in terms of the definition in (7.4.3), and we unlock the considerable efficiency gains that we were promised.
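The following is a minimal sketch of CausalSA as in (7.4.5)–(7.4.7), written in the transposed, row-wise convention (tokens stored as rows rather than columns); all shapes are illustrative.

```python
import torch
import torch.nn.functional as F

# A minimal sketch of CausalSA: Q, K, V are stored row-wise with shape (N, p),
# the transpose of the column-wise convention used in the text.
def causal_sa(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Tensor:
    N, p = Q.shape
    scores = Q @ K.T / p ** 0.5                        # (N, N) pre-softmax scores
    mask = torch.triu(torch.ones(N, N, dtype=torch.bool), diagonal=1)
    scores = scores.masked_fill(mask, float("-inf"))   # token t sees only positions <= t
    return F.softmax(scores, dim=-1) @ V               # (N, p) output

Q = K = V = torch.randn(8, 16)
print(causal_sa(Q, K, V).shape)  # torch.Size([8, 16])
```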

Feature extractor.

We use a post-processing step $f_{\theta}^{\mathrm{ext}}$ which extracts the feature vector of the last known token, so as to predict the next token. In theory, this means that each token feature $\bm{Z}_{\theta}(\bm{X})_{n}$ should contain rich information about all tokens at or before index $n$, i.e., $\bm{x}_{1},\dots,\bm{x}_{n}$, as all of this information should be available for predicting the next token at index $n+1$. In practice, only a few of these tokens are really needed for each prediction task. In any case, the equation for $f_{\theta}^{\mathrm{ext}}$ is

f_{\theta}^{\mathrm{ext}}(\bm{Z}_{\theta}(\bm{X}_{:n}))\doteq(\bm{Z}_{\theta}(\bm{X}))_{n}    (7.4.8)

where (again) the non-colon subscript is the column. In this case, as promised, we just directly extract the feature vector of the last token in the sequence.

Task-specific head.

For our classification head $h_{\theta}$, the GPT-2 architecture uses a simple linear layer and a softmax to get the desired probability vectors:

h_{\theta}(\bm{z})\doteq\operatorname{softmax}(\bm{W}^{\mathrm{out}}\bm{z}+\bm{b}^{\mathrm{out}}),    (7.4.9)

where $\bm{W}^{\mathrm{out}}\in\mathbb{R}^{V\times d}$ and $\bm{b}^{\mathrm{out}}\in\mathbb{R}^{V}$. Some other, more modern architectures use small MLPs and layer normalizations, but the idea is very much the same. Note that this linear layer also has large memory usage (because $V$ is very large) and forms a bottleneck in training; there has been significant effort attempting to circumvent it.
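A minimal sketch of the feature extractor (7.4.8) and the linear head (7.4.9), again with features stored row-wise; all dimensions are illustrative.

```python
import torch
import torch.nn as nn

# A minimal sketch of f_ext and h_theta for next-token prediction.
d, V = 768, 50257
head = nn.Sequential(nn.Linear(d, V), nn.Softmax(dim=-1))  # softmax(W_out z + b_out)

Z = torch.randn(128, d)   # token features Z_theta(X) for an N = 128 sequence
z_last = Z[-1]            # f_ext: extract the feature of the last token
p_next = head(z_last)     # estimated distribution over the next token
print(p_next.shape, float(p_next.sum()))  # torch.Size([50257]), ~1.0
```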

All these architectural choices mean that causal training is extremely efficient relative to non-causal training:

  • We only need one forward pass through the backbone to compute the loss for the whole sequence.

  • The feature extraction is basically free.

  • All tokens can be pushed through the task-specific head in parallel.

7.4.4 Optimization Strategy

We train our language model using end-to-end stochastic optimization. One remaining issue is that, in practice, different documents in a batch have different lengths (in terms of the number of tokens required for each sequence), but as of the time of writing this book, the main deep learning frameworks by and large allow only “rectangular” tensors, which do not accommodate this behavior. To resolve this issue, we just insert a special padding token <|pad|> for all shorter samples in the batch, so that we can batch-process everything using rectangular tensors. At each timestep $k$, we:

  • Subsample $B$ different tokenized documents $\{\bm{X}_{b}^{(k)}\}_{b=1}^{B}\subseteq\mathcal{T}=[V]^{*}$, each with length $N_{b}^{(k)}$.

  • Compute $N_{\max}^{(k)}\doteq\max_{b\in[B]}N_{b}^{(k)}$ and pad each $\bm{X}_{b}^{(k)}$ to length $N_{\max}^{(k)}$ using the special padding token.

  • Compute the features $\bm{Z}_{\theta}(\bm{X}_{b}^{(k)})$.

  • Compute the predicted distributions $\bm{p}_{\theta}(\bm{X}_{b,:n}^{(k)})\doteq(h_{\theta}\circ f_{\theta}^{\mathrm{ext}})(\bm{Z}_{\theta}(\bm{X}_{b}^{(k)})_{:n})$.

  • Form the surrogate stochastic loss

    \hat{\mathcal{L}}_{\mathrm{CLM}}^{(k)}(\theta)\doteq\frac{1}{B(N_{\max}^{(k)}-1)}\sum_{b=1}^{B}\sum_{n=1}^{N_{\max}^{(k)}-1}\operatorname{CE}\bigl(\bm{1}(\bm{x}_{b,n+1}^{(k)}),\bm{p}_{\theta}(\bm{X}_{b,:n}^{(k)})\bigr).    (7.4.10)
  • Compute one step of an optimization algorithm on $\theta$, giving the following iteration:

    \theta^{(k+1)}\doteq\textsc{OptUpdate}^{(k)}(\theta^{(k)};\nabla_{\theta}\hat{\mathcal{L}}_{\mathrm{CLM}}^{(k)}).    (7.4.11)
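The padding step above can be sketched as follows. The numeric id of the <|pad|> token is an assumption (real tokenizers expose their own special-token ids), and in practice the loss terms at padded positions are typically masked out as well.

```python
import torch

# A minimal sketch of padding variable-length token sequences into a
# rectangular (B, N_max) tensor.
PAD_ID = 50256  # hypothetical id of the <|pad|> token

def pad_batch(docs: list[list[int]]) -> torch.Tensor:
    n_max = max(len(d) for d in docs)
    batch = torch.full((len(docs), n_max), PAD_ID, dtype=torch.long)
    for b, doc in enumerate(docs):
        batch[b, : len(doc)] = torch.tensor(doc, dtype=torch.long)
    return batch  # shape (B, N_max^{(k)})

docs = [[12, 7, 99], [3, 5], [42, 42, 42, 42]]
print(pad_batch(docs))
```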

7.4.5 Evaluation Methodology

There are several ways to evaluate a trained transformer language model.

  • On a holdout dataset of arbitrary text, we can evaluate $\mathcal{L}_{\mathrm{CLM}}$; lower losses are better, since they indicate that the model assigns higher probability to the held-out text.

  • On a multiple-choice question dataset, we can use each question as the context and check the estimated probability that the correct answer is generated.

  • We can also test the text generation capabilities. Namely, we can repeatedly sample from the model’s probability distribution over the next token given the context. Each time we sample, we generate a new token, which we print and add to the context. This allows us to sample from the LLM, and judge the generated samples however we please.\footnote{Having to re-run the model on every prior token each time we generate a new one can become prohibitively expensive. Clever storage of different internal features of the language model (such as the so-called KV cache), along with the causality of the architecture, can dramatically reduce the cost of sampling.}

In this section, we perform the first kind of evaluation exclusively.
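As an aside, the generation procedure from the last bullet can be sketched as follows, assuming a hypothetical `model` that maps a token-index tensor of shape (n,) to a next-token distribution of shape (V,) (i.e., $h_{\theta}$ applied to the last-token feature).

```python
import torch

# A minimal sketch of naive autoregressive sampling. This loop re-runs the
# model on the full context at every step; a KV cache would avoid that, as
# noted in the footnote above.
@torch.no_grad()
def generate(model, context: torch.Tensor, num_new_tokens: int) -> torch.Tensor:
    for _ in range(num_new_tokens):
        p_next = model(context)                   # estimated distribution over [V]
        next_tok = torch.multinomial(p_next, 1)   # sample one new token
        context = torch.cat([context, next_tok])  # append it and repeat
    return context
```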

7.4.6 Experimental Setup and Results

Since our causal CRATE architecture is directly built upon GPT-2, for a fair comparison we take the tuned settings for GPT-2 given by the NanoGPT repository [Kar22] and apply the same settings to causal CRATE.

Model architecture.

We use the GPT-2 tokenizer, which has vocabulary size $V=50257$, including a special token for <|pad|>.\footnote{The <|bos|> token is not included in this setup, although it is very common in modern language models.} The context length is $N_{\max}=1024$. The backbone model follows the GPT2-Base architecture [RWC+19] with the appropriate alterations to have causal CRATE layers, and we compare against GPT2-Small and GPT2-Base.

Datasets and optimization.

For training causal CRATE, we follow the implementation in the NanoGPT repository [Kar22]. Specifically, we use a batch size of 384 and train for 600,000 steps with the Adam optimizer [KB14]. For the Adam optimizer, we use $(\beta_{1},\beta_{2})=(0.9,0.95)$ and weight decay $0.1$. For the learning rate schedule, we apply a linear warm-up and cosine decay, with a peak value of $\eta=6\times10^{-4}$ at the 2,000th iteration, and minimum value $6\times10^{-5}$. The training and validation losses over iterations are shown in Figure 7.13. The training/validation loss converges to around $3.37$ after training with a batch size of $384$ for $600{,}000$ iterations. In comparison, the open GPT-2 implementation is pre-trained on OpenWebText with a batch size of $512$ for $600{,}000$ steps and converges to a validation loss of $2.85$ [Kar22].

Figure 7.13: The loss curve of CRATE-GPT-Base trained on the OpenWebText dataset.
Experiment results.

Table 7.5 demonstrates that CRATE models achieve reasonable performance on the causal language modeling loss across a variety of datasets compared to GPT-2 models with similar parameter counts and similar architectures.

Table 7.5: Zero-shot cross-entropy loss (lower is better, ↓) of the CRATE-GPT2-Base model and the GPT2-Small and GPT2-Base models, evaluated on the test split of each dataset.

Model              | # parameters | OWT  | LAMBADA | WikiText | PTB  | Avg
GPT2-Base          | 124M         | 2.85 | 4.12    | 3.89     | 4.63 | 3.87
GPT2-Small         | 64M          | 3.04 | 4.49    | 4.31     | 5.15 | 4.25
Causal-CRATE-Base  | 60M          | 3.37 | 4.91    | 4.61     | 5.53 | 4.61

7.5 Scaling White-Box Transformers

In this section, we will discuss three ways in which various parts of CRATE-type models can be scaled up or made more efficient while still remaining white-box. These developments mix both conceptual and empirical insights, and can be viewed as case studies about how to use white-box understanding to improve deep learning models in practice. The tasks that we use to evaluate the methods will be image classification and next-token-prediction, the data will be ImageNet and OpenWebText respectively, the optimization procedure will be the same backpropagation, and the only thing that changes is the architecture.

7.5.1 Increasing Network Width: CRATE-α

Figure 7.14: One layer of the CRATE-α backbone. The difference from CRATE is that the $\operatorname{ISTA}_{\theta}^{\ell}$ block is replaced by the $\operatorname{ODL}_{\theta}^{\ell}$ block, which performs several $\operatorname{ISTA}$ steps with an overcomplete dictionary.

One design decision enforced by the CRATE framework is the width of the nonlinearity in the network. In a regular transformer, the width is usually set to $4$, $8$, or $\tfrac{11}{3}$ times the feature dimension. However, CRATE enforces that the width is exactly equal to the feature dimension, i.e., the dictionaries $\bm{D}^{\ell}$ are square, which could lead to reduced performance. The fundamental reason that the CRATE framework constrains us to this choice is as follows:

  • The ISTA block takes a single step of optimization for dictionary learning.

  • Usually one step of any iterative optimization algorithm cannot effectively optimize the objective. So then why does this work?

  • Optimization algorithms usually converge very quickly if and only if they have good initializations, or warm starts. The ISTA block has a warm start — it treats the input features as an initialization to the resulting sparse codes.

  • This enforces that the input features and sparse codes have the same dimension. Namely, ISTA learns a complete sparsifying dictionary (cf Chapter 2).

Thus, if we want to use a wide dictionary, we need ISTA to perform overcomplete dictionary learning. This means we cannot have the same warm start (as our sparse codes have a larger dimension than our features), and we need more iterations to converge to a sparse code. Hence the step from features $\bm{Z}_{\theta}^{\ell+1/2}$ to sparse codes $\bm{Z}_{\theta}^{\ell+1}$ would no longer be

\bm{Z}_{\theta}^{\ell+1}=\operatorname{ISTA}_{\theta}^{\ell}(\bm{Z}_{\theta}^{\ell+1/2}\mid\bm{Z}_{\theta}^{\ell+1/2})    (7.5.1)

where the $\operatorname{ISTA}_{\theta}^{\ell}$ function is defined as (by an abuse of notation from earlier sections)

\operatorname{ISTA}_{\theta}^{\ell}(\bm{Z}\mid\bm{Y})\doteq\operatorname{ReLU}\bigl(\bm{Z}-\beta(\bm{D}^{\ell})^{\top}(\bm{D}^{\ell}\bm{Z}-\bm{Y})-\beta\lambda\bm{1}_{s}\bm{1}_{n}^{\top}\bigr)    (7.5.2)

but rather the following iteration:

\bm{Z}_{\theta}^{\ell+1}=\bm{A}_{\theta}^{\ell,T};\qquad\bm{A}_{\theta}^{\ell,t+1}=\operatorname{ISTA}_{\theta}^{\ell}(\bm{A}_{\theta}^{\ell,t}\mid\bm{Z}_{\theta}^{\ell+1/2})\quad\forall\,0\leq t<T;\qquad\bm{A}_{\theta}^{\ell,0}=\bm{0}_{s\times n},    (7.5.3)

i.e., running proximal gradient descent on the LASSO objective for $T\geq 1$ steps in the forward pass at each layer, initialized at $\bm{0}_{s\times n}$. In this circumstance, the dictionary can be as wide as needed, i.e., $\bm{D}^{\ell}\in\mathbb{R}^{d\times s}$ with $s\geq d$ atoms (usually one takes $s=4d$ in practice).

Model        | Detection AP50 ↑ | Detection AP75 ↑ | Detection AP ↑ | Segmentation AP50 ↑ | Segmentation AP75 ↑ | Segmentation AP ↑
CRATE-α-B/8  | 3.5              | 1.1              | 1.5            | 2.2                 | 1.0                 | 1.1
CRATE-α-L/8  | 4.0              | 1.7              | 2.0            | 2.7                 | 1.1                 | 1.4
CRATE-B/8    | 2.9              | 1.0              | 1.3            | 2.2                 | 0.7                 | 1.0
ViT-B/8      | 0.8              | 0.2              | 0.4            | 0.7                 | 0.5                 | 0.4
Table 7.6: Object detection and fine-grained segmentation via MaskCut on COCO val2017 [LMB+14]. Here all models are trained with patch size 8 instead of 16. Compared with existing models such as CRATE and ViT, the CRATE-α model family has noticeably improved performance as well as scalability.

However, this presents an empirical problem. Using the above configuration, if $\bm{Z}^{\ell+1/2}\in\mathbb{R}^{d\times n}$, then $\bm{Z}^{\ell+1}\in\mathbb{R}^{s\times n}$, which can have an arbitrarily large feature dimension. In practice, we want the feature dimension at each layer to be the same. So this sets up a practical trichotomy for designing the network; namely, we cannot have all of the following desiderata:

  1. The feature dimension at each layer is the same.

  2. The dictionary is wide, i.e., overcomplete.

  3. The output of the nonlinearity is the sparse codes of the input with respect to the dictionary.

In practice, giving up (1) is less tractable for efficiency reasons. Giving up (2) leads to the usual CRATE framework. Giving up (3) leads to a wide version of CRATE, i.e., CRATE-α, which has the following nonlinearity to get from $\bm{Z}^{\ell+1/2}$ to $\bm{Z}^{\ell+1}$:

\bm{Z}_{\theta}^{\ell+1}=\bm{D}^{\ell}\bm{A}_{\theta}^{\ell,T};\qquad\bm{A}_{\theta}^{\ell,t+1}=\operatorname{ISTA}_{\theta}^{\ell}(\bm{A}_{\theta}^{\ell,t}\mid\bm{Z}_{\theta}^{\ell+1/2});\qquad\bm{A}_{\theta}^{\ell,0}=\bm{0},    (7.5.4)

i.e., it takes the sparse codes obtained via proximal gradient descent and multiplies them by the dictionary to get a denoised version of the input. Thus CRATE-α’s nonlinearity computes a denoised version of the input which is amenable to sparse coding, not the actual sparse codes themselves. The map from $\bm{Z}_{\theta}^{\ell+1/2}$ to $\bm{Z}_{\theta}^{\ell+1}$ here is called the Overcomplete Dictionary Learning (ODL) block and denoted $\operatorname{ODL}_{\theta}^{\ell}$, i.e.,

\bm{Z}_{\theta}^{\ell+1}(\bm{X})\doteq\operatorname{ODL}_{\theta}^{\ell}(\bm{Z}_{\theta}^{\ell+1/2}(\bm{X})).    (7.5.5)
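A minimal sketch of the ODL block (7.5.2)–(7.5.5) follows, storing the dictionary as a d × s matrix with s ≥ d atoms; the step size, sparsity weight, number of steps, and all dimensions are illustrative assumptions.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# A minimal sketch of the ODL block: T non-negative ISTA steps on the LASSO
# objective with an overcomplete dictionary, initialized at zero, followed by
# multiplication by the dictionary to return to feature dimension d.
class ODLBlock(nn.Module):
    def __init__(self, d: int, s: int, T: int = 2, beta: float = 0.1, lam: float = 0.1):
        super().__init__()
        self.D = nn.Parameter(torch.randn(d, s) / s ** 0.5)  # D^l in R^{d x s}
        self.T, self.beta, self.lam = T, beta, lam

    def forward(self, Z_half: torch.Tensor) -> torch.Tensor:
        # Z_half: (d, n) input features; A: (s, n) sparse codes
        A = torch.zeros(self.D.shape[1], Z_half.shape[1], device=Z_half.device)
        for _ in range(self.T):
            grad = self.D.T @ (self.D @ A - Z_half)                   # fit-term gradient
            A = F.relu(A - self.beta * grad - self.beta * self.lam)   # one ISTA step
        return self.D @ A                                             # back to (d, n)

block = ODLBlock(d=64, s=256)
print(block(torch.randn(64, 32)).shape)  # torch.Size([64, 32])
```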
Figure 7.15: Saliency maps from CRATE-α with patch size 8. Each row is a different image and each column corresponds to a different attention head in the last layer. We observe that the saliency maps strongly correspond to the objects in the input image.
Model         | GPT-2-B(ase) | CRATE-B | CRATE-α-S(mall) | CRATE-α-B
# parameters  | 124M         | 60M     | 57M             | 120M
OWT val. loss | 2.85         | 3.37    | 3.28            | 3.14
Table 7.7: Validation loss in language modeling. Here all models are pre-trained on most of OpenWebText, and the validation cross-entropy loss is measured on a hold-out subset of OpenWebText. CRATE-α shows significant improvement over the CRATE design, though there still exists a gap with traditional transformers like GPT-2.

The CRATE-α layer is shown in Figure 7.14. In practice, this modification of CRATE performs very well at larger scales. For example, when we pre-train CRATE-α models on ImageNet-21K, unsupervised tasks like segmentation (see Figure 7.15 and Table 7.6) generally show significantly improved performance compared to CRATE. Similar trends are present in language model training using causal self-attention (see Table 7.7). Overall, it is a promising avenue for scaling up the performance to match black-box models such as transformers.\footnote{Note that the experimental results in this section use a slightly different model architecture, which adds very slight empirical gains. The changes are: (1) an additional residual connection on the ODL block; (2) modifying $\operatorname{ISTA}$ to use two separate dictionaries instead of $\bm{D}^{\ell}$ and $(\bm{D}^{\ell})^{\top}$.}

7.5.2 Linear Time Complexity Transformers

In practice, deep learning models suffer from bottlenecks in space and time complexity, representing problem sizes beyond which they cannot scale given fixed resources. One such bottleneck, particularly meaningful when dealing with data where each sample is itself high-dimensional and rich (such as long streams of text or videos), is the time complexity of processing long sequences of data. In order to alleviate the time complexity of processing data using transformers, in Section 4.3.2 we proposed a token statistics self-attention operator $\operatorname{TSSA}_{\theta}^{\ell}$. We now build a token statistics transformer, called ToST, around it, which we can use for long-context tasks. In particular, we can use the following layer (depicted in Figure 7.16) as a drop-in replacement for a backbone layer in CRATE:

\bm{Z}_{\theta}^{\ell+1/2}(\bm{X})=\bm{Z}_{\theta}^{\ell}(\bm{X})+\operatorname{TSSA}_{\theta}^{\ell}(\operatorname{LN}_{\theta}^{1,\ell}(\bm{Z}_{\theta}^{\ell}(\bm{X})))    (7.5.6)
\bm{Z}_{\theta}^{\ell+1}(\bm{X})=\bm{Z}_{\theta}^{\ell+1/2}(\bm{X})+\operatorname{MLP}_{\theta}^{\ell}(\operatorname{LN}_{\theta}^{2,\ell}(\bm{Z}_{\theta}^{\ell+1/2}(\bm{X})))    (7.5.7)

where the $\operatorname{TSSA}$ block is defined as in Section 4.3.2. Notice that this is exactly the same as the vision transformer architecture discussed in Section 7.2.3, except that $\operatorname{TSSA}$ replaces the conventional multi-head self-attention block $\operatorname{MHSA}$. Regardless, the computational complexity of the forward pass of this layer is linear in all problem variables: sequence length, feature dimension, number of heads, and head dimension.
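A minimal sketch of the ToST layer wiring (7.5.6)–(7.5.7) follows; the TSSA operator itself (Section 4.3.2) is passed in as an arbitrary module and is not reproduced here, and all sizes (including the MLP width) are illustrative.

```python
import torch
import torch.nn as nn

# A minimal sketch of the ToST layer: two residual branches, one through TSSA
# and one through a token-wise MLP, each preceded by a layer norm.
class ToSTLayer(nn.Module):
    def __init__(self, d: int, tssa: nn.Module, mlp_width: int = 4):
        super().__init__()
        self.ln1, self.ln2 = nn.LayerNorm(d), nn.LayerNorm(d)
        self.tssa = tssa
        self.mlp = nn.Sequential(nn.Linear(d, mlp_width * d), nn.GELU(),
                                 nn.Linear(mlp_width * d, d))

    def forward(self, Z: torch.Tensor) -> torch.Tensor:
        # Z: (n, d) token features, stored row-wise
        Z_half = Z + self.tssa(self.ln1(Z))          # (7.5.6)
        return Z_half + self.mlp(self.ln2(Z_half))   # (7.5.7)

layer = ToSTLayer(d=64, tssa=nn.Identity())  # nn.Identity() stands in for TSSA
print(layer(torch.randn(10, 64)).shape)      # torch.Size([10, 64])
```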

Figure 7.16: One layer of the ToST backbone. Token representations go through layer-norms, the token statistics self-attention (TSSA) operator, and an MLP, in order to form the layer’s output.
Datasets           | ToST-T(iny) | ToST-S(mall) | ToST-M(edium) | XCiT-S | XCiT-M | ViT-S | ViT-B(ase)
# parameters       | 5.8M        | 22.6M        | 68.1M         | 24.9M  | 80.2M  | 22.1M | 86.6M
ImageNet           | 67.3        | 77.9         | 80.3          | 80.5   | 81.5   | 79.8  | 81.8
ImageNet ReaL      | 72.2        | 84.1         | 85.6          | 85.6   | 85.9   | 85.6  | 86.7
CIFAR10            | 95.5        | 96.5         | 97.5          | 98.1   | 98.3   | 98.6  | 98.8
CIFAR100           | 78.3        | 82.7         | 84.5          | 86.1   | 87.6   | 88.8  | 89.3
Oxford Flowers-102 | 88.6        | 92.8         | 94.2          | 93.9   | 94.0   | 94.0  | 95.7
Oxford-IIIT-Pets   | 85.6        | 91.1         | 92.8          | 92.9   | 94.0   | 92.8  | 94.1
Table 7.8: Linear probing classification accuracy of ToST on various datasets with different model sizes when the backbone is pre-trained for ImageNet-1K classification. We observe that, compared to XCiT (an empirically-designed transformer-like architecture specialized for efficient processing of long sequences) and the ViT, ToST maintains relatively similar performance, even while enjoying benefits like faster runtime and white-box design.
Model       | # params | OWT  | Lambada | Wikitext | PTB  | Avg ↓
GPT-2-Base  | 124M     | 2.84 | 4.32    | 4.13     | 5.75 | 4.26
ToST-Base   | 110M     | 3.20 | 4.98    | 4.77     | 6.39 | 4.84
ToST-Medium | 304M     | 2.88 | 4.45    | 4.30     | 5.64 | 4.32
ToST-Large  | 655M     | 2.72 | 4.32    | 3.99     | 5.03 | 4.02
Table 7.9: Language modeling validation loss computed on (holdout sets of) a variety of natural language datasets, after pre-training the model on that dataset. We observe that ToST scales well, so that ToST-Large surpasses the baseline GPT-2-Base in causal language modeling, while enjoying superior efficiency in long contexts.
Model       | ListOps | Text  | Retrieval | Image | Pathfinder | Avg
Reformer    | 37.27   | 56.10 | 53.40     | 38.07 | 68.50      | 50.56
BigBird     | 36.05   | 64.02 | 59.29     | 40.83 | 74.87      | 54.17
LinFormer   | 16.13   | 65.90 | 53.09     | 42.34 | 75.30      | 50.46
Performer   | 18.01   | 65.40 | 53.82     | 42.77 | 77.05      | 51.18
Transformer | 37.11   | 65.21 | 79.14     | 42.94 | 71.83      | 59.24
ToST        | 37.25   | 66.75 | 79.46     | 46.62 | 69.41      | 59.90
Table 7.10: Long-Range Arena (LRA) performance comparison of ToST(-B) versus the top transformer variants optimized for long contexts. Long-Range Arena is a family of benchmarks that test the long-sequence modeling capability of algorithms and architectures by fixing the dataset and evaluation mechanism. ToST scores at the top of the leaderboard compared to all known transformer variants, including XCiT and the regular (ViT) transformer (cf. Table 7.8). Moreover, ToST has the lowest time- and space-complexity inference.

Moreover, the proposed architecture, named ToST (for “Token Statistics Transformer”), performs well at vision tasks (Table 7.8) and language tasks (Table 7.9). This is especially true for long-sequence-length tasks (cf. Table 7.10), where it is both more performant and much more efficient than conventional transformers and all other transformer-like architectures.

7.5.3 Attention-Only Transformers

Another bottleneck to remove from deep learning models, specifically transformer-like architectures, is the memory bottleneck that comes from the massive matrix multiplications in MLPs, where the internal dimension is far greater than the feature dimension $d$. It is thus an interesting and important question to ask: do we really need the MLP inside a transformer, and how good can the performance be without it? To explore this question, we use the attention-only transformer (AoT) architecture (see Section 4.3.1), depicted in Figure 7.17. Namely, each layer is simply of the form

\bm{Z}_{\theta}^{\ell+1}(\bm{X})=\bm{Z}_{\theta}^{\ell}(\bm{X})+\operatorname{MSSA}_{\theta}^{\ell}(\operatorname{LN}_{\theta}^{\ell}(\bm{Z}_{\theta}^{\ell}(\bm{X}))).    (7.5.8)

In our implementation, we also experimented with using multi-head self-attention (MHSA) in place of MSSA. It turns out that this architecture is also viable, though the depth of the network needs to be much greater in order to achieve equivalent performance to the usual CRATE or transformer architecture.
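A minimal sketch of such an attention-only layer follows; for simplicity it uses standard multi-head self-attention (the AoT-MHSA variant just mentioned) via PyTorch's built-in module rather than MSSA, and all sizes are illustrative.

```python
import torch
import torch.nn as nn

# A minimal sketch of an attention-only layer (7.5.8): a layer norm and a
# multi-head self-attention block with a residual connection, and no MLP.
class AoTLayer(nn.Module):
    def __init__(self, d: int, num_heads: int = 8):
        super().__init__()
        self.ln = nn.LayerNorm(d)
        self.attn = nn.MultiheadAttention(d, num_heads)

    def forward(self, Z: torch.Tensor) -> torch.Tensor:
        # Z: (n, batch, d) token features
        Y = self.ln(Z)
        attn_out, _ = self.attn(Y, Y, Y)  # no token-wise nonlinearity follows
        return Z + attn_out

layer = AoTLayer(d=64)
print(layer(torch.randn(16, 1, 64)).shape)  # torch.Size([16, 1, 64])
```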

Figure 7.17: One layer of the AoT backbone. Token representations merely go through a layer-norm and the multi-head (subspace) self-attention operator to form the layer’s output. Notice that there is no token-wise nonlinearity such as MLP or ISTA or ODL.

We conduct experiments using the proposed AoT architecture and demonstrate its potential. We pre-train AoT-MSSA and AoT-MHSA models of different sizes, along with GPT-2, on OpenWebText [GC19]. We plot the training loss and validation loss against the number of training iterations in Figure 7.18(a) and (b), respectively. We observe that medium- and large-sized AoT-based models achieve training and validation losses comparable to those of the GPT-2 base model. Note that the AoT-MHSA model is identical to the GPT-2 base model except for the absence of MLP layers in the architecture; as shown in Figure 7.18, incorporating MLP layers can accelerate the training process. Using the above pre-trained models, we compute the zero-shot cross-entropy validation loss (i.e., without further training) on different datasets in Table 7.11. We observe that the AoT models with medium and large parameter sizes can achieve performance comparable to the GPT-2 base model. Moreover, we find that adding MLP layers to AoT does not improve the zero-shot performance. These results highlight the potential of attention-only models to achieve competitive results while maintaining interpretability.

Table 7.11: Zero-shot results on several language benchmark datasets and tasks: evaluation of different sizes of AoT with the MSSA and MHSA operators, and comparison to the GPT-2 model.
Model (# of parameters)  | LAMBADA (val loss ↓) | PTB (val loss ↓) | WikiText (val loss ↓) | LAMBADA (acc ↑) | CBT CN (acc ↑) | CBT NE (acc ↑)
AoT-MSSA Base (102M)     | 4.70                 | 6.03             | 4.65                  | 0.25            | 0.80           | 0.74
AoT-MSSA Medium (182M)   | 4.47                 | 5.08             | 4.22                  | 0.29            | 0.84           | 0.77
AoT-MHSA Base (122M)     | 4.42                 | 5.52             | 4.19                  | 0.38            | 0.86           | 0.82
GPT-2 Base (124M)        | 4.32                 | 5.75             | 4.13                  | 0.40            | 0.87           | 0.84
Figure 7.18: Evaluating models on language tasks. We plot the training loss (left) and validation loss (right) of the AoT and GPT-2 models pre-trained on OpenWebText.

7.6 Masked Autoencoding for Imagery Data

The second application we discuss is nonlinear image completion, also known as masked autoencoding (MAE), which is a direct generalization of the low-rank matrix completion problem discussed in Chapter 2. Masked autoencoding, since its introduction in the deep learning context by [HCX+22], has been a staple and simple self-supervised representation learning method, which aims to endow each patch feature within $\bm{Z}_{\theta}$ with aggregate information as well as information about its neighbors, such that both the patch features and aggregate features are rich sources of information for the whole sample.

The dataset is kept the same as the image datasets discussed in Section 7.2.1. As usual, we still apply data augmentations to each sample in each new batch.

7.6.1 Task and Objective

As the name suggests, masked autoencoding involves a view $v_{m}$ which, given an input, performs a random resized crop (cf. Section 7.2.2) to turn the input image into a square image of size $(C,S_{\mathrm{mask}},S_{\mathrm{mask}})$, then masks (i.e., sets to zero) a fixed percentage $p_{\mathrm{mask}}\in[0,1]$ of pixels in the input. For efficiency reasons,\footnote{The original implementation of MAE by [HCX+22] embeds the whole image, removes the tokens that would be masked, feeds the resulting token set through the encoder, adds back learned placeholder tokens in the masked spots along with the appropriate positional encoding, and feeds the resulting token set through the decoder to get the autoencoding prediction. This is more efficient since the encoder has fewer tokens to process, but it is conceptually the same as the method discussed in the text, and the resulting models’ performance in the masked autoencoding task and downstream evaluations is very similar.} the masking is done patch-wise, i.e., after embedding the whole image, a $p_{\mathrm{mask}}$ fraction of the patches is set to zero. The goal of MAE is to train an encoder $f_{\theta}\colon\mathcal{I}\to(\mathbb{R}^{d})^{*}$ and a decoder $g_{\eta}\colon(\mathbb{R}^{d})^{*}\to\mathcal{I}$ that can reconstruct an input from its masking, i.e., writing $\hat{\bm{X}}_{\theta,\eta}\doteq g_{\eta}\circ f_{\theta}$, we have

\[
\min_{\theta,\eta}\;\Big\{\mathcal{L}_{\mathrm{MAE}}(\theta,\eta) \doteq \operatorname{\mathbb{E}}\big\|\hat{\bm{X}}_{\theta,\eta}(\bm{X}_{m}) - \bm{X}\big\|_{F}^{2}\Big\}. \tag{7.6.1}
\]

Essentially, this means that the features $\bm{Z}_{\theta}(\bm{X}_{m})$ of the view $\bm{X}_{m} \doteq v_{m}(\bm{X})$ must contain information about the masked patches as well as the visible ones. From the perspective of the compression-based white-box models in Chapter 4, if a white-box autoencoder $(f_{\theta}, g_{\eta})$ succeeds at this task, then the learned subspaces and dictionaries encode the data redundantly enough that missing parts of the data can be reconstructed from the encodings of the remaining parts. In other words, information about each patch is stored in other patches, so each patch feature should contain both information about the patch itself and information about the statistics of the whole image. Thus, again, we expect the representations to contain both local and global semantically relevant information, and therefore representations of different patches with similar local and global information should be related (i.e., lie on the same subspace or be encoded together by a dictionary).
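
As a minimal sketch of the patch-wise masking in $v_{m}$ and the reconstruction objective (7.6.1), consider the following PyTorch-style code. The function names `mask_patches` and `mae_reconstruction_loss` are ours and purely illustrative; a real implementation would instead follow the token-dropping scheme of [HCX+22] described in the footnote above.

```python
import torch

def mask_patches(tokens: torch.Tensor, p_mask: float = 0.75) -> torch.Tensor:
    """Zero out a p_mask fraction of the patch tokens, independently per sample.

    tokens: (B, n, d) patch embeddings (class token excluded).
    """
    B, n, _ = tokens.shape
    n_masked = int(round(p_mask * n))
    masked = tokens.clone()
    for b in range(B):
        idx = torch.randperm(n, device=tokens.device)[:n_masked]  # patches to mask
        masked[b, idx, :] = 0.0                                    # patch-wise masking
    return masked

def mae_reconstruction_loss(X: torch.Tensor, X_hat: torch.Tensor) -> torch.Tensor:
    """Average squared Frobenius norm of the reconstruction error, cf. (7.6.1)."""
    return ((X_hat - X) ** 2).flatten(1).sum(dim=1).mean()
```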

7.6.2 Architecture

Figure 7.19: One layer of the encoder and decoder in a CRATE autoencoder backbone. The encoder and decoder layers both feed their inputs through multi-head subspace self-attention and a dictionary learning or dictionary encoding step. Note that the encoder and decoder layers are symmetrically designed; the conceptual goal of each decoder layer is to invert an encoder layer, so this symmetry is very much by design (see, e.g., Chapter 5).

We use a CRATE encoder and decoder, depicted in Figure 7.19, though of course it is possible to use a regular transformer encoder and decoder. Details follow.

The encoder.

The encoder is the same as the CRATE encoder in Section 7.3.2, with the caveat that there is no feature extractor $f_{\theta}^{\mathrm{ext}}$. However, both the embedding $f_{\theta}^{\mathrm{emb}}$ and the backbone $f_{\theta}^{\mathrm{bb}}$ are the same.

The decoder backbone.

The decoder backbone is the CRATE decoder described in Chapter 5. For completeness, we describe it now. Given a feature sequence $\bm{Z}_{\theta}(\bm{X}) \doteq f_{\theta}(\bm{X}) \in (\mathbb{R}^{d})^{*}$, we can process it using the decoder backbone $g_{\eta}^{\mathrm{bb}}$ as follows. The function $g_{\eta}^{\mathrm{bb}}$ is composed of $L$ layers $g_{\eta}^{\ell}$, i.e.,

\[
g_{\eta}^{\mathrm{bb}} = g_{\eta}^{L} \circ \cdots \circ g_{\eta}^{1}. \tag{7.6.2}
\]

The layer $g_{\eta}^{\ell}$ has the following implementation. First, define $\tilde{\bm{Z}}_{\theta,\eta}^{1}(\bm{X}) \doteq \bm{Z}_{\theta}(\bm{X})$. Then, we obtain

\begin{align*}
\tilde{\bm{Z}}_{\theta,\eta}^{\ell+1/2}(\bm{X}) &= [\tilde{\bm{D}}^{\ell}]^{\top}\operatorname{LN}_{\eta}^{1,\ell}\big(\tilde{\bm{Z}}_{\theta,\eta}^{\ell}(\bm{X})\big) \tag{7.6.3}\\
\tilde{\bm{Z}}_{\theta,\eta}^{\ell+1}(\bm{X}) &= \tilde{\bm{Z}}_{\theta,\eta}^{\ell+1/2}(\bm{X}) - \operatorname{MSSA}_{\eta}^{\ell}\big(\operatorname{LN}_{\eta}^{2,\ell}(\tilde{\bm{Z}}_{\theta,\eta}^{\ell+1/2})\big) \tag{7.6.4}
\end{align*}

and $g_{\eta}^{\ell}$ is defined such that $g_{\eta}^{\ell}(\tilde{\bm{Z}}_{\theta,\eta}^{\ell}(\bm{X})) \doteq \tilde{\bm{Z}}_{\theta,\eta}^{\ell+1}(\bm{X})$. Here, the relevant concept is that $g_{\eta}^{\ell}$ should learn an approximate inverse of $f_{\theta}^{L+1-\ell}$; the two may be viewed as discretizations of a forward- and reverse-time diffusion process, respectively. In particular, $\tilde{\bm{D}}^{\ell}$ should approximate $\bm{D}^{L+1-\ell}$, and similarly, the parameters of $\operatorname{MSSA}_{\eta}^{\ell}$ should be similar to those of $\operatorname{MSSA}_{\theta}^{L+1-\ell}$. The output is $\tilde{\bm{Z}}_{\theta,\eta} \doteq \tilde{\bm{Z}}_{\theta,\eta}^{L+1}$.
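
The decoder layer (7.6.3)-(7.6.4) can be written as a small module. The following is a minimal sketch in PyTorch, under the assumption that a callable `mssa` implementing $\operatorname{MSSA}_{\eta}^{\ell}$ is supplied, and that $[\tilde{\bm{D}}^{\ell}]^{\top}$ is represented by a bias-free linear layer; all names are illustrative.

```python
import torch
from torch import nn

class CrateDecoderLayer(nn.Module):
    """One decoder layer g_eta^ell, following (7.6.3)-(7.6.4)."""

    def __init__(self, d_model: int, mssa: nn.Module):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)                       # LN_eta^{1,ell}
        self.ln2 = nn.LayerNorm(d_model)                       # LN_eta^{2,ell}
        self.dict_T = nn.Linear(d_model, d_model, bias=False)  # applies [D_tilde^ell]^T
        self.mssa = mssa                                       # MSSA_eta^ell (assumed given)

    def forward(self, Z: torch.Tensor) -> torch.Tensor:
        # Z: (B, n, d) token features Z_tilde^ell(X)
        Z_half = self.dict_T(self.ln1(Z))            # (7.6.3): dictionary step
        return Z_half - self.mssa(self.ln2(Z_half))  # (7.6.4): subtract attention output
```

For a quick shape check, one can pass `nn.Identity()` in place of the attention operator, e.g. `CrateDecoderLayer(256, nn.Identity())(torch.randn(2, 10, 256))`.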

The un-embedding module.

To transform $\tilde{\bm{Z}}_{\theta,\eta}(\bm{X})$ back into an estimate of $\bm{X}$, we need to undo the effect of the embedding module $f_{\theta}^{\mathrm{emb}}$ using the un-embedding module $g_{\eta}^{\mathrm{unemb}}$. As such, recalling the functional form of the embedding module in (7.2.11), i.e.,

\[
f_{\theta}^{\mathrm{emb}}(\bm{X}) \doteq \begin{bmatrix}\bm{z}_{\mathrm{cls}}^{1}, & \bm{W}^{\mathrm{emb}} f^{\mathrm{patch}}(\bm{X}) + \bm{E}^{\mathrm{pos}}\end{bmatrix} \tag{7.6.5}
\]

we see that the inverse operation $g_{\eta}^{\mathrm{unemb}}$ should take the following form:

\[
g_{\eta}^{\mathrm{unemb}}(\tilde{\bm{Z}}) \doteq g_{\eta}^{\mathrm{unemb}}\big(\begin{bmatrix}\tilde{\bm{z}}^{1}, \dots, \tilde{\bm{z}}^{n}\end{bmatrix}\big) = g^{\mathrm{unpatch}}\big(\bm{W}^{\mathrm{unemb}}([\tilde{\bm{z}}^{2}, \dots, \tilde{\bm{z}}^{n}] - \tilde{\bm{E}}^{\mathrm{pos}})\big), \tag{7.6.6}
\]

where $g^{\mathrm{unpatch}}$ inverts the unrolling and flattening operation performed by $f^{\mathrm{patch}}$.\footnote{Again, the “inverse positional encoding” $\tilde{\bm{E}}^{\mathrm{pos}}$ is learned for a large input, and for smaller inputs it may be interpolated. It is even possible to directly set $\tilde{\bm{E}}^{\mathrm{pos}}$ equal to the positional encoding $\bm{E}^{\mathrm{pos}}$ and use the same interpolated positional encodings for each input in both the encoder and decoder.}
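
The un-embedding (7.6.6) can be sketched as follows, under the assumption of non-overlapping square patches of side $P$ arranged on a grid of side $S_{\mathrm{mask}}/P$; the class and parameter names below are ours and purely illustrative.

```python
import torch
from torch import nn

class Unembed(nn.Module):
    """Sketch of g_eta^unemb in (7.6.6): drop the class token, subtract the inverse
    positional encoding, apply W^unemb, and fold patches back into an image."""

    def __init__(self, d_model: int, C: int, P: int, grid: int):
        super().__init__()
        n_patches = grid * grid
        self.W_unemb = nn.Linear(d_model, C * P * P, bias=False)          # W^unemb
        self.E_pos_tilde = nn.Parameter(torch.zeros(n_patches, d_model))  # E_tilde^pos
        self.C, self.P, self.grid = C, P, grid                            # grid = S_mask // P

    def forward(self, Z_tilde: torch.Tensor) -> torch.Tensor:
        # Z_tilde: (B, 1 + n_patches, d); position 0 holds the class token.
        tokens = Z_tilde[:, 1:, :] - self.E_pos_tilde      # remove E_tilde^pos
        patches = self.W_unemb(tokens)                     # (B, n_patches, C*P*P)
        B = patches.shape[0]
        x = patches.view(B, self.grid, self.grid, self.C, self.P, self.P)
        x = x.permute(0, 3, 1, 4, 2, 5).contiguous()       # g^unpatch: fold patches back
        return x.view(B, self.C, self.grid * self.P, self.grid * self.P)
```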

This architecture is a white-box autoencoder $(f_{\theta}, g_{\eta})$ where (recall) $f_{\theta} = f_{\theta}^{\mathrm{bb}} \circ f_{\theta}^{\mathrm{emb}}$ and $g_{\eta} = g_{\eta}^{\mathrm{unemb}} \circ g_{\eta}^{\mathrm{bb}}$. In particular, we can use it to compute an estimate from a masked view, namely $\hat{\bm{X}}_{\theta,\eta}(\bm{X}_{m}) = (g_{\eta} \circ f_{\theta})(\bm{X}_{m})$, which should approximately equal $\bm{X}$ itself.

7.6.3 Optimization

As in Section 7.3.3, we use a simple optimization setup: we sample images and masks, compute the loss on those samples and the gradients of this loss, and update the parameters using a generic optimization algorithm and the aforementioned gradients. For each timestep $k$, we do the following (a minimal code sketch follows the list):

  • Subsample $B$ different samples $\{\bm{X}_{b}^{(k)}\}_{b=1}^{B} \subseteq \mathcal{I}$.

  • For each sample $\bm{X}_{b}^{(k)}$, compute a different randomized resized crop and mask $v_{b,m}^{(k)}$ and apply it to $\bm{X}_{b}^{(k)}$ to get $\bm{X}_{b,m}^{(k)} \doteq v_{b,m}^{(k)}(\bm{X}_{b}^{(k)})$.

  • Compute the estimated autoencoding $\hat{\bm{X}}_{\theta,\eta}(\bm{X}_{b,m}^{(k)}) \doteq (g_{\eta} \circ f_{\theta})(\bm{X}_{b,m}^{(k)})$.

  • Form the surrogate stochastic loss

    \[
    \hat{\mathcal{L}}_{\mathrm{MAE}}^{(k)}(\theta,\eta) \doteq \frac{1}{B}\sum_{b=1}^{B}\big\|\hat{\bm{X}}_{\theta,\eta}(\bm{X}_{b,m}^{(k)}) - \bm{X}_{b}^{(k)}\big\|_{F}^{2}. \tag{7.6.7}
    \]
  • Compute one step of an optimization algorithm on $(\theta, \eta)$, giving the following iteration:

    \[
    (\theta^{(k+1)}, \eta^{(k+1)}) \doteq \textsc{OptUpdate}^{(k)}\big(\theta^{(k)}, \eta^{(k)};\; \nabla_{(\theta,\eta)}\hat{\mathcal{L}}_{\mathrm{MAE}}^{(k)}\big). \tag{7.6.8}
    \]
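
The following is a minimal sketch of one such step, assuming hypothetical callables `crop`, `mask`, and `autoencoder` standing in for the two stages of $v_{m}$ and for $g_{\eta} \circ f_{\theta}$. In this sketch the reconstruction target is the cropped-but-unmasked image so that shapes agree; (7.6.7) as written compares against the original sample.

```python
import torch

def mae_train_step(autoencoder, crop, mask, opt, X_batch: torch.Tensor) -> torch.Tensor:
    """One step of (7.6.7)-(7.6.8): crop, mask, reconstruct, and update."""
    X_crop = crop(X_batch)          # per-sample random resized crop
    X_m = mask(X_crop)              # patch-wise masking of a p_mask fraction
    X_hat = autoencoder(X_m)        # estimated autoencoding, same shape as X_crop
    loss = ((X_hat - X_crop) ** 2).flatten(1).sum(dim=1).mean()  # surrogate loss (7.6.7)
    opt.zero_grad()
    loss.backward()                 # gradients with respect to (theta, eta)
    opt.step()                      # one OptUpdate step, cf. (7.6.8)
    return loss.detach()

# Example wiring (all names are placeholders):
# opt = torch.optim.AdamW(autoencoder.parameters(), lr=3e-5, weight_decay=0.1)
# for k, X_batch in enumerate(loader):
#     mae_train_step(autoencoder, crop, mask, opt, X_batch)
```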

7.6.4 Evaluation

This is the first autoencoder network we discuss in this chapter. We use the same center crop view $v_{\mathrm{cc}}$ as in Sections 7.2.5 and 7.3.4, resizing the final image to a square with side length $S_{\mathrm{cc}} = S_{\mathrm{mask}}$ pixels to match the shapes of the input images seen during training.

In addition to evaluating the masked autoencoding loss itself, it is also possible to evaluate the features $\bm{Z}_{\theta}(\bm{X}_{\mathrm{cc}})$ of the view $\bm{X}_{\mathrm{cc}} \doteq v_{\mathrm{cc}}(\bm{X})$ of the data $\bm{X}$ directly. For attention map fidelity evaluation, obtaining $\bm{Z}_{\theta}(\bm{X}_{\mathrm{cc}})$ is sufficient, but for linear probing we need to extract a summarized or aggregate feature from $\bm{Z}_{\theta}$. To do this, we can use a (parameter-free) feature extraction map that returns the feature corresponding to the class token, i.e.,

\[
f_{\theta}^{\mathrm{ext}}(\bm{Z}) \doteq f_{\theta}^{\mathrm{ext}}([\bm{z}^{1}, \dots, \bm{z}^{n}]) = \bm{z}^{1}, \tag{7.6.9}
\]

as in (for example) Sections 7.3.1 and 7.3.2. With this, we have a way to obtain aggregate features $\bm{z}_{\theta}(\bm{X}_{\mathrm{cc}}) \doteq (f_{\theta}^{\mathrm{ext}} \circ f_{\theta})(\bm{X}_{\mathrm{cc}})$, at which point we can perform linear probing, segmentation evaluations, and so on.
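
In code, this extraction map is a one-liner; the following sketch assumes token features of shape (B, n, d) with the class token in position 0.

```python
import torch

def f_ext(Z: torch.Tensor) -> torch.Tensor:
    """Parameter-free extraction map (7.6.9): return the class-token feature."""
    # Z: (B, n, d) token features with the class token at position 0.
    return Z[:, 0, :]
```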

7.6.5 Experiments

Since CRATE-MAE is directly based on ViT-MAE, we compare the optimal settings for ViT-MAE as given by [HCX+22] with the same settings applied to CRATE-MAE for a fair comparison.

Model architecture.

During training, the masked crop $v_{m}$ resizes the whole image so that the shorter edge has size $256$ (i.e., $S_{\mathrm{rsz}} = 256$) before taking a random crop of size $224 \times 224$ (i.e., $S_{\mathrm{mask}} = 224$) and masking $p_{\mathrm{mask}} = \tfrac{3}{4}$ of the patches. We take patch size $16$ (i.e., $P_{H} = P_{W} = 16$). We use the small and base variants of the ViT-MAE architecture as the embedding and backbone for both the encoder and decoder, swapping out the MHSA and MLP components for MSSA, ISTA, and linear layers, respectively. We use the same number of heads and head dimension in the case of MSSA. However, the original ViT-MAE uses an encoder which contains nearly all of the layers and a decoder which contains only a few layers; we instead allocate half the total number of layers (which stays the same from ViT-MAE to CRATE-MAE) to each of our encoder and decoder, as suggested by the conceptual and theoretical framework in Chapter 5. For CRATE-MAE we set $(\beta, \lambda) = (1, 0.1)$.
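
For reference, the stated hyperparameters can be collected into a small configuration object; the field names below are our own and purely illustrative.

```python
from dataclasses import dataclass

@dataclass
class CrateMaeConfig:
    """Masking and architecture hyperparameters stated above."""
    resize_short_edge: int = 256   # S_rsz
    crop_size: int = 224           # S_mask
    patch_size: int = 16           # P_H = P_W
    mask_ratio: float = 0.75       # p_mask
    beta: float = 1.0              # beta
    lam: float = 0.1               # lambda
```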

Datasets and optimization.

For pre-training, we use the ImageNet-1K dataset. We use the AdamW optimizer to pre-train both our ViT-MAE replication and CRATE-MAE. We set the base learning rate to $3 \times 10^{-5}$, the weight decay to $0.1$, and the batch size to $B = 4096$. Our learning rate schedule increases the learning rate linearly to the base learning rate over the first $40$ epochs, then decreases it to $0$ using a cosine schedule over the next $760$ epochs (training all models for $800$ epochs each). For pre-training, we apply the usual regime of data augmentations (flips, Gaussian blurs, solarization, etc.) to the image data.
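
A minimal sketch of this warmup-plus-cosine schedule, using the stated values (base rate $3 \times 10^{-5}$, $40$ warmup epochs, $800$ total epochs), is as follows.

```python
import math

def lr_at_epoch(epoch: int, base_lr: float = 3e-5,
                warmup_epochs: int = 40, total_epochs: int = 800) -> float:
    """Linear warmup to base_lr, then cosine decay to zero (epochs indexed from 0)."""
    if epoch < warmup_epochs:
        return base_lr * (epoch + 1) / warmup_epochs
    progress = (epoch - warmup_epochs) / (total_epochs - warmup_epochs)
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))
```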

For linear probing, we use several evaluation datasets such as CIFAR10, CIFAR100, Oxford Flowers-102, and Oxford-IIIT-Pets. We precompute the features of all samples in the target dataset and apply a fast linear regression solver, e.g., from a standard package such as Scikit-Learn.
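
A minimal sketch of this linear probing protocol, assuming `features` is an (N, d) array of precomputed aggregate features $\bm{z}_{\theta}(\bm{X}_{\mathrm{cc}})$ and `labels` an (N,) array of class indices; we use scikit-learn's logistic regression here as one choice of fast linear solver.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

def linear_probe_accuracy(features: np.ndarray, labels: np.ndarray) -> float:
    """Fit a linear classifier on frozen, precomputed features and report accuracy."""
    X_tr, X_te, y_tr, y_te = train_test_split(
        features, labels, test_size=0.2, random_state=0)
    clf = LogisticRegression(max_iter=1000)   # linear probe on top of frozen features
    clf.fit(X_tr, y_tr)
    return clf.score(X_te, y_te)              # classification accuracy on held-out split
```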

Experiment results.

Table 7.12 demonstrates that CRATE-MAE models achieve, roughly speaking, parity with the popular ViT-MAE architecture at similar parameter counts, and also that the feature learning performance (as measured by performance on downstream classification tasks) increases with scale. Meanwhile, Figure 7.20 demonstrates that the encoder saliency maps (and therefore the fine-grained features learned by the encoder) indeed isolate and highlight the key parts of the input image.

Model                 CRATE-MAE-S(mall)   CRATE-MAE-B(ase)   ViT-MAE-S   ViT-MAE-B
# parameters          25.4M               44.6M              47.6M       143.8M
CIFAR10               79.4                80.9               79.9        87.9
CIFAR100              56.6                60.1               62.3        68.0
Oxford Flowers-102    57.7                61.8               66.8        66.4
Oxford-IIIT-Pets      40.6                46.2               51.8        80.1
Table 7.12: Linear probing classification accuracy of CRATE-MAE and ViT-MAE on various datasets with different model sizes when the backbone is pre-trained for masked autoencoding on ImageNet-1K. Given the same parameter count, CRATE-MAE achieves roughly similar performance while simultaneously enjoying a simpler and more principled architecture design.
Figure 7.20: Saliency maps of CRATE-MAE. Each pair of images consists of the original image (left) and a selected saliency map (right) corresponding to an attention head in the last layer. As is usual for CRATE models, but unusual for general transformer-like models, the saliency maps correspond to the objects in the input image.

7.7 Summary and Notes

All work in this chapter is downstream of the Transformer architecture, introduced by [VSP+17] and formally described in Section 7.2. A main empirical innovation in recent years, spurred by the prevalence and performance of transformers, has been to formulate a given learning problem as a sequence-to-sequence problem and then apply the Transformer architecture to it. This has made the Transformer essentially ubiquitous across (almost) all deep learning applications. As such, direct improvements to the Transformer can propagate into solutions to many problems and have considerable impact; similarly, we can apply our white-box understanding of transformer-like architectures to many modern problems. The material covered in this chapter is merely a subset of the work that has already been done; other work includes masked completion for text data (i.e., BERT) [DCL+19, YBP+24], (mechanistic) interpretability of language and vision models [BM24], and error-correcting coding [ZLG+]. There is much more to do.

There is also much more theory specifically about the practice of scaling neural networks, which is of enormous practical importance, and we at least remark on it here. This direction was popularized by the “Tensor Programs” line of work [YHB+22]. The basic prescription is that the initial gradient updates in a transformer should be of constant size; by carefully working through the backpropagation equations (Appendix A), one can determine the scales of the initialization and of the (layer-wise) learning rates required to achieve this. In practice, such prescriptions greatly improve the stability and convergence of training at large scales; they also prescribe a way to find the “optimal”\footnote{The word “optimal” is used in quotes because this work merely uses desiderata about the weight size, feature size, and gradient size at initialization to determine “optimality”, as opposed to, say, the test loss at convergence.} hyperparameters for large-scale training using only small-scale training runs. Follow-ups to this work attempt to accommodate the feature geometry [BN24], which could be informed by the work in this book on representation learning. Other follow-ups incorporate this weight-wise information into the optimizer itself to obtain these scaling benefits automatically, yielding optimizers such as Muon [JJB+], which have recently been used to train trillion-parameter models very stably [Tea]. Overall, the two approaches to deep learning theory, scaling analysis on the one hand and the representation-learning perspective of this book on the other, are largely orthogonal and thus complementary.

7.8 Exercises and Extensions

Exercise 7.1.

Read the DINO paper [CTM+21].

Exercise 7.2.

DINO v2 [ODM+23] uses everything from DINO v1 but also, during the data augmentation phase, randomly masks out patches within each view. This kind of augmentation should enforce that the features of images with similar local information are similar. Formulate an optimization problem that promotes this in the encoder, and implement it.

Exercise 7.3.

This exercise considers the implementation of stochastic optimization algorithms to minimize losses involving expectations.

  1. (a)

    Propose an alternative to the term involving $R_{\varepsilon}$ in (7.2.34) for approximating the covariance regularization term in (7.2.33). Evaluate the time complexity required to compute your proposed term and its gradient. Include an analysis for computing it on a single compute node vs. multiple nodes.

  2. (b)

    Evaluate the time complexity required to compute the existing term in (7.2.34) and its gradient.

Exercise 7.4.

Prove that (7.2.37) and (7.2.38) are convex optimization problems.

Exercise 7.5.

  1. (a)

    Implement the CRATE and CRATE-$\alpha$ models.

  2. (b)

    Compare their performance and efficiency on the CIFAR-10 dataset.

  3. (c)

    Compare their interpretability in two ways:

    • The sparsity $\|\bm{Z}\|_{0}$ of the representation $\bm{Z}$

    • The attention maps $\bm{a}_{\theta}^{k,\ell}$