“The best theory is inspired by practice, and the best practice is inspired by theory.”
— Donald Knuth
The previous chapters have presented a systematic introduction to the mathematical problems, computational frameworks, and practical algorithms associated with learning low-dimensional distributions from high-dimensional data. Although most theoretical justifications of these methods have been established for idealized models of data such as (mixtures of) subspaces and/or Gaussians, the principles and ideas behind these computational methods are powerful and general, and they are in fact meant to be applicable to real-world datasets and tasks.
To help you, the reader, better understand the material in this book and learn how to apply it to real-world data, we devote this last part of the book to demonstrations and vignettes of several representative applications. Each application solves a real-world task on a real-world dataset (such as visual or text data) using the methods introduced in this book. The results presented in this chapter are meant to serve two purposes:
firstly, to provide additional experimental details and empirical evidence which validate the methods presented earlier in the book, and demonstrate their significant potential in real-world contexts;
secondly, to introduce the reader to certain modern empirical methods and tasks in deep learning which are not well-documented outside of research or production codebases.
However, in our honest opinion, the solutions and results given here serve mainly to verify that the methodology works. As such, there remains great room for improvement, in both engineering and theoretical understanding, that could advance the state of the art. We will discuss some future directions in Chapter 8.
In previous chapters, we alluded to different setups in which we used representation-learning techniques to process real data at scale. In this chapter, we describe these setups in detail. The objective is to enable you, the reader, to reproduce any experiment discussed in this chapter (or indeed in the book) using only the descriptions given here, the principles introduced in previous chapters and expanded upon in this chapter, and hyperparameters taken from the handful of papers whose results we discuss. To this end, we describe all procedures precisely, in prose, pseudocode, or mathematical notation that can be directly implemented in code. Wherever possible, we discuss how the concrete implementations connect to the principles presented earlier in the book.
Let us define the set of possible data as (eventually this will be the set of images , for example, or the set of text ), and the set of finite sequences of tokens in (i.e., the set of matrices with rows) as . In order to discuss the wealth of applications we introduce in this chapter, we first recall that in the rest of the book, we discuss two different types of model architectures.
An encoder architecture, parameterized by , which is composed of several components:
An embedding , which converts the input data into a series of tokens which are mapped into, or embedded in, -dimensional space. In the rest of the chapter, we will often identify tokens and embeddings with each other.
An encoder backbone , which processes the series of embeddings using a sequence-to-sequence operation. This backbone is implemented by the network architectures discussed in the previous chapters, but we will give a more formal description as we go along.
An aggregate feature extractor , which extracts an aggregate representation of the whole sequence. This is used to define a single feature for the entire data sample.
A task-specific head , which extracts an -dimensional output for prediction.
We also define . Given an input , we write and . The overall pipeline is depicted in Figure 7.1.
An autoencoder architecture, which is composed of several components:
An embedding , which converts the input data into a series of tokens which are embedded in -dimensional space.
An encoder backbone , which processes the series of embeddings using a sequence-to-sequence operation.
A decoder backbone , which conceptually undoes the operation of the encoder backbone.
An unembedding , which acts as an inverse of the embedding.
We also define and . Given an input , we write and . The overall pipeline is depicted in Figure 7.2.
We will use this notation repeatedly throughout the chapter, so feel free to refer back to it if something is unclear. This decomposition of our networks also closely mirrors most code implementations, and you can start your own coding projects by defining these networks.
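For instance, a minimal PyTorch sketch of the encoder decomposition might look as follows; the module names here are illustrative placeholders rather than the book's reference implementation.

```python
import torch
from torch import nn

class Encoder(nn.Module):
    """Encoder = task-specific head ∘ aggregate feature extractor ∘ backbone ∘ embedding."""

    def __init__(self, embedding: nn.Module, backbone: nn.Module,
                 extractor: nn.Module, head: nn.Module):
        super().__init__()
        self.embedding = embedding  # data -> sequence of tokens in d-dimensional space
        self.backbone = backbone    # sequence-to-sequence map on token features
        self.extractor = extractor  # token sequence -> single aggregate feature
        self.head = head            # aggregate feature -> task-specific output

    def features(self, x: torch.Tensor) -> torch.Tensor:
        # The "feature map" part of the pipeline (everything except the task head).
        return self.extractor(self.backbone(self.embedding(x)))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.head(self.features(x))
```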
In this chapter, we will discuss applications of the book’s principles to contrastive learning in Section 7.2. That section serves both as an introduction to image data, data augmentation techniques, and the common architecture known as the transformer, and as a first demonstration of the drastic simplifications these principles enable. We continue with modifications to the network architecture in Sections 7.3 and 7.4, which demonstrate the capabilities of simplified architectures for encoding in the image and text domains. We then demonstrate simplified architectures for autoencoding in Section 7.6.
Learning high-quality and faithful representations of data without labels is a fundamental problem in deep learning, known as self-supervised learning. Many approaches have been proposed for this task, most of which do not evidently use the techniques and principles outlined in this book. One such approach is contrastive learning, so named because the learning objective (roughly speaking) ensures that features of “similar” data are close and features of “dissimilar” data are far apart. Contrastive learning solutions are often highly engineered, empirically designed approaches. In this section, we will describe one such approach, named DINO [CTM+21], and use the principles described in the previous chapters to drastically simplify its design decisions while improving the learned representations.
The data that we will use to explore and simplify the DINO methodology are all 2-dimensional image data. For training, we will use the ImageNet-1K and ImageNet-21K datasets. Each sample in the dataset is an RGB image, of varying resolution, and a label indicating the object or scene that the image contains (i.e., the class of the image). The ImageNet-1K dataset contains 1.28M training images and 50K validation images partitioned into 1K classes. The ImageNet-21K dataset contains 14.2M training images and 21.8K classes, but the classes are not disjoint (i.e., some classes are subsets of others). Since we are doing self-supervised learning, the labels will not be used during training, only during evaluation. For evaluation, we will use a litany of datasets. Of these, the most common is CIFAR-10. CIFAR-10 is a dataset of 60K RGB 32×32 natural images partitioned into 10 classes, with a pre-established training set of 50K samples and a validation set of 10K samples. The purpose of using CIFAR-10 is to ensure that models which train on one distribution of images (ImageNet) can generalize to another distribution of images (CIFAR-10). We also refer to other similar datasets, such as CIFAR-100 (disjoint from CIFAR-10), Oxford Flowers, and Oxford Pets. Exemplars of ImageNet-1K and CIFAR-10 data are shown in Figure 7.3. In order to increase the robustness of our model, we often apply small data augmentations to each image during processing, such as flips, added small random noise, or random large crops; we do not include this in our notation, as each augmentation of a natural image is itself (very close to) a natural image in our dataset.
On a slightly more formal level, our data will be images; we let be the set of all images. Since an image is a rectangular array of pixels, and each pixel has a color given by RGB, CMYK, or another color format, we say that an image is an element of — here is the number of channels (i.e., for RGB and for CMYK), is the image height, and is the image width. Consequently, the set of all images is the set of all possible such data. Again, we will use this notation repeatedly.
Our task is to learn a good representation of the data. Contrastive learning, by and large, does this by defining which properties of the input image we wish the features to reflect, constructing images which share these properties but vary in others, and setting up a loss which promotes that the features of images with shared properties are close and the features of images with differing properties are far apart. The natural optimal solution to this learning problem is that the learned features preserve the desired properties of the input. However, there are many practical and empirical complications that arise in the course of training contrastive models.
In the case of DINO, the authors propose to use a methodology which produces a single feature vector for the whole image and desires the feature vector to contain “global” (i.e., image-level) information. Accordingly, the loss will promote that images with similar global information have similar features and images with different global information have different features.
This seems intuitive, but as previously mentioned, there are several empirical considerations, even while setting up the loss. First and foremost, how should we promote similarities and differences? The answer from DINO [CTM+21] is (in the authors’ view, inexplicably) to convert the output features into “logits” corresponding to some probability distribution and take their cross-entropy. More specifically, let be the space of probability vectors in and define the function by
(7.2.1) |
where is the cross-entropy, defined as
(7.2.2) |
Before we continue our discussion, let us build some intuition about this distance function. We have, in particular,
(7.2.3) |
where is the KL divergence, defined as
(7.2.4) |
and is the entropy of a random variable. Note that is minimized if and only if . So minimizing does two things: it makes , and it makes and have minimal entropy (i.e., it drives them toward vectors with 1 in one component and 0 elsewhere — these are called one-hot vectors). Overall, the goal of this objective is not just to match and but also to shape them in a certain way to make them low-entropy. Keep this in mind when we discuss the formulation.
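For concreteness, assuming the standard conventions $\mathrm{CE}(\boldsymbol{p}, \boldsymbol{q}) = -\sum_i p_i \log q_i$ and $H(\boldsymbol{p}) = -\sum_i p_i \log p_i$ (generic symbols, which may differ slightly from the book's notation), the decomposition behind (7.2.3) is the elementary identity
$$
\mathrm{CE}(\boldsymbol{p}, \boldsymbol{q})
\;=\; \sum_i p_i \log \frac{p_i}{q_i} \;-\; \sum_i p_i \log p_i
\;=\; \mathrm{KL}(\boldsymbol{p} \,\|\, \boldsymbol{q}) + H(\boldsymbol{p}),
$$
so the cross-entropy is small precisely when the two distributions match and the first distribution itself has low entropy.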
The next question is, how should we obtain samples with similar global information? The answer from DINO (as well as nearly all contrastive learning) is data augmentation — from each sample, make several correlated samples which share the desired properties. In the DINO case, we use different crops or views of the input image. Recall that we model an image as an element of the set . In this notation, a view is a function . In the DINO case, the view is a random resized crop: it takes a randomly chosen rectangular crop of the image (which has a fixed percentage of the total area of the image), resizes it proportionally so that the shorter edge is pixels long, then resizes it to a fixed shape where is the size of the view and is the number of channels in the original image.
There are two types of views we want to use, depicted in Figure 7.4:
global views, which are random resized crops with area percentage parameter and output shape ;
local views, which are random resized crops with area percentage parameter and output shape . Here and .
DINO desires that the aggregate features of all views of an input image be consistent with each other. It does this by using a “DINO head” (note that this is the task-specific head, which in Section 7.1 is parameterized only by , as opposed to any specific parameters; but since we use two invocations of with different values of the second parameter, we keep the specified notation), parameterized by a matrix and a vector , to extract a probability vector from the aggregate feature , using the following simple recipe:
(7.2.5) |
where the function is defined by
(7.2.6) |
and is a “temperature” parameter which controls the entropy of the softmax’s output.
In particular, DINO minimizes the difference between the probability vector for each global view and the probability vector for each view . Here, can either be a local view or a global view. We will discuss the implementation of and shortly in Section 7.2.3. Overall, DINO solves the problem
(7.2.7) |
where the expectation is over data , global views , and other views .
In this specific case, however, if you try to implement (7.2.7) and optimize it on a real network, it is very likely that you will run into a problem: after running a few iterations of the learning algorithm, the feature mapping will become the constant function! This certainly optimizes the above loss since it minimizes the distance between features of different views of the same image. But we obviously do not want to learn this solution.
Actually avoiding collapse is a very common consideration in contrastive learning. So how do we do it in this case? The solution from DINO, again, is empirically designed, and carefully tunes the optimization of the parameter (which is updated using all samples in the batch) and a “temperature” hyperparameter which is part of the implementation of and discussed in Section 7.2.3. Given a certain special set of hyperparameters that work well, this is indeed enough to ensure non-collapse of the representation. However, outside of this special configuration, training models to converge is difficult, and the training is highly unstable.
To amend this state of affairs, let us discuss simplifications to the formulation. First, instead of computing a probability vector using a learned transformation of the aggregate features , we can directly use the aggregate representation, ignoring the task-specific head (or equivalently, setting it to the identity mapping). But now we need a way to compare the vectors directly. Using our hypothesis from Chapter 4 that good representations should have Euclidean (subspace) geometry, a much more natural measure of difference is the squared distance , defined as
(7.2.8) |
This distance-based score is even more efficient to compute than the cross-entropy score. Thus, takes the place of in our simplification.
Before, collapse was avoided by using tricks to update and . In our simplification, if we compare the features within the representation space instead of converting them to probabilities, we do not have either of these parameters, and so we must consider a different way to avoid collapse. To do this, we return to the fundamentals. The basic idea is that, to ensure that all samples do not map to the exact same feature, we need different samples to have different features. In other words, we would like the covariance of the features to be large in some sense. But from Chapters 3 and 4, we already have a quantity which measures the size of the covariance matrix. Namely, we use the straightforward (population-level) Gaussian coding rate to ensure that the features of global views of different images, which have different global information, are well-separated and not collapsed (hence expanded). The overall modified loss becomes:
(7.2.9) |
where is fixed and the appropriate expectations are, as before, taken over data , global view , and other (local or global) view . The loss in (7.2.9) is the loss used for the simplified DINO (“SimDINO”). As we will see, when properly implemented, it works at least as well as the original DINO.
For the architecture, we use a standard vision transformer. Here is how such an architecture works formally in the context of image data. Recall from Section 7.1 that there are four components to an encoder architecture, namely an embedding, a backbone, a feature extractor, and a task-specific head. We discuss these four parts presently.
Given image data , we embed it as a sequence of tokens in using the map , as follows. The first two steps are depicted in Figure 7.5, and the latter two are depicted in Figure 7.6.
First, we turn the image data into a sequence of patches of shape where and are the patch dimensions. We assume that and evenly divide the height and width of , respectively (in the notation of Section 7.2.2 we assume that and evenly divide and ). Let the resulting grid of patches have rows and columns.
We unroll each patch into a vector of length . There are patch vectors, which we place in “raster order” (top left → top right → … → bottom left → bottom right) into a matrix , where . Notice that depends only on the patch size and number of channels. Since the latter quantity is normally constant among samples in the same dataset, is the same for all images in the dataset, while is different for larger and smaller images.
We then perform the following operation on to project it to where :
(7.2.10) |
Here we have three trainable parameters , , and whose purpose is as follows:
is a matrix which projects each patch vector to a token feature.
is a so-called class token or register token. The class token heuristically holds global information of the whole data and is used for downstream tasks. In the framework of compressive deep networks from Chapter 4, we expect that the class token is projected onto the same subspaces as the salient or semantically relevant tokens during the progression of the forward pass.
is a so-called positional encoding which distinguishes tokens of different patches from each other. That is, token features should have positional information, so that the overall map is not invariant to permutations of the patches, and inserts this positional information.
In this DINO case, where the transformer receives differently-sized images, we learn a positional encoding for the largest size received during training, and interpolate to get the positional encodings for smaller-sized inputs.
Thus, in the end we have
(7.2.11) |
All parameters are contained in the parameter set .
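A minimal sketch of this embedding step is given below; it uses the standard trick of implementing the unroll-and-project operation as a strided convolution, and the names and default sizes are illustrative rather than the book's reference implementation.

```python
import torch
from torch import nn

class PatchEmbedding(nn.Module):
    """Image -> sequence of patch tokens, with a prepended class token and positional encoding."""

    def __init__(self, channels: int = 3, patch: int = 16, dim: int = 768, max_patches: int = 196):
        super().__init__()
        # A strided convolution with kernel = stride = patch size is equivalent to
        # unrolling each patch and applying a shared linear projection.
        self.proj = nn.Conv2d(channels, dim, kernel_size=patch, stride=patch)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))           # class / register token
        self.pos = nn.Parameter(torch.zeros(1, max_patches + 1, dim))   # learned positional encoding

    def forward(self, images: torch.Tensor) -> torch.Tensor:
        # images: (batch, channels, height, width); height and width divisible by the patch size.
        tokens = self.proj(images)                      # (batch, dim, rows, cols)
        tokens = tokens.flatten(2).transpose(1, 2)      # (batch, rows * cols, dim), raster order
        cls = self.cls_token.expand(tokens.shape[0], -1, -1)
        tokens = torch.cat([cls, tokens], dim=1)        # prepend the class token
        return tokens + self.pos[:, : tokens.shape[1]]  # add positional encodings
```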
Given a sequence of embeddings , we process it using the backbone map as follows and as depicted in Figure 7.7. The function is composed of layers , i.e.,
(7.2.12) |
The layer has the following implementation:
(7.2.13)
(7.2.14)
and is defined such that . Here we have used some operators, such as and that are defined as follows:
The operator is multi-head-self-attention, the predecessor of the multi-head subspace self-attention (cf Chapter 4). The formulation is as follows:
(7.2.15)
(7.2.16)
where is a positive integer, , , and are trainable parameters contained in the parameter set , and the is defined column-wise as
(7.2.17)
(7.2.18)
In practice, the dimensions are usually picked such that . The terms
(7.2.19) |
are also known as the th attention map and th attention head output at layer , respectively. Furthermore, the operation can be computed extremely efficiently using specialized software such as FlashAttention [SBZ+25].
The is a two-layer perceptron, a standard nonlinear component used in deep networks, and has the form
(7.2.20) |
where are trainable parameters also contained in the parameter set , and is the element-wise ReLU nonlinearity, i.e., .
Each layer-norm for is a standard normalization, which applies column-wise to each token feature independently:
(7.2.21) |
and has the form
(7.2.22) |
where denotes element-wise multiplication, and are trainable parameters contained in the parameter set . The layer-norm operator serves as a sort of normalization on each token, where the scale of each token afterwards is learnable and shared amongst all tokens.
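As a point of reference, the following is a minimal PyTorch sketch of one such layer, assuming the standard pre-norm arrangement with residual connections; the exact placement of the layer norms should follow (7.2.13)–(7.2.14).

```python
import torch
from torch import nn

class TransformerLayer(nn.Module):
    """One backbone layer: layer norm, multi-head self-attention, layer norm, two-layer MLP."""

    def __init__(self, dim: int = 768, heads: int = 12, mlp_ratio: int = 4):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)   # MHSA
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(                                         # two-layer perceptron
            nn.Linear(dim, mlp_ratio * dim),
            nn.ReLU(),
            nn.Linear(mlp_ratio * dim, dim),
        )

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        # z: (batch, num_tokens, dim)
        h = self.norm1(z)
        z = z + self.attn(h, h, h, need_weights=False)[0]  # residual connection around attention
        z = z + self.mlp(self.norm2(z))                    # residual connection around the MLP
        return z
```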
The transformer is one of the most popular neural network architectures in history, powering applications in almost all fields of deep learning.
We use a post-processing step which extracts the class token feature, which (recall) is the feature meant to contain aggregate information about the input image, and applies an MLP and normalization to it. Namely, we have
(7.2.23) |
For DINO, we use the task-specific DINO head . For SimDINO, we use no task-specific head at all, as previously described.
We have a loss function and an architecture, so we now discuss the optimization strategy. The optimization strategy for DINO uses two sets of weights for the same architecture: student weights and teacher weights . These correspond to two different neural networks, called the teacher network and student network, with the same architecture. The teacher network encodes all global views, while the student network encodes all “other” views. The goal of the loss is to distill teacher outputs into the student model. Namely, we train on the loss :
(7.2.24) |
Now, we can fully describe the overall pipeline of DINO, depicted in Figure 7.8.
While it is easy to reason about (7.2.24), it is impossible in practice to run optimization algorithms such as gradient descent directly on this loss, because the expectations in the loss cannot be evaluated exactly, much less differentiated. As is extremely common in such cases, we approximate the expectation via finite samples. That is, at each timestep we:
Subsample data points from our dataset .
For each data point , sample global views and local views . Apply the views to to obtain and .
For each local view , compute the following quantities:
(7.2.25) |
and for each global view , compute the following quantities (by an abuse of notation):
(7.2.26)
(7.2.27)
Compute the surrogate, approximate loss , defined as follows:
(7.2.28)
as well as its gradients with respect to and , which should be computed under the assumption that , , and are constants — namely that they are detached from the computational graph and not dependent on and .
Update the student parameters and via an iterative gradient-based optimization algorithm, and update , , and via exponential moving averages with decay parameters , , and respectively, i.e.,
(7.2.29)
(7.2.30)
(7.2.31)
(7.2.32)
For example, if the chosen optimization algorithm were stochastic gradient descent, we would have the update , and so on.
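The following is a minimal sketch of the update pattern just described, assuming separate `student` and `teacher` copies of the same architecture, a generic `loss_fn`, and a PyTorch optimizer that holds only the student's parameters; DINO's centering and temperature updates are omitted for brevity.

```python
import torch

@torch.no_grad()
def ema_update(teacher, student, gamma: float) -> None:
    # teacher <- gamma * teacher + (1 - gamma) * student, parameter by parameter.
    for p_t, p_s in zip(teacher.parameters(), student.parameters()):
        p_t.mul_(gamma).add_(p_s, alpha=1.0 - gamma)

def dino_step(student, teacher, optimizer, loss_fn, global_views, other_views, gamma):
    # Teacher outputs are constants: detached from the computational graph.
    with torch.no_grad():
        teacher_out = [teacher(v) for v in global_views]
    student_out = [student(v) for v in other_views]
    loss = loss_fn(teacher_out, student_out)   # surrogate loss over the sampled views
    optimizer.zero_grad()
    loss.backward()                            # gradients flow only to the student parameters
    optimizer.step()                           # gradient-based update of the student
    ema_update(teacher, student, gamma)        # exponential-moving-average update of the teacher
    return loss.item()
```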
Notice that the optimization procedure is rather irregular: although all four parameters change at each iteration, only two of them are directly updated by a gradient-based method. The other two are updated via exponential moving averages, and indeed treated as constants when computing any gradients. After training, we discard the student weights and use the teacher weights for our trained network , as this exponential moving average has been empirically shown to stabilize the resulting model (this idea is known as Polyak averaging or iterate averaging).
The schedules by which and change over the optimization trajectory (i.e., the functions and ) are hyperparameters or design decisions, with and usually, and similarly for . The temperature hyperparameter , used in the DINO head , also changes over the optimization trajectory (though this dependence is not explicitly notated).
Using the surrogate (“empirical”) loss transforms our intractable optimization problem, namely optimizing the loss in (7.2.24), into a tractable stochastic optimization problem of the kind used to train essentially every deep learning model in the world. This conversion becomes extremely natural once you have seen some examples, and we will give such examples throughout the chapter.
The simplified DINO population-level objective is very similar in spirit but much simpler in execution, i.e.,
(7.2.33) |
Thus, as illustrated in Figure 7.9, the SimDINO pipeline is strictly simpler than the DINO pipeline, and we can optimize it using a correspondingly simpler version of the DINO training pipeline. At each timestep , we:
Subsample data points from our dataset .
For each data point , sample global views and local views . Apply the views to to obtain and .
For each local view compute . For each global view compute and .
Compute the surrogate, approximate loss , defined as follows:
(7.2.34)
where is the Gaussian coding rate estimated on finite samples, described in Chapter 4. The gradient of with respect to should (again) be computed, under the assumption that is constant.
Update the student parameters via an iterative gradient-based optimization algorithm, and update via an exponential moving average with decay parameter , i.e.,
(7.2.35)
(7.2.36)
We reiterate that the gradient is only taken w.r.t. , treating as a constant. Here, note that while the choice of is still a design decision, the hyperparameters and are removed.
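The following is a minimal sketch of the coding-rate estimate and the SimDINO surrogate loss, assuming features are arranged as rows and passed as matched pairs; exactly which features enter each term, and which are detached, should follow (7.2.34).

```python
import torch

def coding_rate(Z: torch.Tensor, eps: float = 0.5) -> torch.Tensor:
    """Gaussian coding rate of n feature vectors (rows of Z), each of dimension d."""
    n, d = Z.shape
    cov = Z.T @ Z * (d / (n * eps ** 2))
    return 0.5 * torch.logdet(torch.eye(d, device=Z.device, dtype=Z.dtype) + cov)

def simdino_loss(z_student_other: torch.Tensor,
                 z_teacher_global: torch.Tensor,
                 z_student_global: torch.Tensor,
                 gamma: float) -> torch.Tensor:
    # Distance term: student features of each view should match the (detached)
    # teacher features of the corresponding global view, matched row by row.
    distance = ((z_student_other - z_teacher_global.detach()) ** 2).sum(dim=-1).mean()
    # Expansion term: coding rate of the student's global-view features across the
    # batch, which discourages collapse of all features to a single point.
    expansion = coding_rate(z_student_global)
    return distance - gamma * expansion
```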
There are several ways to evaluate a trained transformer model. We highlight two in this section. Let us define the center crop view which is a deterministic resized crop:
it resizes the image so that the shortest edge is of size (similar to random resized crops with area percentage parameter );
then it takes the center crop;
so that the final shape is . Notice that the view is completely deterministic given an input. For an input , we write . Here .
The first, and most architecture-agnostic, way to evaluate an encoder model is to employ linear probing. Linear probing is, in a sentence, running logistic regression on the aggregate features computed by the encoder. This tells us how much semantic information exists in the representations, as well as how easily this information can be extracted. (That is: to what extent do the features of images with different semantics live on different subspaces of the feature space?)
More formally, let us suppose that we want to evaluate the quality and faithfulness of the features of the encoder on image-label data , where there are classes and is a “one-hot encoding” (namely, zeros in all positions except a in the th position if is in the th class). One way to do this is to solve the logistic regression problem
(7.2.37) |
More practically, if we have labeled data , we can solve the empirical logistic regression problem (akin to (7.2.24) vs. (7.2.28)) given by
(7.2.38) |
This problem is a convex optimization problem in , and thus can be solved efficiently via (stochastic) gradient descent or a litany of other algorithms. This linear probe, together with the encoder, may be used as a classifier, and we can evaluate the classification accuracy. The usual practice is to train the model first on a large dataset (such as ImageNet-1K), then train the linear probe on a dataset (such as the training dataset of CIFAR-10), and evaluate it on a third (“holdout”) dataset which is drawn from the same distribution as the second one (such as the evaluation dataset of CIFAR-10).
We can also evaluate the performance of the features on classification tasks, without needing to explicitly train a classifier, by using the -nearest neighbor algorithm to predict labels via a majority vote in feature space. Namely, given a dataset , define the -nearest neighbors of another point as . Using this notation, we can compute the predicted label as
(7.2.39) |
Here, is (by an abuse of notation, cf. indicator variables) the one-hot probability vector supported at , i.e., in the th coordinate and elsewhere. That is, this procedure takes the most common label among the nearest points in feature space. The -nearest neighbor classification accuracy is just the accuracy of this predicted label, namely,
(7.2.40) |
or more commonly its corresponding empirical version, where ranges over a finite dataset (not the existing samples which are used for the neighbors).
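A minimal sketch of this k-nearest-neighbor evaluation, assuming precomputed aggregate features and integer class labels, might look as follows.

```python
import torch

@torch.no_grad()
def knn_accuracy(train_feats: torch.Tensor, train_labels: torch.Tensor,
                 test_feats: torch.Tensor, test_labels: torch.Tensor, k: int = 20) -> float:
    """k-nearest-neighbor classification accuracy in feature space (integer class labels)."""
    dists = torch.cdist(test_feats, train_feats)      # (n_test, n_train) Euclidean distances
    nn_idx = dists.topk(k, largest=False).indices     # indices of the k nearest training features
    nn_labels = train_labels[nn_idx]                  # (n_test, k) labels of those neighbors
    preds = torch.mode(nn_labels, dim=1).values       # majority vote among the neighbors
    return (preds == test_labels).float().mean().item()
```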
Another way to check the performance of the representations, for a transformer-based encoder, is to examine the fidelity of the attention maps as defined in Equation 7.2.19, at the last layer , and given by the following pipeline:
(7.2.41) |
In particular, we examine what the attention maps for a given input reveal about the salient objects in the input image, i.e., which parts of the image provide the most globally-relevant information to the class token. One particular way to do this is to examine the component of the attention map where the class token is extracted as the query and removed from the value matrix, i.e., or its transpose . Notice that this vector , which we label as the “saliency vector at the th attention head at layer ,” has a value for every patch, , and we use this value to describe how relevant each patch is toward the global information. In particular for visualization’s sake we create a new image where each patch is replaced by its corresponding value in the saliency vector, showcasing the contribution of each patch; we call this image the “saliency map at the th attention head at layer ”. To visualize the total relevance of each patch toward the global information across all heads, we can average the saliency vector, i.e., and expand into the average saliency map. The average saliency maps should highlight the relevant parts of the input image.
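A minimal sketch of extracting these saliency maps, assuming the final layer's attention maps are available with the class token at index 0 (as in the embedding described above):

```python
import torch

def saliency_maps(attn: torch.Tensor, grid_rows: int, grid_cols: int) -> torch.Tensor:
    """Per-head saliency maps from the final layer's attention maps.

    attn: (num_heads, num_tokens, num_tokens) attention maps, class token at index 0.
    Returns: (num_heads, grid_rows, grid_cols) saliency values, one per patch.
    """
    cls_to_patches = attn[:, 0, 1:]   # attention of the class-token query to each patch
    return cls_to_patches.reshape(-1, grid_rows, grid_cols)

def average_saliency_map(attn: torch.Tensor, grid_rows: int, grid_cols: int) -> torch.Tensor:
    # Average over heads: total relevance of each patch toward the global information.
    return saliency_maps(attn, grid_rows, grid_cols).mean(dim=0)
```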
We can evaluate how well the representations capture the fine-grained (i.e., smaller or more detailed) properties of the input by using them for semantic segmentation and object detection. Roughly, this means that we use the features to localize all objects in the input, for example via segmentation masks or bounding boxes. There are several ways to do this, and several ways to score the resulting predictions against the ground truth. Each combination of methods corresponds to a particular metric. We do not formally describe them here as they are not particularly insightful, but the DINO paper [CTM+21] and DINOv2 paper [ODM+23] contain references to all metrics that are used in practice.
Since SimDINO is directly built upon DINO, we compare the optimal settings for DINO as given by their original paper [CTM+21] with the same settings applied to SimDINO for a fair comparison.
We use local views (i.e., ) of resolution (i.e., ) and global views (i.e., ) of resolution (i.e., ) for all experiments. The corresponding portions of the original images cropped for local and global views are and (chosen randomly per-view). The smaller edge size within the resized crops is , and the center crop (evaluation) view edge size is . All of these settings apply to both DINO and SimDINO.
For all inputs, we set the patch size to be (namely, ). We use the small, base, and large models of the ViT [DBK+21] architecture as the embedding and backbone. The feature extractor is a three-layer MLP with a hidden size of and an output dimension of , followed by an -normalization, as specified in Section 7.2.3. For DINO architectures (i.e., not SimDINO architectures), the DINO head is a matrix in , and the parameter is a vector in .
For pre-training, both our DINO reproduction and SimDINO use the ImageNet-1K dataset. We use AdamW [LH17] as the optimizer, which is a very standard choice. We use the following hyperparameters:
The batch size is .
The learning rate (for AdamW and the student model) has “base” value . In the first epochs the learning rate linearly increases from to the base value (i.e., at the th epoch the learning rate is , for ). Then over the next epochs the learning rate decays via a so-called cosine schedule back down to . The definition of a cosine schedule is given in many places, including the PyTorch documentation, and it is commonly used when training deep vision models; a minimal sketch of such a warmup-plus-cosine schedule appears after this list.
The weight decay (the W in AdamW) follows a cosine schedule from 0.04 to over training.
The EMA rate follows a cosine schedule from to over training. Specifically for DINO, the centering EMA rate is fixed at .
Specifically for DINO, the teacher temperature is fixed at , while the student temperature linearly increases from to during the first epochs and is fixed at thereafter.
We use some (essentially information-preserving) data augmentations, such as flips, color jittering, Gaussian blur, and solarization, for each seen image during training, before taking the local and global views. The exact hyperparameters governing these are not listed here, but are referenced in the DINO paper [CTM+21].
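As referenced above, here is a minimal sketch of a linear-warmup-plus-cosine schedule, applicable to the learning rate, weight decay, and EMA rate alike; step counts and values are left as arguments.

```python
import math

def warmup_cosine(step: int, total_steps: int, warmup_steps: int,
                  base_value: float, final_value: float = 0.0) -> float:
    """Linear warmup from 0 to base_value, then cosine decay to final_value."""
    if step < warmup_steps:
        return base_value * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return final_value + 0.5 * (base_value - final_value) * (1.0 + math.cos(math.pi * progress))
```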
For linear probing, the linear probe is usually trained using the AdamW optimizer with learning rate , weight decay , and batch size , but these are often modified on a case-by-case basis to minimize the loss.
Method | Model | Epochs | 20-NN | Linear Probing
---|---|---|---|---
DINO | ViT-B | 100 | 72.9 | 76.3 |
SimDINO | ViT-B | 100 | 74.9 | 77.3 |
DINO | ViT-L | 100 | – | – |
SimDINO | ViT-L | 100 | 75.6 | 77.4 |
SwAV | ViT-S | 800 | 66.3 | 73.5 |
MoCov3 | ViT-B | 300 | – | 76.7 |
Method | Model | Detection AP50 | Detection AP75 | Detection AP | Segmentation AP50 | Segmentation AP75 | Segmentation AP
---|---|---|---|---|---|---|---
SimDINO | ViT-L/16 | 5.4 | 1.9 | 2.4 | 4.5 | 1.4 | 1.9 |
SimDINO | ViT-B/16 | 5.2 | 2.0 | 2.5 | 4.7 | 1.5 | 2.0 |
DINO | ViT-B/16 | 3.9 | 1.5 | 1.8 | 3.1 | 1.0 | 1.4 |
DINO | ViT-B/8 | 5.1 | 2.3 | 2.5 | 4.1 | 1.3 | 1.8 |
In terms of downstream classification performance, we obtain the results in Table 7.1. We observe that the performance of SimDINO is consistently higher than that of DINO under a fair comparison. It is also much more stable: the prescribed settings of DINO cannot train a ViT-L(arge) model at all. Meanwhile, Figure 7.10 shows visualizations of the average saliency maps of DINO and our simplified DINO; the saliency maps look quite similar across models, indicating that SimDINO learns features which are at least as good at capturing fine-grained details. The segmentation and object detection performance in Table 7.2 confirms this claim quantitatively, where SimDINO features show a substantive improvement over those of DINO.
In the previous section, we simplified an overly complex learning objective using our intuition about representation learning through the lens of compression. However, many of the most popular learning procedures are incredibly simple. In these cases, it is difficult to further simplify the objective. Thus, in this and future sections, we will focus on principled ways to modify the deep network architectures for a variety of tasks.
Let us first start with arguably the most classical task in machine learning: image classification, which is often used as a standard task to evaluate pattern recognition algorithms or deep network architectures. From our discussion in Chapter 4, a white-box architecture needs only a semantically meaningful task in order to learn good representations. We will validate this idea in this section.
First, the dataset stays largely the same as in Section 7.2.1. Both the training and test data consist of labeled images, i.e., image-label pairs . We still apply various data augmentations (e.g., flips, Gaussian blurring, and solarization) to each sample in each new batch.
Unlike before, our task is not just to learn a good representation of the data, but also to simultaneously build a classifier. Formally, we have labeled data pairs , where is a one-hot vector denoting the class membership of . We consider a deterministic center crop view of the input data (cf Section 7.2.2). We want to jointly train a feature mapping and a classification head , defined as follows:
(7.3.1) |
where are trainable parameters in the parameter set , such that the map predicts a smoothed label for the view of the input . The learning problem attempts to minimize the distance between and measured through cross-entropy:
(7.3.2) |
The architecture that we use is the CRATE architecture, described in some detail in Chapter 4. The overall setup is similar to that of the regular transformer in Section 7.2.3, with a few changes. The embedding step is the same as in DINO and SimDINO (Section 7.2.3); the feature extraction step is the same as in SimDINO (Section 7.2.3), as it just extracts the feature corresponding to the class token; and the classification head is as described in Section 7.3.1. The backbone architecture, however, is different. Each layer takes the form
(7.3.3)
(7.3.4)
where the and blocks are as described in Chapter 4, namely:
The operator is multi-head-subspace-self-attention, defined as follows:
(7.3.5) |
where , , and are trainable parameters belonging to the parameter set , and (recall) the self-attention operator is defined in (7.2.16).
The operator is the iterative-shrinkage-thresholding-algorithm operator, defined as follows:
(7.3.6) |
so named because the map is one step of the well-established ISTA algorithm to find an element-wise non-negative sparse representation for with respect to the complete dictionary (cf Section 2.3).
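A minimal sketch of an ISTA-style block of this form (one proximal-gradient step with the input as warm start) is given below; the exact update, step size, and sparsity penalty should follow the formulation in Chapter 4 and (7.3.6), and the default hyperparameters are illustrative.

```python
import torch
from torch import nn

class ISTABlock(nn.Module):
    """One proximal-gradient (ISTA) step toward a non-negative sparse code of the tokens
    with respect to a complete (square) dictionary D, using the input as a warm start."""

    def __init__(self, dim: int, step_size: float = 0.1, sparsity: float = 0.1):
        super().__init__()
        self.D = nn.Parameter(torch.empty(dim, dim))
        nn.init.kaiming_uniform_(self.D)
        self.eta, self.lam = step_size, sparsity

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        # z: (batch, num_tokens, dim), tokens as rows.
        # Descent direction of 0.5 * ||Z - D A||_F^2 at A = Z (the warm start).
        descent = (z - z @ self.D.t()) @ self.D
        # Gradient step followed by a shrinkage/thresholding nonlinearity.
        return torch.relu(z + self.eta * descent - self.eta * self.lam)
```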
We call this architecture CRATE, and a layer of the backbone is depicted in Figure 4.13. CRATE models, on top of being interpretable, are generally also highly performant and parameter-efficient.
We train our classifier using a simple end-to-end stochastic optimization procedure, where we subsample data and views, compute the average loss and its gradient over these samples, and use an optimization algorithm to change the parameters. At each timestep , we:
Subsample different labeled samples .
For each labeled sample , compute the central crop view and apply it to to get .
Compute the predictions .
Form the surrogate stochastic loss
(7.3.7) |
Compute one step of an optimization algorithm on , giving the following iteration:
(7.3.8) |
We use the same evaluation procedure as Section 7.2.5. To summarize, for all evaluations (as well as training) we use a center crop view which reshapes the input image and takes a large central crop of size where is the number of channels in the input image. We can then do linear probing, attention map visualization, and detection/segmentation benchmarks, given the output of this view.
Since CRATE is directly based on the transformer, we compare the optimal settings for ViT as given by [DBK+21, TCD+20] with the same settings applied to CRATE for a fair comparison.
The center crop resizes the whole image so that the shorter edge is of size (i.e., ) before taking a center crop of size (i.e., ), both in evaluation and training. We take patch size (i.e., ). We use the tiny, small, base, and large models of the ViT [DBK+21] architecture as the embedding and backbone, swapping out the MHSA and MLP components for MSSA and ISTA, respectively, using the same number of heads and head dimension in the case of MSSA, and therefore reducing the number of training parameters drastically. For CRATE, we set .
For pre-training, we use the ImageNet-1K dataset. We use the LION optimizer [CLH+24] to pre-train both our ViT replication as well as CRATE. We set the base learning rate as , the weight decay as , and batch size as . Our learning rate schedule increases the learning rate linearly to the base learning rate over the first epochs, and decreases to using a cosine schedule over the next epochs (training all models for epochs each). For pre-training, we apply a usual regime of data augmentations (flips, Gaussian blurs, solarization, etc.) to the image data, and also add small noise to the labels (this is called label smoothing [MKH19]).
For linear probing, we use several evaluation datasets such as CIFAR-10, Oxford Flowers, and Oxford-IIIT Pets. We use the AdamW optimizer to train the linear probe, using learning rate , weight decay , and batch size . We also apply the aforementioned data augmentations to the image data.
Model | CRATE-T | CRATE-S | CRATE-B | CRATE-L | ViT-T | ViT-S |
---|---|---|---|---|---|---|
# parameters | 6.09M | 13.12M | 22.80M | 77.64M | 5.72M | 22.05M |
ImageNet-1K | 66.7 | 69.2 | 70.8 | 71.3 | 71.5 | 72.4 |
ImageNet-1K ReaL | 74.0 | 76.0 | 76.5 | 77.4 | 78.3 | 78.4 |
CIFAR10 | 95.5 | 96.0 | 96.8 | 97.2 | 96.6 | 97.2 |
CIFAR100 | 78.9 | 81.0 | 82.7 | 83.6 | 81.8 | 83.2 |
Oxford Flowers-102 | 84.6 | 87.1 | 88.7 | 88.3 | 85.1 | 88.5 |
Oxford-IIIT-Pets | 81.4 | 84.9 | 85.3 | 87.4 | 88.5 | 88.6 |
Model | Detection AP50 | Detection AP75 | Detection AP | Segmentation AP50 | Segmentation AP75 | Segmentation AP
---|---|---|---|---|---|---
CRATE-S/8 | 2.9 | 1.0 | 1.1 | 1.8 | 0.7 | 0.8 |
CRATE-B/8 | 2.9 | 1.0 | 1.3 | 2.2 | 0.7 | 1.0 |
ViT-S/8 | 0.1 | 0.1 | 0.0 | 0.0 | 0.0 | 0.0 |
ViT-B/8 | 0.8 | 0.2 | 0.4 | 0.7 | 0.5 | 0.4 |
Table 7.3 demonstrates that CRATE models achieve parity with or improvement over the popular Vision Transformer (ViT) architecture at similar parameter counts, at least in terms of the linear separability of their features with respect to different classes. In terms of attention map fidelity, Figure 7.11 demonstrates a truly extraordinary result: without training on any segmentation or object detection data, not only do the saliency maps effectively capture all relevant parts of the input image, but they also self-organize so that each head corresponds to a distinct set of concepts, even across samples and classes! To the authors’ knowledge, this is the first system to do so, and it does so without any extra data beyond the image classification data. Table 7.4 confirms these qualitative insights quantitatively, showing a significant improvement over ViTs trained in the same supervised classification setup.
We now study causal language modeling, a method for training large language models (LLMs). This is the same setup used to train GPT-2, among many other language models.
The data we will use to investigate the performance of CRATE for language tasks will be OpenWebText (OWT) [GC19], an open-source reproduction of the unreleased WebText dataset used by OpenAI to train GPT-2. Each sample in OWT is a web document, typically sourced from high-quality web pages, blogs, articles, or online discussions, that is written in well-formed natural language. The OpenWebText dataset contains around 8.01M documents of varying lengths, totaling around 41.70GB of text. For evaluation, we will use several datasets, such as WikiText [MXB+16] (the test splits of WikiText2 and WikiText103 are the same, so we merge them and refer to them as a single dataset, WikiText), LAMBADA [PKL+16] (accuracy on LAMBADA is obtained with greedy decoding), and PTB [MSM93]. PTB and OWT are generally easier compared to other datasets. PTB focuses on simpler journalistic text, ideal for traditional language modeling, while OWT is diverse and informal, covering various topics but with less complexity in language structure or long-range dependencies. WikiText, with its formal structure and domain-specific content, requires a more complex understanding than OWT but remains manageable. LAMBADA is the most challenging, as it involves long-range dependencies, requiring the model to grasp broader contextual information to complete sentences accurately.
On a more formal level, our data will be text, or strings of characters; we let be the set of all strings.
For causal language modeling pre-training, the idea is that we want to train the model to output human-like text. The most popular way to do this by far is to use the following two-stage training process (modern language model training has several additional stages which demand different data distributions and algorithmic approaches, but training a model to merely mimic human writing requires only the steps presented here):
First, we wish to learn a way to optimally encode documents as a sequence of basic (“building block”) strings, called tokens. This is called tokenization, and we build a tokenizer.
Second, we wish to learn a way to predict the distribution of a token given all previous tokens. This is called next-token prediction, and we build a language model.
This procedure actually dates back to Markov, who first noticed that natural language could be modeled by the eponymous Markov chain structure [Mar06] given an appropriate tokenization, and then to Shannon, who proposed doing this exact language modeling setup with a character-level tokenizer (i.e., each character is a token) and a so-called “n-gram” model (i.e., an explicit look-up table, calculated from training data, for the distribution of a token given the previous n − 1 tokens) in place of the language model [Sha48]. (A recent study [LMZ+24] scaling up n-gram models has shown that they are able to model text reasonably well for large n, but of course the memory required to store such a lookup table grows exponentially with n and hence is completely intractable.)
To build a tokenizer amounts to building a vocabulary , which is a set of tokens and has some pre-specified size . There are several methods to do this. One popular algorithm is known as Byte Pair Encoding (BPE), which can be described as:
Start with a list of all unique characters in your training data, and their frequencies. Ensure that there are fewer than such characters, and add each character as a separate string (“token”) to the vocabulary along with its frequency.
Until there are tokens in the vocabulary:
Construct a new token by taking the most frequent pair of adjacent tokens in the data and merging (concatenating) them.
Compute this token’s frequency in the dataset.
Add it to the vocabulary (along with its frequency).
At this point, the frequency information is no longer needed and can be discarded.
The overall process of BPE is in Figure 7.12. Note that this procedure is a modification of a classical information-theoretic compression procedure for learning a lossless encoding of bytestream data (such as text), and as such, one can interpret it as finding an optimal lossless compression of the data. Notice that this is possible because (unlike images), the data here are fundamentally discrete and noise-free.
After such a vocabulary is built, a tokenizer can break down a document into tokens (i.e., “tokenize” it). BPE uses a similar procedure to tokenize data as in training:
Separate the document into a long list of one-character-long tokens. That is, if the document is “Hello” then the initial list is ‘H’, ‘e’, ‘l’, ‘l’, ‘o’.
While any two adjacent tokens have a concatenation that is itself a token in the vocabulary, we merge them, i.e., we replace this pair of tokens with the merged token. Namely, if ‘He’ is a token in the vocabulary, ‘H’, ‘e’, ‘l’, ‘l’, ‘o’ would become ‘He’, ‘l’, ‘l’, ‘o’.
Repeat the above process until no more merges can be done. At this point, the document is partitioned into the final list (sequence) of tokens.
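The following is a toy sketch of both stages, for illustration only; production tokenizers (e.g., the HuggingFace or tiktoken implementations) track merge priorities, work at the byte level, and use far more efficient data structures.

```python
from collections import Counter

def train_bpe(corpus: list[str], vocab_size: int) -> list[str]:
    """Build a BPE vocabulary by repeatedly merging the most frequent adjacent pair of tokens."""
    words = [list(doc) for doc in corpus]              # each document as a list of tokens
    vocab = sorted({ch for w in words for ch in w})    # start from the unique characters
    while len(vocab) < vocab_size:
        pairs = Counter(p for w in words for p in zip(w, w[1:]))
        if not pairs:
            break
        (a, b), _ = pairs.most_common(1)[0]            # most frequent adjacent pair
        vocab.append(a + b)
        for i, w in enumerate(words):                  # replace every occurrence with the merge
            j, merged = 0, []
            while j < len(w):
                if j + 1 < len(w) and (w[j], w[j + 1]) == (a, b):
                    merged.append(a + b); j += 2
                else:
                    merged.append(w[j]); j += 1
            words[i] = merged
    return vocab

def tokenize(text: str, vocab: set[str]) -> list[str]:
    """Greedily merge adjacent tokens whenever their concatenation is in the vocabulary."""
    tokens = list(text)
    merged = True
    while merged:
        merged = False
        for j in range(len(tokens) - 1):
            if tokens[j] + tokens[j + 1] in vocab:
                tokens[j:j + 2] = [tokens[j] + tokens[j + 1]]
                merged = True
                break
    return tokens
```

For example, `tokenize("Hello", set(train_bpe(["Hello world", "Hello there"], 40)))` returns a short list of merged tokens. Note that real BPE applies merges in the order in which they were learned, whereas this sketch greedily applies any available merge.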
There are many practical and efficiency-based considerations to take into account during tokenization; the above algorithm, as presented, is far from efficient if implemented naively. We do not cover this topic in great detail; there are many resources online to learn more, such as HuggingFace tutorials.
One further practical note: each token has a corresponding index, which is just its position in the vocabulary (which, after all, is just a list of length ). Thus, the output of most tokenizers is a list of indices, say an element of . Keep in mind that these indices correspond to substrings of the original document, as shown above.
Once a tokenizer is learned, it can be used as a black box by any language model. For instance, many models have the same (OpenAI-based) tokenizer based on the tiktoken library. In the remainder of this section, we will use such a fixed and pre-built tokenizer for everything, and thus identify each text document with its tokenized version in . Therefore, we may as well consider the text space as equal to the space of token sequences (and lose nothing essential).
Once we have each document as a sequence of tokens , we wish to perform next-token prediction. That is, given a context (i.e., the first tokens in the document)777Note the incongruity with Python notation: here the notation includes index ., we wish to predict the token at position . To do this, we compute the aggregate feature of via , and use a classification head (implemented as either a linear layer, MLP, or something slightly more complicated) to project this feature into the -dimensional probability simplex . This projection serves as an estimated probability distribution of the next token. Then, using the notation to be in the th component and elsewhere, the causal language modeling loss is
(7.4.1) |
Note how similar this is to a classification loss (say for images); one uses the cross-entropy and tries to align a predicted probability vector with the ground truth. The major difference between these two losses is that in this one, we compute the loss on a whole sequence, where each prediction is correlated with the others (unlike in the i.i.d. classification case).
Optimizing this loss is usually called “pre-training” in the language model community (contrast with “post-training” and, more recently, “mid-training”, which are methodologies to modify a next-token-predictor for useful tasks).
Side note: why does the first term of (7.4.1) predict the second token, with no term measuring the loss of predicting the first token? Because predicting the first token would require the empty sequence as context, and therefore a qualitatively different mechanism than the one which applies to the other tokens. So this model is not trained to predict the very first token of any document. The reason this is acceptable is an implementation detail of the tokenizer: often, after building the tokenizer, we insert a special token into its vocabulary, called the beginning-of-string (or document) token and labeled <|bos|>. (There are usually several special tokens for different purposes; existing text containing the special tokens is specially processed.) Then, while processing each document, we add the <|bos|> token at the beginning of the document’s token sequence, increasing the length of the tokenized sequence by one. Thus the above causal language modeling objective has a term which involves trying to predict the first token of the document given only the <|bos|> token as context, and so it is a conceptually correct loss.
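A minimal sketch of the causal language modeling loss is given below, implemented (as is standard in practice) on pre-softmax logits rather than probability vectors; shapes and names are illustrative.

```python
import torch
import torch.nn.functional as F

def causal_lm_loss(logits: torch.Tensor, token_ids: torch.Tensor) -> torch.Tensor:
    """Next-token prediction (causal language modeling) loss.

    logits:    (batch, seq_len, vocab_size) model outputs, where position i holds the
               prediction for the token at position i + 1.
    token_ids: (batch, seq_len) integer token indices, beginning with <|bos|>.
    """
    pred = logits[:, :-1, :]    # predictions made from the contexts x_{1..i}
    target = token_ids[:, 1:]   # the tokens x_{i+1} those contexts should predict
    return F.cross_entropy(pred.reshape(-1, pred.shape[-1]), target.reshape(-1))
```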
For the architecture, we use a standard GPT-2-style transformer, substituting CRATE layers for the transformer layers. (In direct contravention of the conventions in this book and those of many other communities, the NLP community calls such GPT-2-style transformers, which encompass nearly all current LLMs, “decoder-only” transformers. “Encoder-only” transformers have a different architecture, and “encoder-decoder” transformers concatenate an “encoder-only” transformer with a “decoder-only” transformer. This despite the fact that “decoder-only” transformers also compute an encoding of the data!) For completeness, we specify the architecture here.
We first embed the token sequence into Euclidean space. This is often done by associating each index in with a vector in using a massive array (by “massive” we mean that such a structure is often a large fraction of the language model’s total size), and directly forming the sequence . The full embedding map also applies a positional encoding , where is the maximum number of tokens which are possible to process (modern positional encoding methods have since taken care of this issue and allow for, in theory, infinite extrapolation, but such methods are more complex to develop, and for simplicity we only introduce the absolute additive positional encoding here), which yields the embedding map
(7.4.2) |
The parameters and are directly trainable. Since is so large (and the gradient update is very sparse w.r.t. it since only a small fraction of the vocabulary is used in each sample), specialized software is used to make sure the memory updates are not too onerous. Notice also that we do not use a class token like in the other sections; more on this later.
We process the embeddings using a CRATE-like backbone which uses causal masking. To motivate causal masking, consider the causal language modeling loss defined in (7.4.1). The most naive implementation would require us to compute the forward pass times in order to backpropagate once. Obviously this is extremely inefficient, since can often be in the thousands. In order to scale training with this loss efficiently, we impose a causal constraint, i.e.,
(7.4.3) |
i.e., the columns of the token features should be the same as the first columns of the token features regardless of the positive values of and such that . In effect, this means we can apply the backbone once to the whole sequence and compute , then apply to each increasing subset as grows to the sequence length . Then we can use all of those to compute the loss.
So now that we want a causal architecture for the backbone, how can we get it? Since the MLP and layer normalizations inside each transformer layer affect each token individually, the only thing that matters for causality is the attention block (or in the case of CRATE). In order to make causal, we define the block as
(7.4.4)
where (7.4.5)
where (7.4.6)
Here, practitioners say that the causal mask allows future tokens to attend to past tokens but not vice versa. To see why, let us write out the expression for the th column of :
(7.4.7) |
(where here the non-colon subscript denotes the column). This expression for the th token uses no information about any token beyond index . Therefore , hence , hence the whole causal CRATE backbone is causal in terms of the definition in (7.4.3), and we unlock the considerable efficiency gains that we were promised.
We use a post-processing step which extracts the feature vector of the last known token so as to predict the next token. In theory, this means that each token should contain rich information about all tokens that come before or at index , i.e., , as all of this information should be available for predicting the next token at index . In practice, only a few of these tokens are really needed for each prediction task. In any case, the equation for is
(7.4.8) |
where (again) the non-colon subscript is the column. In this case, as promised, we just directly extract the feature vector of the last token in the sequence.
For our classification head , the GPT-2 architecture uses a simple linear layer and a softmax to get the desired probability vectors:
(7.4.9) |
where . Some other more modern architectures use small MLPs and layer normalizations, but the idea is very much the same. Note that this linear layer also has large memory usage (because is very large) and forms a bottleneck in training; there has been significant effort to circumvent it.
All these architectural choices mean that causal training is extremely efficient relative to non-causal training:
We only need one forward pass through the backbone to compute the loss for the whole sequence.
The feature extraction is basically free.
All tokens can be pushed through the task-specific head in parallel.
We train our language model using end-to-end stochastic optimization. One remaining issue is that, in practice, different documents in a batch have different lengths (in terms of the number of tokens in each sequence), but as of the time of writing this book, the main deep learning frameworks by and large allow only “rectangular” tensors, which do not accommodate this behavior. To resolve this issue, we simply insert a special padding token <|pad|> for all shorter samples in the batch, so that we can batch-process everything using rectangular tensors; a minimal sketch of this padding and of the correspondingly masked loss is given after the list below. At each timestep , we:
Subsample different tokenized documents , each with length .
Compute and pad each to length using a special padding token.
Compute the features .
Compute the predicted distributions .
Form the surrogate stochastic loss
(7.4.10) |
Compute one step of an optimization algorithm on , giving the following iteration:
(7.4.11) |
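As referenced above, here is a minimal sketch of padding a batch and masking the padded positions out of the loss; `pad_id` is whatever vocabulary index is assigned to <|pad|>.

```python
import torch
import torch.nn.functional as F

def pad_batch(docs: list[torch.Tensor], pad_id: int) -> torch.Tensor:
    """Right-pad a list of 1-D token-id tensors to a common length with the <|pad|> token."""
    return torch.nn.utils.rnn.pad_sequence(docs, batch_first=True, padding_value=pad_id)

def padded_lm_loss(logits: torch.Tensor, token_ids: torch.Tensor, pad_id: int) -> torch.Tensor:
    """Causal LM loss on a padded batch: positions whose target is <|pad|> contribute nothing."""
    pred = logits[:, :-1, :].reshape(-1, logits.shape[-1])
    target = token_ids[:, 1:].reshape(-1)
    return F.cross_entropy(pred, target, ignore_index=pad_id)
```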
There are several ways to evaluate a trained transformer language model.
On a holdout dataset of arbitrary text, we can evaluate the causal language modeling loss; lower losses are better, since they mean the model assigns higher probability to real text and thus that its sampling yields better text.
On a multiple choice question dataset, for each question we can put it as the context and check the estimated probability of the correct answer being generated.
We can also test the text generation capabilities. Namely, we can repeatedly sample from the model’s probability distribution over the next token given the context. Each time we sample we generate a new token, which we print and add to the context. This allows us to sample from the LLM, and judge the generated samples however we please. (Having to re-run the model on every token each time we sample can become prohibitively expensive; clever caching of internal features of the language model, such as the so-called KV cache, along with the causality of the architecture, can dramatically reduce the cost of sampling.)
In this section, we perform the first kind of evaluation exclusively.
Since our causal CRATE architecture is directly built upon GPT-2, we compare the optimal settings for GPT-2 as given by the NanoGPT repository [Kar22] with the same settings applied to CRATE for a fair comparison.
We use the GPT-2 tokenizer, which has vocabulary size , including a special token for <|pad|>. (The <|bos|> token is not included in this setup, although it is very common in modern language models.) The context length is . The backbone model follows the GPT2-Base architecture [RWC+19] with the appropriate alterations to have causal CRATE layers, and we compare against GPT2-Small and GPT2-Base.
For training causal CRATE, we follow the implementation in the NanoGPT repository [Kar22]. Specifically, we use a batch size of 384 and train for 600,000 steps with the Adam optimizer [KB14], using the momentum parameters and weight decay from the NanoGPT configuration. For the learning rate schedule, we apply a linear warm-up to a peak value followed by a cosine decay to a small minimum value. The training and validation losses over iterations are shown in Figure 7.13; with this batch size and number of iterations, the training/validation loss converges to roughly the OpenWebText value reported for Causal-CRATE-Base in Table 7.5. In comparison, the open GPT-2 implementation is pre-trained on OpenWebText with a different batch size and number of steps, and converges to the lower validation loss reported for GPT2-Base in Table 7.5 [Kar22].
Table 7.5 demonstrates that CRATE models achieve reasonable performance on the causal language modeling loss across a variety of datasets compared to GPT-2 models with similar parameter counts and similar architectures.
Model | # parameters | OWT | LAMBADA | WikiText | PTB | Avg
---|---|---|---|---|---|---
GPT2-Base | 124M | 2.85 | 4.12 | 3.89 | 4.63 | 3.87
GPT2-Small | 64M | 3.04 | 4.49 | 4.31 | 5.15 | 4.25
Causal-CRATE-Base | 60M | 3.37 | 4.91 | 4.61 | 5.53 | 4.61
In this section, we will discuss three ways in which various parts of CRATE-type models can be scaled up or made more efficient while still remaining white-box. These developments mix both conceptual and empirical insights, and can be viewed as case studies about how to use white-box understanding to improve deep learning models in practice. The tasks that we use to evaluate the methods will be image classification and next-token-prediction, the data will be ImageNet and OpenWebText respectively, the optimization procedure will be the same backpropagation, and the only thing that changes is the architecture.
One design decision enforced by the CRATE framework is the width of the nonlinearity in the network. In a regular transformer, the MLP width is usually set to a small integer multiple (classically four times) of the feature dimension. However, CRATE enforces that the width is exactly equal to the feature dimension, i.e., the dictionaries are square, which could lead to reduced performance. The fundamental reason that the CRATE framework constrains us to this choice is as follows:
The ISTA block takes a single step of optimization for dictionary learning.
Usually one step of any iterative optimization algorithm cannot effectively optimize the objective. So then why does this work?
Optimization algorithms usually converge very quickly if and only if they have good initializations, or warm starts. The ISTA block has a warm start — it treats the input features as an initialization to the resulting sparse codes.
This enforces that the input features and sparse codes have the same dimension. Namely, ISTA learns a complete sparsifying dictionary (cf Chapter 2).
Thus if we want to use a wide dictionary, we need ISTA to perform overcomplete dictionary learning. This means we cannot have the same warm start (as our sparse codes have a larger dimension than our features), and need more iterations to converge to a sparse code. Hence the step from features to sparse codes would no longer be
(7.5.1)
where the function is defined as (by an abuse of notation from earlier sections)
(7.5.2)
but rather the following iteration:
(7.5.3)
i.e., running proximal gradient on the LASSO objective for several steps in the forward pass at each layer, with a different initialization than the warm start above. In this circumstance, the dictionary can be as wide as needed, i.e., its width can be any multiple of the feature dimension (in practice one takes a small integer multiple).
Model | Detection | | | Segmentation | |
 | AP | AP | AP | AP | AP | AP
---|---|---|---|---|---|---
CRATE-α-B/8 | 3.5 | 1.1 | 1.5 | 2.2 | 1.0 | 1.1
CRATE-α-L/8 | 4.0 | 1.7 | 2.0 | 2.7 | 1.1 | 1.4
CRATE-B/8 | 2.9 | 1.0 | 1.3 | 2.2 | 0.7 | 1.0
ViT-B/8 | 0.8 | 0.2 | 0.4 | 0.7 | 0.5 | 0.4
However, this presents an empirical problem. Using the above configuration, if the dictionary is overcomplete, then the sparse codes (and hence the layer output) live in the larger code dimension, which can be arbitrarily large. In practice, we want the feature dimension at each layer to be the same. So this sets up a practical trichotomy for designing the network; namely, we cannot have all of the following desiderata:
The feature dimension at each layer is the same.
The dictionary is wide, i.e., overcomplete.
The output of the nonlinearity is the sparse codes of the input with respect to the dictionary.
In practice, giving up (1) is less tractable for efficiency reasons. Giving up (2) leads to the usual CRATE framework. Giving up (3) leads to a wide version of CRATE, namely CRATE-α, which uses the following nonlinearity to map the input features to the output:
(7.5.4)
i.e., it takes the sparse codes obtained via proximal gradient descent and multiplies them by the dictionary to get a denoised version of the input. Thus CRATE-α’s nonlinearity computes a denoised version of the input which is amenable to sparse coding, not the actual sparse codes themselves. The map from input features to output here is called the Overcomplete Dictionary Learning (ODL) block, i.e.,
(7.5.5)
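To make this concrete, the following is a minimal PyTorch sketch of an ODL block, under assumed (not the book's) notation: the inner loop runs a few proximal gradient (ISTA) steps on the LASSO objective with a cold start, as in the iteration above, and the final line multiplies the resulting sparse codes by the dictionary so that the output returns to the input dimension, as in (7.5.4).

```python
import torch

def odl_block(z, D, lam=0.1, eta=0.1, num_steps=4):
    """Overcomplete dictionary learning block (sketch).

    z: (d, n) input features; D: (d, K) overcomplete dictionary with K >= d.
    The codes a approximately minimize (1/2)||z - D a||^2 + lam * ||a||_1.
    """
    a = torch.zeros(D.shape[1], z.shape[1], device=z.device)        # cold start at zero
    for _ in range(num_steps):                                       # a few ISTA steps
        v = a - eta * (D.T @ (D @ a - z))                            # gradient step on the quadratic
        a = torch.sign(v) * torch.clamp(v.abs() - eta * lam, min=0)  # soft-thresholding prox
    return D @ a                                                     # decode back to d dimensions
```

The variants evaluated below include small additional modifications (such as a residual connection around this block), as discussed later in this section.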
Model | GPT-2-B(ase) | CRATE-B | CRATE-α-S(mall) | CRATE-α-B
---|---|---|---|---
# parameters | 124M | 60M | 57M | 120M
OWT val. loss | 2.85 | 3.37 | 3.28 | 3.14
The CRATE-α layer is shown in Figure 7.14. In practice, this modification of CRATE performs very well at larger scales. For example, when we pre-train CRATE-α models on ImageNet-21K, unsupervised tasks like segmentation (see Figure 7.15 and Table 7.6) enjoy significantly improved performance compared to CRATE. Similar trends are present in language model training using causal self-attention (see Table 7.7). Overall, this is a promising avenue for scaling up the performance to match black-box models such as transformers. (Note that the experimental results in this section use a slightly different model architecture, which adds very slight empirical gains; the changes are (1) an additional residual connection on the ODL block, and (2) modifying the ODL block to use two separate dictionaries.)
In practice, deep learning models suffer from bottlenecks in space and time complexity, representing problem sizes beyond which they cannot scale given fixed resources. One such bottleneck, particularly meaningful when dealing with data where each sample is itself high-dimensional and rich (such as long streams of text or video), is the time complexity of processing long sequences. In order to alleviate the time complexity of processing data using transformers, in Section 4.3.2 we proposed a token statistics self-attention (TSSA) operator. We now build a token statistics transformer, called ToST, around it, which we can use for long-context tasks. In particular, we can use the following layer (depicted in Figure 7.16) as a drop-in replacement for a backbone layer in CRATE:
(7.5.6)
(7.5.7)
where the TSSA block is defined as in Section 4.3.2. Notice that this is exactly the same as the vision transformer architecture discussed in Section 7.2.3, except that TSSA replaces the conventional multi-head self-attention (MHSA) block. Regardless, the computational complexity of the forward pass of this layer is linear in all problem variables: the sequence length, feature dimension, number of heads, and head dimension.
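As a rough illustration of this drop-in replacement, here is a minimal PyTorch skeleton of such a layer, with the TSSA operator left as an assumed module (its definition is given in Section 4.3.2) and the second sub-block written as a standard ViT-style MLP, per the comparison to the vision transformer layer above; the exact composition and normalization placement in (7.5.6)–(7.5.7) may differ.

```python
import torch.nn as nn

class ToSTLayer(nn.Module):
    """Pre-norm residual layer with TSSA in place of multi-head self-attention."""
    def __init__(self, dim, tssa: nn.Module, mlp_ratio=4):
        super().__init__()
        self.norm1, self.norm2 = nn.LayerNorm(dim), nn.LayerNorm(dim)
        self.tssa = tssa                              # token statistics self-attention block
        self.mlp = nn.Sequential(nn.Linear(dim, mlp_ratio * dim),
                                 nn.GELU(),
                                 nn.Linear(mlp_ratio * dim, dim))

    def forward(self, z):                             # z: (batch, seq_len, dim)
        z = z + self.tssa(self.norm1(z))              # attention-style sub-block (linear cost)
        z = z + self.mlp(self.norm2(z))               # pointwise sub-block
        return z
```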
Datasets | ToST-T(iny) | ToST-S(mall) | ToST-M(edium) | XCiT-S | XCiT-M | ViT-S | ViT-B(ase) |
---|---|---|---|---|---|---|---|
# parameters | 5.8M | 22.6M | 68.1M | 24.9M | 80.2M | 22.1M | 86.6 M |
ImageNet | 67.3 | 77.9 | 80.3 | 80.5 | 81.5 | 79.8 | 81.8 |
ImageNet ReaL | 72.2 | 84.1 | 85.6 | 85.6 | 85.9 | 85.6 | 86.7 |
CIFAR10 | 95.5 | 96.5 | 97.5 | 98.1 | 98.3 | 98.6 | 98.8 |
CIFAR100 | 78.3 | 82.7 | 84.5 | 86.1 | 87.6 | 88.8 | 89.3 |
Oxford Flowers-102 | 88.6 | 92.8 | 94.2 | 93.9 | 94.0 | 94.0 | 95.7 |
Oxford-IIIT-Pets | 85.6 | 91.1 | 92.8 | 92.9 | 94.0 | 92.8 | 94.1 |
Model | # params | OWT | Lambada | Wikitext | PTB | Avg |
---|---|---|---|---|---|---|
GPT-2-Base | 124M | 2.84 | 4.32 | 4.13 | 5.75 | 4.26 |
ToST-Base | 110M | 3.20 | 4.98 | 4.77 | 6.39 | 4.84 |
ToST-Medium | 304M | 2.88 | 4.45 | 4.30 | 5.64 | 4.32 |
ToST-Large | 655M | 2.72 | 4.32 | 3.99 | 5.03 | 4.02 |
Model | ListOps | Text | Retrieval | Image | Pathfinder | Avg |
---|---|---|---|---|---|---|
Reformer | 37.27 | 56.10 | 53.40 | 38.07 | 68.50 | 50.56 |
BigBird | 36.05 | 64.02 | 59.29 | 40.83 | 74.87 | 54.17 |
LinFormer | 16.13 | 65.90 | 53.09 | 42.34 | 75.30 | 50.46 |
Performer | 18.01 | 65.40 | 53.82 | 42.77 | 77.05 | 51.18 |
Transformer | 37.11 | 65.21 | 79.14 | 42.94 | 71.83 | 59.24 |
ToST | 37.25 | 66.75 | 79.46 | 46.62 | 69.41 | 59.90 |
Moreover, the proposed architecture, named ToST (for “Token Statistics Transformer”), performs well at vision tasks (see Table 7.8) and language tasks (see Table 7.9). This is especially true for long-sequence-length tasks (cf. Table 7.10), where it is both more performant and much more efficient than the conventional transformer and the other efficient transformer-like architectures considered.
Another bottleneck to remove from deep learning models, specifically transformer-like architectures, is the memory bottleneck that comes from the massive matrix multiplications in MLPs, whose internal dimension is far greater than the feature dimension. It is thus an interesting and important question: do we really need the MLP inside a transformer, and how good can the performance be without it? To explore this question, we use the attention-only transformer (AoT) architecture (see Section 4.3.1), depicted in Figure 7.17. Namely, each layer is simply of the form
(7.5.8)
In our implementation, we also experimented with using multi-head self-attention (MHSA) in place of MSSA. It turns out that this architecture is also viable, though the depth of the network needs to be much greater in order to achieve equivalent performance to the usual CRATE or transformer architecture.
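The following is a minimal sketch of one layer of the AoT-MHSA variant, which the text describes as GPT-2 without its MLP layers: a pre-norm residual block containing only causal multi-head self-attention. The normalization placement and the use of torch.nn.MultiheadAttention are assumptions of this sketch, not the book's exact implementation.

```python
import torch
import torch.nn as nn

class AoTMHSALayer(nn.Module):
    """Attention-only layer: residual causal self-attention, with no MLP sub-block."""
    def __init__(self, dim, num_heads):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)

    def forward(self, z):                             # z: (batch, seq_len, dim)
        n = z.size(1)
        # Boolean mask with True above the diagonal, so each token ignores future tokens.
        causal = torch.triu(torch.ones(n, n, dtype=torch.bool, device=z.device), diagonal=1)
        h = self.norm(z)
        out, _ = self.attn(h, h, h, attn_mask=causal)
        return z + out                                # residual connection only
```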
We conduct experiments using the proposed AoT architecture and demonstrate its potential. We pre-train AoT-MSSA and AoT-MHSA models of different sizes, along with GPT-2, on OpenWebText [GC19]. We plot the training and validation losses against the number of training iterations in Figure 7.18(a) and (b), respectively. We observe that medium- and large-sized AoT-based models achieve training and validation losses comparable to those of the GPT-2 base model. The AoT-MHSA model is identical to the GPT-2 base model except for the absence of MLP layers; as shown in Figure 7.18, incorporating MLP layers can accelerate the training process. Using the above pre-trained models, we compute the zero-shot cross-entropy validation loss on different datasets (without further training) in Table 7.11. We observe that the AoT models with medium and large parameter sizes achieve comparable performance to the GPT-2 base model. Moreover, we find that adding MLP layers to AoT does not improve zero-shot performance. These results highlight the potential of attention-only models to achieve competitive results while maintaining interpretability.
Model (# parameters) | LAMBADA (val loss) | PTB (val loss) | WikiText (val loss) | LAMBADA (acc) | CBT CN (acc) | CBT NE (acc)
---|---|---|---|---|---|---
AoT-MSSA Base (102M) | 4.70 | 6.03 | 4.65 | 0.25 | 0.80 | 0.74
AoT-MSSA Medium (182M) | 4.47 | 5.08 | 4.22 | 0.29 | 0.84 | 0.77
AoT-MHSA Base (122M) | 4.42 | 5.52 | 4.19 | 0.38 | 0.86 | 0.82
GPT-2 Base (124M) | 4.32 | 5.75 | 4.13 | 0.40 | 0.87 | 0.84
The second application we discuss is nonlinear image completion, also known as masked autoencoding (MAE), which is a direct generalization of the low-rank matrix completion problem discussed in Chapter 2. Masked autoencoding, since its introduction in the deep learning context by [HCX+22], has been a staple and simple self-supervised representation learning method; it aims to endow each patch feature with aggregate information as well as information about its neighbors, so that both the patch features and the aggregate feature are rich sources of information about the whole sample.
The dataset is kept the same as the image datasets discussed in Section 7.2.1. As usual, we still apply data augmentations to each sample in each new batch.
As the name suggests, masked autoencoding involves a view which, given an input, performs a random resized crop (cf. Section 7.2.2) to turn the input image into a square image, then masks (i.e., sets to zero) a fixed percentage of the input. For efficiency reasons, the masking is done patch-wise, i.e., after embedding the whole image, the given percentage of patch embeddings is set to zero. (The original implementation of MAE by [HCX+22] embeds the whole image, removes the tokens that would be masked, feeds the remaining tokens through the encoder, adds back learned placeholder tokens in the masked positions along with the appropriate positional encoding, and feeds the resulting token set through the decoder to get the autoencoding prediction. This is more efficient, since the encoder has fewer tokens to process, but it is conceptually the same as the method discussed in the text, and the resulting models’ performance on the masked autoencoding task and downstream evaluations is very similar.) The goal of MAE is to train an encoder and a decoder that can reconstruct an input from its masked version, i.e., we have
(7.6.1)
Essentially this means that the features of the view must contain information about the masked patches as well as the existing patches. From the perspective of the compression-based white-box models in Chapter 4, if a white-box autoencoder succeeds at this task, it means that the learned subspaces and dictionaries perform a redundant encoding of the data such that it can reconstruct missing parts of the data from encoded other parts of the data. This means that information about each patch is stored in other patches. Therefore, each patch feature should contain both information about the patch and information about the statistics of the whole image. Thus, again, we expect that the representations should contain both local and global semantically relevant information, and therefore representations of different patches with similar local and global information should be related (i.e., on the same subspace or encoded together by a dictionary).
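A minimal sketch of the patch-wise masking view described above, assuming the image has already been embedded into a sequence of patch tokens; the masking ratio and tensor shapes here are illustrative.

```python
import torch

def mask_patches(tokens, mask_ratio=0.75):
    """Zero out a random subset of patch tokens.

    tokens: (batch, num_patches, dim). Returns the masked tokens and a boolean
    mask (True where a patch was zeroed out), which is useful for the loss.
    """
    B, N, _ = tokens.shape
    num_masked = int(mask_ratio * N)
    # Pick a different random subset of patches to mask for each sample.
    idx = torch.rand(B, N, device=tokens.device).argsort(dim=1)[:, :num_masked]
    mask = torch.zeros(B, N, dtype=torch.bool, device=tokens.device)
    mask.scatter_(1, idx, True)
    masked = tokens.masked_fill(mask.unsqueeze(-1), 0.0)
    return masked, mask
```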
We use a CRATE encoder and decoder, depicted in Figure 7.7, though of course it is possible to use a regular transformer encoder and decoder. Details follow.
The encoder is the same as the CRATE encoder in Section 7.3.2, with the caveat that there is no feature extractor . However, both the embedding and the backbone are the same.
The decoder backbone is the CRATE decoder described in Chapter 5. For completeness, we describe it now. Given a feature sequence , we can process it using the decoder backbone as follows. The function is composed of layers , i.e.,
(7.6.2)
The layer has the following implementation. First, define . Then, we obtain
(7.6.3)
(7.6.4)
Here, the relevant concept is that each decoder layer should learn an approximate inverse of the corresponding encoder layer, with the two acting as discretizations of a forward- and a reverse-time diffusion process; in particular, the decoder layer's operators and parameters should approximately mirror those of the corresponding encoder layer.
To transform back into an estimate for , we need to undo the effect of the embedding module using the unembedding module . As such, harkening back to the functional form of the embedding module in (7.2.11), i.e.,
(7.6.5)
it implies that our inverse operation looks like the following:
(7.6.6)
where the final operation inverts the unrolling and flattening performed by the embedding. (Again, the “inverse positional encoding” is learned for a large input and may be interpolated for smaller inputs. It is even possible to directly set it equal to the positional encoding and use the same interpolated positional encodings for each input in both the encoder and decoder.)
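A minimal sketch of such an unembedding module, under the assumption that the embedding was a ViT-style patchify-and-project map with an additive positional encoding; the layer names, shapes, and the learned “inverse positional encoding” parameter below are illustrative, not the book's exact parameterization.

```python
import torch
import torch.nn as nn

class Unembedding(nn.Module):
    """Approximately invert a ViT-style patch embedding: unproject, then un-flatten."""
    def __init__(self, dim, patch_size, channels, num_patches, img_size):
        super().__init__()
        self.inv_pos = nn.Parameter(torch.zeros(num_patches, dim))  # learned "inverse" positional encoding
        self.proj = nn.Linear(dim, channels * patch_size * patch_size)
        self.patch_size, self.channels, self.img_size = patch_size, channels, img_size

    def forward(self, z):                          # z: (batch, num_patches, dim)
        x = self.proj(z + self.inv_pos)            # undo the positional encoding and projection
        B = x.shape[0]
        g = self.img_size // self.patch_size       # number of patches per side
        p, c = self.patch_size, self.channels
        # Fold the per-patch pixel vectors back into an image grid.
        x = x.view(B, g, g, c, p, p).permute(0, 3, 1, 4, 2, 5)
        return x.reshape(B, c, self.img_size, self.img_size)
```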
This architecture is a white-box autoencoder where (recall) and . In particular, we can use it to compute an estimate for a masked view which should approximately equal itself.
As in Section 7.3.3, we use a simple optimization setup: we sample images and masks, compute the loss on those samples and the gradients of this loss, and update the parameters using a generic optimization algorithm and the aforementioned gradients. At each training step, we do the following (a minimal code sketch of one such step appears after the list):
Subsample a batch of samples.
For each sample, compute a different random resized crop and mask and apply them to obtain the masked view.
Compute the estimated autoencoding of each masked view.
Form the surrogate stochastic loss
(7.6.7)
Compute one step of an optimization algorithm on this loss, giving the following iteration:
(7.6.8)
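For concreteness, here is a minimal sketch of one such step, assuming embedding, encoder, decoder, and unembedding modules as described above (all names illustrative). For brevity, patches are masked independently with probability mask_ratio rather than at an exact percentage, and the loss is written as a mean-squared error over the whole image; the original MAE recipe computes the reconstruction loss only on the masked patches.

```python
import torch
import torch.nn.functional as F

def mae_step(embed, encoder, decoder, unembed, optimizer, images, mask_ratio=0.75):
    """One masked-autoencoding training step (sketch)."""
    tokens = embed(images)                                  # (B, N, dim) patch tokens
    B, N, _ = tokens.shape
    # Mask each patch independently with probability mask_ratio.
    mask = torch.rand(B, N, device=tokens.device) < mask_ratio
    masked = tokens.masked_fill(mask.unsqueeze(-1), 0.0)
    recon = unembed(decoder(encoder(masked)))               # reconstructed image
    loss = F.mse_loss(recon, images)                        # simple reconstruction loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```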
This is the first autoencoder network we discuss in this chapter. We use the same center crop view as in Sections 7.2.5 and 7.3.4, resizing the final image to a square with side length pixels to match the shapes of the input images seen during training.
In addition to evaluating the masked autoencoding loss itself, it is also possible to evaluate the features of the view of the data directly. For attention map fidelity evaluation, obtaining is sufficient, but for linear probing we need to extract a summarized or aggregate feature from . To do this, we can use a (parameter-free) feature extraction map that returns the feature corresponding to the class token, i.e.,
(7.6.9)
as in (for example) Sections 7.3.1 and 7.3.2. With this, we have a way to obtain aggregate features , at which point we can perform linear probing, segmentation evaluations, and so on.
Since CRATE-MAE is directly based on ViT-MAE, we compare the optimal settings for ViT-MAE as given by [HCX+22] with the same settings applied to CRATE-MAE for a fair comparison.
During training, the masked crop resizes the whole image so that the shorter edge is of size (i.e., ) before taking a random crop of size (i.e., ), and masking of the patches. We take patch size (i.e., ). We use the small and base variants of the ViT-MAE architecture as the embedding and backbone for both the encoder and decoder, swapping out the MHSA and MLP components for MSSA, ISTA, and linear layers, respectively. We use the same number of heads and head dimension in the case of MSSA. However, the original ViT-MAE uses an encoder which uses nearly all the total layers and a decoder which only uses a few layers; we allocate half the total number of layers (which stays the same from ViT-MAE to CRATE-MAE) to our encoder and decoder, as suggested by the conceptual and theoretical framework in Chapter 5. For CRATE-MAE we set .
For pre-training, we use the ImageNet-1K dataset. We use the AdamW optimizer to pre-train both our ViT-MAE replication as well as CRATE-MAE. We set the base learning rate as , the weight decay as , and batch size as . Our learning rate schedule increases the learning rate linearly to the base learning rate over the first epochs, and decreases to using a cosine schedule over the next epochs (training all models for epochs each). For pre-training, we apply the usual regime of data augmentations (flips, Gaussian blurs, solarization, etc) to the image data.
For linear probing, we use several evaluation datasets such as CIFAR10, CIFAR100, Oxford Flowers-102, and Oxford-IIIT-Pets. We precompute the features of all samples in the target dataset and apply a fast linear regression solver, e.g., from a standard package such as Scikit-Learn.
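A minimal sketch of such a linear probe, assuming the aggregate features and labels have already been precomputed as arrays; a logistic-regression classifier from Scikit-Learn stands in here for the fast linear solver mentioned above.

```python
from sklearn.linear_model import LogisticRegression

def linear_probe(train_feats, train_labels, test_feats, test_labels):
    """Fit a linear classifier on frozen features and report test accuracy."""
    clf = LogisticRegression(max_iter=1000)     # simple linear probe on top of the features
    clf.fit(train_feats, train_labels)
    return clf.score(test_feats, test_labels)   # mean accuracy on the held-out split
```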
Table 7.12 demonstrates that CRATE-MAE models achieve, roughly speaking, parity with the popular ViT-MAE architecture at similar parameter counts, and also that the feature learning performance (as measured by performance on downstream classification tasks) increases with scale. Meanwhile, Figure 7.20 demonstrates that the encoder saliency maps (and therefore the fine-grained features learned by the encoder) indeed isolate and highlight the key parts of the input image.
Model | CRATE-MAE-S(mall) | CRATE-MAE-B(ase) | ViT-MAE-S | ViT-MAE-B |
---|---|---|---|---|
# parameters | 25.4M | 44.6M | 47.6M | 143.8M |
CIFAR10 | 79.4 | 80.9 | 79.9 | 87.9 |
CIFAR100 | 56.6 | 60.1 | 62.3 | 68.0 |
Oxford Flowers-102 | 57.7 | 61.8 | 66.8 | 66.4 |
Oxford-IIIT-Pets | 40.6 | 46.2 | 51.8 | 80.1 |
All work in this chapter is downstream of the Transformer architecture, which was introduced by [VSP+17]. The Transformer architecture is formally described in Section 7.2. A main empirical innovation in recent years, spurred by the prevalence and performance of the Transformer architecture, is to formulate a given learning problem as a sequence-to-sequence problem and apply the Transformer architecture. This has enabled the Transformer architecture to be essentially ubiquitous in (almost) all deep learning applications. As such, direct improvements to the Transformer can propagate to become solutions to many problems and have considerable impact; similarly, we can apply our white-box understanding of transformer-like architectures to many modern problems. The material covered in this chapter is merely a subset of the work that has already been done; other work includes masked completion for text data (i.e., BERT) [DCL+19, YBP+24], (mechanistic) interpretability of language and vision models [BM24], and error correcting coding [ZLG+]. There is much more to do.
There is also much more theory specifically about the practice of scaling neural networks, which is of enormous practical value, and we at least remark on it here. This direction was popularized by the “Tensor Programs” line of work [YHB+22]. The basic prescription is that we want the initial gradient updates in a transformer to be constant size; by working through the backpropagation equations (Appendix A) carefully, we can determine the scale of the initialization and the learning rates (chosen layer-wise) required to achieve this. In practice, such prescriptions greatly increase the stability and convergence of training at large scales; they also prescribe a way to find the “optimal” hyperparameters for large-scale training using only small-scale training. (The word “optimal” is in quotes because this work merely uses desiderata about the weight, feature, and gradient sizes at initialization to determine “optimality”, as opposed to, say, the test loss at convergence.) Follow-ups to this work attempt to accommodate the feature geometry [BN24], which could be informed by the work in this book about representation learning. Other follow-ups incorporate this layer-wise information into the optimizer itself to obtain these scaling benefits automatically, yielding optimizers such as Muon [JJB+], which have recently been used to train trillion-parameter models very stably [Tea]. Overall, this scaling-focused theory and the representation-learning theory developed in this book are orthogonal and complementary.
Read the DINO paper [CTM+21].
DINO v2 [ODM+23] uses everything from DINO v1 but also, during the data augmentation phase, randomly masks out patches within each view. This kind of augmentation should enforce that the features of images with similar local information are similar. Formulate an optimization problem that promotes this in the encoder, and implement it.
This exercise considers the implementation of stochastic optimization algorithms to minimize losses involving expectations.
Evaluate the time complexity required to compute the existing term in (7.2.34) and its gradient.
Implement the CRATE and CRATE-α models.
Compare their performance and efficiency on the CIFAR-10 dataset.
Compare their interpretability in two ways:
The sparsity of the representation
The attention maps