Chapter 4 Deep Representations as Unrolled Optimization

“What I cannot create, I do not understand.”

  — Richard Feynman

In previous chapters, we have shown how to identify low-dimensional structures in high-dimensional spaces, focusing mainly on linear structures. For example, we introduced principal component analysis (PCA) to learn the linear denoiser $\hat{\bm{U}}^{\top}$ when the observed data $\bm{x}$ follow the statistical model $\bm{x} = \bm{U}\bm{z} + \bm{\varepsilon}$. In this setting, the learned representations are linearly transformed input data $\hat{\bm{U}}^{\top}\bm{x}$. Under the linear model assumption, one can learn the low-dimensional linear structure with efficient optimization algorithms and strong theoretical guarantees. Moreover, the linear model assumption covers a wide range of applications and problems, including face recognition, magnetic resonance image recovery, and structured texture recovery [WM22].

On the other hand, the linear model can be limiting in real-world applications, especially when the input data $\bm{x}$ are complex, such as speech and natural language, images and videos, and robotic motions. The low-dimensional distributions of such data are typically nonlinear. How to deal with nonlinearity has a long history across disciplines such as control theory, signal processing, and pattern recognition. There have been considerable efforts to extend methods and solutions for linear models to handle nonlinearity, including early efforts to extend PCA to nonlinear PCA (which we will study in more detail in Chapter 5). In most cases, these methods are designed based on certain assumptions about the data distributions and tailored to specific problems.

More recently, deep neural networks have achieved remarkable success across a wide range of data and applications. A neural network

f(\cdot,\bm{\theta})\colon \bm{x} \xrightarrow{\;f^{0}\;} \bm{z}^{0} \rightarrow \cdots \rightarrow \bm{z}^{\ell} \xrightarrow{\;f^{\ell}\;} \bm{z}^{\ell+1} \rightarrow \cdots \rightarrow \bm{z}^{L} = \bm{z}.   (4.0.1)

can learn effective features/representations for downstream applications. For example, a trained deep neural network $f(\cdot,\bm{\theta})$ can be applied to map images to feature vectors, that is, $\bm{z}_{i} = f(\bm{x}_{i},\bm{\theta})$, while a linear classifier can be learned on top of such representations $\{\bm{z}_{i}\}$. One notable breakthrough is AlexNet [KSH12], a deep convolutional neural network trained on more than a million natural images, which outperformed all previous approaches based on hand-crafted features. One of the key differences between AlexNet and previous approaches is that the former learns the parameters of its nonlinear transformation from massive amounts of data via back-propagation (BP) [RHW86a], as detailed in Section A.2.3 of Appendix A.

Subsequent popular practice models the mapping $f$ with other empirically designed artificial deep neural networks and learns the parameters $\bm{\theta}$ from random initialization via BP. Starting with AlexNet [KSH12], the architectures of modern deep networks have continued to be empirically revised and improved. Network architectures such as VGG [SZ14], ResNet [HZR+16a], DenseNet [HLV+17], CNNs, RNNs or LSTMs [HS97], Transformers [VSP+17], and mixtures of experts (MoE) [SMM+17, FZS22] have continued to push the performance envelope. As part of this effort, almost every component of the networks has been empirically scrutinized, and various revisions and improvements have been proposed, including but not limited to nonlinear activation functions [MHN13, KUM+17, XWC+15, NIG+18], skip connections [RFB15, HZR+16a], normalizations [IS15, BKH16, UVL16, WH18, MKK+18], up/down sampling or pooling [SMB10], and convolutions [LBB+98a, KSH12]. However, almost all such modifications have been developed through years of empirical trial and error or ablation studies. Some recent practices even take this to the extreme by searching for effective network structures and training strategies through extensive random search techniques, such as Neural Architecture Search [ZL17, BGN+17], AutoML [HKV19], and Learning to Learn [ADG+16].

Despite the wide application of deep neural networks, it is not clear what the underlying design principles of such empirically constructed networks are. In particular, it is not clear what mathematical function each layer of the network performs. In this chapter, building on the results of previous chapters, we develop a principled framework that provides a fully rigorous mathematical interpretation of the role of a deep network, both for its individual layers and for the network as a whole.

To understand deep networks and how they should be better designed, we must start with the objective of representation learning. In previous chapters, we have argued that the objective is to identify the intrinsically low-dimensional data distribution and then transform it into a compact and structured (say piecewise linear) representation. As we have seen in the previous chapter, the general approach to identifying a low-dimensional data distribution is through a compression process that progressively minimizes the entropy or coding rate of the distribution. However, up to this point, we have been using empirically designed deep networks to model or approximate the operations that optimize these objectives, such as the score function for denoising (in Section 1.3.1) or the transformation that maximizes the rate reduction (in Section 3.4.3).

As we have argued in the previous chapter, Section 3.4 in particular, one can measure the goodness of the resulting representation by the information it gains over a “lazy” representation that models all the data as one big Gaussian (a representation we saw in the previous chapter as one particular interpretation of the sampled dataset). In particular, if we use a mixture of Gaussians (subspaces), which we studied thoroughly in the previous chapter, as prototypical distributions to approximate the nonlinear distribution of interest, then we can efficiently measure the coding rate of such a representation using the (sum of the) rate distortion functions of the associated Gaussians. The amount of information gained, or (relative) entropy reduced, by such a modeling can then be measured by the difference between the coding rate of the lazy representation and that of the more refined representation. The objective of representation learning is to maximize this information gain, also known as the rate reduction objective.

As we will see in this chapter, once the objective of representation learning is clear, the role of a deep neural network is precisely to optimize this objective iteratively. Each layer of a deep neural network can be naturally derived as an iterative optimization step that incrementally maximizes the information gain; this derivation covers popular architectures such as ResNet, CNNs, and Transformers, as well as more advanced variants. In particular, this chapter aims to answer the following questions about deep networks:

  • Section 4.1 — given a measure of goodness for a learned representation, how can we construct the nonlinear mapping from the data to the optimal representation by unrolling an optimization of this objective?

  • Section 4.2 — does the above unrolling approach provide a principled interpretation of the popular transformer architectures, and if so, what are the associated objective and optimization mechanism?

  • Section 4.3 — how would this framework guide us to design more efficient or more parsimonious deep architectures?

4.1 White-Box Deep Networks via Unrolled Optimization

Now, if we agree that maximizing the rate reduction or information gain leads to the desired representation, as discussed in Section 3.4, the remaining question is how to construct and learn a (nonlinear) mapping from the data $\bm{X}$ to the optimal representation $\bm{Z}^{*}$. This involves designing a network architecture and learning algorithm that can effectively capture the underlying structures in the data and faithfully realize the optimal representation.

4.1.1 Deep Networks from Unrolled Gradient Descent

In the previous chapter, we presented the rate reduction objective (3.4.12) as a principled objective for learning linear discriminative representations of the data. We have, however, not specified the architecture of the feature mapping $\bm{z} = f(\bm{x},\bm{\theta})$ for extracting such representations from the input data $\bm{x}$. A straightforward choice is to use a conventional deep network, such as ResNet, to implement $f(\bm{x},\bm{\theta})$. As we have seen in Figure 3.24, such a choice often leads to decent empirical performance. Nonetheless, several problems remain unanswered if we adopt an arbitrary deep network. Although the learned feature representation is now more interpretable, the network itself is still not. It is unclear why any chosen “black-box” network is able to optimize the desired MCR$^2$ objective at all. Good empirical results (say with a ResNet) do not by themselves justify the particular choice of architectures and operators of the network: Why is a deep layered model even necessary? What do additional layers try to improve or simplify? How wide and deep is adequate? Is there any rigorous justification for the convolutions (in the popular multi-channel form) and the nonlinear operators (e.g., ReLU or softmax) used?

In this chapter, we show that using gradient ascent to maximize the rate reduction $\Delta R_{\epsilon}(\bm{Z}\mid\bm{\Pi})$ as defined in (3.4.12) naturally leads to a “white-box” deep network that realizes the desired mapping. All network layers, linear/nonlinear operators, and parameters are explicitly constructed in a purely forward-propagation fashion. Moreover, the resulting network architectures resemble existing empirically designed deep networks, providing principled justifications for their design.

Gradient Ascent for Coding Rate Reduction.

From the previous chapter, we see that to seek a linear discriminative representation (LDR), we are mathematically seeking a continuous mapping $f(\cdot)\colon \bm{x}\mapsto\bm{z}$ from the data $\bm{X}=[\bm{x}_{1},\ldots,\bm{x}_{N}]\in\mathbb{R}^{D\times N}$ (or from initial features extracted from the data; we will see the necessity of such a feature extraction in the next section) to an optimal representation $\bm{Z}=[\bm{z}_{1},\ldots,\bm{z}_{N}]\in\mathbb{R}^{d\times N}$ that maximizes the following coding rate reduction objective:

\Delta R_{\epsilon}(\bm{Z}\mid\bm{\Pi}) \doteq \underbrace{\frac{1}{2}\log\det\Big(\bm{I}+\alpha\bm{Z}\bm{Z}^{\top}\Big)}_{R_{\epsilon}(\bm{Z})} \;-\; \underbrace{\sum_{k=1}^{K}\frac{\gamma_{k}}{2}\log\det\Big(\bm{I}+\alpha_{k}\bm{Z}\bm{\Pi}_{k}\bm{Z}^{\top}\Big)}_{R_{\epsilon}^{c}(\bm{Z}\mid\bm{\Pi})},   (4.1.1)

where $\epsilon>0$ is a prescribed quantization error and, for simplicity (using slightly simplified notation compared to Chapter 3), we denote

\alpha \doteq \frac{d}{N\epsilon^{2}}, \qquad \alpha_{k} \doteq \frac{d}{\mathrm{tr}(\bm{\Pi}_{k})\,\epsilon^{2}}, \qquad \gamma_{k} \doteq \frac{\mathrm{tr}(\bm{\Pi}_{k})}{N}, \qquad \text{for } k=1,\ldots,K.
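To make the objective concrete, here is a minimal NumPy sketch that evaluates (4.1.1) for given features $\bm{Z}$, membership matrices $\{\bm{\Pi}_k\}$, and quantization error $\epsilon$. The function name rate_reduction and the representation of memberships as diagonal matrices are our own illustrative choices, not from the text.

```python
import numpy as np

def rate_reduction(Z, Pis, eps):
    """Evaluate Delta R_eps(Z | Pi) as in (4.1.1).

    Z   : d x N feature matrix.
    Pis : list of K diagonal N x N membership matrices Pi_k.
    eps : prescribed quantization error.
    """
    d, N = Z.shape
    alpha = d / (N * eps**2)
    # Expansion term R_eps(Z): coding rate of all features together.
    _, logdet_all = np.linalg.slogdet(np.eye(d) + alpha * Z @ Z.T)
    R_all = 0.5 * logdet_all
    # Compression term R^c_eps(Z | Pi): weighted sum of per-class coding rates.
    R_classes = 0.0
    for Pi_k in Pis:
        tr_k = np.trace(Pi_k)
        alpha_k = d / (tr_k * eps**2)
        gamma_k = tr_k / N
        _, logdet_k = np.linalg.slogdet(np.eye(d) + alpha_k * Z @ Pi_k @ Z.T)
        R_classes += 0.5 * gamma_k * logdet_k
    return R_all - R_classes
```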

The question really boils down to whether there is a constructive way of finding such a continuous mapping $f(\cdot,\bm{\theta})$ from $\bm{x}$ to $\bm{z}$. To this end, let us consider incrementally maximizing the objective $\Delta R_{\epsilon}$ as a function of $\bm{Z}\subseteq\mathbb{S}^{d-1}$. Although there are many optimization schemes to choose from, for simplicity we first consider arguably the simplest one, a projected gradient ascent (PGA) scheme (here and below, a subscript $j$ on $\bm{Z}_{j}$ indicates features in the $j$-th class, while a superscript $\ell$ on $\bm{Z}^{\ell}$ indicates all features at the $\ell$-th iteration or layer):

\bm{Z}^{\ell+1} \;\propto\; \bm{Z}^{\ell} + \eta\cdot\frac{\partial \Delta R_{\epsilon}}{\partial \bm{Z}}(\bm{Z}^{\ell}) \quad\text{s.t.}\quad \bm{Z}^{\ell+1}\subseteq\mathbb{S}^{d-1}, \quad \ell=1,2,\ldots,   (4.1.2)

for some step size $\eta>0$, where the iteration starts with the given data $\bm{Z}^{0}=\bm{X}$ and $\ell$ denotes the iteration index. (Again, for simplicity, we first assume that the initial features are the data themselves, so the data and the features have the same dimension $d$. This need not be the case: as we will see in the next section, the initial features can be some (lifted) features of the data and could in principle have a different, much higher, dimension; all subsequent iterates then share that dimension.) This scheme can be interpreted as prescribing how one should incrementally adjust the locations of the current features $\bm{Z}^{\ell}$, initialized as the input data $\bm{X}$, so that the resulting $\bm{Z}^{\ell+1}$ improves the rate reduction $\Delta R_{\epsilon}$, as illustrated in Figure 4.1.

Figure 4.1: Incremental deformation via gradient flow to both flatten data of each class into a subspace and push different classes apart.

A simple calculation shows that the gradient $\partial \Delta R_{\epsilon}/\partial \bm{Z}$ entails evaluating the following derivatives of the two terms in $\Delta R_{\epsilon}$:

\frac{1}{2}\frac{\partial \log\det(\bm{I}+\alpha\bm{Z}\bm{Z}^{\top})}{\partial \bm{Z}}(\bm{Z}^{\ell}) = \underbrace{\alpha\big(\bm{I}+\alpha\bm{Z}^{\ell}(\bm{Z}^{\ell})^{\top}\big)^{-1}}_{\bm{E}^{\ell}\,\in\,\mathbb{R}^{d\times d}}\bm{Z}^{\ell},   (4.1.3)

\frac{1}{2}\frac{\partial \big(\gamma_{k}\log\det(\bm{I}+\alpha_{k}\bm{Z}\bm{\Pi}_{k}\bm{Z}^{\top})\big)}{\partial \bm{Z}}(\bm{Z}^{\ell}) = \gamma_{k}\underbrace{\alpha_{k}\big(\bm{I}+\alpha_{k}\bm{Z}^{\ell}\bm{\Pi}_{k}(\bm{Z}^{\ell})^{\top}\big)^{-1}}_{\bm{C}^{\ell}_{k}\,\in\,\mathbb{R}^{d\times d}}\bm{Z}^{\ell}\bm{\Pi}_{k}.   (4.1.4)

Notice that in the above, the matrix $\bm{E}^{\ell}$ depends only on $\bm{Z}^{\ell}$, and it aims to expand all the features so as to increase the overall coding rate; the matrix $\bm{C}^{\ell}_{k}$ depends only on the features of the $k$-th class, and it aims to compress them so as to reduce the coding rate of each class. The complete gradient $\frac{\partial \Delta R_{\epsilon}}{\partial \bm{Z}}(\bm{Z}^{\ell})\in\mathbb{R}^{d\times N}$ is then of the form:

\frac{\partial \Delta R_{\epsilon}}{\partial \bm{Z}}(\bm{Z}^{\ell}) = \underbrace{\bm{E}^{\ell}}_{\text{Expansion}}\bm{Z}^{\ell} \;-\; \sum_{k=1}^{K}\gamma_{k}\underbrace{\bm{C}_{k}^{\ell}}_{\text{Compression}}\bm{Z}^{\ell}\bm{\Pi}_{k}.   (4.1.5)
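Both operators and the full gradient (4.1.5) can be formed in closed form from the current features. A minimal sketch under the same conventions as the rate_reduction sketch above (function and variable names are ours):

```python
import numpy as np

def rate_reduction_gradient(Z, Pis, eps):
    """Closed-form gradient (4.1.5) of Delta R_eps at Z (d x N)."""
    d, N = Z.shape
    alpha = d / (N * eps**2)
    # Expansion operator E^l as in (4.1.3).
    E = alpha * np.linalg.inv(np.eye(d) + alpha * Z @ Z.T)
    grad = E @ Z
    Cs = []
    for Pi_k in Pis:
        tr_k = np.trace(Pi_k)
        alpha_k = d / (tr_k * eps**2)
        gamma_k = tr_k / N
        # Compression operator C_k^l as in (4.1.4).
        C_k = alpha_k * np.linalg.inv(np.eye(d) + alpha_k * Z @ Pi_k @ Z.T)
        Cs.append(C_k)
        grad -= gamma_k * C_k @ Z @ Pi_k
    return grad, E, Cs
```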
Remark 4.1 (Interpretation of $\bm{E}^{\ell}$ and $\bm{C}_{k}^{\ell}$ as linear operators).

For any $\bm{z}^{\ell}\in\mathbb{R}^{d}$,

\bm{E}^{\ell}\bm{z}^{\ell} = \alpha\big(\bm{z}^{\ell}-\bm{Z}^{\ell}\bm{q}^{\ell}_{\star}\big), \qquad\text{where}\qquad \bm{q}^{\ell}_{\star} \doteq \operatorname*{arg\,min}_{\bm{q}^{\ell}}\big\{\alpha\|\bm{z}^{\ell}-\bm{Z}^{\ell}\bm{q}^{\ell}\|_{2}^{2} + \|\bm{q}^{\ell}\|_{2}^{2}\big\}.   (4.1.6)

Notice that $\bm{q}^{\ell}_{\star}$ is exactly the solution to the ridge regression of $\bm{z}^{\ell}$ against all the data points in $\bm{Z}^{\ell}$. Therefore, $\bm{E}^{\ell}$ (and similarly $\bm{C}^{\ell}_{k}$) is approximately (i.e., when $N$ is large enough) the projection onto the orthogonal complement of the subspace spanned by the columns of $\bm{Z}^{\ell}$. Another way to interpret the matrix $\bm{E}^{\ell}$ is through the eigenvalue decomposition of the covariance matrix $\bm{Z}^{\ell}(\bm{Z}^{\ell})^{\top}$. Assuming $\bm{Z}^{\ell}(\bm{Z}^{\ell})^{\top} \doteq \bm{U}^{\ell}\bm{\Lambda}^{\ell}(\bm{U}^{\ell})^{\top}$ with $\bm{\Lambda}^{\ell} \doteq \operatorname{diag}(\lambda^{\ell}_{1},\ldots,\lambda^{\ell}_{d})$, we have

\bm{E}^{\ell} = \alpha\,\bm{U}^{\ell}\operatorname{diag}\left(\frac{1}{1+\alpha\lambda^{\ell}_{1}},\ldots,\frac{1}{1+\alpha\lambda^{\ell}_{d}}\right)(\bm{U}^{\ell})^{\top}.   (4.1.7)

Therefore, the matrix $\bm{E}^{\ell}$ operates on a vector $\bm{z}^{\ell}$ by stretching it in such a way that directions of large variance are shrunk while directions of vanishing variance are preserved. These are exactly the directions (4.1.3) in which we move the features so that the overall volume expands and the coding rate increases, hence the positive sign. To the opposite effect, the directions associated with (4.1.4) are the “residuals” of the features of each class that deviate from the subspace to which they are supposed to belong. These are exactly the directions in which the features need to be compressed back onto their respective subspaces, hence the negative sign (see Figure 4.2).
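Both interpretations in Remark 4.1 are easy to check numerically. The sketch below, with our own variable names and randomly generated features, verifies the ridge-regression identity (4.1.6) and the spectral formula (4.1.7) for $\bm{E}^{\ell}$:

```python
import numpy as np

rng = np.random.default_rng(0)
d, N, eps = 8, 50, 0.5
Z = rng.standard_normal((d, N))
Z /= np.linalg.norm(Z, axis=0)          # features on the unit sphere
alpha = d / (N * eps**2)

E = alpha * np.linalg.inv(np.eye(d) + alpha * Z @ Z.T)
z = Z[:, 0]

# (i) Ridge-regression view (4.1.6): q* minimizes alpha*||z - Z q||^2 + ||q||^2.
q_star = np.linalg.solve(alpha * Z.T @ Z + np.eye(N), alpha * Z.T @ z)
assert np.allclose(E @ z, alpha * (z - Z @ q_star))

# (ii) Spectral view (4.1.7): directions of large variance of Z Z^T are shrunk.
lam, U = np.linalg.eigh(Z @ Z.T)
E_spec = alpha * U @ np.diag(1.0 / (1.0 + alpha * lam)) @ U.T
assert np.allclose(E, E_spec)
```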

Figure 4.2: Interpretation of $\bm{C}^{\ell}_{k}$ and $\bm{E}^{\ell}$: $\bm{C}^{\ell}_{k}$ compresses each class by contracting its features toward a low-dimensional subspace; $\bm{E}^{\ell}$ expands all features by contrasting and repelling features across different classes.

Essentially, the linear operators $\bm{E}^{\ell}$ and $\bm{C}_{k}^{\ell}$ in the gradient ascent for rate reduction are determined by “auto-regressions” conducted on the training data. The recent renewed understanding of ridge regression in the over-parameterized setting [YYY+20, WX20] indicates that using seemingly redundant samples (from each subspace) as regressors does not lead to overfitting.

Gradient-Guided Feature Map Increment.

Notice that the above gradient ascent treats all the features $\bm{Z}^{\ell}=[\bm{z}^{\ell}_{1},\ldots,\bm{z}^{\ell}_{N}]$ as free variables. The increment $\bm{Z}^{\ell+1}-\bm{Z}^{\ell} = \eta\,\frac{\partial \Delta R_{\epsilon}}{\partial \bm{Z}}(\bm{Z}^{\ell})$ does not yet define a transformation on the entire feature domain $\bm{z}^{\ell}\in\mathbb{R}^{d}$: according to (4.1.5), the gradient cannot be evaluated at a point whose class membership is unknown, as illustrated in Figure 4.1. Hence, in order to find the optimal $f(\bm{x},\bm{\theta})$ explicitly, we may consider constructing a small increment transform $g(\cdot,\bm{\theta}^{\ell})$ on the $\ell$-th layer feature $\bm{z}^{\ell}$ that emulates the above (projected) gradient scheme:

\bm{z}^{\ell+1} \;\propto\; \bm{z}^{\ell} + \eta\cdot g(\bm{z}^{\ell},\bm{\theta}^{\ell}) \quad\text{subject to}\quad \bm{z}^{\ell+1}\in\mathbb{S}^{d-1},   (4.1.8)

such that $\big[g(\bm{z}_{1}^{\ell},\bm{\theta}^{\ell}),\ldots,g(\bm{z}_{N}^{\ell},\bm{\theta}^{\ell})\big] \approx \frac{\partial \Delta R_{\epsilon}}{\partial \bm{Z}}(\bm{Z}^{\ell})$. That is, we need to approximate the gradient flow $\frac{\partial \Delta R_{\epsilon}}{\partial \bm{Z}}$, which locally deforms all (training) features $\{\bm{z}_{i}^{\ell}\}_{i=1}^{N}$, with a continuous mapping $g(\bm{z},\bm{\theta})$ defined on the entire feature space $\bm{z}^{\ell}\in\mathbb{R}^{d}$. Notice that one may interpret the increment (4.1.8) as a discretized version of a continuous differential equation:

\dot{\bm{z}} = g(\bm{z},\bm{\theta}).   (4.1.9)

Hence the (deep) network so constructed can be interpreted as a certain neural ODE [CRB+18]. Nevertheless, unlike neural ODEs, where the flow $g$ is chosen to have some generic structure, here our $g(\bm{z},\bm{\theta})$ emulates the gradient flow of the rate reduction on the feature set (as shown in Figure 4.1):

\dot{\bm{Z}} = \frac{\partial \Delta R_{\epsilon}}{\partial \bm{Z}},

and its structure is entirely derived from and fully determined by this objective, without any other priors or heuristics.

Inspecting the structure of the gradient (4.1.5) suggests that a natural candidate for the increment transform $g(\bm{z}^{\ell},\bm{\theta}^{\ell})$ is of the form:

g(\bm{z}^{\ell},\bm{\theta}^{\ell}) \;\doteq\; \bm{E}^{\ell}\bm{z}^{\ell} - \sum_{k=1}^{K}\gamma_{k}\,\pi_{k}(\bm{z}^{\ell})\,\bm{C}_{k}^{\ell}\bm{z}^{\ell} \in \mathbb{R}^{d},   (4.1.10)

where $\pi_{k}(\bm{z}^{\ell})\in[0,1]$ indicates the probability of $\bm{z}^{\ell}$ belonging to the $k$-th class. The parameters $\bm{\theta}^{\ell}$ of the increment map consist of two parts: first, a set of linear maps $\bm{E}^{\ell}$ and $\{\bm{C}^{\ell}_{k}\}_{k=1}^{K}$ that depend only on the statistics of the training features $\bm{Z}^{\ell}$; second, the membership $\{\pi_{k}(\bm{z}^{\ell})\}_{k=1}^{K}$ of any feature $\bm{z}^{\ell}$. Notice that on the training samples $\bm{Z}^{\ell}$, for which the memberships $\bm{\Pi}_{k}$ are known, the so-defined $g(\bm{z}^{\ell},\bm{\theta}^{\ell})$ gives exactly the values of the gradient $\frac{\partial \Delta R_{\epsilon}}{\partial \bm{Z}}(\bm{Z}^{\ell})$.

Since we only have the membership for the training samples, the function $g(\cdot)$ defined in (4.1.10) can only be evaluated on the training data. To extrapolate $g(\cdot)$ to the entire feature space, we need to estimate $\pi_{k}(\bm{z}^{\ell})$ in its second term. In conventional deep learning, this map would typically be modeled by a deep network and learned from the training data, say via back propagation. Nevertheless, our goal here is not yet to learn a precise classifier $\pi_{k}(\bm{z}^{\ell})$. Instead, we only need a good enough estimate of the class information in order for $g(\cdot)$ to approximate the gradient $\frac{\partial \Delta R_{\epsilon}}{\partial \bm{Z}}$ well.

From the geometric interpretation of the linear maps $\bm{E}^{\ell}$ and $\bm{C}_{k}^{\ell}$ given in Remark 4.1, the term $\bm{p}_{k}^{\ell} \doteq \bm{C}^{\ell}_{k}\bm{z}^{\ell}$ can be viewed as (approximately) the projection of $\bm{z}^{\ell}$ onto the orthogonal complement of the subspace of class $k$. Therefore, $\|\bm{p}_{k}^{\ell}\|_{2}$ is small if $\bm{z}^{\ell}$ is in class $k$ and large otherwise. This motivates us to estimate its membership via the following softmax function:

\widehat{\bm{\pi}}(\bm{z}^{\ell}) \doteq \operatorname{softmax}\left(-\lambda\begin{bmatrix}\|\bm{C}^{\ell}_{1}\bm{z}^{\ell}\|_{2}\\ \vdots\\ \|\bm{C}^{\ell}_{K}\bm{z}^{\ell}\|_{2}\end{bmatrix}\right) = \frac{1}{\sum_{k=1}^{K}\exp(-\lambda\|\bm{C}^{\ell}_{k}\bm{z}^{\ell}\|_{2})}\begin{bmatrix}\exp(-\lambda\|\bm{C}^{\ell}_{1}\bm{z}^{\ell}\|_{2})\\ \vdots\\ \exp(-\lambda\|\bm{C}^{\ell}_{K}\bm{z}^{\ell}\|_{2})\end{bmatrix}\in[0,1]^{K}.   (4.1.11)

Hence, the second term of (4.1.10) can be approximated by this estimated membership:

\sum_{k=1}^{K}\gamma_{k}\,\pi_{k}(\bm{z}^{\ell})\,\bm{C}_{k}^{\ell}\bm{z}^{\ell} \;\approx\; \sum_{k=1}^{K}\gamma_{k}\,\widehat{\pi}_{k}(\bm{z}^{\ell})\,\bm{C}^{\ell}_{k}\bm{z}^{\ell} \;\doteq\; \bm{\sigma}\Big([\bm{C}^{\ell}_{1}\bm{z}^{\ell},\ldots,\bm{C}^{\ell}_{K}\bm{z}^{\ell}]\Big),   (4.1.12)

which we denote as a nonlinear operator $\bm{\sigma}(\cdot)$ acting on the responses of the feature $\bm{z}^{\ell}$ to $K$ groups of filters $[\bm{C}^{\ell}_{1},\ldots,\bm{C}^{\ell}_{K}]$. Notice that the nonlinearity arises from a “soft” assignment of class membership based on the feature responses to these filters.
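For concreteness, here is a minimal sketch of the estimated membership (4.1.11) and the resulting nonlinear operator $\bm{\sigma}(\cdot)$ in (4.1.12). The function names are ours, and Cs and gammas stand for the lists $\{\bm{C}^{\ell}_{k}\}$ and $\{\gamma_{k}\}$ computed as above.

```python
import numpy as np

def soft_membership(z, Cs, lam):
    """Estimated membership pi_hat(z) in (4.1.11) from the K compression operators."""
    scores = -lam * np.array([np.linalg.norm(C_k @ z) for C_k in Cs])
    scores -= scores.max()                      # numerical stability of the softmax
    w = np.exp(scores)
    return w / w.sum()

def sigma(z, Cs, gammas, lam):
    """Nonlinear operator sigma in (4.1.12): membership-weighted compression of z."""
    pi_hat = soft_membership(z, Cs, lam)
    return sum(g * p * (C_k @ z) for g, p, C_k in zip(gammas, pi_hat, Cs))
```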

Overall, combining (4.1.8), (4.1.10), and (4.1.12), the increment feature transform from $\bm{z}^{\ell}$ to $\bm{z}^{\ell+1}$ now becomes

\bm{z}^{\ell+1} \;\propto\; \bm{z}^{\ell} + \eta\cdot\bm{E}^{\ell}\bm{z}^{\ell} - \eta\cdot\bm{\sigma}\big([\bm{C}^{\ell}_{1}\bm{z}^{\ell},\ldots,\bm{C}^{\ell}_{K}\bm{z}^{\ell}]\big) \;=\; \bm{z}^{\ell} + \eta\cdot g(\bm{z}^{\ell},\bm{\theta}^{\ell}) \quad\text{s.t.}\quad \bm{z}^{\ell+1}\in\mathbb{S}^{d-1},   (4.1.13)

with the nonlinear function $\bm{\sigma}(\cdot)$ defined above and $\bm{\theta}^{\ell}$ collecting all the layer-wise parameters, that is, $\bm{\theta}^{\ell} = \{\bm{E}^{\ell},\bm{C}^{\ell}_{1},\ldots,\bm{C}^{\ell}_{K},\gamma_{k},\lambda\}$. Note that the features at each layer are always “normalized” by projecting onto the unit sphere $\mathbb{S}^{d-1}$; we denote this projection by $\mathcal{P}_{\mathbb{S}^{d-1}}$. The form of the increment in (4.1.13) is illustrated by the diagram in Figure 4.3(a).

Figure 4.3: Network architectures of the ReduNet and comparison with others. (a): Layer structure of the ReduNet, derived from one iteration of gradient ascent for optimizing rate reduction. (b) (left): A layer of ResNet [HZR+16a]; (b) (right): A layer of ResNeXt [XGD+17]. As we will see in Section 4.1.2, the linear operators $\bm{E}^{\ell}$ and $\bm{C}_{k}^{\ell}$ of the ReduNet naturally become (multi-channel) convolutions when shift-invariance is imposed.
Algorithm 4.1 Training algorithm for ReduNet
1:𝑿=[𝒙1,,𝒙N]D×N\bm{X}=[\bm{x}_{1},\ldots,\bm{x}_{N}]\in\mathbb{R}^{D\times N}bold_italic_X = [ bold_italic_x start_POSTSUBSCRIPT 1 end_POSTSUBSCRIPT , … , bold_italic_x start_POSTSUBSCRIPT italic_N end_POSTSUBSCRIPT ] ∈ blackboard_R start_POSTSUPERSCRIPT italic_D × italic_N end_POSTSUPERSCRIPT, 𝚷={𝚷k}k=1K\bm{\Pi}=\{\bm{\Pi}_{k}\}_{k=1}^{K}bold_Π = { bold_Π start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT } start_POSTSUBSCRIPT italic_k = 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_K end_POSTSUPERSCRIPT, ϵ>0\epsilon>0italic_ϵ > 0, λ\lambdaitalic_λ, and a learning rate η\etaitalic_η.
2:The learned parameters {𝑬}=1L,{{𝑪k}k=1K}=1L,{γk}k=1k\{\bm{E}^{\ell}\}_{\ell=1}^{L},\{\{\bm{C}^{\ell}_{k}\}_{k=1}^{K}\}_{\ell=1}^{L},\{\gamma_{k}\}_{k=1}^{k}{ bold_italic_E start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT } start_POSTSUBSCRIPT roman_ℓ = 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_L end_POSTSUPERSCRIPT , { { bold_italic_C start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT } start_POSTSUBSCRIPT italic_k = 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_K end_POSTSUPERSCRIPT } start_POSTSUBSCRIPT roman_ℓ = 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_L end_POSTSUPERSCRIPT , { italic_γ start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT } start_POSTSUBSCRIPT italic_k = 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_k end_POSTSUPERSCRIPT
3:procedure ReduNetTraining(𝑿,𝚷,ϵ,λ,η\bm{X},\bm{\Pi},\epsilon,\lambda,\etabold_italic_X , bold_Π , italic_ϵ , italic_λ , italic_η)
4:  # Define constants
5:     αd/(Nϵ2)\alpha\leftarrow d/(N\epsilon^{2})italic_α ← italic_d / ( italic_N italic_ϵ start_POSTSUPERSCRIPT 2 end_POSTSUPERSCRIPT )
6:     for k{1,,K}k\in\{1,\dots,K\}italic_k ∈ { 1 , … , italic_K } do
7:         αkD/(tr(𝚷k)ϵ2)\alpha_{k}\leftarrow D/(\operatorname{tr}(\bm{\Pi}_{k})\epsilon^{2})italic_α start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT ← italic_D / ( roman_tr ( bold_Π start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT ) italic_ϵ start_POSTSUPERSCRIPT 2 end_POSTSUPERSCRIPT )
8:         γktr(𝚷k)/D\gamma_{k}\leftarrow\operatorname{tr}(\bm{\Pi}_{k})/Ditalic_γ start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT ← roman_tr ( bold_Π start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT ) / italic_D
9:     end for
10:
11:  # ReduNet layer-by-layer iteration
12:     𝒁1=[𝒛11,,𝒛N1]𝑿\bm{Z}^{1}=\begin{bmatrix}\bm{z}_{1}^{1},\dots,\bm{z}_{N}^{1}\end{bmatrix}\leftarrow\bm{X}bold_italic_Z start_POSTSUPERSCRIPT 1 end_POSTSUPERSCRIPT = [ start_ARG start_ROW start_CELL bold_italic_z start_POSTSUBSCRIPT 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT 1 end_POSTSUPERSCRIPT , … , bold_italic_z start_POSTSUBSCRIPT italic_N end_POSTSUBSCRIPT start_POSTSUPERSCRIPT 1 end_POSTSUPERSCRIPT end_CELL end_ROW end_ARG ] ← bold_italic_X \triangleright Initialize the ReduNet per-layer iteration
13:     for {1,,L}\ell\in\{1,\dots,L\}roman_ℓ ∈ { 1 , … , italic_L } do
14:  # Step 1: Compute network parameters 𝑬,{𝑪k}k=1K\bm{E}^{\ell},\{\bm{C}^{\ell}_{k}\}_{k=1}^{K}bold_italic_E start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT , { bold_italic_C start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT } start_POSTSUBSCRIPT italic_k = 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_K end_POSTSUPERSCRIPT
15:         𝑬α(𝑰+α𝒁(𝒁))1d×d\bm{E}^{\ell}\leftarrow\alpha\left(\bm{I}+\alpha\bm{Z}^{\ell}(\bm{Z}^{\ell})^{\top}\right)^{-1}\in\mathbb{R}^{d\times d}bold_italic_E start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT ← italic_α ( bold_italic_I + italic_α bold_italic_Z start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT ( bold_italic_Z start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT ) start_POSTSUPERSCRIPT ⊤ end_POSTSUPERSCRIPT ) start_POSTSUPERSCRIPT - 1 end_POSTSUPERSCRIPT ∈ blackboard_R start_POSTSUPERSCRIPT italic_d × italic_d end_POSTSUPERSCRIPT
16:         for k{1,,K}k\in\{1,\dots,K\}italic_k ∈ { 1 , … , italic_K } do
17:              $\bm{C}^{\ell}_{k}\leftarrow\alpha_{k}\left(\bm{I}+\alpha_{k}\bm{Z}^{\ell}\bm{\Pi}_{k}(\bm{Z}^{\ell})^{\top}\right)^{-1}\in\mathbb{R}^{d\times d}$
18:         end for
19:
20:   # Step 2: Update features $\bm{Z}^{\ell}$
21:         for $i\in\{1,\dots,N\}$ do
22:              $\hat{\bm{\pi}}(\bm{z}^{\ell}_{i})\leftarrow\operatorname{softmax}\big(-\lambda[\|\bm{C}^{\ell}_{1}\bm{z}^{\ell}_{i}\|_{2},\dots,\|\bm{C}^{\ell}_{K}\bm{z}^{\ell}_{i}\|_{2}]\big)\in[0,1]^{K}$ \triangleright Compute soft assignments $\hat{\bm{\pi}}(\bm{z}^{\ell}_{i})$
23:              $\bm{z}^{\ell+1}_{i}\leftarrow\mathcal{P}_{\mathbb{S}^{d-1}}\Big(\bm{z}^{\ell}_{i}+\eta\big(\bm{E}^{\ell}\bm{z}^{\ell}_{i}-\sum_{k=1}^{K}\gamma_{k}\hat{\pi}_{k}(\bm{z}^{\ell}_{i})\bm{C}^{\ell}_{k}\bm{z}^{\ell}_{i}\big)\Big)\in\mathbb{R}^{d}$ \triangleright Update features $\bm{z}^{\ell+1}_{i}$ from $\bm{z}^{\ell}_{i}$
24:         end for
25:     end for
26:     return $\{\bm{E}^{\ell}\}_{\ell=1}^{L},\{\{\bm{C}^{\ell}_{k}\}_{k=1}^{K}\}_{\ell=1}^{L},\{\gamma_{k}\}_{k=1}^{K}$ \triangleright Return all network parameters.
27:end procedure
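To make the forward construction concrete, here is a minimal NumPy sketch of one layer of the construction above: compute $\bm{E}^{\ell}$ and $\bm{C}^{\ell}_{k}$ from the current features, then update every feature. The function name construct_layer is ours, and the constants $\alpha=d/(N\epsilon^{2})$, $\alpha_{k}=d/(\mathrm{tr}(\bm{\Pi}_{k})\epsilon^{2})$, $\gamma_{k}=\mathrm{tr}(\bm{\Pi}_{k})/N$ are assumed choices following the rate reduction objective; treat this as a sketch rather than a reference implementation.

import numpy as np

def construct_layer(Z, Pi, eps=0.1, eta=0.5, lam=1.0):
    """One layer of the forward ReduNet construction (a sketch of the loop above).

    Z  : (d, N) current features, columns assumed normalized to the unit sphere.
    Pi : (K, N) binary membership matrix, Pi[k, i] = 1 if sample i belongs to class k.
    Returns the layer parameters (E, Cs, gammas) and the updated features Z_next.
    """
    d, N = Z.shape
    K = Pi.shape[0]

    # Expansion operator E (assumed choice alpha = d / (N * eps^2)).
    alpha = d / (N * eps ** 2)
    E = alpha * np.linalg.inv(np.eye(d) + alpha * Z @ Z.T)

    # Compression operators C_k (assumed alpha_k = d / (tr(Pi_k) eps^2), gamma_k = tr(Pi_k) / N).
    Cs, gammas = [], []
    for k in range(K):
        Nk = Pi[k].sum()
        alpha_k = d / (Nk * eps ** 2)
        Zk = Z * Pi[k]                      # zero out columns belonging to other classes
        Cs.append(alpha_k * np.linalg.inv(np.eye(d) + alpha_k * Zk @ Zk.T))
        gammas.append(Nk / N)

    # Gradient-ascent-like feature update followed by projection back to the sphere.
    Z_next = np.zeros_like(Z)
    for i in range(N):
        z = Z[:, i]
        norms = np.array([np.linalg.norm(C @ z) for C in Cs])
        pi_hat = np.exp(-lam * norms)
        pi_hat /= pi_hat.sum()              # soft assignments, softmax(-lam * norms)
        g = E @ z - sum(gammas[k] * pi_hat[k] * (Cs[k] @ z) for k in range(K))
        z_new = z + eta * g
        Z_next[:, i] = z_new / np.linalg.norm(z_new)
    return E, Cs, gammas, Z_next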

Deep Network for Optimizing Rate Reduction.

Notice that the increment is constructed to emulate gradient ascent for the rate reduction $\Delta R_{\epsilon}$. Hence, by transforming the features iteratively via the above process, we expect the rate reduction to increase, as we will see in the experimental section. This iterative process, once converged, say after $L$ iterations, gives the desired feature map $f(\bm{x},\bm{\theta})$ on the input $\bm{x}=\bm{z}^{0}$, precisely in the form of a deep network, in which each layer has the structure shown in Figure 4.3 (left):

\[
\begin{aligned}
f(\bm{x},\bm{\theta})\;=\;& f^{L}\circ f^{L-1}\circ\cdots\circ f^{1}\circ f^{0}(\bm{z}^{0}),\\
f^{\ell}(\bm{z}^{\ell},\bm{\theta}^{\ell})\;\doteq\;& \bm{z}^{\ell+1}=\mathcal{P}_{\mathbb{S}^{n-1}}\big[\bm{z}^{\ell}+\eta\cdot g(\bm{z}^{\ell},\bm{\theta}^{\ell})\big],\\
g(\bm{z}^{\ell},\bm{\theta}^{\ell})\;=\;& \bm{E}^{\ell}\bm{z}^{\ell}-\bm{\sigma}\big([\bm{C}^{\ell}_{1}\bm{z}^{\ell},\dots,\bm{C}^{\ell}_{K}\bm{z}^{\ell}]\big).
\end{aligned}
\tag{4.1.14}
\]

As this deep network is derived from maximizing the rate reduction, we call it the ReduNet. Comparing the architecture of the ReduNet with those of popular empirically designed networks such as ResNet and ResNeXt, shown in Figure 4.3, the similarity is somewhat uncanny. Conceptually, the ReduNet could also be used to justify the popular mixture-of-experts (MoE) architecture [SMM+17], as each parallel channel $\bm{C}^{\ell}_{k}$ can be viewed as an “expert” trained for one class of objects.

Figure 4.4: Left: a mixture of experts (MoE) deep network [SMM+17]. Right: a sparsity-promoting Switch Transformer [FZS22], used to implement MoE with 1.7 trillion parameters.

We summarize the training and evaluation of ReduNet in Algorithm 4.1 and Algorithm 4.2, respectively. Notice that all parameters of the network are explicitly constructed layer by layer in a forward propagation fashion. The construction does not need any back propagation! The so-learned features can be directly used for classification, say via a nearest subspace classifier.

Algorithm 4.2 Evaluation algorithm for ReduNet
1:Input: $\bm{x}\in\mathbb{R}^{D}$, network parameters $\{\bm{E}^{\ell}\}_{\ell=1}^{L},\{\{\bm{C}^{\ell}_{k}\}_{k=1}^{K}\}_{\ell=1}^{L},\{\gamma_{k}\}_{k=1}^{K}$, step size $\eta$, and parameter $\lambda$
2:Output: feature $\bm{z}^{L+1}$
3:procedure ReduNetEvaluation($\bm{x}$)
4:     $\bm{z}^{1}\leftarrow\bm{x}\in\mathbb{R}^{D}$ \triangleright Initialize the ReduNet per-layer iteration
5:     for $\ell\in\{1,\dots,L\}$ do
6:         $\hat{\bm{\pi}}(\bm{z}^{\ell})\leftarrow\operatorname{softmax}\big(-\lambda[\|\bm{C}^{\ell}_{1}\bm{z}^{\ell}\|_{2},\dots,\|\bm{C}^{\ell}_{K}\bm{z}^{\ell}\|_{2}]\big)\in[0,1]^{K}$ \triangleright Compute soft assignments $\hat{\bm{\pi}}(\bm{z}^{\ell})$
7:         $\bm{z}^{\ell+1}\leftarrow\mathcal{P}_{\mathbb{S}^{d-1}}\Big(\bm{z}^{\ell}+\eta\big(\bm{E}^{\ell}\bm{z}^{\ell}-\sum_{k=1}^{K}\gamma_{k}\hat{\pi}_{k}(\bm{z}^{\ell})\bm{C}^{\ell}_{k}\bm{z}^{\ell}\big)\Big)\in\mathbb{R}^{d}$ \triangleright Update feature $\bm{z}^{\ell+1}$ using $\bm{z}^{\ell}$
8:     end for
9:     return $\bm{z}^{L+1}$ \triangleright Return the output features
10:end procedure
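For the evaluation pass, here is a correspondingly minimal sketch of Algorithm 4.2, reusing the conventions of the construction sketch above; the arguments Es, Cs_per_layer, and gammas are assumed to be the parameters returned by the layer-by-layer construction, and the defaults are illustrative.

import numpy as np

def redunet_eval(x, Es, Cs_per_layer, gammas, eta=0.5, lam=1.0):
    """Forward pass through an L-layer ReduNet (a sketch of Algorithm 4.2).

    Es           : list of L expansion operators E^l.
    Cs_per_layer : list of L lists of compression operators [C^l_1, ..., C^l_K].
    gammas       : list of K class weights gamma_k.
    """
    z = np.asarray(x, dtype=float)                      # z^1 <- x
    for E, Cs in zip(Es, Cs_per_layer):                 # one iteration per layer
        norms = np.array([np.linalg.norm(C @ z) for C in Cs])
        pi_hat = np.exp(-lam * norms)
        pi_hat /= pi_hat.sum()                          # soft assignments
        g = E @ z - sum(g_k * p_k * (C @ z) for g_k, p_k, C in zip(gammas, pi_hat, Cs))
        z = z + eta * g
        z = z / np.linalg.norm(z)                       # project back onto the sphere
    return z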
Example 4.1.

To provide some intuition on how the ReduNet transforms the features, we provide a simple example with a mixture of 3D Gaussians and visualize how the features are transformed in Figure 4.5. Consider a mixture of three Gaussian distributions in $\mathbb{R}^{3}$ that is projected onto $\mathbb{S}^{2}$. We first generate data points for 3 classes: for $k=1,2,3$, $\bm{X}_{k}=[\bm{x}_{k,1},\ldots,\bm{x}_{k,m}]\in\mathbb{R}^{3\times m}$ with $\bm{x}_{k,i}\sim\mathcal{N}(\bm{\mu}_{k},\sigma_{k}^{2}\bm{I})$ and $\pi(\bm{x}_{k,i})=k$. We set $m=500$, $\sigma_{1}=\sigma_{2}=\sigma_{3}=0.1$, and $\bm{\mu}_{1},\bm{\mu}_{2},\bm{\mu}_{3}\in\mathbb{S}^{2}$. Then we project all the data points onto $\mathbb{S}^{2}$, i.e., $\bm{x}_{k,i}/\|\bm{x}_{k,i}\|_{2}$. To construct the network (computing $\bm{E}^{\ell},\bm{C}^{\ell}_{k}$ for the $\ell$-th layer), we set the number of iterations/layers $L=2{,}000$, step size $\eta=0.5$, and precision $\epsilon=0.1$. We do this only to demonstrate that our framework leads to stable deep networks even with thousands of layers. In practice, thousands of layers may not be necessary, and one can stop whenever adding new layers gives diminishing returns; for this example, a couple of hundred layers is sufficient. Hence, the clear optimization objective gives a natural criterion for the depth of the network needed.
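The data for this example can be generated in a few lines of NumPy; the specific class means below are arbitrary points on $\mathbb{S}^{2}$ chosen only for illustration.

import numpy as np

rng = np.random.default_rng(0)
m, sigma = 500, 0.1
mus = np.array([[1.0, 0.0, 0.0],        # class means on S^2 (illustrative choices)
                [0.0, 1.0, 0.0],
                [0.0, 0.0, 1.0]])

X, labels = [], []
for k, mu in enumerate(mus):
    Xk = mu + sigma * rng.standard_normal((m, 3))       # x_{k,i} ~ N(mu_k, sigma^2 I)
    Xk /= np.linalg.norm(Xk, axis=1, keepdims=True)     # project each point onto S^2
    X.append(Xk)
    labels.append(np.full(m, k))
X = np.concatenate(X).T                                 # shape (3, 3m); columns are samples
labels = np.concatenate(labels)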

As shown in Figure 4.5, we observe that after the mapping $f(\cdot,\bm{\theta})$, samples from the same class are highly compressed and converge to a single cluster, and the angle between two different clusters is approximately $\pi/2$, which is well aligned with the optimal solution $\bm{Z}^{\star}$ of the MCR2 loss in $\mathbb{S}^{2}$. The MCR2 loss of the features at different layers is shown in Figure 4.5(c). Empirically, we find that the constructed ReduNet maximizes the MCR2 loss and converges stably: samples from the same class converge to one cluster, and different clusters are orthogonal to each other. Moreover, when sampling new data points from the same distributions, we find that new samples from each class consistently converge to the same cluster center as the training samples.

(a) $\bm{X}_{\text{train}}$ (b) $\bm{Z}_{\text{train}}$ (c) Loss (train/val)
Figure 4.5: Original samples and learned representations for the 3D mixture of Gaussians. We visualize the data points $\bm{X}$ (before the mapping $f(\cdot,\bm{\theta})$) in (a) and the learned features $\bm{Z}$ (after the mapping $f(\cdot,\bm{\theta})$) in (b) by scatter plots. In each scatter plot, each color represents one class of samples. In (c), we show the progression of the values of the objective function.

\blacksquare

4.1.2 Convolutional Networks from Invariant Rate Reduction

In the previous section, we derived the layer-wise architecture of a deep network, the ReduNet, using unrolled optimization for the rate reduction objective. Specifically, the compression term $R^{c}_{\epsilon}(\bm{Z}\mid\bm{\Pi})$ in (4.1.1) is designed to compress representations from the same class. However, this formulation does not account for possible domain transformations or deformations of the input data. For instance, shifting an object slightly to the right does not change the semantic label of an image. In this section, we will demonstrate how convolutional layers can be derived by maximizing a rate reduction objective that is invariant to certain domain deformations, such as image rotations and translations.

For many clustering or classification tasks (such as object detection in images), we consider two samples as equivalent if they differ by certain classes of domain deformations or augmentations $\mathcal{T}=\{\tau\}$. Hence, we are only interested in low-dimensional structures that are invariant to such deformations (i.e., $\bm{x}\in\mathcal{M}$ iff $\tau(\bm{x})\in\mathcal{M}$ for all $\tau\in\mathcal{T}$), which are known to have sophisticated geometric and topological structures and can be difficult to learn precisely in practice, even with rigorously designed CNNs [CW16]. In this framework, this requirement can be formulated in a very natural way: all equivariant instances are to be embedded into the same subspace, so that the subspace itself is invariant to the transformations under consideration.

In many applications, such as serial data or imagery data, the semantic meaning (label) of the data is invariant to certain transformations $\mathfrak{g}\in\mathbb{G}$ (for some group $\mathbb{G}$) [CW16b, ZKR+17]. For example, the meaning of an audio signal is invariant to shifts in time, and the identity of an object in an image is invariant to translation in the image plane. Hence, we prefer that the feature mapping $f(\bm{x},\bm{\theta})$ be rigorously invariant to such transformations:

\[
\mbox{\em Group Invariance:}\quad f(\bm{x}\circ\mathfrak{g},\bm{\theta})\sim f(\bm{x},\bm{\theta}),\quad\forall\,\mathfrak{g}\in\mathbb{G},
\tag{4.1.15}
\]

where “$\sim$” indicates that two features belong to the same equivalence class. Although convolutional operators have become common practice in deep networks to ensure invariance or equivariance [CW16b], it remains challenging in practice to train an (empirically designed) convolutional network from scratch that can guarantee invariance even to simple transformations such as translation and rotation [AW18, ETT+17]. An alternative approach is to carefully design the convolution filters of each layer so as to ensure translational invariance for a wide range of signals, say using wavelets as in the ScatteringNet [BM13] and follow-up works [WB18]. However, in order to ensure invariance for generic signals, the number of convolutions needed usually grows exponentially with network depth. This is why networks of this type cannot be made very deep and usually have only a few layers.

Now, we show that the MCR2 principle is compatible with invariance in a natural and precise way: we only need to assign all transformed versions $\{\bm{x}\circ\mathfrak{g}\mid\mathfrak{g}\in\mathbb{G}\}$ to the same class as the data $\bm{x}$ and map their features $\bm{z}$ all to the same subspace $\mathcal{S}$. Hence, all group-equivariant information is encoded only inside the subspace, and any classifier defined on the resulting set of subspaces will be automatically invariant to such group transformations. See Figure 4.6 for an illustration with the examples of 1D rotation and 2D translation. Next, we will rigorously show that when the group $\mathbb{G}$ is circular 1D shifting, the resulting deep network naturally becomes a multi-channel convolutional network. Because the so-constructed network only needs to ensure invariance for the given data $\bm{X}$ or their features $\bm{Z}$, the number of convolutions needed actually remains constant through a very deep network, as opposed to the ScatteringNet.

Figure 4.6: Illustration of the sought representation that is equivariant/invariant to image rotation (left) or translation (right): all transformed images of each class are mapped into the same subspace that is incoherent to other subspaces. The features embedded in each subspace are equivariant to the transformation group, whereas each subspace is invariant to such transformations.

1D Serial Data and Shift Invariance

To classify one-dimensional data $\bm{x}=[x(0),x(1),\ldots,x(D-1)]\in\mathbb{R}^{D}$ invariant under shifting, we take $\mathbb{G}$ to be the group of all circular shifts. Each observation $\bm{x}_{i}$ generates a family $\{\bm{x}_{i}\circ\mathfrak{g}\mid\mathfrak{g}\in\mathbb{G}\}$ of shifted copies, which are the columns of the circulant matrix $\mathsf{circ}(\bm{x}_{i})\in\mathbb{R}^{D\times D}$ given by

\[
\mathsf{circ}(\bm{x})\,\doteq\,
\begin{bmatrix}
x(0) & x(D-1) & \dots & x(2) & x(1)\\
x(1) & x(0) & x(D-1) & \cdots & x(2)\\
\vdots & x(1) & x(0) & \ddots & \vdots\\
x(D-2) & \vdots & \ddots & \ddots & x(D-1)\\
x(D-1) & x(D-2) & \dots & x(1) & x(0)
\end{bmatrix}
\in\mathbb{R}^{D\times D}.
\tag{4.1.16}
\]

We refer the reader to [KS12] for properties of circulant matrices. For simplicity, let $\bm{Z}^{1}\doteq[\bm{z}_{1}^{1},\dots,\bm{z}_{N}^{1}]=\bm{X}\in\mathbb{R}^{d\times N}$. (Again, to simplify the discussion, we assume for now that the initial features $\bm{Z}^{1}$ are $\bm{X}$ themselves and hence have the same dimension $d$, i.e., $D=d$. This need not be the case, as we will soon see that we need to lift $\bm{X}$ to a higher dimension.) Then what happens if we construct the ReduNet from their circulant families $\mathsf{circ}(\bm{Z}^{1})=\big[\mathsf{circ}(\bm{z}_{1}^{1}),\dots,\mathsf{circ}(\bm{z}_{N}^{1})\big]\in\mathbb{R}^{d\times(dN)}$? That is, we want to compress and map all of these into the same subspace by the ReduNet.

Notice that now the data covariance matrix:

\[
\mathsf{circ}(\bm{Z}^{1})\,\mathsf{circ}(\bm{Z}^{1})^{\top}
=\big[\mathsf{circ}(\bm{z}_{1}^{1}),\dots,\mathsf{circ}(\bm{z}_{N}^{1})\big]\big[\mathsf{circ}(\bm{z}_{1}^{1}),\dots,\mathsf{circ}(\bm{z}_{N}^{1})\big]^{\top}
=\sum_{i=1}^{N}\mathsf{circ}(\bm{z}_{i}^{1})\,\mathsf{circ}(\bm{z}_{i}^{1})^{\top}\;\in\mathbb{R}^{d\times d}
\tag{4.1.17}
\]

associated with this family of samples is automatically a (symmetric) circulant matrix. Moreover, because the circulant property is preserved under sums, inverses, and products, the matrices $\bm{E}^{1}$ and $\bm{C}^{1}_{k}$ are also automatically circulant matrices, whose application to a feature vector $\bm{z}\in\mathbb{R}^{d}$ can be implemented via circular convolution “$\circledast$”. Specifically, we have the following proposition.

Proposition 4.1 (Convolution structures of $\bm{E}^{1}$ and $\bm{C}^{1}_{k}$).

The matrix

\[
\bm{E}^{1}=\alpha\big(\bm{I}+\alpha\,\mathsf{circ}(\bm{Z}^{1})\,\mathsf{circ}(\bm{Z}^{1})^{\top}\big)^{-1}
\tag{4.1.18}
\]

is a circulant matrix and represents a circular convolution:

\[
\bm{E}^{1}\bm{z}=\bm{e}_{1}\circledast\bm{z},
\]

where $\bm{e}_{1}\in\mathbb{R}^{d}$ is the first column vector of $\bm{E}^{1}$ and “$\circledast$” is the circular convolution defined as

\[
(\bm{e}_{1}\circledast\bm{z})_{i}\doteq\sum_{j=0}^{d-1}e_{1}(j)\,z(i+d-j\ \,\textsf{mod}\ \,d).
\]

Similarly, the matrices $\bm{C}^{1}_{k}$ associated with any subsets of $\bm{Z}^{1}$ are also circular convolutions.
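Proposition 4.1 is easy to check numerically: the sketch below builds $\mathsf{circ}(\bm{Z}^{1})$ for a few random signals, forms $\bm{E}^{1}$ as in (4.1.18), and verifies that applying $\bm{E}^{1}$ agrees with circular convolution with its first column (implemented here via the DFT). The sizes and the value of $\alpha$ are arbitrary choices for illustration.

import numpy as np

def circ(x):
    """Circulant matrix whose columns are all circular shifts of x, as in (4.1.16)."""
    return np.stack([np.roll(x, s) for s in range(len(x))], axis=1)

rng = np.random.default_rng(0)
d, N, alpha = 16, 5, 10.0
Z1 = np.column_stack([circ(rng.standard_normal(d)) for _ in range(N)])   # d x (dN)

E1 = alpha * np.linalg.inv(np.eye(d) + alpha * Z1 @ Z1.T)
e1 = E1[:, 0]                                            # first column of E^1
z = rng.standard_normal(d)

conv = np.real(np.fft.ifft(np.fft.fft(e1) * np.fft.fft(z)))   # e1 circularly convolved with z
assert np.allclose(E1 @ z, conv)                         # E^1 acts as a circular convolution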

Not only do the first-layer parameters $\bm{E}^{1}$ and $\bm{C}^{1}_{k}$ of the ReduNet become circular convolutions, but the next-layer features also remain circulant matrices. That is, the incremental feature transform in (4.1.13) applied to all shifted versions of a $\bm{z}^{1}\in\mathbb{R}^{d}$, given by

\[
\mathsf{circ}(\bm{z}^{1})+\eta\cdot\bm{E}^{1}\mathsf{circ}(\bm{z}^{1})-\eta\cdot\bm{\sigma}\big([\bm{C}_{1}^{1}\mathsf{circ}(\bm{z}^{1}),\ldots,\bm{C}^{1}_{K}\mathsf{circ}(\bm{z}^{1})]\big),
\tag{4.1.19}
\]

is a circulant matrix. This implies that there is no need to construct circulant families from the second layer features as we did for the first layer. By denoting

\[
\bm{z}^{2}\propto\bm{z}^{1}+\eta\cdot g(\bm{z}^{1},\bm{\theta}^{1})=\bm{z}^{1}+\eta\cdot\bm{e}_{1}\circledast\bm{z}^{1}-\eta\cdot\bm{\sigma}\big([\bm{c}_{1}^{1}\circledast\bm{z}^{1},\dots,\bm{c}^{1}_{K}\circledast\bm{z}^{1}]\big),
\tag{4.1.20}
\]

the features at the next level can be written as

\[
\mathsf{circ}(\bm{Z}^{2})=\big[\mathsf{circ}(\bm{z}_{1}^{1}+\eta\,g(\bm{z}_{1}^{1},\bm{\theta}^{1})),\dots,\mathsf{circ}(\bm{z}_{N}^{1}+\eta\,g(\bm{z}_{N}^{1},\bm{\theta}^{1}))\big].
\]

Continuing inductively, we see that all matrices $\bm{E}^{\ell}$ and $\bm{C}^{\ell}_{k}$ based on such $\mathsf{circ}(\bm{Z}^{\ell})$ are circulant, and so are all features. By virtue of the properties of the data, the ReduNet has taken the form of a convolutional network, with no need to explicitly choose this structure!

A Fundamental Trade-off between Invariance and Sparsity.

There is one problem though: in general, the set of all circular permutations of a vector $\bm{z}$ gives a full-rank matrix. That is, the $d$ “augmented” features associated with each sample (hence each class) typically already span the entire space $\mathbb{R}^{d}$. For instance, all shifted versions of a delta function $\delta(d)$ can generate any other signal as their (dense) weighted superposition. The MCR2 objective (3.4.12) will not be able to distinguish classes as different subspaces.

One natural remedy is to improve the separability of the data by “lifting” the original signal to a higher-dimensional space, e.g., by taking its responses to multiple filters $\bm{k}_{1},\ldots,\bm{k}_{C}\in\mathbb{R}^{d}$:

\[
\bm{z}[c]=\bm{k}_{c}\circledast\bm{x}=\mathsf{circ}(\bm{k}_{c})\,\bm{x}\in\mathbb{R}^{d},\quad c=1,\ldots,C.
\tag{4.1.21}
\]

The filters can be pre-designed invariance-promoting filters (for 1D signals like audio, one may consider the conventional short-time Fourier transform (STFT); for 2D images, one may consider 2D wavelets as in the ScatteringNet [BM13]); they can be adaptively learned from the data (e.g., as the principal components of samples as in the PCANet [CJG+15], or from convolutional dictionary learning [LB19, QLZ19]); or they can be randomly selected, as we do in our experiments. This operation lifts each original signal $\bm{x}\in\mathbb{R}^{d}$ to a $C$-channel feature, denoted as $\bar{\bm{z}}\doteq[\bm{z}[1],\ldots,\bm{z}[C]]^{\top}\in\mathbb{R}^{C\times d}$. Then, we may construct the ReduNet on the vectorized representation of $\bar{\bm{z}}$, denoted as $\mathsf{vec}(\bar{\bm{z}})\doteq[\bm{z}[1]^{\top},\ldots,\bm{z}[C]^{\top}]^{\top}\in\mathbb{R}^{dC}$. The associated circulant version $\mathsf{circ}(\bar{\bm{z}})$ and the data covariance matrix, denoted as $\bar{\bm{\Sigma}}(\bar{\bm{z}})$, of all its shifted versions are given by:

\[
\mathsf{circ}(\bar{\bm{z}})\doteq
\begin{bmatrix}
\mathsf{circ}(\bm{z}[1])\\ \vdots\\ \mathsf{circ}(\bm{z}[C])
\end{bmatrix}\in\mathbb{R}^{dC\times d},
\qquad
\bar{\bm{\Sigma}}(\bar{\bm{z}})\doteq
\begin{bmatrix}
\mathsf{circ}(\bm{z}[1])\\ \vdots\\ \mathsf{circ}(\bm{z}[C])
\end{bmatrix}
\begin{bmatrix}
\mathsf{circ}(\bm{z}[1])^{\top},\ldots,\mathsf{circ}(\bm{z}[C])^{\top}
\end{bmatrix}\in\mathbb{R}^{dC\times dC},
\tag{4.1.22}
\]

where $\mathsf{circ}(\bm{z}[c])\in\mathbb{R}^{d\times d}$, for $c\in[C]$, is the circulant version of the $c$-th channel of the feature $\bar{\bm{z}}$. Then the columns of $\mathsf{circ}(\bar{\bm{z}})$ will span at most a $d$-dimensional proper subspace in $\mathbb{R}^{dC}$. However, this simple lifting operation (if linear) is not yet sufficient to render the classes separable: features associated with other classes will span the same $d$-dimensional subspace. This reflects a fundamental conflict between invariance and linear (subspace) modeling: one cannot hope for arbitrarily shifted and superposed signals to belong to the same class.

Figure 4.7: Each input signal $\bm{x}$ (an image here) can be represented as a superposition of sparse convolutions with multiple kernels $\bm{d}_{c}$ in a dictionary $\bm{D}$.

One way of resolving this conflict is to leverage additional structure within each class, in the form of sparsity: signals within each class are not generated as arbitrary linear superpositions of some base atoms (or motifs), but only as sparse combinations of them and their shifted versions, as shown in Figure 4.7. More precisely, let $\bm{D}_{k}=[\bm{d}_{k,1},\ldots,\bm{d}_{k,c}]$ denote a matrix whose columns are a collection of atoms associated with class $k$, also known as a dictionary. Then each signal $\bm{x}$ in this class is sparsely generated as

\[
\bm{x}=\bm{d}_{k,1}\circledast z_{1}+\cdots+\bm{d}_{k,c}\circledast z_{c}=\mathsf{circ}(\bm{D}_{k})\,\bm{z},
\tag{4.1.23}
\]

for some sparse vector $\bm{z}$. Signals in different classes are then generated by different dictionaries whose atoms (or motifs) are incoherent with one another. Due to this incoherence, signals in one class are unlikely to be sparsely represented by atoms of any other class. Hence all signals can be represented as

\[
\bm{x}=\big[\mathsf{circ}(\bm{D}_{1}),\mathsf{circ}(\bm{D}_{2}),\ldots,\mathsf{circ}(\bm{D}_{K})\big]\,\bar{\bm{z}},
\tag{4.1.24}
\]

where $\bar{\bm{z}}$ is sparse. (Notice that similar sparse representation models have long been proposed and used for classification purposes in applications such as face recognition, demonstrating excellent effectiveness [WYG+09, WWG+12]. Recently, the convolutional sparse coding model has been proposed by [PRE17] as a framework for interpreting the structures of deep convolutional networks.) There is a vast literature on how to learn the most compact and optimal sparsifying dictionaries from sample data, e.g., [LB19, QLZ19], and subsequently solve the inverse problem to compute the associated sparse code $\bm{z}$ or $\bar{\bm{z}}$. Recent studies [QLZ20, QZL+20] even show that under broad conditions the convolutional dictionary learning problem can be solved effectively and efficiently.
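A tiny sketch of the generative model (4.1.23): a signal of one class is a sparse superposition of circular shifts of that class's atoms. The dictionary, atom length, and sparsity level below are arbitrary choices for illustration.

import numpy as np

rng = np.random.default_rng(0)
d, num_atoms, sparsity = 64, 3, 4

# Dictionary D_k for one class: a few short random motifs, zero-padded to length d.
D_k = np.zeros((d, num_atoms))
D_k[:8, :] = rng.standard_normal((8, num_atoms))

# x = sum_c d_{k,c} (circularly convolved with) z_c, with each code z_c sparse.
x = np.zeros(d)
for c in range(num_atoms):
    z_c = np.zeros(d)
    support = rng.choice(d, size=sparsity, replace=False)
    z_c[support] = rng.standard_normal(sparsity)
    x += np.real(np.fft.ifft(np.fft.fft(D_k[:, c]) * np.fft.fft(z_c)))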

Nevertheless, for tasks such as classification, we are not necessarily interested in the precise optimal dictionary, nor in the precise sparse code for each individual signal. We are mainly interested in whether, collectively, the sparse codes of each class are adequately separable from those of the other classes. Under the assumption of the sparse generative model, if the convolution kernels $\{\bm{k}_{c}\}_{c=1}^{C}$ match well with the “transpose” or “inverse” of the above sparsifying dictionaries $\bm{D}=[\bm{D}_{1},\ldots,\bm{D}_{K}]$, also known as the analysis filters [NDE+13, RE14], then signals in one class will have high responses to only a small subset of those filters and low responses to the others (due to the incoherence assumption). In practice, a sufficiently large number $C$ of random filters $\{\bm{k}_{c}\}_{c=1}^{C}$ often suffices to ensure that the extracted $C$-channel features

\[
\big[\bm{k}_{1}\circledast\bm{x},\bm{k}_{2}\circledast\bm{x},\ldots,\bm{k}_{C}\circledast\bm{x}\big]^{\top}=\big[\mathsf{circ}(\bm{k}_{1})\bm{x},\ldots,\mathsf{circ}(\bm{k}_{C})\bm{x}\big]^{\top}\in\mathbb{R}^{C\times d}
\tag{4.1.25}
\]

of different classes have different response patterns to different filters, hence making the classes separable [CJG+15].

Therefore, in our framework, the number of channels (the width of the network) largely plays the role of the statistical resource, whereas the number of layers (the depth of the network) plays the role of the computational resource. The theory of compressive sensing precisely characterizes how many measurements are needed in order to preserve the intrinsic low-dimensional structures (including separability) of the data [WM21].

Figure 4.8: Estimate the sparse code $\bar{\bm{z}}$ of an input signal $\bm{x}$ (an image here) by taking convolutions with multiple kernels $\bm{k}_{c}$ and then sparsifying.

The multi-channel responses $\bar{\bm{z}}$ should be sparse. So, to approximate the sparse code $\bar{\bm{z}}$, we may apply an entry-wise sparsity-promoting nonlinear thresholding, say $\bm{\tau}(\cdot)$, to the above filter outputs, setting low (say, with absolute value below $\epsilon$) or negative responses to zero:

\[
\bar{\bm{z}}\doteq\bm{\tau}\Big(\big[\mathsf{circ}(\bm{k}_{1})\bm{x},\ldots,\mathsf{circ}(\bm{k}_{C})\bm{x}\big]^{\top}\Big)\in\mathbb{R}^{C\times d}.
\tag{4.1.26}
\]

Figure 4.8 illustrates the basic idea. One may refer to [RE14] for a more systematic study on the design of the sparsifying thresholding operator. Nevertheless, here we are not so interested in obtaining the best sparse codes, as long as the codes are sufficiently separable. Hence the nonlinear operator $\bm{\tau}$ can simply be chosen to be a soft thresholding or a ReLU. These presumably sparse features $\bar{\bm{z}}$ can be assumed to lie on a lower-dimensional (nonlinear) submanifold of $\mathbb{R}^{dC}$, which can be linearized and separated from the other classes by the subsequent ReduNet layers, as illustrated later in Figure 4.9.
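The lifting (4.1.21) followed by the thresholding (4.1.26) can be sketched as follows, using random Gaussian kernels and a ReLU as the operator $\bm{\tau}$; both are choices made here only for illustration, not requirements.

import numpy as np

def lift_and_sparsify(x, C=16, kernel_size=5, seed=0):
    """Lift a 1D signal x in R^d to C channels by circular convolution with random
    kernels, then apply an entry-wise ReLU as the sparsifying operator tau."""
    rng = np.random.default_rng(seed)
    d = len(x)
    kernels = np.zeros((C, d))
    kernels[:, :kernel_size] = rng.standard_normal((C, kernel_size))   # short random filters
    Z = np.real(np.fft.ifft(np.fft.fft(kernels, axis=1) * np.fft.fft(x), axis=1))  # k_c conv x
    return np.maximum(Z, 0.0)                                          # tau = ReLU, shape (C, d)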

The ReduNet constructed from the circulant versions of these multi-channel features $\bar{\bm{Z}}\doteq[\bar{\bm{z}}_{1},\ldots,\bar{\bm{z}}_{N}]\in\mathbb{R}^{C\times d\times N}$, i.e., $\mathsf{circ}(\bar{\bm{Z}})\doteq[\mathsf{circ}(\bar{\bm{z}}_{1}),\dots,\mathsf{circ}(\bar{\bm{z}}_{N})]\in\mathbb{R}^{dC\times dN}$, retains the good invariance properties described above: the linear operators, now denoted as $\bar{\bm{E}}$ and $\bar{\bm{C}}_{k}$, remain block circulant, and they represent multi-channel 1D circular convolutions. Specifically, we have the following result.

Proposition 4.2 (Multi-channel convolution structures of $\bar{\bm{E}}$ and $\bar{\bm{C}}_{k}$).

The matrix

\[
\bar{\bm{E}}\doteq\alpha\big(\bm{I}+\alpha\,\mathsf{circ}(\bar{\bm{Z}})\,\mathsf{circ}(\bar{\bm{Z}})^{\top}\big)^{-1}
\tag{4.1.27}
\]

is block circulant, i.e.,

\[
\bar{\bm{E}}=
\begin{bmatrix}
\bar{\bm{E}}_{1,1} & \cdots & \bar{\bm{E}}_{1,C}\\
\vdots & \ddots & \vdots\\
\bar{\bm{E}}_{C,1} & \cdots & \bar{\bm{E}}_{C,C}
\end{bmatrix}\in\mathbb{R}^{dC\times dC},
\]

where each $\bar{\bm{E}}_{c,c^{\prime}}\in\mathbb{R}^{d\times d}$ is a circulant matrix. Moreover, $\bar{\bm{E}}$ represents a multi-channel circular convolution, i.e., for any multi-channel signal $\bar{\bm{z}}\in\mathbb{R}^{C\times d}$ we have

\[
\bar{\bm{E}}\cdot\mathsf{vec}(\bar{\bm{z}})=\mathsf{vec}(\bar{\bm{e}}\circledast\bar{\bm{z}}).
\]

In the above, $\bar{\bm{e}}\in\mathbb{R}^{C\times C\times d}$ is a multi-channel convolutional kernel with $\bar{\bm{e}}[c,c^{\prime}]\in\mathbb{R}^{d}$ being the first column vector of $\bar{\bm{E}}_{c,c^{\prime}}$, and $\bar{\bm{e}}\circledast\bar{\bm{z}}\in\mathbb{R}^{C\times d}$ is the multi-channel circular convolution defined as

\[
(\bar{\bm{e}}\circledast\bar{\bm{z}})[c]\doteq\sum_{c^{\prime}=1}^{C}\bar{\bm{e}}[c,c^{\prime}]\circledast\bar{\bm{z}}[c^{\prime}],\quad\forall\,c=1,\ldots,C.
\]

Similarly, the matrices $\bar{\bm{C}}_{k}$ associated with any subsets of $\bar{\bm{Z}}$ are also multi-channel circular convolutions.
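The multi-channel circular convolution in Proposition 4.2 is just a sum of single-channel circular convolutions over the input channels; a direct sketch (via the DFT) is given below, with the helper name and array layout being our own conventions.

import numpy as np

def multichannel_circ_conv(e_bar, z_bar):
    """Multi-channel circular convolution: (e conv z)[c] = sum_{c'} e[c, c'] conv z[c'].
    e_bar has shape (C, C, d), z_bar has shape (C, d); returns an array of shape (C, d)."""
    E_hat = np.fft.fft(e_bar, axis=2)                    # DFT of every kernel e[c, c']
    Z_hat = np.fft.fft(z_bar, axis=1)                    # DFT of every channel z[c']
    out_hat = np.einsum('abf,bf->af', E_hat, Z_hat)      # pointwise product, summed over c'
    return np.real(np.fft.ifft(out_hat, axis=1))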

From Proposition 4.2, the shift-invariant ReduNet is, by construction, a deep convolutional network for multi-channel 1D signals. Notice that even though the initial lifting (4.1.26) applies each kernel to the signal separately, the matrix inverse in (4.1.27) for computing $\bar{\bm{E}}$ (and similarly for $\bar{\bm{C}}_{k}$) introduces “cross talk” among all $C$ channels. Hence, these multi-channel convolutions are in general not depth-wise separable, unlike the Xception nets [Cho17] that were once suggested to simplify multi-channel convolutional neural networks. (It remains open what additional structures on the data would lead to depth-wise separable convolutions.)

Remark 4.2 (Reducing Computational Complexity in the Frequency Domain).

The calculation of $\bar{\bm{E}}$ in (4.1.27) requires inverting a matrix of size $dC\times dC$, which in general has complexity $O(d^{3}C^{3})$. Nevertheless, using the fact that a circulant matrix can be diagonalized by the Discrete Fourier Transform (DFT) matrix, the complexity can be significantly reduced. As shown in [CYY+22], to compute $\bar{\bm{E}}$ and $\bar{\bm{C}}_{k}\in\mathbb{R}^{dC\times dC}$, we only need to invert, in the frequency domain, $d$ blocks of size $C\times C$, so the overall complexity becomes $O(dC^{3})$.
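A sketch of this frequency-domain shortcut is given below. For simplicity of illustration it keeps $\bar{\bm{E}}$ as one $C\times C$ block per frequency rather than assembling the full $dC\times dC$ matrix; the function names and array layouts are our own assumptions.

import numpy as np

def E_bar_frequency_domain(Z_bar, alpha):
    """Compute E-bar of (4.1.27) in the frequency domain: one C x C inverse per frequency.

    Z_bar : (C, d, N) multi-channel features.  Returns E_hat of shape (d, C, C),
    i.e., the DFT-domain blocks of E-bar, so only d inversions of size C x C are
    needed instead of one inversion of size dC x dC."""
    C, d, N = Z_bar.shape
    Z_hat = np.fft.fft(Z_bar, axis=1)                    # DFT along the signal axis
    E_hat = np.zeros((d, C, C), dtype=complex)
    for f in range(d):
        Zf = Z_hat[:, f, :]                              # (C, N) coefficients at frequency f
        M = Zf @ Zf.conj().T                             # covariance block at frequency f
        E_hat[f] = alpha * np.linalg.inv(np.eye(C) + alpha * M)
    return E_hat

def apply_E_bar(E_hat, z_bar):
    """Apply E-bar to a multi-channel signal z_bar of shape (C, d), per frequency."""
    z_hat = np.fft.fft(z_bar, axis=1)
    out_hat = np.einsum('fab,bf->af', E_hat, z_hat)      # C x C block times channel vector
    return np.real(np.fft.ifft(out_hat, axis=1))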

Overall Network Architecture and Comparison.

Following the above derivation, we see that in order to find a linear discriminative representation (LDR) for multiple classes of signals/images that is invariant to translation, several components all become necessary for achieving the objective effectively and efficiently: sparse coding, a multi-layer architecture with multi-channel convolutions, nonlinear activations, and computation in the spectral (frequency) domain. Figure 4.9 illustrates the overall process of learning such a representation via invariant rate reduction on the input sparse codes.

Figure 4.9: The overall process for classifying multi-class signals with shift invariance: multi-channel lifting, sparse coding, followed by a multi-channel convolutional ReduNet for invariant rate reduction. These components are necessary in order to map shift-invariant multi-class signals to incoherent (linear) subspaces as an LDR. Note that the architectures of most modern deep neural networks resemble this process. The so-learned LDR facilitates subsequent tasks such as classification.
Example 4.2 (Invariant Classification of Digits).

We next evaluate the empirical performance of the ReduNet on learning rotation-invariant features on the real 10-class MNIST dataset. We impose a polar grid on the image \bm{x}\in\mathbb{R}^{H\times W}, with its geometric center being the center of the 2D polar grid (as illustrated in Figure 4.10). For each radius r_{i}, i\in[C], we sample \Gamma pixels at the angles \gamma_{l}=l\cdot(2\pi/\Gamma), l\in[\Gamma]. Then, given a sample image \bm{x} from the dataset, we represent it in the (sampled) polar coordinates as a multi-channel signal \bm{x}_{p}\in\mathbb{R}^{\Gamma\times C}. The goal is to learn a rotation-invariant representation, i.e., we expect to learn f(\cdot,\bm{\theta}) such that \{f(\bm{x}_{p}\circ\mathfrak{g},\bm{\theta})\}_{\mathfrak{g}\in\mathbb{G}} lie in the same subspace, where \mathfrak{g} is a cyclic shift in polar angle. We use N=100 training samples (10 from each class) and set \Gamma=200, C=15 for the polar sampling. By performing the above sampling in polar coordinates, we obtain the data matrix \bm{X}_{p}\in\mathbb{R}^{(\Gamma\cdot C)\times N}. For the ReduNet, we set the number of layers/iterations to L=40, the precision to \epsilon=0.1, and the step size to \eta=0.5. Before the first layer, we lift the input by 1D circulant convolution with 20 random Gaussian kernels of size 5.
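A minimal sketch of the polar sampling step, assuming nearest-neighbor interpolation and an evenly spaced radial grid (the actual experiment may use a different interpolation scheme):

```python
import numpy as np

def polar_sample(img, num_angles=200, num_radii=15):
    """Sample an (H, W) image on a polar grid centered at the image center.

    Returns x_p of shape (num_angles, num_radii): row l corresponds to the angle
    gamma_l = l * 2*pi / num_angles, column i to the radius r_i.  A rotation of
    the image then corresponds, approximately, to a cyclic shift along the rows.
    """
    H, W = img.shape
    cy, cx = (H - 1) / 2.0, (W - 1) / 2.0
    r_max = min(cy, cx)
    radii = np.linspace(r_max / num_radii, r_max, num_radii)
    angles = 2 * np.pi * np.arange(num_angles) / num_angles
    rows = np.clip(np.round(cy + radii[None, :] * np.sin(angles[:, None])).astype(int), 0, H - 1)
    cols = np.clip(np.round(cx + radii[None, :] * np.cos(angles[:, None])).astype(int), 0, W - 1)
    return img[rows, cols]

# Toy usage: a 28 x 28 "digit" becomes a (200, 15) multi-channel polar signal.
img = np.zeros((28, 28)); img[10:18, 10:18] = 1.0
print(polar_sample(img).shape)  # (200, 15)
```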

Figure 4.10: Examples of rotated images of MNIST digits, each rotated by 18 degrees. (Left) Diagram of the polar coordinate representation; (Right) rotated images of the digits ‘0’ and ‘1’.

To evaluate the learned representation, each training sample is augmented with 20 of its rotated versions, each obtained by cyclically shifting the polar representation with a stride of 10. We compute the cosine similarities among the N\times 20 augmented training inputs \bm{X}_{\text{rotation}} and show the results in Figure 4.11(a). We then compare the cosine similarities among the learned features of all the augmented versions, i.e., \bar{\bm{Z}}_{\text{rotation}}, and summarize the results in Figure 4.11(b). As we see, the so-constructed rotation-invariant ReduNet is able to map the training data (as well as all of its rotated versions) from the 10 different classes into 10 nearly orthogonal subspaces. That is, the learned subspaces are indeed invariant to shifts in the polar angle. Next, we randomly draw another 100 test samples and apply the same augmentation procedure. In Figure 4.11(c), we visualize the MCR^{2} loss of the \ell-th layer representation of the ReduNet on the training and test data. From these results, we find that the constructed ReduNet is indeed able to maximize the MCR^{2} objective and generalizes to the test data.
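The evaluation itself only requires pairwise cosine similarities between the columns of the input or feature matrix; a small sketch (the helper name is our own):

```python
import numpy as np

def cosine_similarity_matrix(Z):
    """Pairwise cosine similarities between the columns of Z (one sample per column).
    Plotting the result as a heatmap reproduces panels like Figure 4.11(a)(b)."""
    Zn = Z / (np.linalg.norm(Z, axis=0, keepdims=True) + 1e-12)
    return Zn.T @ Zn
```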

(a) \bm{X}_{\text{rotation}};  (b) \bar{\bm{Z}}_{\text{rotation}};  (c) Loss
Figure 4.11: (a)(b) Heatmaps of cosine similarity among the rotated training data \bm{X}_{\text{rotation}} and the learned features \bar{\bm{Z}}_{\text{rotation}}, respectively, illustrating rotation invariance. (c) The training/test MCR^{2} losses across layers.

\blacksquare

4.2 White-Box Transformers from Unrolled Optimization

As we have seen in the previous section, we used the problem of classification to provide a rigorous interpretation of the main architectural characteristics of popular deep networks such as ResNets and CNNs: each layer of such networks can be viewed as imitating a gradient step that increases the rate reduction (or information gain) objective. This perspective also leads to a somewhat surprising fact: the parameters and operators of the layers of such a deep network, the ReduNet, can be computed in a purely forward fashion.

Despite the theoretical and conceptual importance of the ReduNet, several factors limit its practicality. First, as discussed above, the computational cost of computing the matrix operators in each layer in a forward fashion can be very high. Second, the so-computed operators may not be very effective in optimizing the objective, and it may take thousands of iterations (hence layers) to do so. As we have seen in Section 2.3.3 for LISTA, these two issues can be addressed by making those operators learnable and optimizing them via back-propagation (or, perhaps, by a mixture of both forward and backward optimization).

The supervised classification setting in which the ReduNet was derived is also somewhat limiting. In practice, an image might not belong to a single class, as it may contain multiple objects. Hence it is more general to assume that different regions of the image belong to different low-dimensional models (say Gaussians or subspaces). As we will see, such a generalization leads to an architecture that is both simple and general, unifying the rate reduction and denoising operations that we have seen in the previous chapter. Moreover, the so-obtained architecture closely resembles the popular transformer architecture.

4.2.1 Unrolled Optimization for Sparse Rate Reduction

We consider a general learning setup associated with real-world signals. Let \bm{X}=[\bm{x}_{1},\dots,\bm{x}_{N}]\in\mathbb{R}^{D\times N} denote random variables representing our data source. In vision tasks, each \bm{x}_{i}\in\mathbb{R}^{D} is interpreted as a token, typically corresponding to an image patch. In language tasks, each \bm{x}_{i}\in\mathbb{R}^{D} is interpreted as a token embedding, i.e., a continuous vector representation of a discrete token such as a word or subword (with a slight abuse of terminology, we refer to both the discrete tokens and their associated embeddings simply as tokens throughout this chapter for convenience). The \bm{x}_{i}'s may have arbitrary correlation structures. We use \bm{Z}=[\bm{z}_{1},\dots,\bm{z}_{N}]\in\mathbb{R}^{d\times N} to denote the random variables that define our representations, where \bm{z}_{i}\in\mathbb{R}^{d} is the representation of the corresponding token \bm{x}_{i}\in\mathbb{R}^{D}.

Remark 4.3.

In transformers, each input sample is typically converted into a sequence of tokens. A token is a basic unit of information derived from the raw input: in natural language processing, tokens are typically words or subwords; in computer vision, they correspond to image patches; and in other modalities, they may represent time steps, spatial locations, or other domain-specific units. A token embedding is a continuous vector representation of a token that serves as the input to a transformer. It maps each token to a point in a high-dimensional space, enabling the model to process symbolic inputs using numerical computation. A token representation is a vector that encodes the semantic or structural information of a token, typically produced by the intermediate or final layers of a transformer. These representations are designed to capture meaningful features of the input that are useful for downstream tasks such as classification, generation, or regression. Please refer to Section 7.2 for more details about these concepts in implementations.

Objective for Learning a Structured and Compact Representation.

Following the framework of rate reduction in Section 4.1, we contend that the goal of representation learning is to find a feature mapping f\colon\bm{X}\in\mathbb{R}^{D\times N}\to\bm{Z}\in\mathbb{R}^{d\times N} which transforms input tokens \{\bm{x}_{i}\}_{i=1}^{N}\subset\mathbb{R}^{D} with a potentially nonlinear and multi-modal distribution into (piecewise) linearized and compact token representations \{\bm{z}_{i}\}_{i=1}^{N}\subset\mathbb{R}^{d}. While the joint distribution of the token representations \{\bm{z}_{i}\}_{i=1}^{N} may be sophisticated (and task-specific), we further contend that it is reasonable and practical to require that the target marginal distribution of individual token representations be highly compressed and structured, amenable to compact coding. In particular, we require the distribution to be a mixture of K low-dimensional Gaussian distributions, such that the k-th Gaussian has mean \bm{0}\in\mathbb{R}^{d}, covariance \bm{\Sigma}_{k}\succeq\bm{0}\in\mathbb{R}^{d\times d}, and support spanned by the orthonormal basis \bm{U}_{k}\in\mathbb{R}^{d\times p}. We denote by \bm{U}_{[K]}=\{\bm{U}_{k}\}_{k=1}^{K} the set of bases of all Gaussians. Hence, to maximize the information gain [MTS22] for the final token representations, we wish to maximize their rate reduction (see Section 3.4.2), i.e.,

\max_{\bm{Z}\in\mathbb{R}^{d\times N}}\ \Delta R_{\epsilon}(\bm{Z}\mid\bm{U}_{[K]}) \doteq R_{\epsilon}(\bm{Z}) - R^{c}_{\epsilon}(\bm{Z}\mid\bm{U}_{[K]}). \qquad (4.2.1)

Here, the first term RϵR_{\epsilon}italic_R start_POSTSUBSCRIPT italic_ϵ end_POSTSUBSCRIPT is an estimate of the lossy coding rate for the whole set of token representations. More specifically, if we view the token representations {𝒛i}i=1N\{\bm{z}_{i}\}_{i=1}^{N}{ bold_italic_z start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT } start_POSTSUBSCRIPT italic_i = 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_N end_POSTSUPERSCRIPT as i.i.d. samples from a single zero-mean Gaussian, their lossy coding rate subject to a quantization precision ϵ>0\epsilon>0italic_ϵ > 0 is given as

R_{\epsilon}(\bm{Z}) \doteq \frac{1}{2}\log\det\!\left(\bm{I}+\frac{d}{N\epsilon^{2}}\bm{Z}^{\top}\bm{Z}\right) = \frac{1}{2}\log\det\!\left(\bm{I}+\frac{d}{N\epsilon^{2}}\bm{Z}\bm{Z}^{\top}\right). \qquad (4.2.2)

The second term RϵcR_{\epsilon}^{c}italic_R start_POSTSUBSCRIPT italic_ϵ end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_c end_POSTSUPERSCRIPT is an estimate of the lossy coding rate under the codebook 𝑼[K]\bm{U}_{[K]}bold_italic_U start_POSTSUBSCRIPT [ italic_K ] end_POSTSUBSCRIPT, which is given as

R_{\epsilon}^{c}(\bm{Z}\mid\bm{U}_{[K]}) \doteq \sum_{k=1}^{K}R_{\epsilon}(\bm{U}_{k}^{\top}\bm{Z}) = \frac{1}{2}\sum_{k=1}^{K}\log\det\!\left(\bm{I}+\frac{p}{N\epsilon^{2}}(\bm{U}_{k}^{\top}\bm{Z})^{\top}(\bm{U}_{k}^{\top}\bm{Z})\right). \qquad (4.2.3)
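For concreteness, both coding-rate terms can be evaluated directly from the definitions (4.2.2) and (4.2.3). The following NumPy sketch uses the N\times N form of each log-determinant; the default value of \epsilon and the toy dimensions are illustrative choices only.

```python
import numpy as np

def R(Z, eps=0.5):
    """Lossy coding rate of Z (d x N), eq. (4.2.2)."""
    d, N = Z.shape
    return 0.5 * np.linalg.slogdet(np.eye(N) + (d / (N * eps**2)) * Z.T @ Z)[1]

def Rc(Z, U_list, eps=0.5):
    """Coding rate of Z with respect to subspace bases U_k (each d x p), eq. (4.2.3)."""
    d, N = Z.shape
    total = 0.0
    for U in U_list:
        P = U.T @ Z                       # p x N projected tokens
        p = U.shape[1]
        total += 0.5 * np.linalg.slogdet(np.eye(N) + (p / (N * eps**2)) * P.T @ P)[1]
    return total

# Toy check: d = 8, two 2-dimensional subspaces, N = 50 tokens.
rng = np.random.default_rng(0)
U_list = [np.linalg.qr(rng.standard_normal((8, 2)))[0] for _ in range(2)]
Z = rng.standard_normal((8, 50))
print(R(Z), Rc(Z, U_list), R(Z) - Rc(Z, U_list))   # the last value is Delta R
```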
Remark 4.4.

The expression (4.2.3) for the coding rate can be viewed as a generalization of the coding rate RϵcR_{\epsilon}^{c}italic_R start_POSTSUBSCRIPT italic_ϵ end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_c end_POSTSUPERSCRIPT used in the original rate reduction objective (3.4.13). In particular, the original objective is defined with respect to a set of known membership labels {𝚷k}\{\bm{\Pi}_{k}\}{ bold_Π start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT } specific to the particular data realization 𝑿\bm{X}bold_italic_X. In contrast, the current objective is defined with respect to subspaces 𝑼[K]\bm{U}_{[K]}bold_italic_U start_POSTSUBSCRIPT [ italic_K ] end_POSTSUBSCRIPT, which are independent of any particular realization but are assumed to support the distribution of token representations. Suppose that a token representation 𝒛i\bm{z}_{i}bold_italic_z start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT belongs to a subspace 𝑼k\bm{U}_{k}bold_italic_U start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT and these subspaces are approximately orthogonal to each other, i.e., 𝑼k𝑼l𝟎\bm{U}_{k}^{\top}\bm{U}_{l}\approx\bm{0}bold_italic_U start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT start_POSTSUPERSCRIPT ⊤ end_POSTSUPERSCRIPT bold_italic_U start_POSTSUBSCRIPT italic_l end_POSTSUBSCRIPT ≈ bold_0 for all klk\neq litalic_k ≠ italic_l. Then, one can verify that the projections 𝑼k𝑼k𝒛i=𝒛i\bm{U}_{k}\bm{U}_{k}^{\top}\bm{z}_{i}=\bm{z}_{i}bold_italic_U start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT bold_italic_U start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT start_POSTSUPERSCRIPT ⊤ end_POSTSUPERSCRIPT bold_italic_z start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT = bold_italic_z start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT and 𝑼l𝑼l𝒛i𝟎\bm{U}_{l}\bm{U}_{l}^{\top}\bm{z}_{i}\approx\bm{0}bold_italic_U start_POSTSUBSCRIPT italic_l end_POSTSUBSCRIPT bold_italic_U start_POSTSUBSCRIPT italic_l end_POSTSUBSCRIPT start_POSTSUPERSCRIPT ⊤ end_POSTSUPERSCRIPT bold_italic_z start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ≈ bold_0 for all lkl\neq kitalic_l ≠ italic_k. These orthogonal projections effectively serve as implicit membership labels, identifying the subspace to which each token representation belongs.
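A tiny numerical check of this remark, using two exactly orthogonal two-dimensional subspaces of \mathbb{R}^{4} (an idealized setting chosen only for illustration):

```python
import numpy as np

# U1 spans the first two coordinate axes, U2 the last two; a token z lying in
# the first subspace projects onto U1 exactly and onto U2 to (approximately) zero.
U1 = np.vstack([np.eye(2), np.zeros((2, 2))])   # 4 x 2 orthonormal basis
U2 = np.vstack([np.zeros((2, 2)), np.eye(2)])   # 4 x 2 orthonormal basis
z = U1 @ np.array([1.0, -2.0])

print(np.allclose(U1 @ U1.T @ z, z))   # True: implicit membership in subspace 1
print(np.allclose(U2 @ U2.T @ z, 0))   # True: no membership in subspace 2
```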

Figure 4.12: Comparison of three sets of representations via rate reduction and sparsity. Each S_{i} represents one linear subspace, and the number of blue balls represents the difference between the coding rates \Delta R_{\epsilon}(\bm{Z}\mid\bm{U}_{[K]})=R_{\epsilon}(\bm{Z})-R^{c}_{\epsilon}(\bm{Z}\mid\bm{U}_{[K]}).

Sparse Rate Reduction.

Note that the rate reduction objective (4.2.1) is invariant to arbitrary joint rotations of the representations and subspaces. In particular, optimizing the rate reduction objective alone may not lead to axis-aligned (i.e., sparse) representations. For instance, consider the three sets of learned representations in Figure 4.12: the coding rate reduction increases from (a) to (b), but because it is invariant under rotations, it remains the same from (b) to (c). Therefore, to ensure that the final representations are amenable to more compact coding, we would like to transform the representations (and their supporting subspaces) so that they eventually become sparse, i.e., have few nonzero entries, with respect to the standard coordinates of the resulting representation space, as in Figure 4.12(c). Computationally, we may combine the above two goals into a unified objective for optimization:

\max_{f\in\mathcal{F}}\ \left[\Delta R_{\epsilon}(\bm{Z}\mid\bm{U}_{[K]}) - \lambda\|\bm{Z}\|_{0}\right] \qquad \text{s.t.}\quad \bm{Z}=f(\bm{X}), \qquad (4.2.4)

where \mathcal{F}caligraphic_F denotes a general function class and the 0\ell_{0}roman_ℓ start_POSTSUBSCRIPT 0 end_POSTSUBSCRIPT norm 𝒁0\|\bm{Z}\|_{0}∥ bold_italic_Z ∥ start_POSTSUBSCRIPT 0 end_POSTSUBSCRIPT promotes the sparsity of the final token representations 𝒁=f(𝑿)\bm{Z}=f(\bm{X})bold_italic_Z = italic_f ( bold_italic_X ).

In practice, the \ell_{0} norm is often relaxed to the \ell_{1} norm to improve computational tractability and enable convex optimization techniques [WM22]. Motivated by this, we relax Problem (4.2.4) accordingly, leading to a formulation that remains faithful to the original sparsity objective while being more amenable to efficient algorithms, as follows:

\max_{f\in\mathcal{F}}\ \left[\Delta R_{\epsilon}(\bm{Z}\mid\bm{U}_{[K]}) - \lambda\|\bm{Z}\|_{1}\right] \qquad \text{s.t.}\quad \bm{Z}=f(\bm{X}). \qquad (4.2.5)

With a slight abuse of terminology, we often refer to this objective function also as the sparse rate reduction.

White-Box Network Architecture via Unrolled Optimization.

Although easy to state, each term in the above objective is computationally challenging to optimize [WM22]. Hence it is natural to adopt an approximation approach that realizes the global transformation ffitalic_f to optimize (4.2.4) through a concatenation of multiple, say LLitalic_L, simple incremental and local operations ff^{\ell}italic_f start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT that push the representation distribution towards the desired parsimonious model distribution:

f\colon \bm{X}=\bm{Z}^{0} \xrightarrow{\ f^{0}\ } \bm{Z}^{1} \rightarrow \cdots \rightarrow \bm{Z}^{\ell} \xrightarrow{\ f^{\ell}\ } \bm{Z}^{\ell+1} \rightarrow \cdots \xrightarrow{\ f^{L-1}\ } \bm{Z}^{L}=\bm{Z}, \qquad (4.2.6)

where f0:Ddf^{0}:\mathbb{R}^{D}\rightarrow\mathbb{R}^{d}italic_f start_POSTSUPERSCRIPT 0 end_POSTSUPERSCRIPT : blackboard_R start_POSTSUPERSCRIPT italic_D end_POSTSUPERSCRIPT → blackboard_R start_POSTSUPERSCRIPT italic_d end_POSTSUPERSCRIPT is the pre-processing mapping that transforms each input token 𝒙iD\bm{x}_{i}\in\mathbb{R}^{D}bold_italic_x start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ∈ blackboard_R start_POSTSUPERSCRIPT italic_D end_POSTSUPERSCRIPT to the initial token representations 𝒛i1d\bm{z}_{i}^{1}\in\mathbb{R}^{d}bold_italic_z start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT start_POSTSUPERSCRIPT 1 end_POSTSUPERSCRIPT ∈ blackboard_R start_POSTSUPERSCRIPT italic_d end_POSTSUPERSCRIPT. Each incremental forward mapping 𝒁+1=f(𝒁)\bm{Z}^{\ell+1}=f^{\ell}(\bm{Z}^{\ell})bold_italic_Z start_POSTSUPERSCRIPT roman_ℓ + 1 end_POSTSUPERSCRIPT = italic_f start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT ( bold_italic_Z start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT ), or a “layer”, transforms the token distribution to optimize the above sparse rate reduction objective (4.2.4), conditioned on the distribution of its input 𝒁\bm{Z}^{\ell}bold_italic_Z start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT.

Remark 4.5.

In contrast to other unrolled optimization approaches such as the ReduNet (see Section 4.1), we explicitly model the distribution of 𝒁\bm{Z}^{\ell}bold_italic_Z start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT at each layer, say as a mixture of linear subspaces or sparsely generated from a dictionary. The model parameters are learned from data (say via backward propagation with end-to-end training). This separation between forward “optimization” and backward “learning” clarifies the mathematical role of each layer as an operator that transforms the distribution of its input, whereas the input distribution is in turn modeled (and subsequently learned) by the parameters of the layer.

We now show how to derive these incremental and local operations from an unrolled optimization perspective for solving Problem (4.2.5). Once we decide on an incremental approach, there are a variety of possible choices to achieve the optimization. Given a model for \bm{Z}^{\ell}, say a mixture of subspaces \bm{U}_{[K]}^{\ell}, we opt for a two-step alternating minimization method with a strong conceptual basis. First, we compress the tokens \bm{Z}^{\ell} via a gradient descent step to minimize the coding rate term R^{c}_{\epsilon}(\bm{Z}\mid\bm{U}_{[K]}^{\ell}). Specifically, we take a gradient step on R^{c}_{\epsilon} with learning rate \kappa as follows:

\bm{Z}^{\ell+1/2} = \bm{Z}^{\ell} - \kappa\,\nabla_{\bm{Z}}R^{c}_{\epsilon}(\bm{Z}\mid\bm{U}_{[K]}^{\ell}). \qquad (4.2.7)

Next, we sparsify the compressed tokens, generating \bm{Z}^{\ell+1} via a suitably relaxed proximal gradient step to minimize the remaining term \lambda\|\bm{Z}\|_{1}-R_{\epsilon}(\bm{Z}). As we will argue in detail later, we can find such a \bm{Z}^{\ell+1} by solving a sparse representation problem with respect to a dictionary \bm{D}^{\ell}:

\bm{Z}^{\ell+1} = \operatorname*{arg\,min}_{\bm{Z}} \left\{ \lambda\|\bm{Z}\|_{1} + \frac{1}{2}\|\bm{Z}^{\ell+1/2}-\bm{D}^{\ell}\bm{Z}\|_{F}^{2} \right\}. \qquad (4.2.8)

In the following, we provide technical details for each of the two steps above and derive efficient updates for their implementation.

Self-Attention as Gradient Descent on Coding Rate of Token Representations.

For the first step (4.2.7), the gradient of the coding rate 𝒁Rϵc\nabla_{\bm{Z}}R^{c}_{\epsilon}∇ start_POSTSUBSCRIPT bold_italic_Z end_POSTSUBSCRIPT italic_R start_POSTSUPERSCRIPT italic_c end_POSTSUPERSCRIPT start_POSTSUBSCRIPT italic_ϵ end_POSTSUBSCRIPT is costly to compute, as it involves KKitalic_K separate matrix inverses, one for each of the KKitalic_K subspaces with basis 𝑼k\bm{U}_{k}^{\ell}bold_italic_U start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT:

\nabla_{\bm{Z}}R_{\epsilon}^{c}(\bm{Z}\mid\bm{U}_{[K]}) = \frac{p}{N\epsilon^{2}}\sum_{k=1}^{K}\bm{U}_{k}\bm{U}_{k}^{\top}\bm{Z}\left(\bm{I}+\frac{p}{N\epsilon^{2}}(\bm{U}_{k}^{\top}\bm{Z})^{\top}(\bm{U}_{k}^{\top}\bm{Z})\right)^{-1}. \qquad (4.2.9)

Now, we demonstrate that this gradient can be naturally approximated using a so-called multi-head subspace self-attention (MSSA) operator, which has a similar functional form to the multi-head self-attention operator [VSP+17] with KKitalic_K heads (i.e., one for each subspace, coming from each matrix inverse). Here, we approximate the gradient (4.2.9) using the first-order Neumann series (see Exercise 4.2):

\nabla_{\bm{Z}}R_{\epsilon}^{c}(\bm{Z}\mid\bm{U}_{[K]}) \approx \frac{p}{N\epsilon^{2}}\sum_{k=1}^{K}\bm{U}_{k}\bm{U}_{k}^{\top}\bm{Z}\left(\bm{I}-\frac{p}{N\epsilon^{2}}(\bm{U}_{k}^{\top}\bm{Z})^{\top}(\bm{U}_{k}^{\top}\bm{Z})\right)
= \frac{p}{N\epsilon^{2}}\left(\sum_{k=1}^{K}\bm{U}_{k}\bm{U}_{k}^{\top}\right)\bm{Z} - \left(\frac{p}{N\epsilon^{2}}\right)^{2}\sum_{k=1}^{K}\bm{U}_{k}(\bm{U}_{k}^{\top}\bm{Z})(\bm{U}_{k}^{\top}\bm{Z})^{\top}(\bm{U}_{k}^{\top}\bm{Z}). \qquad (4.2.10)

In this approximation, we compute the similarity between the projected token representations \{\bm{U}_{k}^{\top}\bm{z}_{i}\} through the auto-correlation (\bm{U}_{k}^{\top}\bm{Z})^{\top}(\bm{U}_{k}^{\top}\bm{Z}) of the projected features, and convert it to a distribution of membership with a softmax, namely \mathrm{softmax}\!\left((\bm{U}_{k}^{\top}\bm{Z})^{\top}(\bm{U}_{k}^{\top}\bm{Z})\right). Suppose that the union of the subspaces \bm{U}_{[K]} spans the whole space. Then we have \sum_{k=1}^{K}\bm{U}_{k}\bm{U}_{k}^{\top}=\bm{I}. Hence, (4.2.10) becomes

\nabla_{\bm{Z}}R_{\epsilon}^{c}(\bm{Z}\mid\bm{U}_{[K]}) \approx \frac{p}{N\epsilon^{2}}\bm{Z} - \left(\frac{p}{N\epsilon^{2}}\right)^{2}\operatorname{MSSA}\!\left(\bm{Z}^{\ell}\mid\bm{U}_{[K]}^{\ell}\right), \qquad (4.2.11)

where MSSA is defined through an SSA operator as follows:

\mathrm{SSA}(\bm{Z}\mid\bm{U}_{k}) \doteq (\bm{U}_{k}^{\top}\bm{Z})\,\mathrm{softmax}\!\left((\bm{U}_{k}^{\top}\bm{Z})^{\top}(\bm{U}_{k}^{\top}\bm{Z})\right), \quad \forall k\in[K], \qquad (4.2.12)
\mathrm{MSSA}(\bm{Z}\mid\bm{U}_{[K]}) \doteq \frac{p}{N\epsilon^{2}}\begin{bmatrix}\bm{U}_{1},\dots,\bm{U}_{K}\end{bmatrix}\begin{bmatrix}\mathrm{SSA}(\bm{Z}\mid\bm{U}_{1})\\ \vdots\\ \mathrm{SSA}(\bm{Z}\mid\bm{U}_{K})\end{bmatrix}. \qquad (4.2.13)

Substituting the approximation (4.2.11) into the gradient step (4.2.7) shows that the latter can be naturally approximated by

\bm{Z}^{\ell+1/2} = \left(1-\frac{\kappa p}{N\epsilon^{2}}\right)\bm{Z}^{\ell} + \frac{\kappa p}{N\epsilon^{2}}\,\mathrm{MSSA}\!\left(\bm{Z}^{\ell}\mid\bm{U}_{[K]}^{\ell}\right). \qquad (4.2.14)
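The SSA and MSSA operators and the compression step (4.2.14) can be sketched in a few lines of NumPy. The column-wise softmax convention and the default values of \kappa and \epsilon below are our own illustrative choices, not part of the formal definition.

```python
import numpy as np

def _col_softmax(M):
    """Column-wise softmax: each column of the output sums to one."""
    M = M - M.max(axis=0, keepdims=True)
    E = np.exp(M)
    return E / E.sum(axis=0, keepdims=True)

def ssa(Z, U):
    """Subspace self-attention (4.2.12) for one subspace basis U (d x p)."""
    P = U.T @ Z                              # p x N tokens projected onto U
    return P @ _col_softmax(P.T @ P)         # p x N

def mssa(Z, U_list, eps=0.5):
    """Multi-head subspace self-attention (4.2.13)."""
    d, N = Z.shape
    p = U_list[0].shape[1]
    return (p / (N * eps**2)) * sum(U @ ssa(Z, U) for U in U_list)   # d x N

def compress_step(Z, U_list, kappa=1.0, eps=0.5):
    """Approximate gradient step (4.2.14) on the compression term R^c_eps."""
    d, N = Z.shape
    p = U_list[0].shape[1]
    step = kappa * p / (N * eps**2)
    return (1 - step) * Z + step * mssa(Z, U_list, eps=eps)

# Toy usage: d = 8, K = 4 heads of dimension p = 2, N = 16 tokens.
rng = np.random.default_rng(0)
U_list = [np.linalg.qr(rng.standard_normal((8, 2)))[0] for _ in range(4)]
Z = rng.standard_normal((8, 16))
print(compress_step(Z, U_list).shape)   # (8, 16)
```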
Remark 4.6.

The SSA operator in (4.2.12) resembles the attention operator in a typical transformer [VSP+17], except that here the linear operators of value, key, and query are all set to be the same as the subspace basis, i.e., \bm{V}_{k}=\bm{K}_{k}=\bm{Q}_{k}=\bm{U}_{k}. Hence, we call \mathrm{SSA}(\,\cdot\mid\bm{U}_{k})\colon\mathbb{R}^{d\times N}\to\mathbb{R}^{p\times N} the Subspace Self-Attention (SSA) operator. The whole operator in (4.2.13), formally defined as \mathrm{MSSA}(\,\cdot\mid\bm{U}_{[K]})\colon\mathbb{R}^{d\times N}\to\mathbb{R}^{d\times N} and called the Multi-Head Subspace Self-Attention (MSSA) operator, aggregates the attention-head outputs by averaging with model-dependent weights, similar in concept to the popular multi-head self-attention operator in existing transformer networks. The overall gradient step (4.2.14) resembles multi-head self-attention implemented with a skip connection in transformers.

MLP as Proximal Gradient Descent for Sparse Coding of Token Representations.

For the second step of the alternating minimization, we need to minimize \lambda\|\bm{Z}\|_{1}-R_{\epsilon}(\bm{Z}). Note that the gradient \nabla R_{\epsilon}(\bm{Z}) involves a matrix inverse, and thus naive proximal gradient descent (see Section A.1.3) becomes intractable for large-scale problems. We therefore take a different, simplifying approach to trading off between representational diversity and sparsification: we posit a (complete) incoherent or orthogonal dictionary \bm{D}^{\ell}\in\mathbb{R}^{d\times d} and seek to sparsify the intermediate iterates \bm{Z}^{\ell+1/2} with respect to \bm{D}^{\ell}. That is, \bm{Z}^{\ell+1/2}\approx\bm{D}^{\ell}\bm{Z}^{\ell+1}, where \bm{Z}^{\ell+1} is a sparser encoding of \bm{Z}^{\ell+1/2}. The dictionary \bm{D}^{\ell} is used to sparsify all tokens simultaneously. By the incoherence assumption, we have (\bm{D}^{\ell})^{\top}\bm{D}^{\ell}\approx\bm{I}. Thus from (4.2.2) we have

R_{\epsilon}(\bm{Z}^{\ell+1/2}) \approx R_{\epsilon}(\bm{D}^{\ell}\bm{Z}^{\ell+1}) \approx R_{\epsilon}(\bm{Z}^{\ell+1}). \qquad (4.2.15)

To minimize \lambda\|\bm{Z}\|_{1}-R_{\epsilon}(\bm{Z}), we therefore optimize the following problem:

\bm{Z}^{\ell+1} \approx \operatorname*{arg\,min}_{\bm{Z}} \|\bm{Z}\|_{1} \quad \text{subject to} \quad \bm{Z}^{\ell+1/2} = \bm{D}^{\ell}\bm{Z}.

The above sparse representation program is usually solved by relaxing it to an unconstrained convex program, known as LASSO [WM22]:

\bm{Z}^{\ell+1} \approx \operatorname*{arg\,min}_{\bm{Z}} \left[ \lambda\|\bm{Z}\|_{1} + \frac{1}{2}\|\bm{Z}^{\ell+1/2}-\bm{D}^{\ell}\bm{Z}\|_{F}^{2} \right]. \qquad (4.2.16)

In our implementation, we also add a non-negative constraint to 𝒁+1\bm{Z}^{\ell+1}bold_italic_Z start_POSTSUPERSCRIPT roman_ℓ + 1 end_POSTSUPERSCRIPT, and solve the corresponding non-negative LASSO:

\bm{Z}^{\ell+1} \approx \operatorname*{arg\,min}_{\bm{Z}\geq\bm{0}} \left[ \lambda\|\bm{Z}\|_{1} + \frac{1}{2}\|\bm{Z}^{\ell+1/2}-\bm{D}^{\ell}\bm{Z}\|_{F}^{2} \right]. \qquad (4.2.17)

Then, we incrementally optimize (4.2.17) by performing one step of unrolled proximal gradient descent, known as an ISTA step [BT09], which gives the update:

\bm{Z}^{\ell+1} = \mathrm{ISTA}(\bm{Z}^{\ell+1/2}\mid\bm{D}^{\ell}), \qquad (4.2.18)
\text{where}\quad \mathrm{ISTA}(\bm{Z}\mid\bm{D}) \doteq \operatorname{ReLU}\!\left(\bm{Z} - \eta\,\bm{D}^{\top}(\bm{D}\bm{Z}-\bm{Z}) - \eta\lambda\bm{1}\right). \qquad (4.2.19)
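A direct transcription of the ISTA block (4.2.19) in NumPy; the default step size \eta and sparsity level \lambda are illustrative values only.

```python
import numpy as np

def ista_block(Z, D, eta=0.1, lam=0.1):
    """One ISTA step (4.2.19): a proximal gradient step on the non-negative LASSO,
    initialized at Z itself, so the block amounts to a linear correction followed
    by a shifted ReLU non-linearity."""
    return np.maximum(Z - eta * D.T @ (D @ Z - Z) - eta * lam, 0.0)
```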
Figure 4.13: One layer of the CRATE encoder architecture. The full architecture is simply a concatenation of such layers, with some initial tokenizer, pre-processing head, and final task-specific head (i.e., a classification head).

4.2.2 Overall White-Box Transformer Architecture: CRATE

We now design a white-box transformer architecture, named the Coding RATE Transformer (crate), by unrolling the above updates. By combining the above two steps (4.2.14) and (4.2.18):

  1. Local compression of tokens within a sample towards a mixture-of-subspace structure, leading to the multi-head subspace self-attention block – MSSA;

  2. Global sparsification of token sets across all samples through sparse coding, leading to the sparsification block – ISTA;

we can get the following rate-reduction-based transformer layer, illustrated in Figure 4.13,

\bm{Z}^{\ell+1/2} \doteq \bm{Z}^{\ell} + \texttt{MSSA}(\bm{Z}^{\ell}\mid\bm{U}_{[K]}^{\ell}), \qquad \bm{Z}^{\ell+1} \doteq \texttt{ISTA}(\bm{Z}^{\ell+1/2}\mid\bm{D}^{\ell}). \qquad (4.2.20)
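Assuming the mssa and ista_block sketches given earlier in this section, one such layer (4.2.20) is simply their composition with a skip connection; stacking layers gives a minimal forward pass (the parameter container format below is our own choice).

```python
def crate_layer(Z, U_list, D, eps=0.5, eta=0.1, lam=0.1):
    """One CRATE-style layer (4.2.20): compression (MSSA with a skip connection)
    followed by sparsification (one ISTA step against the dictionary D)."""
    Z_half = Z + mssa(Z, U_list, eps=eps)            # compression against U_[K]
    return ista_block(Z_half, D, eta=eta, lam=lam)   # sparsification against D

def crate_forward(Z0, params, **kw):
    """Stack L layers; params is a list of (U_list, D) pairs, one pair per layer."""
    Z = Z0
    for U_list, D in params:
        Z = crate_layer(Z, U_list, D, **kw)
    return Z
```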

Composing multiple such layers following the incremental construction of our representation in (4.2.6), we obtain a white-box transformer architecture that transforms the data tokens towards a compact and sparse union of incoherent subspaces, in which a pre-processing mapping f^{\mathrm{pre}}\colon\mathbb{R}^{D\times N}\to\mathbb{R}^{d\times N} first transforms the input tokens \bm{X}\in\mathbb{R}^{D\times N} into the first-layer representations \bm{Z}^{1}\in\mathbb{R}^{d\times N}. An overall flow of this architecture is shown in Figure 4.14.

Figure 4.14: The ‘main loop’ of the crate white-box deep network design. After encoding input data as a sequence of tokens \bm{Z}^{0}, crate constructs a deep network that transforms the data to a canonical configuration of low-dimensional subspaces by successive compression against a local model for the distribution, generating \bm{Z}^{\ell+1/2}, and sparsification against a global dictionary, generating \bm{Z}^{\ell+1}. Repeatedly stacking these blocks and training the model parameters via backpropagation yields a powerful and interpretable representation of the data.
Remark 4.7 (The roles of the forward pass and backward propagation).

In contrast to other unrolled optimization approaches such as the ReduNet [CYY+22], we explicitly model the distribution of each \bm{Z}^{\ell} and \bm{Z}^{\ell+1/2} at each layer, either as a mixture of linear subspaces or as sparsely generated from a dictionary. We introduced the interpretation that at each layer \ell, the learned subspace bases \bm{U}_{[K]}^{\ell} and the learned dictionary \bm{D}^{\ell} together serve as a codebook or analysis filter that encodes and transforms the intermediate representations at that layer. Since the input distribution to layer \ell is first modeled by \bm{U}_{[K]}^{\ell} and then transformed by \bm{D}^{\ell}, the input distribution to each layer is different, and so we require a separate codebook at each layer to obtain the most parsimonious encoding. The parameters of these codebooks (i.e., the subspace bases and the dictionaries), heretofore assumed to be perfectly known, are in practice learned from data (say via backward propagation within end-to-end training).

Hence, our methodology features a clear conceptual separation between forward “optimization” and backward “learning” for the so-derived white-box deep neural network. Namely, in the forward pass, we interpret each layer as an operator which, conditioned on a learned model (i.e., a codebook) for the distribution of its input, transforms this distribution towards a more parsimonious representation. In backward propagation, the codebook of this model, for the distribution of the input to each layer, is updated to better fit a certain (supervised) input-output relationship, as illustrated in Figure 4.15. This conceptual interpretation implies a certain agnosticism of the model representations towards the particular task and loss; in particular, many types of tasks and losses will ensure that the model at each layer is well fit, and hence that the network produces parsimonious representations.

We now present the empirical performance of the proposed crate networks by measuring their top-1 classification accuracy on ImageNet-1K as well as their transfer learning performance on several widely used downstream datasets. We summarize the results in Table 4.1. For transfer learning, we initialize from the pre-trained networks and fine-tune with the cross-entropy loss. Because the designed white-box transformer architecture leverages parameter sharing in both the attention block (MSSA) and the nonlinearity block (ISTA), the crate-Base model (22.80 million parameters) has a similar number of parameters to ViT-Small (22.05 million) [DBK+21], and fewer than 30% of the parameters of an identically configured ViT-Base (86.54 million). From Table 4.1, we find that with a similar number of model parameters, our proposed network achieves ImageNet-1K and transfer learning performance comparable to ViT, while having a simple and principled design. Moreover, with the same set of training hyperparameters, we observe promising scaling behavior in crate: performance consistently improves as the model size grows. To summarize, crate achieves promising performance on real-world large-scale datasets by directly implementing our principled architecture. We provide more details of the implementation and an analysis of the image classification experiments in the final application chapter, Chapter 7.

(a) Forward pass
(b) Backward propagation
Figure 4.15: The roles of the forward pass and backward propagation in deep networks. (a) Given fixed subspaces and dictionaries $\{(\bm{U}_{[K]}^{\ell},\bm{D}^{\ell})\}_{\ell=1}^{L}$, each layer performs compression and sparsification on the representations in the forward pass; (b) backward propagation learns the subspaces and dictionaries $\{(\bm{U}_{[K]}^{\ell},\bm{D}^{\ell})\}_{\ell=1}^{L}$ from training data.
Table 4.1: Top-1 classification accuracy of crate on various datasets with different model scales when pre-trained on ImageNet-1K. For ImageNet-1K/ImageNet-1K ReaL, we directly evaluate the top-1 accuracy. For the other datasets, we use models pre-trained on ImageNet-1K as initialization and evaluate the transfer learning performance via fine-tuning.
Model crate-T crate-S crate-B crate-L ViT-T ViT-S
# parameters 6.09M 13.12M 22.80M 77.64M 5.72M 22.05M
ImageNet-1K 66.7 69.2 70.8 71.3 71.5 72.4
ImageNet-1K ReaL 74.0 76.0 76.5 77.4 78.3 78.4
CIFAR10 95.5 96.0 96.8 97.2 96.6 97.2
CIFAR100 78.9 81.0 82.7 83.6 81.8 83.2
Oxford Flowers-102 84.6 87.1 88.7 88.3 85.1 88.5
Oxford-IIIT-Pets 81.4 84.9 85.3 87.4 88.5 88.6

4.3 Variants of Deep Architectures by Design

So far, we hope to have provided compelling evidence that the role of (popular) deep networks is to realize certain optimization algorithms for minimizing the coding rate (or maximizing the information gain) of the learned representations. However, readers who are familiar with optimization methods might have noticed that the above architectures (the ReduNet or CRATE) correspond to rather basic optimization techniques, which may leave plenty of room for improvement in efficiency or effectiveness. Moreover, if we believe the proposed theoretical framework for interpreting deep networks is correct, it should not only help explain existing architectures but also guide us in developing more efficient and effective ones. In this section, we show that this is indeed possible: the resulting new architectures are not only fully interpretable but also come with guaranteed correctness and improved efficiency.

4.3.1 Attention-Only Transformer Architecture

In this subsection, we propose a minimalistic transformer architecture consisting of interpretable layers based on the MSSA operator. To derive a fully interpretable transformer architecture with only the necessary components, we contend that the goal of representation learning is to compress a set of noisy initial token representations towards a mixture of low-dimensional subspaces. Here, we assume that the initial token representations $\bm{Z}^{(1)}$ are sampled from a mixture of low-rank Gaussians perturbed by noise, as follows:

Definition 4.1.

Let $C_{1},\dots,C_{K}$ be a partition of the index set $[N]$ and let $\bm{U}_{k}\in\mathcal{O}^{d\times p_{k}}$ denote the orthonormal basis of the $k$-th subspace for each $k\in[K]$. We say that the token representations $\{\bm{z}_{i}\}_{i=1}^{N}\subseteq\mathbb{R}^{d}$ are sampled from a mixture of noisy low-rank Gaussian distributions if for each $k\in[K]$,

$$\bm{z}_{i}=\underbrace{\bm{U}_{k}\bm{a}_{i}}_{\mathbf{signal}}+\underbrace{\sum_{j\neq k}^{K}\bm{U}_{j}\bm{e}_{i,j}}_{\mathbf{noise}},\quad\forall i\in C_{k}, \tag{4.3.1}$$

where $\bm{a}_{i}\overset{i.i.d.}{\sim}\mathcal{N}(\bm{0},\bm{I}_{p_{k}})$ and $\bm{e}_{i,j}\overset{i.i.d.}{\sim}\mathcal{N}(\bm{0},\delta^{2}\bm{I}_{p_{j}})$ for all $i\in C_{k}$ and $k\in[K]$, the $\{\bm{a}_{i}\}$ and $\{\bm{e}_{i,j}\}$ are each mutually independent, and $\{\bm{a}_{i}\}$ is independent of $\{\bm{e}_{i,j}\}$.
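To make Definition 4.1 concrete, the following is a minimal NumPy sketch (not part of the original text) that samples tokens from this model; the dimensions $d$, $p$, $K$, $N$, the noise level $\delta$, and the function names are illustrative choices.

```python
# Minimal sketch of sampling token representations per Definition 4.1 (NumPy).
import numpy as np

def sample_tokens(d=128, p=8, K=4, N=200, delta=0.2, seed=0):
    """Sample N tokens from a mixture of K noisy low-rank Gaussians (Def. 4.1)."""
    rng = np.random.default_rng(seed)
    # Orthonormal bases for K mutually orthogonal p-dimensional subspaces.
    Q, _ = np.linalg.qr(rng.standard_normal((d, K * p)))
    U = [Q[:, k * p:(k + 1) * p] for k in range(K)]
    labels = rng.integers(K, size=N)       # which subspace each token lives on
    Z = np.zeros((d, N))
    for i, k in enumerate(labels):
        Z[:, i] = U[k] @ rng.standard_normal(p)          # signal term U_k a_i
        for j in range(K):
            if j != k:                                   # noise terms U_j e_{i,j}
                Z[:, i] += U[j] @ (delta * rng.standard_normal(p))
    return Z, U, labels

Z0, U, labels = sample_tokens()
print(Z0.shape)  # (128, 200)
```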

This model serves as an idealized framework for approximating token representations in real-world pretrained LLMs. It assumes that the token representations are sampled from a mixture of multiple low-rank Gaussian distributions with noise. Under this model, the goal of representation learning is to compress a set of noisy initial token representations into the corresponding subspaces. In addition, this model aligns well with two well-established hypotheses about the structure of token representations in pretrained large language models: the “linear representation hypothesis” [JRR+24, PCV24] and the “superposition hypothesis” [EHO+22, YCO+21].

Remark 4.8.

The linear representation hypothesis posits that token representations in LLMs lie in low-dimensional linear subspaces that encode semantic features. Similarly, the superposition hypothesis suggests that these representations can be approximately expressed as sparse linear combinations of these feature vectors. In Definition 4.1, each basis $\bm{U}_{k}$ of the subspaces can be interpreted as a set of semantic features, where each feature corresponds to a specific aspect of the token’s meaning. Token representations are then approximately expressed as sparse linear combinations of these subspace bases, capturing the essential semantic components of the token while ignoring irrelevant dimensions.

Denoising Operator for Token Representations.

Now, we show that the MSSA operator (see (4.2.13)) can incrementally denoise token representations generated from the above model. Specifically, we consider, for each $\ell=1,\dots,L$,

$$\bm{Z}^{(\ell+1)}=\bm{Z}^{(\ell)}+\eta\sum_{k=1}^{K}\bm{U}_{k}\bm{U}_{k}^{T}\bm{Z}^{(\ell)}\,\varphi\!\left(\bm{Z}^{(\ell)T}\bm{U}_{k}\bm{U}_{k}^{T}\bm{Z}^{(\ell)}\right), \tag{4.3.2}$$

where $\{\bm{U}_{k}\}_{k=1}^{K}$ is as defined in Definition 4.1, $\eta>0$ is the step size, and $\varphi$ is an element-wise operator, such as softmax, ReLU, or other functions. To simplify our development, we assume that the subspaces in Definition 4.1 are orthogonal to each other, i.e., $\bm{U}_{k}^{T}\bm{U}_{j}=\bm{0}$ for all $k\neq j$. Note that this assumption is not restrictive: in high-dimensional spaces, random low-dimensional subspaces are incoherent to each other with high probability, i.e., $\bm{U}_{k}^{T}\bm{U}_{j}\approx\bm{0}$ [WM21]. (One may straightforwardly generalize our results to non-orthogonal subspaces, with a slightly more sophisticated analysis.)
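The following is a minimal sketch of one update step of (4.3.2), assuming that $\varphi$ is the thresholded softmax $h\circ\sigma$ used later in Theorem 4.1 and that the softmax is applied column-wise to the similarity matrix; $\eta$ and $\tau$ are illustrative values.

```python
# Sketch of the layer update (4.3.2) with phi = thresholded softmax (NumPy).
import numpy as np

def softmax(A, axis=0):
    A = A - A.max(axis=axis, keepdims=True)
    E = np.exp(A)
    return E / E.sum(axis=axis, keepdims=True)

def mssa_denoise_step(Z, U_list, eta=1.0, tau=0.7):
    """One step of Z <- Z + eta * sum_k U_k U_k^T Z phi(Z^T U_k U_k^T Z)."""
    Z_new = Z.copy()
    for U_k in U_list:
        P = U_k @ (U_k.T @ Z)            # projection of tokens onto subspace k
        A = softmax(Z.T @ P, axis=0)     # column-wise softmax of similarities
        A = tau * (A > tau)              # thresholding h(x) = tau * 1{x > tau}
        Z_new = Z_new + eta * P @ A
    return Z_new
```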

(a) Noise level $\delta=0.2$
(b) Noise level $\delta=0.5$
Figure 4.16: Denoising performance of the attention-only transformer. Here, we sample initial token representations from the mixture of low-rank Gaussians in Definition 4.1. Then, we apply (4.3.2) to update the token representations and report the SNR at each layer.

Now, let the columns of $\bm{Z}_{k}^{(\ell)}$ denote the token representations from the $k$-th subspace at the $\ell$-th layer. To quantify the denoising capability, we define the signal-to-noise ratio (SNR) for each block of the token representations at the $\ell$-th layer as follows:

$$\mathrm{SNR}(\bm{Z}_{k}^{(\ell)})\doteq\frac{\|\bm{U}_{k}\bm{U}_{k}^{T}\bm{Z}_{k}^{(\ell)}\|_{F}}{\|(\bm{I}-\bm{U}_{k}\bm{U}_{k}^{T})\bm{Z}_{k}^{(\ell)}\|_{F}},\quad\forall k\in[K]. \tag{4.3.3}$$
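For reference, here is a short sketch of the block-wise SNR in (4.3.3); the random example at the end is purely illustrative.

```python
# Sketch of the block-wise SNR (4.3.3) with a small random example (NumPy).
import numpy as np

def block_snr(Z_k, U_k):
    """SNR(Z_k) = ||U_k U_k^T Z_k||_F / ||(I - U_k U_k^T) Z_k||_F."""
    P = U_k @ (U_k.T @ Z_k)           # component inside the k-th subspace
    signal = np.linalg.norm(P)
    noise = np.linalg.norm(Z_k - P)   # residual outside the subspace
    return signal / noise

rng = np.random.default_rng(0)
U_k, _ = np.linalg.qr(rng.standard_normal((64, 8)))
Z_k = U_k @ rng.standard_normal((8, 50)) + 0.1 * rng.standard_normal((64, 50))
print(block_snr(Z_k, U_k))
```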

To simplify our analysis, we assume that $p=p_{1}=\dots=p_{K}$, $N_{1}=\dots=N_{K}=N/K$, and

$$\begin{bmatrix}\bm{U}_{1}&\dots&\bm{U}_{K}\end{bmatrix}\in\mathcal{O}^{d\times Kp}. \tag{4.3.4}$$

With the above setup, we now characterize the denoising performance of the MSSA operator.

Figure 4.17: Details of the attention-only transformer architecture. Each layer consists of the MSSA operator and a skip connection. In addition, LayerNorm is included only for language tasks. In practice, backpropagation is applied to train the model parameters using training samples.
Theorem 4.1.

Let $\bm{Z}^{(1)}$ be as defined in Definition 4.1 and let $\varphi(\cdot)$ in (4.3.2) be $\varphi(\bm{x})=h(\sigma(\bm{x}))$, where $\sigma:\mathbb{R}^{N}\to\mathbb{R}^{N}$ is the softmax function and $h:\mathbb{R}^{N}\to\mathbb{R}^{N}$ is the element-wise thresholding function with $h(x)=\tau\,\mathbb{I}\{x>\tau\}$ applied to each entry. Suppose that $p\gtrsim\log N$, $\delta\lesssim\sqrt{\log N}/\sqrt{p}$, and

$$\tau\in\left(\frac{1}{2},\ \frac{1}{1+N\exp(-9p/32)}\right].$$

For sufficiently large $N$, it holds with probability at least $1-KLN^{-\Omega(1)}$ that, for each $\ell\in[L]$,

$$\mathrm{SNR}(\bm{Z}_{k}^{(\ell+1)})=(1+\eta\tau)\,\mathrm{SNR}(\bm{Z}_{k}^{(\ell)}),\quad\forall k\in[K]. \tag{4.3.5}$$

This theorem shows that when the initial token representations are sampled from a mixture of low-rank Gaussian distributions with noise level $O(\sqrt{\log N}/\sqrt{p})$, each layer of the proposed transformer denoises the token representations at a linear rate. This indicates the MSSA operator’s efficiency in reducing noise across layers. Notably, our theoretical results are well supported by the experimental observations in Figure 4.16. This theorem thus provides a theoretical foundation for the practical denoising capability of the transformer architecture derived by unrolling (4.3.2).

Remark 4.9.

Under this model, the goal of representation learning is to compress a set of noisy initial token representations into the corresponding subspaces. However, we should point out that in real-world applications, where token representations exhibit more complicated structures, the goal of representation learning is to find a compact and structured representation by compressing token sets.

Attention-Only Transformer.

Now, we formally propose an attention-only transformer architecture. Specifically, by unrolling the iterative optimization steps (4.3.2) as layers of a deep network, we construct the transformer architecture in Figure 4.17. Each layer of the proposed architecture consists only of the MSSA operator and a skip connection. For language tasks, we additionally incorporate LayerNorm before the MSSA operator to improve performance. The complete architecture is built by stacking such layers, along with essential task-specific pre-processing and post-processing steps, such as positional encoding, token embedding, and a final task-specific head to adapt to different applications.

Generally speaking, the standard decoder-only transformer architecture is composed of the following key components [VSP+17]: (1) positional encoding, (2) multi-head QKV self-attention mechanisms, (3) feed-forward MLP networks, (4) layer normalization, and (5) residual connections. In contrast, our proposed transformer architecture adopts a streamlined design by incorporating several key simplifications. Specifically, it employs shared-QKV subspace self-attention mechanisms, excludes MLP layers, and reduces the frequency of LayerNorm.
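To make the construction concrete, the following is a hedged PyTorch sketch of one such layer: an MSSA-style update in the spirit of (4.3.2) with learnable subspace bases, a skip connection, and an optional LayerNorm. The class name, hyperparameters, initialization, and the exact placement of the softmax are illustrative assumptions rather than the reference implementation.

```python
# Hedged sketch of one attention-only layer (MSSA-style update + skip connection).
import torch
import torch.nn as nn

class AttentionOnlyLayer(nn.Module):
    def __init__(self, d=256, p=32, K=8, eta=0.1, use_layernorm=False):
        super().__init__()
        self.U = nn.Parameter(torch.randn(K, d, p) / d ** 0.5)  # K subspace bases
        self.eta = eta
        self.norm = nn.LayerNorm(d) if use_layernorm else nn.Identity()

    def forward(self, Z):                    # Z: (batch, N_tokens, d)
        Zn = self.norm(Z)
        out = Z
        for U_k in self.U:                   # shared-QKV "subspace attention" heads
            P = Zn @ U_k @ U_k.T             # project tokens onto subspace k
            A = torch.softmax(Zn @ P.transpose(1, 2), dim=-1)  # token similarities
            out = out + self.eta * A @ P     # residual update as in (4.3.2)
        return out

# usage: stack layers and add task-specific embedding / head around them
layer = AttentionOnlyLayer()
tokens = torch.randn(2, 10, 256)
print(layer(tokens).shape)   # torch.Size([2, 10, 256])
```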

4.3.2 Linear-Time Attention: Token Statistics Transformer

In this subsection, we propose a new transformer attention operator, based on the coding rate reduction objective, whose computational complexity scales linearly with the number of tokens. Specifically, we derive a novel variational form of the MCR2 objective and show that the architecture that results from unrolled gradient descent on this variational objective leads to a new attention module called Token Statistics Self-Attention (TSSA). TSSA has linear computational and memory complexity and radically departs from the typical attention architecture that computes pairwise similarities between tokens. Recall from (3.4.2) that $\bm{\Pi}=[\bm{\pi}_{1},\ldots,\bm{\pi}_{K}]\in\mathbb{R}^{N\times K}$ denotes a stochastic “group assignment” matrix (i.e., $\bm{\Pi}\bm{1}=\bm{1}$ and $\Pi_{ik}\geq 0$ for all $(i,k)\in[N]\times[K]$), where $\Pi_{ik}$ denotes the probability of assigning the $i$-th token to the $k$-th group.

A New Variational Form for Coding Rates.

To begin, we consider a general form of MCR2-like objectives based on concave functions of the spectrum of a matrix. Namely, for a given PSD matrix $\bm{M}\in\mathsf{PSD}(d)$ and any scalar $c\geq 0$, we have $\log\det(\bm{I}+c\bm{M})=\sum_{i=1}^{d}\log(1+c\lambda_{i}(\bm{M}))$, where $\lambda_{i}(\bm{M})$ is the $i$-th largest eigenvalue of $\bm{M}$. Further, note that $\log(1+c\sigma)$ is a concave non-decreasing function of $\sigma$. Thus, we describe our results in terms of a more general form of MCR2 based on spectral functions of PSD matrices of the form $F(\bm{M})=\sum_{i=1}^{d}f(\lambda_{i}(\bm{M}))$, where $f$ is concave and non-decreasing. In particular, recall from our discussion above that the attention mechanism arises from unrolling the compression component of MCR2, so we consider a more general MCR2-style compression function:

$$R_{c,f}(\bm{Z},\bm{\Pi})\doteq\frac{1}{2}\sum_{k=1}^{K}\frac{N_{k}}{N}\,F\!\left(\frac{1}{N_{k}}\bm{Z}\,\mathrm{Diag}(\bm{\pi}_{k})\,\bm{Z}^{\top}\right). \tag{4.3.6}$$

For the above objective, we now note the following result:

Theorem 4.2.

Let $f\colon[0,\infty)\to\mathbb{R}$ be non-decreasing, concave, and satisfy $f(0)=0$, and let $F\colon\mathsf{PSD}(d)\to\mathbb{R}$ have the form $F(\bm{M})=\sum_{i=1}^{d}f(\lambda_{i}(\bm{M}))$. Then for each $\bm{M}\in\mathsf{PSD}(d)$ and $\bm{Q}\in\mathsf{O}(d)$, we have

$$F(\bm{M})\leq\sum_{i=1}^{d}f\!\left((\bm{Q}^{\top}\bm{M}\bm{Q})_{ii}\right). \tag{4.3.7}$$

Further, the inequality in (4.3.7) holds with equality for any $\bm{Q}$ that diagonalizes $\bm{M}$, and if $f$ is strictly concave then the inequality in (4.3.7) holds with equality if and only if $\bm{Q}$ diagonalizes $\bm{M}$.
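Theorem 4.2 is easy to verify numerically. The following small check (an arbitrary PSD matrix, $f(x)=\log(1+cx)$, and illustrative values; not part of the original text) confirms inequality (4.3.7) for a random orthogonal $\bm{Q}$ and its tightness when $\bm{Q}$ diagonalizes $\bm{M}$.

```python
# Numerical check of the variational bound (4.3.7) with f(x) = log(1 + c x).
import numpy as np

rng = np.random.default_rng(0)
d, c = 6, 10.0
A = rng.standard_normal((d, d))
M = A @ A.T                                    # a random PSD matrix
f = lambda x: np.log1p(c * x)                  # concave, non-decreasing, f(0)=0

F_M = np.sum(f(np.linalg.eigvalsh(M)))          # F(M) = sum_i f(lambda_i(M))

Q_rand, _ = np.linalg.qr(rng.standard_normal((d, d)))   # arbitrary orthogonal Q
bound_rand = np.sum(f(np.diag(Q_rand.T @ M @ Q_rand)))

_, Q_eig = np.linalg.eigh(M)                   # Q that diagonalizes M
bound_eig = np.sum(f(np.diag(Q_eig.T @ M @ Q_eig)))

print(F_M <= bound_rand + 1e-9)                # True: inequality (4.3.7)
print(np.isclose(F_M, bound_eig))              # True: equality at the eigenbasis
```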

Using the above result, we can replace (4.3.6) with an equivalent variational objective of the form

$$R^{\rm var}_{c,f}(\bm{Z},\bm{\Pi}\mid\bm{U}_{[K]})\doteq\frac{1}{2}\sum_{k=1}^{K}\frac{N_{k}}{N}\sum_{i=1}^{d}f\!\left(\frac{1}{N_{k}}\big(\bm{U}_{k}^{\top}\bm{Z}\,\mathrm{Diag}(\bm{\pi}_{k})\,\bm{Z}^{\top}\bm{U}_{k}\big)_{ii}\right), \tag{4.3.8}$$

where the equivalence is in the sense that, for an optimal choice of matrices $\{\bm{U}_{k}\in\mathsf{O}(d)\}_{k=1}^{K}$ as described in Theorem 4.2 (i.e., orthogonal matrices which diagonalize each $\bm{Z}\,\mathrm{Diag}(\bm{\pi}_{k})\,\bm{Z}^{\top}$), the bound is tight: $R^{\rm var}_{c,f}(\bm{Z},\bm{\Pi}\mid\bm{U}_{[K]})=R_{c,f}(\bm{Z},\bm{\Pi})$. Note that in general, achieving this bound would require selecting, for each sampled instance of $\bm{Z}$, a new optimal set of $\bm{U}_{k}$ parameter matrices which diagonalize each $\bm{Z}\,\mathrm{Diag}(\bm{\pi}_{k})\,\bm{Z}^{\top}$, which is clearly impractical for a network architecture. Instead, as an alternative viewpoint, rather than considering the data $\bm{Z}$ as fixed and optimizing the $\bm{U}_{k}$ parameters to achieve the tight variational bound, we can take the algorithmic unrolling design principle described above and design an operator that perturbs $\bm{Z}$ to incrementally minimize $R_{c,f}^{\rm var}(\cdot\mid\bm{U}_{[K]})$. To make this point explicit, each variational bound becomes tight when the eigenspaces of $\bm{Z}\,\mathrm{Diag}(\bm{\pi}_{k})\,\bm{Z}^{\top}$ align with the columns of $\bm{U}_{k}$, so by rotating the appropriate columns of $\bm{Z}$ (namely, those which correspond to large entries in $\bm{\pi}_{k}$) to align with $\bm{U}_{k}$ we can approach a tight variational bound. That is, instead of rotating $\bm{U}_{k}$ to align with the data for each instance of $\bm{Z}$, we instead rotate the token features in each $\bm{Z}$ to align with $\bm{U}_{k}$.

Following this approach, we compute a gradient descent step on $R_{c,f}^{\rm var}$ with respect to $\bm{Z}$. To begin this computation, first let $\bm{\pi}\in\mathbb{R}^{N}$ be any element-wise non-negative vector. Then we have

$$\nabla_{\bm{Z}}\,\frac{1}{2}\sum_{i=1}^{d}f\!\left((\bm{Z}\,\mathrm{Diag}(\bm{\pi})\,\bm{Z}^{\top})_{ii}\right)=\mathrm{Diag}\!\left(\nabla f[\bm{Z}^{\odot 2}\bm{\pi}]\right)\bm{Z}\,\mathrm{Diag}(\bm{\pi}), \tag{4.3.9}$$

where $\nabla f$ is the gradient of $f$, and (recall) $\nabla f[\cdot]$ applies $\nabla f$ to each element of the vector in the bracket. In particular, for $f(x)=\log(1+(d/\epsilon^{2})x)$, the gradient $\nabla f(x)=(d/\epsilon^{2})\,(1+(d/\epsilon^{2})x)^{-1}$ is simply a non-linear activation. Also, recall that $N_{k}=\langle\bm{\pi}_{k},\bm{1}\rangle$. Thus, the gradient of $R^{\rm var}_{c,f}$ with respect to $\bm{Z}$ is:

$$\nabla_{\bm{Z}}R^{\rm var}_{c,f}(\bm{Z},\bm{\Pi}\mid\bm{U}_{[K]})=\frac{1}{N}\sum_{k=1}^{K}\bm{U}_{k}\underbrace{\mathrm{Diag}\!\left(\nabla f\!\left[(\bm{U}_{k}^{\top}\bm{Z})^{\odot 2}\,\frac{\bm{\pi}_{k}}{\langle\bm{\pi}_{k},\bm{1}\rangle}\right]\right)}_{\doteq\,\bm{D}(\bm{Z},\bm{\pi}_{k}\mid\bm{U}_{k})}\bm{U}_{k}^{\top}\bm{Z}\,\mathrm{Diag}(\bm{\pi}_{k}). \tag{4.3.10}$$

(Note that the $1/N$ constant arises from a $(N_{k}/N)\cdot(1/N_{k})=1/N$ factor in each term of the sum.) If we now consider a gradient step with respect to the $j$-th token $\bm{z}_{j}$, we arrive at our proposed incremental compression operator, i.e., our surrogate for a self-attention + residual operator:

$$\bm{z}_{j}^{+}=\bm{z}_{j}-\tau\,\nabla_{\bm{z}_{j}}R_{c,f}^{\rm var}(\bm{Z},\bm{\Pi}\mid\bm{U}_{[K]})=\bm{z}_{j}-\frac{\tau}{N}\sum_{k=1}^{K}\Pi_{jk}\,\bm{U}_{k}\bm{D}(\bm{Z},\bm{\pi}_{k}\mid\bm{U}_{k})\bm{U}_{k}^{\top}\bm{z}_{j} \tag{4.3.11}$$

for each $j\in[N]$, where $\tau>0$ is a step size parameter for the incremental optimization. Then, we can construct a layer of ToST as in Figure 4.18.

Figure 4.18: One layer $\ell$ of the proposed Token Statistics Transformer (ToST). Notably, the self-attention of ToST transforms the tokens $\bm{Z}^{\ell}$ into $\bm{Z}^{\ell+1}$ efficiently, multiplying each row of the projected tokens by only a scalar. This leads to reduced complexity of the attention: it has $O(p)$ space and $O(pN)$ time complexity, where $p$ is the dimension of the projected tokens of each head and $N$ is the number of tokens.

Model Interpretation.

Given the proposed attention operator in (4.3.11), first recall that the rows of $\bm{\Pi}$ are non-negative and sum to 1, so our operator takes a weighted average of $K$ “attention head”-esque operators and then adds a residual connection. Using the fact that $\sum_{k=1}^{K}\Pi_{jk}=1$, we can rewrite (4.3.11) as:

$$\bm{z}_{j}^{+}=\sum_{k=1}^{K}\Pi_{jk}\Big[\bm{z}_{j}\underbrace{-\frac{\tau}{N}\,\bm{U}_{k}\bm{D}(\bm{Z},\bm{\pi}_{k}\mid\bm{U}_{k})\bm{U}_{k}^{\top}}_{\text{action of one attention head}}\bm{z}_{j}\Big]. \tag{4.3.12}$$

That is, we can view each attention head as first projecting the token features onto the basis $\bm{U}_{k}$ via multiplication by $\bm{U}_{k}^{\top}$, multiplying by the diagonal matrix $\bm{D}(\bm{Z},\bm{\pi}_{k}\mid\bm{U}_{k})$ (abbreviated as $\bm{D}_{k}$), projecting back into the standard basis via multiplication by $\bm{U}_{k}$, and subtracting the result from the original token features via the residual connection. The core aspect of our attention layer is the computation of $\bm{D}_{k}$. Namely, since $\Pi_{jk}\geq 0$, the vector $\bm{\pi}_{k}/\langle\bm{\pi}_{k},\bm{1}\rangle\in\mathbb{R}^{N}$ forms a probability distribution over which tokens belong to the $k$-th group. As a result, $(\bm{U}_{k}^{\top}\bm{Z})^{\odot 2}\,\bm{\pi}_{k}/\langle\bm{\pi}_{k},\bm{1}\rangle$ estimates the second moment of $\bm{U}_{k}^{\top}\bm{Z}$ under the distribution given by $\bm{\pi}_{k}/\langle\bm{\pi}_{k},\bm{1}\rangle$. Further, since $f$ is a concave non-decreasing function, $\nabla f(x)$ monotonically decreases towards 0 as $x$ increases, so the entries of $\bm{D}_{k}$ (which have the form $\nabla f(x)$) achieve their maximum at $x=0$ and decay monotonically to 0 as $x$ increases.

From this, we arrive at the core interpretation of our attention head + residual operators $[\bm{I}-(\tau/N)\bm{U}_{k}\bm{D}_{k}\bm{U}_{k}^{\top}]$. Namely, this operator performs an approximate low-rank, data-dependent projection: directions which have a large amount of “power” after the projection $\bm{U}_{k}^{\top}\bm{Z}$ (i.e., directions with a large second moment $(\bm{U}_{k}^{\top}\bm{Z})^{\odot 2}\,\bm{\pi}_{k}/\langle\bm{\pi}_{k},\bm{1}\rangle$) are preserved, while directions which do not are suppressed. To see this, recall that the entries of $\bm{D}_{k}$ decrease monotonically to 0 as the second moment increases, so for directions with large second moments the attention + residual operator acts largely as the identity. Conversely, for directions with a small second moment, our operator subtracts a projection of the tokens along those directions, so those directions are suppressed. Compared to the standard self-attention operator, our method clearly does not compute any pairwise similarities between tokens. Rather, the interactions between the tokens in $\bm{Z}$ affect the operator solely through their contribution to the second moment statistic used to construct the $\bm{D}_{k}$’s. Nevertheless, similar to the standard self-attention operator, our method still has a clear interpretation as performing a form of compression towards a data-dependent low-rank structure, in the sense that it performs an approximate low-rank projection in which the suppressed directions are those not strongly aligned with the other tokens in the group.

Practical Implementation Details.

Having introduced our proposed attention operator, we now discuss further practical considerations. First, until this point in the presentation, we have avoided discussing how tokens are “grouped” into the various attention heads via the $\bm{\Pi}$ matrix, but clearly a means of constructing $\bm{\Pi}$ is needed to implement our method. Additionally, our variational form in Theorem 4.2 requires the $\bm{U}$ matrices to be square and orthogonal, but one would ideally like to use smaller matrices (i.e., reduce the number of columns in each $\bm{U}$) for efficiency, as well as drop the orthogonality constraints.

In practice, we do not enforce the orthogonality constraints. To reduce the number of columns in the $\bm{U}$ matrices, we note that, similar to CRATE [YBP+23], if we assume the features $\bm{Z}$ within group $k$ are (approximately) clustered around a low-dimensional subspace, say of dimension $p$, then the within-group-$k$ covariance $\bm{Z}\,\mathrm{Diag}(\bm{\pi}_{k})\,\bm{Z}^{\top}$ is low-rank; recall that [YCY+20] shows that the optimal geometry of $\bm{Z}$ is for each group to span a low-rank subspace, orthogonal to the other groups. We can thus explicitly find a low-dimensional orthonormal basis for the image of this covariance, i.e., the linear span of the data in group $k$. If the dimension is $p\leq d$, the basis can be represented by a $d\times p$ orthogonal matrix $\bm{U}_{k}\in\mathsf{O}(d,p)$. In this case, we can more efficiently upper-bound $R_{c,f}$ using these low-rank orthogonal basis matrices. To show this, we use a more general version of Theorem 4.2 to obtain the following corollary.

Corollary 4.1.

Let $f\colon[0,\infty)\to\mathbb{R}$ be non-decreasing, concave, and satisfy $f(0)=0$, and let $F\colon\mathsf{PSD}(p)\to\mathbb{R}$ have the form $F(\bm{M})=\sum_{i=1}^{p}f(\lambda_{i}(\bm{M}))$. Let $\bm{Z}$, $\bm{\Pi}$ be fixed. Then, for all $\bm{U}_{1},\dots,\bm{U}_{K}\in\mathsf{O}(d,p)$ such that $\mathrm{image}(\bm{Z}\,\mathrm{diag}(\bm{\pi}_{k})\,\bm{Z}^{\top})\subset\mathrm{image}(\bm{U}_{k})$ for all $k\in[K]$, we have

$$R_{c,f}(\bm{Z},\bm{\Pi})\leq R_{c,f}^{\rm var}(\bm{Z},\bm{\Pi}\mid\bm{U}_{[K]}), \tag{4.3.13}$$

where $R_{c,f}^{\rm var}$ is formally defined in (4.3.8). Equality holds if $\bm{U}_{k}$ diagonalizes $\bm{Z}\,\mathrm{diag}(\bm{\pi}_{k})\,\bm{Z}^{\top}$ for each $k\in[K]$, and if $f$ is strongly concave then this equality condition becomes an “if and only if.”

The final step needed to define our attention operator is to estimate the group membership $\bm{\Pi}$. For this we posit a simple model of how each feature $\bm{z}_{j}$ deviates from its supporting subspace and then find the optimal subspace assignment. [YBP+23] show that if we independently model each $\bm{z}_{j}$ as belonging to a low-dimensional Gaussian mixture model, where each Gaussian has a covariance matrix with identical spectrum and is supported on a subspace with orthonormal basis $\bm{U}_{k}$, plus independent Gaussian noise with covariance $\eta\bm{I}$, then the posterior probability that each token $\bm{z}_{j}$ belongs to each subspace is given by the assignment matrix $\bm{\Pi}=\bm{\Pi}(\bm{Z}\mid\bm{U}_{[K]})$ as follows:

$$\bm{\Pi}=\begin{bmatrix}\bm{\nu}(\bm{z}_{1}\mid\bm{U}_{[K]})^{\top}\\ \vdots\\ \bm{\nu}(\bm{z}_{N}\mid\bm{U}_{[K]})^{\top}\end{bmatrix},\quad\text{where}\quad\bm{\nu}(\bm{z}_{j}\mid\bm{U}_{[K]})\doteq\operatorname{softmax}\!\left(\frac{1}{2\eta}\begin{bmatrix}\|\bm{U}_{1}^{\top}\bm{z}_{j}\|_{2}^{2}\\ \vdots\\ \|\bm{U}_{K}^{\top}\bm{z}_{j}\|_{2}^{2}\end{bmatrix}\right),\quad\forall j\in[N], \tag{4.3.14}$$

where $\eta$ becomes a learnable temperature parameter. Thus, given an input feature matrix $\bm{Z}$, we estimate $\bm{\Pi}$ using (4.3.14) and then compute the attention operator. Combining the construction of $\bm{\Pi}$ in (4.3.14) with (4.3.11), we obtain the Token Statistics Self-Attention operator:

$$\texttt{TSSA}(\bm{Z}\mid\bm{U}_{[K]})\doteq-\frac{\tau}{N}\sum_{k=1}^{K}\bm{U}_{k}\bm{D}(\bm{Z},\bm{\pi}_{k}\mid\bm{U}_{k})\bm{U}_{k}^{\top}\bm{Z}\,\mathrm{diag}(\bm{\pi}_{k}), \tag{4.3.15}$$

where the $\bm{\pi}_{k}$ are the columns of $\bm{\Pi}=\bm{\Pi}(\bm{Z}\mid\bm{U}_{[K]})$ defined in (4.3.14) and $\bm{D}$ is defined in (4.3.10).
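Putting (4.3.14) and (4.3.15) together, the following is a hedged NumPy sketch of the TSSA operator. The bases $\bm{U}_{k}$ (random here) would be learned in practice, and $\epsilon$, $\eta$, $\tau$ are illustrative values; note that each head touches the tokens only through per-direction second-moment statistics, so the cost is $O(pN)$ per head rather than $O(N^{2})$.

```python
# Hedged sketch of the TSSA operator (4.3.14)-(4.3.15); all values illustrative.
import numpy as np

def softmax(v):
    v = v - v.max(axis=-1, keepdims=True)
    e = np.exp(v)
    return e / e.sum(axis=-1, keepdims=True)

def tssa(Z, U_list, eps=1.0, eta=1.0, tau=0.5):
    """Token Statistics Self-Attention: O(p*N) per head, no pairwise similarities."""
    d, N = Z.shape
    proj = [U_k.T @ Z for U_k in U_list]                     # (p, N) projections
    # group assignments Pi (4.3.14): softmax over squared projection norms
    scores = np.stack([(P ** 2).sum(axis=0) for P in proj], axis=1)  # (N, K)
    Pi = softmax(scores / (2 * eta))
    grad_f = lambda x: (d / eps**2) / (1 + (d / eps**2) * x)
    out = np.zeros_like(Z)
    for U_k, P, pi_k in zip(U_list, proj, Pi.T):
        second_moment = (P ** 2) @ (pi_k / pi_k.sum())       # (p,) per-direction power
        D_k = grad_f(second_moment)                          # diagonal of D(Z, pi_k | U_k)
        out -= (tau / N) * U_k @ (D_k[:, None] * P) * pi_k   # U_k D_k U_k^T Z diag(pi_k)
    return out                                               # TSSA(Z); layer output is Z + out

# illustrative usage
rng = np.random.default_rng(0)
d, p, K, N = 64, 8, 4, 100
U_list = [np.linalg.qr(rng.standard_normal((d, p)))[0] for _ in range(K)]
Z = rng.standard_normal((d, N))
print((Z + tssa(Z, U_list)).shape)   # (64, 100)
```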

4.4 Summary and Notes

The materials presented in this chapter are based on a series of recent works on this topic, including [CYY+22, WLP+24, WLY+25, WDL+25, YBP+23]. These contributions encompass both theoretical advances and practical methodologies for constructing interpretable deep networks through unrolled optimization. Many of the key results and proofs discussed in this chapter are derived directly from, or inspired by, these foundational works.

The idea of unrolling an optimization algorithm to construct a neural network traces back to the seminal work [GL10], which showed that sparse coding algorithms such as the Iterative Shrinkage-Thresholding Algorithm (ISTA) can be unrolled into multilayer perceptrons (MLPs), effectively bridging iterative optimization and neural network design. Notably, [MLE19] demonstrated that such unrolled networks are more interpretable, more parameter-efficient, and often more effective than generic networks. In this chapter, we build on this perspective to develop principled, white-box deep network architectures by unrolling optimization algorithms designed to minimize well-motivated objectives, such as the (sparse) rate reduction objective introduced earlier. This approach not only clarifies the role of each layer in the network but also offers theoretical grounding for architectural choices, moving beyond empirical trial-and-error toward interpretable, goal-driven design. The table below compares conventional DNNs, which are typically constructed through empirical design and heuristic tuning, with our mathematically grounded ReduNet architectures; a minimal sketch of one unrolled ISTA layer follows the table.

\begin{center}
\begin{tabular}{lll}
\hline
 & Conventional DNNs & ReduNets \\
\hline
Objectives & input/output fitting & information gain \\
Deep architectures & trial \& error & iterative optimization \\
Layer operators & empirical & projected gradient \\
Shift invariance & CNNs + augmentation & invariant ReduNets \\
Initializations & random / pre-designed & forward unrolled \\
Training / fine-tuning & back propagation & forward / back propagation \\
Interpretability & black box & white box \\
Representations & hidden / latent & incoherent subspaces \\
\hline
\end{tabular}
\end{center}
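To illustrate the unrolling idea, here is a minimal sketch (ours, not the exact construction of [GL10]) of ISTA for the sparse coding problem $\min_{\bm z}\ \tfrac{1}{2}\|\bm{x}-\bm{A}\bm{z}\|_2^2+\lambda\|\bm{z}\|_1$, written so that each iteration reads as one network layer. In learned variants such as LISTA, the matrices `W_e` and `S` and the threshold become trainable parameters of each layer.

```python
import numpy as np

def soft_threshold(v, lam):
    # Proximal operator of lam * ||.||_1, i.e. the shrinkage nonlinearity.
    return np.sign(v) * np.maximum(np.abs(v) - lam, 0.0)

def unrolled_ista(x, A, lam, num_layers=16):
    """Run num_layers ISTA iterations, each reading as one network layer:
       z^{l+1} = soft_threshold(W_e x + S z^l, lam / L)."""
    L = np.linalg.norm(A, 2) ** 2            # Lipschitz constant of the smooth part
    W_e = A.T / L                            # 'encoder' weight
    S = np.eye(A.shape[1]) - A.T @ A / L     # recurrent / lateral weight
    z = np.zeros(A.shape[1])
    for _ in range(num_layers):              # unrolled depth = number of layers
        z = soft_threshold(W_e @ x + S @ z, lam / L)
    return z
```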

4.5 Exercises and Extensions

Exercise 4.1.

Let $\bm{Z}=[\bm{Z}_{1},\dots,\bm{Z}_{K}]\in\mathbb{R}^{d\times m}$ with $\bm{Z}_{k}\in\mathbb{R}^{d\times m_{k}}$ for each $k\in[K]$. For some $\alpha>0$, let

\[
R(\bm{Z})=\frac{1}{2}\log\det\left(\bm{I}+\alpha\bm{Z}\bm{Z}^{\top}\right).
\]

1. Given any direction $\bm{D}\in\mathbb{R}^{d\times m}$, please show that $\nabla R(\bm{Z})=\alpha\bm{X}^{-1}\bm{Z}$ and

\[
\nabla^{2}R(\bm{Z})[\bm{D},\bm{D}]=\alpha\operatorname{Tr}\left(\bm{X}^{-1}\bm{D}\bm{D}^{\top}\right)-\frac{\alpha^{2}}{2}\operatorname{Tr}\left(\bm{X}^{-1}\left(\bm{Z}\bm{D}^{\top}+\bm{D}\bm{Z}^{\top}\right)\bm{X}^{-1}\left(\bm{Z}\bm{D}^{\top}+\bm{D}\bm{Z}^{\top}\right)\right),
\]

where $\bm{X}\doteq\bm{I}+\alpha\bm{Z}\bm{Z}^{\top}$. Hint: Note that

\[
\nabla^{2}R(\bm{Z})[\bm{D},\bm{D}]\doteq\left\langle\lim_{t\to 0}\frac{\nabla R(\bm{Z}+t\bm{D})-\nabla R(\bm{Z})}{t},\;\bm{D}\right\rangle.
\]
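As a sanity check (not part of the exercise), the gradient and second-order formulas above can be verified numerically by finite differences; the script below is a minimal sketch under the definition of $R$ given in this exercise.

```python
import numpy as np

rng = np.random.default_rng(0)
d, m, alpha = 5, 8, 0.7
Z = rng.standard_normal((d, m))
D = rng.standard_normal((d, m))

R = lambda Y: 0.5 * np.linalg.slogdet(np.eye(d) + alpha * Y @ Y.T)[1]

X = np.eye(d) + alpha * Z @ Z.T
grad = alpha * np.linalg.solve(X, Z)                     # claimed gradient: alpha X^{-1} Z
M = Z @ D.T + D @ Z.T
hess = alpha * np.trace(np.linalg.solve(X, D @ D.T)) \
     - 0.5 * alpha ** 2 * np.trace(np.linalg.solve(X, M) @ np.linalg.solve(X, M))

t = 1e-4
fd_grad = (R(Z + t * D) - R(Z - t * D)) / (2 * t)        # directional derivative along D
fd_hess = (R(Z + t * D) - 2 * R(Z) + R(Z - t * D)) / t ** 2

print(fd_grad, np.sum(grad * D))                         # agree to high precision
print(fd_hess, hess)                                     # agree to several decimal places
```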

2. Please show that

\[
R(\bm{Z})\leq\frac{1}{2}\sum_{k=1}^{K}\log\det\left(\bm{I}+\alpha\bm{Z}_{k}\bm{Z}_{k}^{\top}\right),
\]

where the equality holds if and only if $\bm{Z}_{k}^{\top}\bm{Z}_{l}=\bm{0}$ for all $k\neq l\in[K]$.
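A quick numerical illustration (again, not part of the exercise): for generic blocks the inequality is strict, while blocks supported on disjoint coordinate subspaces, for which $\bm{Z}_{k}^{\top}\bm{Z}_{l}=\bm{0}$, attain equality.

```python
import numpy as np

rng = np.random.default_rng(2)
alpha, sizes, d = 0.3, (3, 4, 2), 9

R = lambda Y: 0.5 * np.linalg.slogdet(np.eye(Y.shape[0]) + alpha * Y @ Y.T)[1]

# Generic blocks: the inequality is strict with probability one.
blocks = [rng.standard_normal((d, mk)) for mk in sizes]
print(R(np.hstack(blocks)), sum(R(Zk) for Zk in blocks))

# Blocks supported on disjoint coordinate subspaces (so Z_k^T Z_l = 0): equality.
ortho, start = [], 0
for mk in sizes:
    B = np.zeros((d, mk))
    B[start:start + mk, :] = rng.standard_normal((mk, mk))
    ortho.append(B)
    start += mk
print(R(np.hstack(ortho)), sum(R(Bk) for Bk in ortho))   # equal up to round-off
```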

3. Given some $\alpha>0$, let $\alpha_{k}=m\alpha/m_{k}$ for each $k\in[K]$. Please derive a closed form for the first-order critical points of the following function:

\[
f(\bm{Z}_{k})=\frac{1}{2}\log\det\left(\bm{I}+\alpha\bm{Z}_{k}\bm{Z}_{k}^{\top}\right)-\frac{m_{k}}{2m}\log\det\left(\bm{I}+\alpha_{k}\bm{Z}_{k}\bm{Z}_{k}^{\top}\right)-\frac{\lambda}{2}\|\bm{Z}_{k}\|_{F}^{2}.
\]

Hint: Let $r_{k}=\operatorname{rank}(\bm{Z}_{k})$ and consider the following singular value decomposition of $\bm{Z}_{k}$:

\[
\bm{Z}_{k}=\bm{P}_{k}\bm{\Sigma}_{k}\bm{Q}_{k}^{\top}=\begin{bmatrix}\bm{P}_{k,1}&\bm{P}_{k,2}\end{bmatrix}\begin{bmatrix}\tilde{\bm{\Sigma}}_{k}&\bm{0}\\ \bm{0}&\bm{0}\end{bmatrix}\begin{bmatrix}\bm{Q}_{k,1}^{\top}\\ \bm{Q}_{k,2}^{\top}\end{bmatrix},
\]

where $\bm{P}_{k}\in\mathcal{O}^{d}$ with $\bm{P}_{k,1}\in\mathbb{R}^{d\times r_{k}}$ and $\bm{P}_{k,2}\in\mathbb{R}^{d\times(d-r_{k})}$, $\bm{\Sigma}_{k}\in\mathbb{R}^{d\times m_{k}}$ with $\tilde{\bm{\Sigma}}_{k}\in\mathbb{R}^{r_{k}\times r_{k}}$ a diagonal matrix, and $\bm{Q}_{k}\in\mathcal{O}^{m_{k}}$ with $\bm{Q}_{k,1}\in\mathbb{R}^{m_{k}\times r_{k}}$ and $\bm{Q}_{k,2}\in\mathbb{R}^{m_{k}\times(m_{k}-r_{k})}$.

Exercise 4.2 (Neumann series for matrix inverse).

Let $\bm{A}\in\mathbb{R}^{n\times n}$. If $\|\bm{A}\|<1$, please show that

\begin{equation}
\left(\bm{I}-\bm{A}\right)^{-1}=\sum_{k=0}^{\infty}\bm{A}^{k}. \tag{4.5.1}
\end{equation}

Hint: The proof consists of two steps.
(i) Step 1: Please show that the infinite series $\sum_{k=0}^{\infty}\bm{A}^{k}$ converges when $\|\bm{A}\|<1$, using $\|\bm{A}^{k}\|\leq\|\bm{A}\|^{k}$.
(ii) Step 2: Compute the matrix product $(\bm{I}-\bm{A})\sum_{k=0}^{\infty}\bm{A}^{k}$.
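As a numerical illustration (not part of the exercise), truncating the series at a moderate order already recovers $(\bm{I}-\bm{A})^{-1}$ essentially to machine precision when $\|\bm{A}\|<1$:

```python
import numpy as np

rng = np.random.default_rng(1)
A = rng.standard_normal((4, 4))
A *= 0.5 / np.linalg.norm(A, 2)            # rescale so that ||A||_2 < 1

partial_sum = np.zeros_like(A)
power = np.eye(4)                          # A^0 = I
for _ in range(200):
    partial_sum += power
    power = power @ A

print(np.max(np.abs(partial_sum - np.linalg.inv(np.eye(4) - A))))   # close to machine precision
```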

Exercise 4.3.

Please compute the gradients in (4.3.9) and (4.3.10).

Exercise 4.4.

Please show Corollary 4.1 when $Kp\leq d$.