“What I cannot create, I do not understand.”
— Richard Feynman
In previous chapters, we have shown how to identify low-dimensional structures in high-dimensional spaces, mainly focusing on linear structures. For example, we introduced principal component analysis (PCA) to learn the linear denoiser when the observed data follow the statistical model . In this setting, the learned representations are linearly transformed input data . Under the linear model assumption, one can learn the low-dimensional linear structure with efficient optimization algorithms and strong theoretical guarantees. Moreover, the linear model assumption covers a wide range of applications and problems, including face recognition, magnetic resonance image recovery, and structure texture recovery [WM22].
On the other hand, the linear model can be limited when dealing with real-world applications, especially when the input data are complex, such as speech and natural language, images and videos, and robotic motions. The low-dimensional distributions of such data are typically nonlinear. How to deal with nonlinearity has a long history across different disciplines such as control theory, signal processing, and pattern recognition. There have been considerable efforts to extend methods and solutions for linear models to handle nonlinearity, including early efforts to extend PCA to nonlinear PCA (as we will study in more detail in Chapter 5). In most cases, the methods are designed based on certain assumptions about the data distributions and tailored to specific problems.
More recently, deep neural networks have achieved remarkable success across a wide range of data and applications. A neural network
(4.0.1) |
can learn effective features/representations for downstream applications. For example, a trained deep neural network can be applied to map images to feature vectors, that is, , while a linear classifier can be learned on top of such representations . One notable breakthrough is AlexNet [KSH12], a deep convolutional neural network trained with more than a million natural images, outperforming all previous approaches that were based on hand-crafted features. One of the key differences between AlexNet and previous approaches is that the former learns parameters of the nonlinear transformation from massive amounts of data trained with back-propagation (BP) [RHW86a], as detailed in Section A.2.3 of Appendix A.
Subsequent popular practice models the mapping with other empirically designed artificial deep neural networks and learns the parameters from random initialization via BP. Starting with AlexNet [KSH12], the architectures of modern deep networks have continued to be empirically revised and improved. Network architectures such as VGG [SZ14], ResNet [HZR+16a], DenseNet [HLV+17], CNNs, RNNs or LSTMs [HS97], Transformers [VSP+17], and mixtures of experts (MoE) [SMM+17, FZS22] have continued to push the performance envelope. As part of the effort to improve the performance of deep networks, almost every component of the networks has been empirically scrutinized, and various revisions and improvements have been proposed. These include, but are not limited to, nonlinear activation functions [MHN13, KUM+17, XWC+15, NIG+18], skip connections [RFB15, HZR+16a], normalizations [IS15, BKH16, UVL16, WH18, MKK+18], up/down sampling or pooling [SMB10], and convolutions [LBB+98a, KSH12]. However, almost all such modifications have been developed through years of empirical trial and error or ablation studies. Some recent practice even takes this to the extreme by searching for effective network structures and training strategies through extensive random search techniques, such as Neural Architecture Search [ZL17, BGN+17], AutoML [HKV19], and Learning to Learn [ADG+16].
Despite the wide application of deep neural networks, it is not clear what the underlying design principles of such a constructed network are. In particular, it is not clear what mathematical function each layer of the network performs. In this chapter, based on the results from previous chapters, we develop a principled framework that will provide a fully rigorous mathematical interpretation of the role of a deep network, including its individual layers and the network as a whole.
To understand deep networks and how they should be better designed, we must start with the objective of representation learning. In previous chapters, we have argued that the objective is to identify the intrinsically low-dimensional data distribution and then transform it to a compact and structured (say piecewise linear) representation. As we have seen in the previous chapter, the general approach to identifying a low-dimensional data distribution is through a compression process that progressively minimizes the entropy or coding rate of the distribution. However, up to this point, we have been using empirically designed deep networks to model or approximate the operations that aim to optimize these objectives, such as the score function for denoising (in Section 1.3.1) or the transformation that maximizes the rate reduction (in Section 3.4.3).
As we have argued in the previous chapter, Section 3.4 in particular, one can measure the goodness of the resulting representation by the information of the representation gained relative to a “lazy” representation which models all data as one big Gaussian (a model we have seen in the previous chapter as one particular choice of interpretation of the sampled dataset). In particular, if we use a mixture of Gaussians (subspaces), which we have studied thoroughly in the previous chapter, as prototypical distributions to approximate the nonlinear distribution of interest, then we can efficiently measure the coding rate of such a representation using the (sum of) rate distortion functions of the associated Gaussians. The amount of information gained, or (relative) entropy reduced, by such a modeling can then be measured by the difference between the coding rate for the lazy representation and that for the more refined representation. The objective of representation learning is then to maximize this information gain, also known as the rate reduction objective.
As we will see in this chapter, once the objective of representation learning is clear, the role of a deep neural network is precisely to help optimize the objective iteratively. Each layer of a deep neural network can be naturally derived as an iterative optimization step to incrementally maximize the information gain, including the popular architectures of ResNet, CNN, and Transformer, and other more advanced variants. In particular, this chapter aims to answer the following questions about deep networks:
Section 4.1 — given a measure of goodness for a learned representation, how to construct the nonlinear mapping from the data to the optimal representation via unrolled optimization for the objective?
Section 4.2 — can the above unrolling approach provide a principled interpretation of the popular transformer architectures, and if so, what are the associated objective and optimization mechanisms?
Section 4.3 — how would this framework guide us to design more efficient or more parsimonious deep architectures?
Now, if we agree that maximizing the rate reduction or information gain leads to the desired representation as discussed in Section 3.4, the remaining question is how to construct and learn a (nonlinear) mapping from the data to the optimal representation . This involves designing a network architecture and learning algorithm that can effectively capture the underlying structures in the data and faithfully realize the optimal representation.
In the previous chapter, we presented the rate reduction objective (3.4.12) as a principled objective for learning linear discriminative representations of the data. We have, however, not specified the architecture of the feature mapping for extracting such representations from input data . A straightforward choice is to use a conventional deep network, such as ResNet, for implementing . As we have seen in the example of Figure 3.24, such a choice often leads to decent performance empirically. Nonetheless, there remain several unanswered questions with adopting an arbitrary deep network. Although the learned feature representation is now more interpretable, the network itself is still not. It is unclear why any chosen “black-box” network is able to optimize the desired MCR2 objective at all. The good empirical results (say with a ResNet) do not necessarily justify the particular choice of architectures and operators of the network: Why is a deep layered model even necessary; what do additional layers try to improve or simplify; how wide and deep is adequate; and is there any rigorous justification for the convolutions (in a popular multi-channel form) and nonlinear operators (e.g., ReLU or softmax) used?
In this chapter, we show that using gradient ascent to maximize the rate reduction as defined in (3.4.12) naturally leads to a “white-box” deep network that realizes the desired mapping. All network layers, linear/nonlinear operators, and parameters are explicitly constructed in a purely forward propagation fashion. Moreover, such network architectures resemble existing empirically-designed deep networks, providing principled justifications for their design.
From the previous chapter, we see that to seek a linear discriminative representation (LDR), mathematically, we are essentially seeking a continuous mapping from the data (or from initial features extracted from the data; we will see the necessity of such a feature extraction in the next section) to an optimal representation that maximizes the following coding rate reduction objective:
(4.1.1) |
where is a prescribed quantization error and, for simplicity, we adopt slightly simplified notation compared to Chapter 3.
The question really boils down to: is there a constructive way of finding such a continuous mapping from to ? To this end, let us consider incrementally maximizing the objective as a function of . Although there might be many optimization schemes to choose from, for simplicity we first consider the arguably simplest projected gradient ascent (PGA) scheme (note that we use a subscript on to indicate features in the -th class and a superscript on to indicate all features at the -th iteration or layer):
(4.1.2) |
for some step size , where the iteration starts from the given data . (Again, for simplicity, we here first assume the initial features are the data themselves; here denotes the number of iterations. Hence, the data and the features have the same dimension . This need not be the case, though: as we will see in the next section, the initial features can be some (lifted) features of the data to begin with and could in principle have a different, much higher, dimension; all subsequent iterates then share that dimension.) This scheme can be interpreted as prescribing how one should incrementally adjust the locations of the current features , initialized as the input data , so that the resulting improve the rate reduction , as illustrated in Figure 4.1.
Simple calculation shows that the gradient entails evaluating the following derivatives of the two terms in :
(4.1.3) |
(4.1.4) |
Notice that in the above, the matrix only depends on and it aims to expand all the features to increase the overall coding rate; the matrix depends on features from the -class and aims to compress them to reduce the coding rate of each class. Then the complete gradient is of the form:
(4.1.5) |
For any ,
(4.1.6) |
Notice that is exactly the solution to the ridge regression by all the data points concerned. Therefore, (similarly for ) is approximately (i.e., when is large enough) the projection onto the orthogonal complement of the subspace spanned by columns of . Another way to interpret the matrix is through eigenvalue decomposition of the covariance matrix . Assuming that where , we have
(4.1.7) |
Therefore, the matrix operates on a vector by scaling it so that directions of large variance are shrunk while directions of vanishing variance are preserved. These are exactly the directions (4.1.3) in which we move the features so that the overall volume expands and the coding rate increases, hence the positive sign. To the opposite effect, the directions associated with (4.1.4) are “residuals” of features of each class that deviate from the subspace to which they are supposed to belong. These are exactly the directions in which the features need to be compressed back onto their respective subspaces, hence the negative sign (see Figure 4.2).
Essentially, the linear operations and in the gradient ascent for rate reduction are determined by “auto-regressions” conducted on the training data. The recent renewed understanding of ridge regression in an over-parameterized setting [YYY+20, WX20] indicates that using seemingly redundantly sampled data (from each subspace) as regressors does not lead to overfitting.
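To make these operators concrete, the following minimal numpy sketch computes the expansion operator and the class-wise compression operators from a matrix of features (stored as columns) and their class labels. The parameterization below (a quantization precision and scalings proportional to the ambient dimension and the class sizes) follows the standard ReduNet convention and is meant only as an illustration of the formulas above; the variable names and default values are ours.

```python
import numpy as np

def expansion_op(Z, eps=0.1):
    """Expansion operator E = alpha * (I + alpha * Z Z^T)^{-1} with
    alpha = d / (m * eps^2), where Z is d x m with features as columns.
    E pushes features apart to expand the overall coding rate."""
    d, m = Z.shape
    alpha = d / (m * eps ** 2)
    return alpha * np.linalg.inv(np.eye(d) + alpha * Z @ Z.T)

def compression_ops(Z, labels, eps=0.1):
    """Class-wise compression operators C_j and weights gamma_j = m_j / m.
    Each C_j is built from the features of class j only and compresses
    them toward their own subspace."""
    d, m = Z.shape
    Cs, gammas = [], []
    for j in np.unique(labels):
        Zj = Z[:, labels == j]
        mj = Zj.shape[1]
        alpha_j = d / (mj * eps ** 2)
        Cs.append(alpha_j * np.linalg.inv(np.eye(d) + alpha_j * Zj @ Zj.T))
        gammas.append(mj / m)
    return Cs, gammas
```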
Notice that in the above, the gradient ascent considers all the features as free variables. The increment does not yet give a transformation on the entire feature domain . According to equation (4.1.5), the gradient cannot be evaluated at a point whose membership is not known, as illustrated in Figure 4.1. Hence, in order to find the optimal explicitly, we may consider constructing a small increment transform on the -th layer feature to emulate the above (projected) gradient scheme:
(4.1.8) |
such that That is, we need to approximate the gradient flow that locally deforms all (training) features with a continuous mapping defined on the entire feature space . Notice that one may interpret the increment (4.1.8) as a discretized version of a continuous differential equation:
(4.1.9) |
Hence the (deep) network so constructed can be interpreted as a certain neural ODE [CRB+18]. Nevertheless, unlike neural ODEs, where the flow is chosen to have some generic structure, here our flow is designed to emulate the gradient flow of the rate reduction on the feature set (as shown in Figure 4.1):
and its structure is entirely derived and fully determined from this objective, without any other priors or heuristics.
Inspecting the structure of the gradient (4.1.5) suggests that a natural candidate for the increment transform is of the form:
(4.1.10) |
where indicates the probability of belonging to the -th class. The parameters of the increment map depend on two things: first, a set of linear maps represented by and that depend only on statistics of the training features ; and second, the membership of any feature . Notice that on the training samples , for which the memberships are known, the so-defined gives exactly the values of the gradient .
Since we only have the membership for the training samples, the function defined in (4.1.10) can only be evaluated on the training set. To extrapolate to the entire feature space, we need to estimate in its second term. In conventional deep learning, this map is typically modeled as a deep network and learned from the training data, say via back propagation. Nevertheless, our goal here is not yet to learn a precise classifier. Instead, we only need a good enough estimate of the class information in order for to approximate the gradient well.
From the geometric interpretation of the linear maps and given by Remark 4.1, the term can be viewed as (approximately) the projection of onto the orthogonal complement of each class . Therefore, is small if is in class and large otherwise. This motivates us to estimate its membership based on the following softmax function:
(4.1.11) |
Hence, the second term of (4.1.10) can be approximated by this estimated membership:
(4.1.12) |
which we denote as a nonlinear operator on the outputs of the feature through groups of filters: . Notice that the nonlinearity arises due to a “soft” assignment of class membership based on the feature responses to those filters.
Overall, combining (4.1.8), (4.1.10), and (4.1.12), the increment feature transform from to now becomes
(4.1.13) | ||||
with the nonlinear function defined above and collecting all the layer-wise parameters, that is, . Note that features at each layer are always “normalized” by projecting onto the unit sphere , denoted as . The form of the increment in (4.1.13) is illustrated by the diagram in Figure 4.3(a).
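Building on the helpers from the previous sketch, one forward layer of this construction can be sketched as follows: estimate the soft memberships with a softmax as in (4.1.11), take the gradient-style step of (4.1.13), and project each updated feature back onto the unit sphere. The step size and softmax temperature below are illustrative choices, not the values used in the text.

```python
import numpy as np

def membership(z, Cs, lam=500.0):
    """Soft class assignment pi_j(z) proportional to exp(-lam * ||C_j z||),
    in the spirit of (4.1.11)."""
    scores = np.array([-lam * np.linalg.norm(C @ z) for C in Cs])
    scores -= scores.max()                       # numerical stability
    p = np.exp(scores)
    return p / p.sum()

def redunet_layer(Z, labels, eps=0.1, eta=0.5, lam=500.0):
    """One forward layer built from the current (labeled) features Z:
    expand via E, softly compress via the C_j weighted by the estimated
    memberships, then re-normalize each feature onto the sphere."""
    E = expansion_op(Z, eps)                     # from the previous sketch
    Cs, gammas = compression_ops(Z, labels, eps)
    Z_next = np.empty_like(Z)
    for i in range(Z.shape[1]):
        z = Z[:, i]
        pi = membership(z, Cs, lam)
        grad = E @ z - sum(g * p * (C @ z) for g, p, C in zip(gammas, pi, Cs))
        z_new = z + eta * grad
        Z_next[:, i] = z_new / np.linalg.norm(z_new)
    return Z_next
```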
Notice that the increment is constructed to emulate the gradient ascent for the rate reduction . Hence by transforming the features iteratively via the above process, we expect the rate reduction to increase, as we will see in the experimental section. This iterative process, once converged say after iterations, gives the desired feature map on the input , precisely in the form of a deep network, in which each layer has the structure shown in Figure 4.3 left:
(4.1.14) | ||||
As this deep network is derived from maximizing the rate reduction, we call it the ReduNet. Comparing the architecture of the ReduNet with those of popular empirically designed networks such as ResNet and ResNeXt, shown in Figure 4.3, the similarity is somewhat uncanny. Conceptually, the ReduNet could also be used to justify the popular mixture of experts (MoE) architecture [SMM+17], as each parallel channel, , can be viewed as an “expert” trained for each class of objects.
We summarize the training and evaluation of ReduNet in Algorithm 4.1 and Algorithm 4.2, respectively. Notice that all parameters of the network are explicitly constructed layer by layer in a forward propagation fashion. The construction does not need any back propagation! The so-learned features can be directly used for classification, say via a nearest subspace classifier.
To provide some intuition on how ReduNet transforms the features, we provide a simple example with mixed 3D Gaussians and visualize how the features are transformed in Figure 4.5. Consider a mixture of three Gaussian distributions in that is projected onto . We first generate data points for 3 classes: for , , , and . We set , and . Then we project all the data points onto , i.e., . To construct the network (computing for the -th layer), we set the number of iterations/layers , step size , and precision . We do this only to demonstrate that our framework leads to stable deep networks even with thousands of layers. In practice, thousands of layers may not be necessary and one can stop whenever adding new layers gives diminishing returns. For this example, a couple of hundred layers is sufficient. Hence, the clear optimization objective gives a natural criterion for the depth of the network needed.
As shown in Figure 4.5, we observe that after the mapping , samples from the same class are highly compressed and converge to a single cluster, and the angle between two different clusters is approximately , which is well aligned with the optimal solution of the MCR2 loss in . The MCR2 loss of the features at different layers can be found in Figure 4.5(c). Empirically, we find that the constructed ReduNet is able to maximize the MCR2 loss and converges stably: samples from the same class converge to one cluster, and different clusters are orthogonal to each other. Moreover, when sampling new data points from the same distributions, we find that new samples from the same class consistently converge to the same cluster center as the training samples.
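For readers who wish to reproduce a toy version of this experiment, the snippet below generates three noisy clusters on the sphere and pushes them through a stack of the layers sketched above; the cluster means, noise level, and number of layers are illustrative stand-ins for the exact values used in the figure.

```python
import numpy as np

rng = np.random.default_rng(0)
d, n_per_class, sigma = 3, 100, 0.1
means = np.eye(d)                                   # three well-separated means
X = np.hstack([means[:, [j]] + sigma * rng.standard_normal((d, n_per_class))
               for j in range(3)])
X /= np.linalg.norm(X, axis=0, keepdims=True)       # project data onto the sphere
labels = np.repeat(np.arange(3), n_per_class)

Z = X.copy()
for _ in range(200):                                # a few hundred layers suffice here
    Z = redunet_layer(Z, labels)                    # layer from the sketch above
```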
In the previous section, we derived the layer-wise architecture of a deep network, the ReduNet, using unrolled optimization for the rate reduction objective. Specifically, the compression term in (4.1.1) is designed to compress representations from the same class. However, this formulation does not account for possible domain transformation or deformation of the input data. For instance, shifting an object slightly to the right does not change the semantic label of an image. In this section, we will demonstrate how convolutional layers can be derived by maximizing a rate reduction objective that is invariant to certain domain deformations, such as image rotations and translations.
For many clustering or classification tasks (such as object detection in images), we consider two samples as equivalent if they differ by certain classes of domain deformations or augmentations . Hence, we are only interested in low-dimensional structures that are invariant to such deformations (i.e., iff for all ), which are known to have sophisticated geometric and topological structures and can be difficult to learn precisely in practice, even with rigorously designed CNNs [CW16]. In this framework, this can be formulated in a very natural way: all equivalent instances are to be embedded into the same subspace, so that the subspace itself is invariant to the transformations under consideration.
In many applications, such as serial data or imagery data, the semantic meaning (labels) of the data is invariant to certain transformations (for some group ) [CW16b, ZKR+17]. For example, the meaning of an audio signal is invariant to shifts in time, and the identity of an object in an image is invariant to translation in the image plane. Hence, we prefer that the feature mapping be rigorously invariant to such transformations:
(4.1.15) |
where “” indicates two features belonging to the same equivalence class. Although convolutional operators have become common practice in deep networks for ensuring invariance or equivariance [CW16b], it remains challenging in practice to train an (empirically designed) convolutional network from scratch that can guarantee invariance even to simple transformations such as translation and rotation [AW18, ETT+17]. An alternative approach is to carefully design the convolution filters of each layer so as to ensure translational invariance for a wide range of signals, say using wavelets as in the ScatteringNet [BM13] and follow-up works [WB18]. However, in order to ensure invariance for generic signals, the number of convolutions needed usually grows exponentially with network depth. That is the reason why this type of network cannot be made very deep and usually has only several layers.
Now, we show that the MCR2 principle is compatible with invariance in a natural and precise way: we only need to assign all transformed versions into the same class as the data and map their features all to the same subspace . Hence, all group equivariant information is encoded only inside the subspace, and any classifier defined on the resulting set of subspaces will be automatically invariant to such group transformations. See Figure 4.6 for an illustration of the examples of 1D rotation and 2D translation. Next, we will rigorously show that when the group is circular 1D shifting, the resulting deep network naturally becomes a multi-channel convolution network. Because the so-constructed network only needs to ensure invariance for the given data or their features , the number of convolutions needed actually remains constant through a very deep network, as opposed to the ScatteringNet.
To classify one-dimensional data invariant under shifting, we take to be the group of all circular shifts. Each observation generates a family of shifted copies, which are the columns of the circulant matrix given by
(4.1.16) |
We refer the reader to [KS12] for properties of circulant matrices. For simplicity, let . (Again, to simplify the discussion, we assume for now that the initial features are the data themselves and hence have the same dimension , i.e., . This does not need to be the case, as we will soon see that we need to lift to a higher dimension.) Then what happens if we construct the ReduNet from their circulant families ? That is, we want to compress and map all these into the same subspace by the ReduNet.
Notice that now the data covariance matrix:
(4.1.17) | |||||
associated with this family of samples is automatically a (symmetric) circulant matrix. Moreover, because the circulant property is preserved under sums, inverses, and products, the matrices and are also automatically circulant matrices, whose application to a feature vector can be implemented using circular convolution “”. Specifically, we have the following proposition.
The matrix
(4.1.18) |
is a circulant matrix and represents a circular convolution:
where is the first column vector of and “” is circular convolution defined as
Similarly, the matrices associated with any subsets of are also circular convolutions.
Not only do the first-layer parameters and of the ReduNet represent circular convolutions, but the next-layer features also remain circulant matrices. That is, the incremental feature transform in (4.1.13) applied to all shifted versions of a , given by
(4.1.19) |
is a circulant matrix. This implies that there is no need to construct circulant families from the second layer features as we did for the first layer. By denoting
(4.1.20) |
the features at the next level can be written as
Continuing inductively, we see that all matrices and based on such are circulant, and so are all features. By virtue of the properties of the data, ReduNet has taken the form of a convolutional network, with no need to explicitly choose this structure!
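The following short numerical check illustrates the statement: build the circulant family of a random signal, form the corresponding expansion operator, and verify that it is itself circulant and that applying it amounts to a circular convolution with its first column. The quantization precision is an arbitrary illustrative value.

```python
import numpy as np
from scipy.linalg import circulant

rng = np.random.default_rng(0)
d, eps = 8, 0.1
x = rng.standard_normal(d)
Zx = circulant(x)                                   # all circular shifts of x as columns

alpha = d / (Zx.shape[1] * eps ** 2)
E = alpha * np.linalg.inv(np.eye(d) + alpha * Zx @ Zx.T)

z = rng.standard_normal(d)
e = E[:, 0]                                         # first column of E
conv = np.real(np.fft.ifft(np.fft.fft(e) * np.fft.fft(z)))  # e circularly convolved with z

assert np.allclose(E, circulant(e))                 # E is itself circulant
assert np.allclose(E @ z, conv)                     # applying E = circular convolution
```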
There is one problem though: In general, the set of all circular permutations of a vector gives a full-rank matrix. That is, the “augmented” features associated with each sample (hence each class) typically already span the entire space . For instance, all shifted versions of a delta function can generate any other signal as their (dense) weighted superposition. The MCR2 objective (3.4.12) will not be able to distinguish classes as different subspaces.
One natural remedy is to improve the separability of the data by “lifting” the original signal to a higher-dimensional space, e.g., by taking their responses to multiple filters :
(4.1.21) |
The filters can be pre-designed invariance-promoting filters (for 1D signals like audio, one may consider the conventional short-time Fourier transform (STFT); for 2D images, one may consider 2D wavelets as in the ScatteringNet [BM13]); they can be adaptively learned from the data (say as the principal components of samples, as in the PCANet [CJG+15], or from convolutional dictionary learning [LB19, QLZ19]); or they can be randomly selected, as we do in our experiments. This operation lifts each original signal to a -channel feature, denoted as . Then, we may construct the ReduNet on vector representations of , denoted as . The associated circulant version and its data covariance matrix, denoted as , for all its shifted versions are given as:
(4.1.22) |
where with is the circulant version of the -th channel of the feature . Then the columns of will only span at most a -dimensional proper subspace in . However, this simple lifting operation (if linear) is not sufficient to render the classes separable yet—features associated with other classes will span the same -dimensional subspace. This reflects a fundamental conflict between invariance and linear (subspace) modeling: one cannot hope for arbitrarily shifted and superposed signals to belong to the same class.
One way of resolving this conflict is to leverage additional structure within each class, in the form of sparsity: signals within each class are not generated as arbitrary linear superpositions of some base atoms (or motifs), but only as sparse combinations of them and their shifted versions, as shown in Figure 4.7. More precisely, let denote a matrix with a collection of atoms associated with class , also known as a dictionary; then each signal in this class is sparsely generated as:
(4.1.23) |
for some sparse vector . Signals in different classes are then generated by different dictionaries whose atoms (or motifs) are incoherent from one another. Due to incoherence, signals in one class are unlikely to be sparsely represented by atoms in any other class. Hence all signals can be represented as
(4.1.24) |
where is sparse. (Notice that similar sparse representation models have long been proposed and used for classification purposes in applications such as face recognition, demonstrating excellent effectiveness [WYG+09, WWG+12]. More recently, the convolutional sparse coding model has been proposed by [PRE17] as a framework for interpreting the structures of deep convolutional networks.) There is a vast literature on how to learn the most compact and optimal sparsifying dictionaries from sample data, e.g., [LB19, QLZ19], and subsequently solve the inverse problem to compute the associated sparse code or . Recent studies [QLZ20, QZL+20] even show that under broad conditions the convolutional dictionary learning problem can be solved effectively and efficiently.
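A minimal sketch of this sparse generative model is given below: each signal of a class is produced by circularly convolving the class's atoms with sparse spike trains, i.e., a sparse combination of the atoms and their shifts. The atom length, sparsity level, and random atoms are illustrative assumptions, not choices prescribed by the text.

```python
import numpy as np

def sample_class_signal(atoms, d, sparsity=3, rng=None):
    """Draw one signal as a sum over the class atoms, each circularly convolved
    with a sparse spike train (a sparse combination of the atom and its shifts)."""
    rng = rng or np.random.default_rng()
    x = np.zeros(d)
    for a in atoms:
        spikes = np.zeros(d)
        idx = rng.choice(d, size=sparsity, replace=False)
        spikes[idx] = rng.standard_normal(sparsity)   # sparse coefficients
        a_pad = np.zeros(d)
        a_pad[: a.size] = a                           # zero-pad the short atom
        x += np.real(np.fft.ifft(np.fft.fft(a_pad) * np.fft.fft(spikes)))
    return x

# Example: two classes, each with its own pair of short random atoms.
rng = np.random.default_rng(0)
atoms_class0 = [rng.standard_normal(5) for _ in range(2)]
atoms_class1 = [rng.standard_normal(5) for _ in range(2)]
x0 = sample_class_signal(atoms_class0, d=64, rng=rng)
x1 = sample_class_signal(atoms_class1, d=64, rng=rng)
```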
Nevertheless, for tasks such as classification, we are not necessarily interested in the precise optimal dictionary nor the precise sparse code for each individual signal. We are mainly interested in whether, collectively, the sets of sparse codes for different classes are adequately separable from one another. Under the assumption of the sparse generative model, if the convolution kernels match well with the “transpose” or “inverse” of the above sparsifying dictionaries , also known as the analysis filters [NDE+13, RE14], signals in one class will only have high responses to a small subset of those filters and low responses to the others (due to the incoherence assumption). In practice, though, a sufficiently large number of, say , random filters often suffices to ensure that the extracted -channel features
(4.1.25) |
for different classes have different response patterns to different filters and hence make the different classes separable [CJG+15].
Therefore, in our framework, to a large extent the number of channels (or the width of the network) truly plays the role of the statistical resource, whereas the number of layers (the depth of the network) plays the role of the computational resource. The theory of compressive sensing precisely characterizes how many measurements are needed in order to preserve the intrinsic low-dimensional structures (including separability) of the data [WM21].
The multi-channel responses should be sparse. Hence, to approximate the sparse code , we may apply an entry-wise sparsity-promoting nonlinear thresholding, say , to the above filter outputs by setting low (say, absolute value below ) or negative responses to zero:
(4.1.26) |
Figure 4.8 illustrates the basic ideas. One may refer to [RE14] for a more systematic study of the design of the sparsifying thresholding operator. Nevertheless, here we are not so interested in obtaining the best sparse codes, as long as the codes are sufficiently separable. Hence the nonlinear operator can simply be chosen to be a soft thresholding or a ReLU. These presumably sparse features can be assumed to lie on a lower-dimensional (nonlinear) submanifold of , which can be linearized and separated from the other classes by subsequent ReduNet layers, as illustrated later in Figure 4.9.
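A minimal sketch of this lifting-plus-thresholding step is shown below: circularly convolve the signal with a bank of short random filters and keep only sufficiently strong responses via a shifted ReLU. The number of filters, kernel length, and threshold are illustrative.

```python
import numpy as np

def lift_and_sparsify(x, num_filters=16, ksize=5, tau=0.1, rng=None):
    """Lift a 1-D signal to a multi-channel feature by circular convolution with
    random filters, then apply a sparsity-promoting threshold (a shifted ReLU)."""
    rng = rng or np.random.default_rng(0)
    d = x.shape[0]
    channels = []
    for _ in range(num_filters):
        k = np.zeros(d)
        k[:ksize] = rng.standard_normal(ksize)                   # short random kernel
        response = np.real(np.fft.ifft(np.fft.fft(k) * np.fft.fft(x)))  # circular conv
        channels.append(np.maximum(response - tau, 0.0))          # keep strong responses
    return np.stack(channels)                                     # shape (num_filters, d)
```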
The ReduNet constructed from circulant version of these multi-channel features , i.e., , retains the good invariance properties described above: the linear operators, now denoted as and , remain block circulant, and represent multi-channel 1D circular convolutions. Specifically, we have the following result.
The matrix
(4.1.27) |
is block circulant, i.e.,
where each is a circulant matrix. Moreover, represents a multi-channel circular convolution, i.e., for any multi-channel signal we have
In above, is a multi-channel convolutional kernel with being the first column vector of , and is the multi-channel circular convolution defined as
Similarly, the matrices associated with any subsets of are also multi-channel circular convolutions.
From Proposition 4.2, the shift-invariant ReduNet is by construction a deep convolutional network for multi-channel 1D signals. Notice that even if the initial lifting kernels are separated (4.1.26), the matrix inverse in (4.1.27) for computing (similarly for ) introduces “cross talk” among all channels. Hence, these multi-channel convolutions are in general not depth-wise separable, unlike the Xception nets [Cho17] that were once suggested to simplify multi-channel convolutional neural networks. (It remains open what additional structures on the data would lead to depth-wise separable convolutions.)
The calculation of in (4.1.27) requires inverting a matrix of size , which in general has complexity . Nevertheless, by using the fact that a circulant matrix can be diagonalized by the Discrete Fourier Transform (DFT) matrix, the complexity can be significantly reduced. As shown in [CYY+22], to compute and , we only need to invert blocks in the frequency domain times; hence the overall complexity becomes .
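The single-channel case already illustrates the saving: because the covariance of a circulant family is diagonalized by the DFT, the expansion operator can be obtained directly from the spectrum of the signal, avoiding any dense matrix inverse. The sketch below checks this against the dense construction; the precision value is illustrative.

```python
import numpy as np
from scipy.linalg import circulant

def expansion_kernel_fft(x, eps=0.1):
    """First column of E = alpha*(I + alpha*Z Z^T)^{-1} for Z = circulant(x),
    computed via the DFT in O(d log d) rather than by an O(d^3) matrix inverse."""
    d = x.shape[0]
    alpha = d / (d * eps ** 2)                 # m = d shifted samples
    spectrum = np.abs(np.fft.fft(x)) ** 2      # eigenvalues of Z Z^T
    return np.real(np.fft.ifft(alpha / (1.0 + alpha * spectrum)))

# Sanity check against the dense construction.
rng = np.random.default_rng(0)
x = rng.standard_normal(64)
d, eps = x.size, 0.1
alpha = d / (d * eps ** 2)
Z = circulant(x)
E_dense = alpha * np.linalg.inv(np.eye(d) + alpha * Z @ Z.T)
assert np.allclose(circulant(expansion_kernel_fft(x, eps)), E_dense, atol=1e-6)
```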
Following the above derivation, we see that, in order to find a linear discriminative representation (LDR) for multiple classes of signals/images that is invariant to translation, sparse coding, a multi-layer architecture with multi-channel convolutions, nonlinear activations, and frequency-domain (spectral) computation all become necessary components for achieving the objective effectively and efficiently. Figure 4.9 illustrates the overall process of learning such a representation via invariant rate reduction on the input sparse codes.
We next evaluate the empirical performance of the ReduNet on learning rotation-invariant features on the real 10-class MNIST dataset. We impose a polar grid on the image , with its geometric center being the center of the 2D polar grid (as illustrated in Figure 4.10). For each radius , , we can sample pixels with respect to each angle with . Then, given a sample image from the dataset, we represent the image in the (sampled) polar coordinates as a multi-channel signal . The goal here is to learn a rotation-invariant representation, i.e., we expect to learn such that lie in the same subspace, where is the cyclic shift in polar angle. We use training samples ( from each class) and set , for polar sampling. By performing the above sampling in polar coordinates, we obtain the data matrix . For the ReduNet, we set the number of layers/iterations , precision , and step size . Before the first layer, we lift the input by 1D circular convolution with 20 random Gaussian kernels of size 5.
To evaluate the learned representation, each training sample is augmented by 20 of its rotated versions, each shifted with stride 10. We compute the cosine similarities among the augmented training inputs, and the results are shown in Figure 4.11(a). We compare the cosine similarities among the learned features of all the augmented versions, i.e., , and summarize the results in Figure 4.11(b). As we see, the so-constructed rotation-invariant ReduNet is able to map the training data (as well as all its rotated versions) from the 10 different classes into 10 nearly orthogonal subspaces. That is, the learned subspaces are truly invariant to shift transformations in polar angle. Next, we randomly draw another test samples and apply the same augmentation procedure. In Figure 4.11(c), we visualize the MCR2 loss on the -th layer representation of the ReduNet on the training and test datasets. From these results, we find that the constructed ReduNet is indeed able to maximize the MCR2 loss as well as generalize to the test data.
As we have seen in the previous section, we used the problem of classification to provide a rigorous interpretation for the main architectural characteristics of popular deep networks such as the ResNet and the CNN: each layer of such networks can be viewed as imitating a gradient step that increases the rate reduction (or information gain) objective. This perspective also leads to a somewhat surprising fact: the parameters and operators of the layers of such a deep network, the ReduNet, can be computed in a purely forward fashion.
Despite the theoretical and conceptual importance of the ReduNet, several factors limit it from being very practical. First, as we have discussed above, the computational cost of computing the matrix operators of each layer in a forward fashion can be very high. Second, the so-computed operators may not be so effective in optimizing the objective, and it might take thousands of iterations (hence layers). As we have seen in Section 2.3.3 for LISTA, these two issues can be addressed by making those operators learnable and optimizing them via back-propagation (or, perhaps, by a mixture of both forward and backward optimization).
The supervised classification setting in which the ReduNet was derived is also somewhat limiting. In practice, an image might not belong to a single class, as it may contain multiple objects. Hence it is more general to assume that different regions of the image belong to different low-dimensional models (say a Gaussian or a subspace). As we will see, such a generalization leads to an architecture that is both simple and general, unifying the rate reduction and denoising operations that we have seen in the previous chapter. Moreover, the so-obtained architecture resembles the popular Transformer architecture.
We consider a general learning setup associated with real-world signals. Let denote random variables representing our data source. In vision tasks, each is interpreted as a token, typically corresponding to an image patch. In language tasks, each is interpreted as a token embedding, i.e., a continuous vector representation of a discrete token such as a word or subword. (With a slight abuse of terminology, we refer to both the discrete tokens and their associated embeddings simply as tokens throughout this chapter for convenience.) The ’s may have arbitrary correlation structures. We use to denote the random variables that define our representations, where is the representation of the corresponding token .
In transformers, each input sample is typically converted into a sequence of tokens. A token is a basic unit of information derived from the raw input: in natural language processing, tokens are typically words or subwords; in computer vision, they correspond to image patches; and in other modalities, they may represent time steps, spatial locations, or other domain-specific units. A token embedding is a continuous vector representation of a token that serves as the input to a transformer. It maps each token to a point in a high-dimensional space, enabling the model to process symbolic inputs using numerical computation. A token representation is a vector that encodes the semantic or structural information of a token, typically produced by the intermediate or final layers of a transformer. These representations are designed to capture meaningful features of the input that are useful for downstream tasks such as classification, generation, or regression. Please refer to Section 7.2 for more details about these concepts in implementations.
Following the rate reduction framework of Section 4.1, we contend that the goal of representation learning is to find a feature mapping which transforms input tokens with a potentially nonlinear and multi-modal distribution to (piecewise) linearized and compact token representations . While the joint distribution of token representations may be sophisticated (and task-specific), we further contend that it is reasonable and practical to require that the target marginal distribution of individual token representations be highly compressed and structured, amenable to compact coding. In particular, we require the distribution to be a mixture of low-dimensional (say ) Gaussian distributions, such that the -th Gaussian has mean , covariance , and support spanned by the orthonormal basis . We denote by the set of bases of all Gaussians. Hence, to maximize the information gain [MTS22] for the final token representations, we wish to maximize their rate reduction (see Section 3.4.2), i.e.,
(4.2.1) |
Here, the first term is an estimate of the lossy coding rate for the whole set of token representations. More specifically, if we view the token representations as i.i.d. samples from a single zero-mean Gaussian, their lossy coding rate subject to a quantization precision is given as
(4.2.2) |
The second term is an estimate of the lossy coding rate under the codebook , which is given as
(4.2.3) |
The expression (4.2.3) for the coding rate can be viewed as a generalization of the coding rate used in the original rate reduction objective (3.4.13). In particular, the original objective is defined with respect to a set of known membership labels specific to the particular data realization . In contrast, the current objective is defined with respect to subspaces , which are independent of any particular realization but are assumed to support the distribution of token representations. Suppose that a token representation belongs to a subspace and these subspaces are approximately orthogonal to each other, i.e., for all . Then, one can verify that the projections and for all . These orthogonal projections effectively serve as implicit membership labels, identifying the subspace to which each token representation belongs.
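Both coding-rate terms can be written down in a few lines of numpy, as sketched below. The scaling constants inside the log-determinants follow the usual rate-distortion form and may differ slightly from the exact constants in (4.2.2) and (4.2.3); the sparse rate reduction of (4.2.5) is then the difference of the two terms minus an l1 penalty.

```python
import numpy as np

def coding_rate(Z, eps=0.5):
    """R(Z): lossy coding rate of token representations Z (d x n), treating the
    columns as i.i.d. samples from a single zero-mean Gaussian."""
    d, n = Z.shape
    return 0.5 * np.linalg.slogdet(np.eye(d) + (d / (n * eps ** 2)) * Z @ Z.T)[1]

def coding_rate_wrt_bases(Z, Us, eps=0.5):
    """R^c(Z; U): sum of coding rates of the projections U_k^T Z onto each
    subspace basis U_k (d x p); the projections act as implicit memberships."""
    _, n = Z.shape
    total = 0.0
    for U in Us:
        P = U.T @ Z                                   # p x n projected tokens
        p = U.shape[1]
        total += 0.5 * np.linalg.slogdet(
            np.eye(p) + (p / (n * eps ** 2)) * P @ P.T)[1]
    return total

def sparse_rate_reduction(Z, Us, lam=0.1, eps=0.5):
    """Relaxed sparse rate reduction: expand globally, compress onto the
    subspaces, and penalize the l1 norm to promote axis-aligned sparsity."""
    return coding_rate(Z, eps) - coding_rate_wrt_bases(Z, Us, eps) - lam * np.abs(Z).sum()
```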
Note that the rate reduction objective (4.2.1) is invariant to arbitrary joint rotations of the representations and subspaces. In particular, optimizing the rate reduction objective may not naturally lead to axis-aligned (i.e., sparse) representations. For instance, consider the three sets of learned representations in Figure 4.12. The coding rate reduction increases from (a) to (b) but, because it is invariant under rotations, remains the same from (b) to (c). Therefore, to ensure the final representations are amenable to more compact coding, we would like to transform the representations (and their supporting subspaces) so that they eventually become sparse (that is, have few nonzero entries) with respect to the standard coordinates of the resulting representation space, as in Figure 4.12(c). Computationally, we may combine the above two goals into a unified objective for optimization:
(4.2.4) |
where denotes a general function class and the norm promotes the sparsity of the final token representations .
In practice, the norm is often relaxed to the norm to improve computational tractability and enable convex optimization techniques [WM22]. Motivated by this, we relax Problem (4.2.4) accordingly, leading to a formulation that remains faithful to the original sparsity objective while being more amenable to efficient algorithms, as follows:
(4.2.5) |
With a slight abuse of terminology, we often refer to this objective function also as the sparse rate reduction.
Although easy to state, each term in the above objective is computationally challenging to optimize [WM22]. Hence it is natural to adopt an approximation approach that realizes the global transformation to optimize (4.2.4) through a concatenation of multiple, say , simple incremental and local operations that push the representation distribution towards the desired parsimonious model distribution:
(4.2.6) |
where is the pre-processing mapping that transforms each input token to the initial token representations . Each incremental forward mapping , or a “layer”, transforms the token distribution to optimize the above sparse rate reduction objective (4.2.4), conditioned on the distribution of its input .
In contrast to other unrolled optimization approaches such as the ReduNet (see Section 4.1), we explicitly model the distribution of at each layer, say as a mixture of linear subspaces or sparsely generated from a dictionary. The model parameters are learned from data (say via backward propagation with end-to-end training). This separation between forward “optimization” and backward “learning” clarifies the mathematical role of each layer as an operator that transforms the distribution of its input, whereas the input distribution is in turn modeled (and subsequently learned) by the parameters of the layer.
Now, we show how to derive these incremental and local operations through an unrolled optimization perspective to solve Problem (4.2.5). Once we decide on using an incremental approach to optimizing Problem (4.2.5), there are a variety of possible choices to achieve the optimization. Given a model for , say a mixture of subspaces , we opt for a two-step alternating minimization method with a strong conceptual basis. First, we compress the tokens via a gradient descent to minimize the coding rate term . Specifically, we take a gradient step on with a learning rate as follows:
(4.2.7) |
Next, we sparsify the compressed tokens, generating via a suitably relaxed proximal gradient step to minimize the remaining term . As we will argue in detail later, we can find such a by solving a sparse representation problem with respect to a dictionary :
(4.2.8) |
In the following, we provide technical details for each of the two steps above and derive efficient updates for their implementation.
For the first step (4.2.7), the gradient of the coding rate is costly to compute, as it involves separate matrix inverses, one for each of the subspaces with basis :
(4.2.9) |
Now, we demonstrate that this gradient can be naturally approximated using a so-called multi-head subspace self-attention (MSSA) operator, which has a similar functional form to the multi-head self-attention operator [VSP+17] with heads (i.e., one for each subspace, coming from each matrix inverse). Here, we approximate the gradient (4.2.9) using the first-order Neumann series (see Exercise 4.2):
(4.2.10) |
In this approximation, we compute the similarity between projected token representations through an auto-correlation among the projected features as and convert it to a distribution of membership with a softmax, namely . Suppose that the union of the subspaces spans the whole space. Then, we have . Hence, (4.2.10) becomes
(4.2.11) |
where MSSA is defined through an SSA operator as follows:
(4.2.12) | |||
(4.2.13) |
Substituting (4.2.11) into (4.2.7) shows that the compression step can be naturally approximated by
(4.2.14) |
The SSA operator in (4.2.12) resembles the attention operator in a typical transformer [VSP+17], except that here the linear operators for the value, key, and query are all set to be the same as the subspace basis, i.e., . Hence, we name it the Subspace Self-Attention (SSA) operator. Then, the whole operator in (4.2.13), formally defined as and called the Multi-Head Subspace Self-Attention (MSSA) operator, aggregates the attention head outputs by averaging using model-dependent weights, similar in concept to the popular multi-head self-attention operator in existing transformer networks. The overall gradient step (4.2.14) resembles multi-head self-attention implemented with a skip connection in transformers.
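The SSA and MSSA operators, together with the resulting compression step, can be sketched as follows. The softmax normalization, the scaling by the subspace dimension over the number of tokens times the squared precision, and the step size are illustrative approximations of (4.2.12) through (4.2.14), not the exact expressions.

```python
import numpy as np

def softmax(A, axis=0):
    A = A - A.max(axis=axis, keepdims=True)
    expA = np.exp(A)
    return expA / expA.sum(axis=axis, keepdims=True)

def ssa(Z, U):
    """Subspace self-attention for one head: query = key = value = U^T Z."""
    P = U.T @ Z                      # p x n projections onto the subspace
    A = softmax(P.T @ P, axis=0)     # n x n attention among projected tokens
    return P @ A                     # p x n

def mssa(Z, Us, eps=0.5):
    """Multi-head subspace self-attention: lift each head back with U_k and sum."""
    _, n = Z.shape
    p = Us[0].shape[1]
    return (p / (n * eps ** 2)) * sum(U @ ssa(Z, U) for U in Us)

def compression_step(Z, Us, kappa=1.0, eps=0.5):
    """Gradient-style compression update with a skip connection, cf. (4.2.14)."""
    return Z + kappa * mssa(Z, Us, eps)
```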
For the second step of the alternating minimization, we need to minimize . Note that the gradient involves a matrix inverse, and thus a naive proximal gradient scheme (see Section A.1.3) for this problem becomes intractable at large scale. We therefore take a different, simplifying approach to trading off between representational diversity and sparsification: we posit a (complete) incoherent or orthogonal dictionary and seek to sparsify the intermediate iterates with respect to . That is, , where is a sparse encoding of . The dictionary is used to sparsify all tokens simultaneously. By the incoherence assumption, we have . Thus, from (4.2.2) we have
(4.2.15) |
To solve , we optimize the following problem
The above sparse representation program is usually solved by relaxing it to an unconstrained convex program, known as LASSO [WM22]:
(4.2.16) |
In our implementation, we also add a non-negative constraint to , and solve the corresponding non-negative LASSO:
(4.2.17) |
Then, we incrementally optimize Equation 4.2.17 by performing an unrolled proximal gradient descent step, known as an ISTA step [BT09], to give the update:
(4.2.18) | ||||
(4.2.19) |
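The sparsification step is a single proximal gradient (ISTA) update of the non-negative LASSO, initialized at the compressed tokens; in the sketch below the shifted ReLU plays the role of the non-negative soft-thresholding operator, and the step size and sparsity weight are illustrative.

```python
import numpy as np

def ista_step(Z_half, D, step=0.1, lam=0.1):
    """One non-negative ISTA step, cf. (4.2.18): gradient step on the data-fit
    term 0.5 * ||Z_half - D Z||_F^2 at Z = Z_half, then shifted-ReLU threshold."""
    grad = D.T @ (D @ Z_half - Z_half)
    return np.maximum(Z_half - step * grad - step * lam, 0.0)
```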
We now design a white-box transformer architecture, named the Coding RATE Transformer (crate), by unrolling the above updates. By combining the above two steps (4.2.14) and (4.2.18):
Local compression of tokens within a sample towards a mixture-of-subspace structure, leading to the multi-head subspace self-attention block – MSSA;
Global sparsification of token sets across all samples through sparse coding, leading to the sparsification block – ISTA;
we can get the following rate-reduction-based transformer layer, illustrated in Figure 4.13,
(4.2.20) |
Composing multiple such layers following the incremental construction of our representation in (4.2.6), we obtain a white-box transformer architecture that transforms the data tokens towards a compact and sparse union of incoherent subspaces, where is the pre-processing mapping that transforms the input tokens to the first-layer representations . The overall flow of this architecture is shown in Figure 4.14.
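Putting the two sketched steps together gives one such layer, and stacking layers (each with its own codebook, learned in practice by back-propagation) gives the forward map. The sketch below reuses the compression and ISTA helpers from the earlier sketches; all hyperparameters are illustrative.

```python
def crate_layer(Z, Us, D, kappa=1.0, eps=0.5, step=0.1, lam=0.1):
    """One CRATE-style layer: MSSA compression followed by ISTA sparsification."""
    Z_half = compression_step(Z, Us, kappa, eps)   # multi-head subspace attention
    return ista_step(Z_half, D, step, lam)         # sparse coding w.r.t. dictionary D

def crate_forward(Z0, layer_params):
    """Compose layers; layer_params is a list of per-layer (U_[K], D) codebooks."""
    Z = Z0
    for Us, D in layer_params:
        Z = crate_layer(Z, Us, D)
    return Z
```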
In contrast to other unrolled optimization approaches such as the ReduNet [CYY+22], we explicitly model the distribution of each and at each layer, either as a mixture of linear subspaces or as sparsely generated from a dictionary. We introduce the interpretation that, at each layer , the learned bases for the subspaces and the learned dictionaries together serve as a codebook or analysis filter that encodes and transforms the intermediate representations at that layer. Since the input distribution to layer is first modeled by and then transformed by , the input distribution to each layer is different, and so we require a separate codebook at each layer to obtain the most parsimonious encoding. Parameters of these codebooks (i.e., the subspace bases and the dictionaries), heretofore assumed to be perfectly known, are actually learned from data (say, via backward propagation within end-to-end training).
Hence, our methodology features a clear conceptual separation between forward “optimization” and backward “learning” for the so-derived white-box deep neural network. Namely, in its forward pass, we interpret each layer as an operator which, conditioned on a learned model (i.e., a codebook) for the distribution of its input, transforms this distribution towards a more parsimonious representation. In its backward propagation, the codebook of this model, for the distribution of the input to each layer, is updated to better fit a certain (supervised) input-output relationship, as illustrated in Figure 4.15. This conceptual interpretation implies a certain agnosticism of the model representations towards the particular task and loss; in particular, many types of tasks and losses will lead to the models at each layer being well fit, which in turn ensures that the network produces parsimonious representations.
We now present the empirical performance of the proposed crate networks by measuring their top-1 classification accuracy on ImageNet-1K as well as their transfer learning performance on several widely used downstream datasets. We summarize the results in Table 4.1. The transfer learning methodology is to fine-tune with the cross-entropy loss, initializing from the pre-trained networks. As the designed white-box transformer architecture leverages parameter sharing in both the attention block (MSSA) and the nonlinearity block (ISTA), the crate-Base model (22.80 million parameters) has a similar number of parameters to the ViT-Small (22.05 million) [DBK+21], and fewer than 30% of the parameters of an identically configured ViT-Base (86.54 million). From Table 4.1, we find that with a similar number of model parameters, our proposed network achieves similar ImageNet-1K and transfer learning performance to ViT, while having a simple and principled design. Moreover, with the same set of training hyperparameters, we observe promising scaling behavior in crate: we consistently improve the performance by scaling up the model size. To summarize, crate achieves promising performance on real-world large-scale datasets by directly implementing our principled architecture. We will provide more details of the implementation and analysis of the experimental results on image classification in the final application chapter, Chapter 7.
Model | crate-T | crate-S | crate-B | crate-L | ViT-T | ViT-S |
---|---|---|---|---|---|---|
# parameters | 6.09M | 13.12M | 22.80M | 77.64M | 5.72M | 22.05M |
ImageNet-1K | 66.7 | 69.2 | 70.8 | 71.3 | 71.5 | 72.4 |
ImageNet-1K ReaL | 74.0 | 76.0 | 76.5 | 77.4 | 78.3 | 78.4 |
CIFAR10 | 95.5 | 96.0 | 96.8 | 97.2 | 96.6 | 97.2 |
CIFAR100 | 78.9 | 81.0 | 82.7 | 83.6 | 81.8 | 83.2 |
Oxford Flowers-102 | 84.6 | 87.1 | 88.7 | 88.3 | 85.1 | 88.5 |
Oxford-IIIT-Pets | 81.4 | 84.9 | 85.3 | 87.4 | 88.5 | 88.6 |
So far, we hope that we have provided compelling evidence that the role of (popular) deep networks is to realize certain optimization algorithms for minimizing the coding rate (or maximizing the information gain) of the learned representations. However, readers who are familiar with optimization methods might have noticed that the above architectures (the ReduNet or CRATE) correspond to rather basic optimization techniques. They may have plenty of room for improvement in efficiency or effectiveness. Moreover, if we believe the proposed theoretical framework for interpreting deep networks is correct, it should not only help explain existing architectures but also guide us in developing more efficient and effective architectures. In this section, we show that this can be the case: the resulting new architectures are not only fully interpretable but also come with guaranteed correctness and improved efficiency.
In this subsection, we propose a minimalistic transformer architecture consisting of interpretable layers based on the MSSA operator. To derive a fully interpretable transformer architecture with only necessary components, we contend that the goal of representation learning is to compress a set of noisy initial token representations towards a mixture of low-dimensional subspaces. Here, we assume that the initial token representations are sampled from a mixture of low-rank Gaussians perturbed by noise as follows:
Let be a partition of the index set and denote the orthonormal basis of the -th subspace for each . We say that the token representations are sampled from a mixture of noisy low-rank Gaussian distributions if for each ,
(4.3.1) |
where and for all and , and are respectively mutually independent, and is independent of .
This model serves as an idealized framework for approximating token representations in real-world pretrained LLMs. It assumes that the token representations are sampled from a mixture of multiple low-rank Gaussian distributions with noise. Under this model, the goal of representation learning is to compress a set of noisy initial token representations onto their corresponding subspaces. In addition, this model aligns well with two well-established hypotheses about the structure of token representations in pretrained large language models: the “linear representation hypothesis” [JRR+24, PCV24] and the “superposition hypothesis” [EHO+22, YCO+21].
The linear representation hypothesis posits that token representations in LLMs lie in low-dimensional linear subspaces that encode semantic features. Similarly, the superposition hypothesis suggests that these representations can be approximately expressed as a sparse linear combination of these feature vectors. In Definition 4.1, each basis of the subspaces can be interpreted as a set of semantic features, where each feature corresponds to a specific aspect of the token’s meaning. Token representations are then approximately expressed as sparse linear combinations of these subspace bases, capturing the essential semantic components of the token while ignoring irrelevant dimensions.
Now, we show that the MSSA operator (see (4.2.13)) can incrementally denoise token representations generated from the above model. Specifically, we consider for each ,
(4.3.2) |
where is defined in Definition 4.1, is the step size, and is an element-wise operator, such as softmax, ReLU, or other functions. To simplify our development, we assume that the subspaces in Definition 4.1 are orthogonal to each other, i.e., for all . Note that this assumption is not restrictive, as in high-dimensional spaces random low-dimensional subspaces are incoherent to each other with high probability, i.e., [WM21]. (One may straightforwardly generalize our results to non-orthogonal subspaces, with a slightly more sophisticated analysis.)
Now, let the columns of denote the token representations from the -th subspace at the -th layer. To quantify the denoising capability, we define the signal-to-noise ratio (SNR) for each block of the token representations at the -th layer as follows:
(4.3.3) |
To simplify our analysis, we assume that , , and
(4.3.4) |
With the above setup, we now characterize the denoising performance of the MSSA operator.
Let be defined in Definition 4.1 and in (4.3.2) be , where is the softmax function and is an element-wise thresholding function with for each . Suppose that , , and
For sufficiently large , it holds with probability at least that for each ,
(4.3.5) |
This theorem shows that when the initial token representations are sampled from a mixture of low-rank Gaussian distributions with a noise level , each layer of the proposed transformer denoises the token representations at a linear rate. This indicates the MSSA operator’s efficiency in reducing noise across layers. Notably, our theoretical results are well-supported by experimental observations in Figure 4.16. This theorem provides a theoretical foundation for the practical denoising capability of the transformer architecture derived by unrolling (4.3.2).
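The following NumPy sketch illustrates the qualitative behavior described by this theorem under the model of Definition 4.1. It does not implement the exact MSSA operator of (4.2.13); instead, each layer softly assigns tokens to subspaces via a softmax over projection energies and then takes a step that amplifies each token's component within its assigned subspace, leaving the orthogonal noise untouched. The SNR proxy below (in-subspace versus out-of-subspace energy per group) is a stand-in for (4.3.3), whose exact notation is not reproduced here; it should grow roughly geometrically with depth, mirroring the linear-rate denoising.

```python
import numpy as np

rng = np.random.default_rng(1)
d, K, p, n_k, sigma, eta, tau = 64, 4, 4, 100, 0.2, 0.5, 0.1
U = [np.linalg.qr(rng.standard_normal((d, p)))[0] for _ in range(K)]

# Tokens drawn from the mixture-of-noisy-low-rank-Gaussians model.
Z = np.concatenate(
    [U[k] @ rng.standard_normal((p, n_k)) + sigma * rng.standard_normal((d, n_k))
     for k in range(K)], axis=1)
labels = np.repeat(np.arange(K), n_k)

def snr(Z, labels):
    # SNR proxy per group: energy inside the true subspace vs. energy outside it.
    out = []
    for k in range(K):
        Zk = Z[:, labels == k]
        inside = np.linalg.norm(U[k].T @ Zk)
        outside = np.linalg.norm(Zk - U[k] @ (U[k].T @ Zk))
        out.append(inside / outside)
    return np.array(out)

for layer in range(8):
    # One schematic denoising layer: soft subspace assignment by projection energy,
    # then a step that boosts each token's in-subspace component (skip connection kept).
    logits = np.stack([np.sum((Uk.T @ Z) ** 2, axis=0) for Uk in U]) / tau   # (K, n)
    logits -= logits.max(axis=0, keepdims=True)
    pi = np.exp(logits)
    pi /= pi.sum(axis=0, keepdims=True)                                       # soft memberships
    Z = Z + eta * sum(U[k] @ (U[k].T @ Z) * pi[k] for k in range(K))
    print(layer, np.round(snr(Z, labels), 2))
```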
As discussed above, under this model the goal of representation learning is to compress a set of noisy initial token representations into their corresponding subspaces. In real-world applications, however, where token representations exhibit more complicated structures, the goal becomes finding a compact and structured representation by compressing the token sets.
Now, we formally propose an attention-only transformer architecture. Specifically, by unrolling the iterative optimization steps (4.3.2) as layers of a deep network, we construct the transformer architecture shown in Figure 4.17. Each layer of the proposed architecture consists only of the MSSA operator and a skip connection. For language tasks, we additionally incorporate LayerNorm before the MSSA operator to improve performance. The complete architecture is built by stacking such layers, along with essential task-specific pre-processing and post-processing steps, such as positional encoding, token embedding, and a final task-specific head to adapt to different applications.
Generally speaking, the standard decoder-only transformer architecture is composed of the following key components [VSP+17]: (1) positional encoding, (2) multi-head QKV self-attention mechanisms, (3) feed-forward MLP networks, (4) layer normalization, and (5) residual connections. In contrast, our proposed transformer architecture adopts a streamlined design by incorporating several key simplifications. Specifically, it employs shared-QKV subspace self-attention mechanisms, excludes MLP layers, and reduces the frequency of LayerNorm.
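A minimal PyTorch sketch of one such layer may help fix ideas. The class name SubspaceAttentionLayer and the parameters dim, num_heads, head_dim, and eta are ours for illustration; the learnable bases U_k play the role of the shared-QKV subspace projections, and the block is only a schematic stand-in for the exact MSSA operator of (4.2.13).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SubspaceAttentionLayer(nn.Module):
    """One layer in the spirit of Figure 4.17: (optional) LayerNorm, then
    multi-head subspace self-attention with a single shared projection per
    head (shared QKV), then a skip connection."""

    def __init__(self, dim, num_heads, head_dim, eta=1.0, use_layernorm=True):
        super().__init__()
        self.U = nn.Parameter(torch.randn(num_heads, dim, head_dim) / dim ** 0.5)
        self.eta = eta
        self.norm = nn.LayerNorm(dim) if use_layernorm else nn.Identity()

    def forward(self, z):                      # z: (batch, seq_len, dim)
        x = self.norm(z)
        out = 0.0
        for Uk in self.U:                      # one "head" per subspace basis U_k
            q = x @ Uk                         # shared Q = K = V projection: (B, n, head_dim)
            attn = F.softmax(q @ q.transpose(-1, -2) / q.shape[-1] ** 0.5, dim=-1)
            out = out + (attn @ q) @ Uk.T      # project back to the ambient space
        return z + self.eta * out              # skip connection

layer = SubspaceAttentionLayer(dim=64, num_heads=4, head_dim=8)
tokens = torch.randn(2, 16, 64)
print(layer(tokens).shape)                     # torch.Size([2, 16, 64])
```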
In this subsection, we propose a new transformer attention operator, derived from the coding rate reduction objective, whose computational complexity scales linearly with the number of tokens. Specifically, we derive a novel variational form of the MCR2 objective and show that unrolling gradient descent on this variational objective leads to a new attention module called Token Statistics Self-Attention (TSSA). TSSA has linear computational and memory complexity and radically departs from the typical attention architecture, which computes pairwise similarities between tokens. Recall from (3.4.2) that denotes a stochastic “group assignment” matrix (i.e., and ), where denotes the probability of assigning the -th token to the -th group.
To begin, we consider a general form of MCR2-like objectives based on concave functions of the spectrum of a matrix. Namely, for a given PSD matrix and any scalar we have that , where is the -th largest eigenvalue of . Further, note that is a concave non-decreasing function of . Thus, we describe our results in terms of a more general form of MCR2 based on general spectral functions of PSD matrices of the form , where is concave and non-decreasing. In particular, recall from our above discussion that the attention mechanism arises from unrolling the compression component of MCR2, so we consider a more general MCR2-style compression function:
(4.3.6) |
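For concreteness, the following NumPy snippet checks the log-det identity behind the spectral form just described with f(x) = log(1 + alpha*x), and evaluates one plausible instantiation of a group-wise compression term of the kind in (4.3.6). The normalization by n and the shape convention for the membership matrix Pi are our assumptions, since the exact expression is not reproduced here.

```python
import numpy as np

rng = np.random.default_rng(2)
d, n, K, alpha = 16, 50, 3, 0.5

Z = rng.standard_normal((d, n))
Pi = rng.dirichlet(np.ones(K), size=n).T        # assumed membership matrix, shape (K, n)

def f(x):
    # A concave, non-decreasing spectral function with f(0) = 0 (the log-det /
    # coding-rate case used throughout the chapter).
    return np.log(1.0 + alpha * x)

def spectral_sum(M):
    # sum_i f(lambda_i(M)) for a PSD matrix M; for this f it equals log det(I + alpha*M).
    return float(np.sum(f(np.linalg.eigvalsh(M))))

# Sanity check: sum_i log(1 + alpha*lambda_i(M)) = log det(I + alpha*M).
M = Z @ Z.T / n
assert np.isclose(spectral_sum(M), np.linalg.slogdet(np.eye(d) + alpha * M)[1])

# One plausible instantiation of a group-wise compression term: apply the spectral
# function to each group's membership-weighted second moment.
R_c = sum(spectral_sum(Z @ np.diag(Pi[k]) @ Z.T / n) for k in range(K))
print(R_c)
```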
For the above objective, we now note the following result:
Let be non-decreasing, concave, and obey , and let have the form . Then for each and , we have
(4.3.7) |
Further, the inequality in (4.3.7) is achieved with equality for any which diagonalizes , and if is strictly concave then the inequality in (4.3.7) is achieved with equality if and only if diagonalizes .
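Since the precise form of (4.3.7) is not reproduced here, the snippet below checks the bound under the natural reading suggested by the equality condition: for concave, non-decreasing f with f(0) = 0, the eigenvalues of a PSD matrix M majorize the diagonal of U^T M U for any orthogonal U, so the sum of f over the eigenvalues is at most the sum of f over that diagonal, with equality when U diagonalizes M.

```python
import numpy as np

rng = np.random.default_rng(3)
d, alpha = 12, 0.5
f = lambda x: np.log(1.0 + alpha * x)           # concave, non-decreasing, f(0) = 0

A = rng.standard_normal((d, d))
M = A @ A.T                                      # a random PSD matrix

eigvals, eigvecs = np.linalg.eigh(M)
U_rand, _ = np.linalg.qr(rng.standard_normal((d, d)))   # an arbitrary orthogonal U

lhs = np.sum(f(eigvals))                                 # spectral value: sum_i f(lambda_i(M))
rhs_rand = np.sum(f(np.diag(U_rand.T @ M @ U_rand)))     # variational value for a generic U
rhs_opt = np.sum(f(np.diag(eigvecs.T @ M @ eigvecs)))    # U chosen to diagonalize M

print(lhs <= rhs_rand + 1e-9)      # the bound holds for any orthogonal U
print(np.isclose(lhs, rhs_opt))    # and is tight when U diagonalizes M
```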
Using the above result, we can replace (4.3.6) with an equivalent variational objective of the form
(4.3.8) |
where the equivalence is in the sense that for an optimal choice of matrices as described in Theorem 4.2 (i.e., orthogonal matrices which diagonalize each ), the bound is tight, with . Note that, in general, achieving this bound would require selecting, for each sampled instance of , a new optimal set of parameter matrices which diagonalize each , which is clearly impractical for a network architecture. As an alternative viewpoint, rather than treating the data () as fixed and optimizing the parameters to achieve the tight variational bound, we can take the algorithmic unrolling design principle described above and design an operator that perturbs to incrementally minimize . To make this point explicit, each variational bound becomes tight when the eigenspaces of align with the columns of , so by rotating the appropriate columns of (namely, those corresponding to large entries in ) to align with , we can approach a tight variational bound. That is, instead of rotating to align with the data for each instance of , we can rotate the token features in each to align with .
Following this approach, we compute a gradient descent step on w.r.t. . To begin this computation, first let be any element-wise non-negative vector. Then we have
(4.3.9) |
where is the gradient of , and (recall) applies to each element of the vector in the bracket. In particular, for , is simply a non-linear activation. Also, (recall) . Thus, the gradient of w.r.t. is:
(4.3.10) |
(Note that the constant arises from a constant in each term of the sum.) If we now consider a gradient step w.r.t. the -th token , we arrive at our proposed incremental compression operator, i.e., our surrogate for a self-attention + residual operator:
(4.3.11) |
for each , where is a step size parameter for the incremental optimization. Then, we can construct a layer of TOST, as shown in Figure 4.18.
Given the proposed attention operator in (4.3.11), first recall that the rows of are non-negative and sum to 1, so our operator takes a weighted average of “attention head”-esque operators and then adds a residual connection. Using that , we can rewrite (4.3.11) as:
(4.3.12) |
That is, we can view each attention head as first projecting the token features onto the basis via multiplying by , multiplying by the diagonal matrix (abbreviated as ), projecting back into the standard basis via multiplying by , and subtracting this from the original token features via the residual connection. The core aspect of our attention layer is the computation of . Namely, , so forms a probability distribution over which tokens belong to the group. As a result, estimates the second moment of under the distribution given by . Further, since is a concave non-decreasing function, monotonically decreases towards as increases, so the entries of (which have form ) achieve their maximum at and decay monotonically to as increases.
From this, we arrive at the core interpretation of our attention head + residual operators . Namely, this operator does an approximate low-rank data-dependent projection, where directions which have a large amount of “power” after the projection (i.e., directions which have a large second moment ) are preserved, while directions which do not are suppressed. To see this, recall that the entries of decrease monotonically to 0 as the second moment increases, so for directions with large second moments the attention + residual operator acts largely as the identity operator. Conversely, for directions with a small second moment, our operator subtracts a projection of the tokens along those directions, resulting in those directions being suppressed. Compared to the standard self-attention operator, our method clearly does not compute any pairwise similarities between tokens. Rather, the interactions between the tokens in impact the operator solely through their contribution to the second moment statistic used to construct the ’s. Nevertheless, similar to the standard self-attention operator, our method still has a clear interpretation as performing a form of compression towards a data-dependent low-rank structure, in the sense that it performs an approximate low-rank projection, where the specific directions that are suppressed are those which are not strongly aligned with other tokens in the group.
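The following NumPy sketch of a single attention head + residual operator illustrates this suppression behavior, using f(x) = log(1 + alpha*x) so that the diagonal entries take the form f'(m) = alpha/(1 + alpha*m). The values of alpha and eta and the membership convention (all tokens fully assigned to this head) are our choices for illustration; directions of the tokens with a large second moment along U_k pass through almost unchanged, while weak directions are suppressed.

```python
import numpy as np

rng = np.random.default_rng(4)
d, p, n, eta, alpha = 32, 4, 500, 1.0, 1.0
U_k = np.linalg.qr(rng.standard_normal((d, p)))[0]       # basis of one attention head

# Tokens with a large second moment along the first basis direction of U_k and a
# tiny second moment along the last one.
scales = np.array([5.0, 1.0, 1.0, 0.05])
Z = U_k @ (scales[:, None] * rng.standard_normal((p, n))) + 0.1 * rng.standard_normal((d, n))
pi_k = np.ones(n)                                         # all tokens fully assigned to this head

proj = U_k.T @ Z                                          # coordinates in the head's basis
second_moment = (proj ** 2) @ (pi_k / pi_k.sum())         # per-direction second moment
D_k = alpha / (1.0 + alpha * second_moment)               # f'(second moment): large moment -> small entry
Z_new = Z - eta * (U_k @ (D_k[:, None] * proj)) * pi_k    # attention head + residual

# Strong directions are (almost) preserved; weak directions are suppressed.
before = np.sqrt((U_k.T @ Z) ** 2 @ np.ones(n) / n)
after = np.sqrt((U_k.T @ Z_new) ** 2 @ np.ones(n) / n)
print(np.round(before, 3))
print(np.round(after, 3))
```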
Having introduced our proposed attention operator, we now discuss further practical considerations. First, until this point in the presentation, we have avoided discussion of how tokens are “grouped” into various attention heads via the matrix, but clearly a means of constructing is needed to implement our method. Additionally, our variational form in Theorem 4.2 requires the matrices to be square and orthogonal, but one would ideally like to use smaller matrices (i.e., reduce the number of columns in ) for efficiency as well as drop the orthogonality constraints.
In practice, we do not enforce the orthogonality constraints. To reduce the number of columns in the matrices, we note that, similar to CRATE [YBP+23], if the features within group are (approximately) clustered around a low-dimensional subspace — say of dimension — then the within-group- covariance is low-rank; recall that [YCY+20] shows that the optimal geometry of is for each group to span a low-dimensional subspace orthogonal to the other groups. We can thus explicitly find a low-dimensional orthonormal basis for the image of this covariance, i.e., the linear span of the data in group . If the dimension is , the basis can be represented by a orthogonal matrix . In this case, we can more efficiently upper-bound using these low-rank orthogonal basis matrices. To show this, we use a more general version of Theorem 4.2 to obtain the following corollary.
Let be non-decreasing, concave, and obey , and let have the form . Let , be fixed. Then, for all such that for all , we have
(4.3.13) |
where is formally defined in (4.3.8). Equality holds if diagonalizes for each , and if is strictly concave then this equality condition becomes an “if and only if.”
The final step to define our attention operator is to estimate the group membership . For this we posit a simple model of how each feature deviates from its supporting subspace and then find the optimal subspace assignment. [YBP+23] show that if we independently model each as belonging to a low-dimensional Gaussian mixture model, where each Gaussian has a covariance matrix with identical spectrum and is supported on a subspace with orthonormal basis , plus independent Gaussian noise with covariance , then the posterior probability that each token belongs to each subspace is given by the assignment matrix as follows:
(4.3.14) |
where becomes a learnable temperature parameter. Thus, given an input feature , we estimate using (4.3.14) and then compute the attention operator. Combining the construction of in (4.3.14) with (4.3.11), we obtain the Token Statistics Self-Attention operator:
(4.3.15) |
where are the columns of defined in (4.3.14) and is defined in (4.3.10).
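Putting the pieces together, here is a hedged NumPy sketch of the full operator: memberships are estimated by a softmax over per-token projection energies, which is one plausible reading of the Gaussian-mixture posterior (4.3.14), and then the incremental compression step of (4.3.11) is applied head by head. The temperature tau, step size eta, and alpha in f are illustrative choices; note that no n-by-n pairwise similarity matrix is ever formed, so the cost is linear in the number of tokens.

```python
import numpy as np

def tssa(Z, U, tau=1.0, eta=0.5, alpha=1.0):
    """Sketch of a Token Statistics Self-Attention step: estimate soft
    memberships Pi from projection energies (a plausible reading of (4.3.14)),
    then apply the incremental compression step of (4.3.11) head by head."""
    K = len(U)
    # Membership: posterior-style softmax over heads of per-token projection energy.
    energies = np.stack([np.sum((Uk.T @ Z) ** 2, axis=0) for Uk in U]) / tau   # (K, n)
    energies -= energies.max(axis=0, keepdims=True)
    Pi = np.exp(energies)
    Pi /= Pi.sum(axis=0, keepdims=True)

    Z_new = Z.copy()
    for k in range(K):
        proj = U[k].T @ Z                                       # (p, n)
        second_moment = (proj ** 2) @ (Pi[k] / Pi[k].sum())     # token statistics, shape (p,)
        D_k = alpha / (1.0 + alpha * second_moment)             # f'(second moment)
        Z_new = Z_new - eta * (U[k] @ (D_k[:, None] * proj)) * Pi[k]
    return Z_new

rng = np.random.default_rng(5)
d, p, n, K = 32, 4, 200, 3
U = [np.linalg.qr(rng.standard_normal((d, p)))[0] for _ in range(K)]
Z = rng.standard_normal((d, n))
print(tssa(Z, U).shape)                                          # (d, n), linear cost per head
```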
The materials presented in this chapter are based on a series of recent works on this topic, including [CYY+22, WLP+24, WLY+25, WDL+25, YBP+23]. These contributions encompass both theoretical advances and practical methodologies for constructing interpretable deep networks through unrolled optimization. Many of the key results and proofs discussed in this chapter are derived directly from, or inspired by, these foundational works.
The idea of unrolling an optimization algorithm to construct a neural network traces back to the seminal work [GL10]. In this work, the authors demonstrated that sparse coding algorithms—such as the Iterative Shrinkage-Thresholding Algorithm (ISTA)—can be unrolled to form multilayer perceptrons (MLPs), effectively bridging iterative optimization and neural network design. Notably, [MLE19] demonstrated that such unrolled networks are more interpretable, parameter-efficient, and effective compared to generic networks. In this chapter, we build on this perspective to develop principled, white-box deep network architectures by unrolling optimization algorithms that are designed to minimize well-motivated objectives—such as the (sparse) rate reduction objective introduced earlier. This approach not only clarifies the role of each layer in the network but also offers theoretical grounding for architectural choices, moving beyond empirical trial-and-error toward interpretable and goal-driven design. In the following, we compare conventional DNNs, which are typically constructed through empirical design and heuristic tuning, with our mathematically grounded ReduNet architectures:
| | Conventional DNNs | ReduNets |
|---|---|---|
| Objectives | input/output fitting | information gain |
| Deep architectures | trial & error | iterative optimization |
| Layer operators | empirical | projected gradient |
| Shift invariance | CNNs+augmentation | invariant ReduNets |
| Initializations | random/pre-design | forward unrolled |
| Training/fine-tuning | back prop | forward/back prop |
| Interpretability | black box | white box |
| Representations | hidden/latent | incoherent subspaces |
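As a small illustration of the unrolling idea traced back to [GL10] in the notes above, the sketch below writes ISTA for sparse coding as a fixed-depth network: each layer is an affine map followed by a soft-thresholding nonlinearity. The problem sizes, the regularization weight, and the depth are arbitrary; in LISTA-style training, the matrices W and S below would become learnable parameters.

```python
import numpy as np

def soft_threshold(x, theta):
    # Proximal operator of the l1 norm; the "nonlinear activation" of the unrolled network.
    return np.sign(x) * np.maximum(np.abs(x) - theta, 0.0)

def unrolled_ista(y, A, lam=0.05, num_layers=100):
    """ISTA for min_x 0.5*||Ax - y||^2 + lam*||x||_1, written as a fixed-depth
    network: each layer applies an affine map followed by soft-thresholding."""
    L = np.linalg.norm(A, 2) ** 2            # Lipschitz constant of the smooth part's gradient
    W = A.T / L                              # input-to-layer weights
    S = np.eye(A.shape[1]) - A.T @ A / L     # layer-to-layer (recurrent) weights
    x = np.zeros(A.shape[1])
    for _ in range(num_layers):
        x = soft_threshold(W @ y + S @ x, lam / L)
    return x

rng = np.random.default_rng(6)
A = rng.standard_normal((30, 100)) / np.sqrt(30)
x_true = np.zeros(100)
x_true[rng.choice(100, 5, replace=False)] = 1.0
y = A @ x_true + 0.01 * rng.standard_normal(30)
print(np.round(unrolled_ista(y, A)[x_true > 0], 2))   # support entries approximately recovered
```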
Let with for each . For some , let
1. Given any direction , please show that and
where . Hint: Note that
2. Please show that
where the equality holds if and only if for all .
3. Given some , let for each . Please derive the closed-form expression for the first-order critical point of the following function:
Hint: Let . Consider the following singular value decomposition of :
where with and , with being a diagonal matrix, and with and .
Let . If , please show
(4.5.1) |
Hint: The proof consists of two steps.
(i) Step 1: Please show that the infinite series converges when using .
(ii) Step 2: Compute the matrix product .
Please show Corollary 4.1 when .