Chapter 3 Pursuing Low-Dimensionality via Lossy Compression

“We compress to learn, and we learn to compress.”
  — High-Dimensional Data Analysis, Wright and Ma, 2022

In Chapter 2, we showed how to learn simple classes of distributions whose supports are assumed to be a single low-dimensional subspace, a mixture of low-dimensional subspaces, or low-rank Gaussians. For further simplicity, the different (hidden) linear or Gaussian modes are assumed to be orthogonal or independent (or can easily be reduced to such idealistic cases), as illustrated in Figure 2.4. As we have shown, for such special distributions, one can derive rather simple and effective learning algorithms with correctness and efficiency guarantees. The geometric and statistical interpretation of the operations in the associated algorithms is also very clear.

In practice, both linearity and independence are rather idealistic assumptions that distributions of real-world high-dimensional data rarely satisfy. The only thing that we may assume is that the intrinsic dimension of the distribution is very low compared to the dimension of the ambient space in which the data are embedded. Hence, in this chapter, we show how to learn a more general class of low-dimensional distributions in a high-dimensional space, whose supports are not necessarily (piecewise) linear.

The distribution of real data typically contains multiple components or modes, say corresponding to different classes of objects in the case of images. These modes might not be statistically independent, and they may even have different intrinsic dimensions. It is also typical that we have access to only a finite number of samples from the distribution. Therefore, in general, we may assume our data are distributed on a mixture of (nonlinear) low-dimensional submanifolds in a high-dimensional space. Figure 3.1 illustrates an example of such a distribution.

To learn such a distribution under such conditions, there are several fundamental questions that we need to address:

  • What is a general approach to learn a general low-dimensional distribution in a high-dimensional space and represent the learned distribution?

  • How do we measure the complexity of the resulting representation so that we can effectively exploit the low dimensionality to learn?

  • How do we make the learning process computationally tractable and even scalable, as the ambient dimension is usually high and the number of samples typically large?

As we will see, the fundamental idea of compression, or dimension reduction, which has been shown to be very effective for the linear/independent case, still serves as a general principle for developing effective computational models and methods for learning general low-dimensional distributions.

Due to its theoretical and practical significance, we will study in greater depth how this general framework of learning low-dimensional distributions via compression is instantiated when the distribution of interest can be well modeled or approximated by a mixture of low-dimensional subspaces or low-rank Gaussians.

Figure 3.1: Data distributed on a mixture of low-dimensional submanifolds $\cup_{j}\mathcal{M}_{j}$ in a very high-dimensional ambient space, say $\mathbb{R}^{D}$.

3.1 Entropy Minimization and Compression

3.1.1 Entropy and Coding Rate

In Chapter 1, we have mentioned that the goal of learning is to find the simplest way to generate a given set of data. Conceptually, the Kolmogorov complexity was intended to provide such a measure of complexity, but it is not computable and is not associated with any implementable scheme that can actually reproduce the data. Hence, we need an alternative, computable, and realizable measure of complexity. That leads us to the notion of entropy, introduced by Shannon in 1948 [Sha48].

To illustrate the constructive nature of entropy, let us start with the simplest case. Suppose that we have a discrete random variable that takes $N$ distinct values, or tokens, $\{\bm{x}_1, \ldots, \bm{x}_N\}$ with equal probability $1/N$. Then we could encode each token $\bm{x}_i$ using the $\log_2 N$-bit binary representation of $i$. This coding scheme can be generalized to arbitrary discrete distributions [CT91]: given a distribution $p$ such that $\sum_{i=1}^{N} p(\bm{x}_i) = 1$, one could assign each token $\bm{x}_i$ with probability $p(\bm{x}_i)$ a binary code of length $\log_2[1/p(\bm{x}_i)] = -\log_2 p(\bm{x}_i)$ bits. Hence the average number of bits, or the coding rate, needed to encode a sample from the distribution $p(\cdot)$ is given by the following expression (by the convention of information theory [CT91], the $\log$ here is to the base $2$, so entropy is measured in binary bits):

$$
H(\bm{x}) \doteq \mathbb{E}[\log 1/p(\bm{x})] = -\sum_{i=1}^{N} p(\bm{x}_i)\log p(\bm{x}_i). \tag{3.1.1}
$$

This is known as the entropy of the (discrete) distribution $p(\cdot)$. Note that this entropy is always nonnegative, and it is zero if and only if $p(\bm{x}_i) = 1$ for some $\bm{x}_i$ with $i \in [N]$ (here we use the fact that $\lim_{p \to 0} p\log p = 0$).
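As a quick numerical illustration of the coding-rate interpretation, here is a minimal sketch (the token probabilities are made up purely for illustration): it computes the entropy (3.1.1) of a small discrete distribution in bits and compares it with the average length of an idealized code that spends $-\log_2 p(\bm{x}_i)$ bits per token, as well as with the naive fixed-length code of $\log_2 N$ bits.

```python
import numpy as np

# A toy discrete distribution over N = 4 tokens (illustrative values).
p = np.array([0.5, 0.25, 0.125, 0.125])

# Entropy (3.1.1), measured in bits.
entropy_bits = -np.sum(p * np.log2(p))

# Average length of an idealized code assigning -log2 p(x_i) bits to token x_i.
avg_code_length = np.sum(p * (-np.log2(p)))

# Naive fixed-length code: log2(N) bits per token, regardless of probability.
fixed_length = np.log2(len(p))

print(f"entropy            : {entropy_bits:.3f} bits")        # 1.750
print(f"ideal variable code: {avg_code_length:.3f} bits/token")  # 1.750
print(f"fixed-length code  : {fixed_length:.3f} bits/token")     # 2.000
```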

3.1.2 Differential Entropy

When the random variable $\bm{x} \in \mathbb{R}^{D}$ is continuous and has a probability density $p$, one may view the limit of the above sum (3.1.1) as being related to an integral:

$$
h(\bm{x}) \doteq \mathbb{E}[\log 1/p(\bm{x})] = -\int_{\mathbb{R}^{D}} p(\bm{\xi}) \log p(\bm{\xi}) \, \mathrm{d}\bm{\xi}. \tag{3.1.2}
$$

More precisely, given a continuous variable $\bm{x}$, we may quantize it with a quantization size $\epsilon > 0$. Denote the resulting discrete variable as $\bm{x}^{\epsilon}$. Then one can show that $H(\bm{x}^{\epsilon}) + \log(\epsilon) \approx h(\bm{x})$. Hence, when $\epsilon$ is small, the differential entropy $h(\bm{x})$ can be negative. Interested readers may refer to [CT91] for a more detailed explanation.
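A quick numerical sketch of this relation (a uniform distribution on $[0, 4]$ is used purely for illustration, since its differential entropy is exactly $\log 4$ nats): quantize samples into bins of width $\epsilon$, estimate the discrete entropy $H(\bm{x}^{\epsilon})$ from the histogram, and compare $H(\bm{x}^{\epsilon}) + \log\epsilon$ with $h(\bm{x})$.

```python
import numpy as np

rng = np.random.default_rng(0)
a = 4.0                                   # x ~ Uniform[0, a], so h(x) = log(a) nats
x = rng.uniform(0.0, a, size=1_000_000)

for eps in [0.5, 0.1, 0.01]:
    # Quantize into bins of width eps and estimate the discrete entropy H(x^eps).
    bins = np.arange(0.0, a + eps, eps)
    counts, _ = np.histogram(x, bins=bins)
    p_hat = counts / counts.sum()
    p_hat = p_hat[p_hat > 0]
    H_eps = -np.sum(p_hat * np.log(p_hat))          # in nats
    print(f"eps={eps:5.2f}:  H(x^eps) + log(eps) = {H_eps + np.log(eps):.4f}"
          f"   vs   h(x) = {np.log(a):.4f}")
```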

Example 3.1 (Entropy of Gaussian Distributions).

Through direct calculation, it is possible to show that the entropy of a Gaussian distribution $x \sim \mathcal{N}(\mu, \sigma^{2})$ is given by:

$$
h(x) = \frac{1}{2}\log(2\pi\sigma^{2}) + \frac{1}{2}. \tag{3.1.3}
$$

It is also known that the Gaussian distribution achieves the maximal entropy among all distributions with the same variance $\sigma^{2}$. The entropy of a multivariate Gaussian distribution $\bm{x} \sim \mathcal{N}(\bm{\mu}, \bm{\Sigma})$ in $\mathbb{R}^{D}$ is given by:

$$
h(\bm{x}) = \frac{D}{2}\big(1 + \log(2\pi)\big) + \frac{1}{2}\log\det(\bm{\Sigma}). \tag{3.1.4}
$$

\blacksquare
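As a sanity check on (3.1.4), the short sketch below (with an arbitrary, illustrative covariance) compares the closed-form entropy of a multivariate Gaussian against a Monte Carlo estimate of $\mathbb{E}[-\log p(\bm{x})]$.

```python
import numpy as np
from scipy.stats import multivariate_normal

rng = np.random.default_rng(0)
D = 3
mu = np.zeros(D)
A = rng.standard_normal((D, D))
Sigma = A @ A.T + 0.5 * np.eye(D)      # an arbitrary positive definite covariance

# Closed-form differential entropy (3.1.4), in nats.
h_formula = 0.5 * D * (1 + np.log(2 * np.pi)) + 0.5 * np.linalg.slogdet(Sigma)[1]

# Monte Carlo estimate of E[-log p(x)].
x = rng.multivariate_normal(mu, Sigma, size=500_000)
h_mc = -multivariate_normal(mean=mu, cov=Sigma).logpdf(x).mean()

print(f"formula: {h_formula:.4f} nats,  Monte Carlo: {h_mc:.4f} nats")
```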

Similar to the entropy for a discrete distribution, we would like the differential entropy to be associated with the coding rate of some realizable coding scheme. For example, as above, we may discretize the domain of the distribution with a grid of size $\epsilon > 0$. The coding rate of the resulting discrete distribution can be viewed as an approximation to the differential entropy [CT91].

Be aware that there are some caveats associated with the definition of differential entropy. For a distribution in a high-dimensional space, when its support becomes degenerate (low-dimensional), its differential entropy diverges to $-\infty$. This fact is proved in Theorem B.1 (which also recalls the maximum entropy characterization of the Gaussian distribution mentioned above). Even in the simple explicit case of Gaussian distributions (3.1.4), when the covariance $\bm{\Sigma}$ is singular we see that $\log\det(\bm{\Sigma}) = -\infty$, and hence $h(\bm{x}) = -\infty$. In such a situation, it is not obvious how to properly quantize or encode such a distribution. Nevertheless, degenerate (Gaussian) distributions are precisely the simplest possible, and arguably the most important, instances of low-dimensional distributions in a high-dimensional space. In this chapter, we will discuss a complete resolution to this seeming difficulty with degeneracy.

3.1.3 Minimizing Coding Rate

Remember that the learning problem entails the recovery of a (potentially continuous) distribution $p(\bm{x})$ from a set of samples $\{\bm{x}_1, \ldots, \bm{x}_N\}$ drawn from the distribution. For ease of exposition, we write $\bm{X} = [\bm{x}_1, \ldots, \bm{x}_N] \in \mathbb{R}^{D \times N}$. Given that the distributions of interest here are (nearly) low-dimensional, we should expect their (differential) entropy to be very small. But unlike the situations studied in the previous chapter, in general we do not know the family of (analytical) low-dimensional models to which the distribution $p(\bm{x})$ belongs. So checking whether the entropy is small seems to be the only guideline we can rely on to identify and model the distribution.

Now, given the samples alone without knowing what $p(\bm{x})$ is, in theory they could be interpreted as samples from any generic distribution. In particular, they could be interpreted as any of the following cases:

  1. as samples from the empirical distribution $p^{\bm{X}}$ itself, which assigns probability $1/N$ to each of the $N$ samples $\bm{x}_i$, $i = 1, \ldots, N$;

  2. as samples from an isotropic normal distribution $\bm{x}^{n} \sim p^{n} \doteq \mathcal{N}(\bm{0}, \sigma^2\bm{I})$ with a variance $\sigma^2$ large enough (say, larger than the sample norms);

  3. as samples from a normal distribution $\bm{x}^{e} \sim p^{e} \doteq \mathcal{N}(\bm{0}, \hat{\bm{\Sigma}})$ with covariance $\hat{\bm{\Sigma}} = \frac{1}{N}\bm{X}\bm{X}^{\top}$, the empirical covariance of the samples;

  4. as samples from a distribution $\hat{\bm{x}} \sim \hat{q}(\bm{x})$ that closely approximates the ground truth distribution $p$.

Now the question is: which one is better, and in what sense? Suppose that you believe these data $\bm{X}$ are drawn from a particular distribution $q(\bm{x})$, which may be one of the distributions considered above. Then we could encode the data points with the optimal codebook for the distribution $q(\bm{x})$. The required average coding length (or coding rate) is given by:

$$
\frac{1}{N}\sum_{i=1}^{N} -\log q(\bm{x}_i) \;\approx\; -\int_{\mathbb{R}^{D}} p(\bm{\xi}) \log q(\bm{\xi}) \, \mathrm{d}\bm{\xi} \tag{3.1.5}
$$

as the number of samples $N$ becomes large. If we have identified the correct distribution $p(\bm{x})$, the coding rate is given by the entropy $-\int p(\bm{\xi})\log p(\bm{\xi})\,\mathrm{d}\bm{\xi}$. It turns out that the above coding rate $-\int p(\bm{\xi})\log q(\bm{\xi})\,\mathrm{d}\bm{\xi}$ is always greater than or equal to the entropy, with equality only when $q(\bm{x}) = p(\bm{x})$. Their difference, denoted as

$$
\mathsf{KL}(p \,\|\, q) \;\doteq\; -\int_{\mathbb{R}^{D}} p(\bm{\xi})\log q(\bm{\xi})\,\mathrm{d}\bm{\xi} - \Big(-\int_{\mathbb{R}^{D}} p(\bm{\xi})\log p(\bm{\xi})\,\mathrm{d}\bm{\xi}\Big) \tag{3.1.6}
$$
$$
= \int_{\mathbb{R}^{D}} p(\bm{\xi})\log\frac{p(\bm{\xi})}{q(\bm{\xi})}\,\mathrm{d}\bm{\xi} \tag{3.1.7}
$$

is known as the Kullback-Leibler (KL) divergence, or relative entropy. This quantity is always non-negative.

Theorem 3.1 (Information Inequality).

Let $p(\bm{x})$ and $q(\bm{x})$ be two probability density functions (that have the same support). Then $\mathsf{KL}(p \,\|\, q) \geq 0$, where the inequality becomes equality if and only if $p = q$. (Technically, this equality should be taken to mean "almost everywhere," i.e., except possibly on a set of zero measure (volume), since this set would not impact the value of any integral.)

Proof.
$$
-\mathsf{KL}(p \,\|\, q) = -\int_{\mathbb{R}^{D}} p(\bm{\xi})\log\frac{p(\bm{\xi})}{q(\bm{\xi})}\,\mathrm{d}\bm{\xi} = \int_{\mathbb{R}^{D}} p(\bm{\xi})\log\frac{q(\bm{\xi})}{p(\bm{\xi})}\,\mathrm{d}\bm{\xi} \leq \log\int_{\mathbb{R}^{D}} p(\bm{\xi})\frac{q(\bm{\xi})}{p(\bm{\xi})}\,\mathrm{d}\bm{\xi} = \log\int_{\mathbb{R}^{D}} q(\bm{\xi})\,\mathrm{d}\bm{\xi} = \log 1 = 0,
$$

where the inequality follows from Jensen's inequality and the fact that the function $\log(\cdot)$ is strictly concave. The equality holds if and only if $p = q$. ∎
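A small numerical sketch of the information inequality (the two particular Gaussians below are illustrative choices): encoding samples from $p$ with the codebook of a mismatched $q$ costs, on average, the cross-entropy $-\mathbb{E}_p[\log q]$, which exceeds the entropy of $p$ by exactly $\mathsf{KL}(p\,\|\,q) \geq 0$.

```python
import numpy as np
from scipy.stats import norm

rng = np.random.default_rng(0)

# True distribution p and a mismatched model q (illustrative 1-D Gaussians).
p = norm(loc=0.0, scale=1.0)
q = norm(loc=1.0, scale=2.0)

x = p.rvs(size=1_000_000, random_state=rng)

# Average coding lengths (in nats) under each codebook, cf. (3.1.5).
rate_p = -p.logpdf(x).mean()   # ~ entropy of p
rate_q = -q.logpdf(x).mean()   # ~ cross-entropy, always >= entropy of p

# Closed-form KL divergence between the two Gaussians, for comparison.
mu_p, s_p, mu_q, s_q = 0.0, 1.0, 1.0, 2.0
kl_closed = np.log(s_q / s_p) + (s_p**2 + (mu_p - mu_q)**2) / (2 * s_q**2) - 0.5

print(f"rate under p : {rate_p:.4f} nats")
print(f"rate under q : {rate_q:.4f} nats")
print(f"difference   : {rate_q - rate_p:.4f}  vs  KL(p||q) = {kl_closed:.4f}")
```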

Hence, given a set of sampled data $\bm{X}$, to determine which case is better among $p^{n}$, $p^{e}$, and $\hat{q}$, we may compare their coding rates for $\bm{X}$ and see which one gives the lowest rate. We know from the above that the (theoretically achievable) coding rate for a distribution is closely related to its entropy. In general, we have:

$$
h(\bm{x}^{n}) > h(\bm{x}^{e}) > h(\hat{\bm{x}}). \tag{3.1.8}
$$

Hence, if the data $\bm{X}$ were encoded with the codebook associated with each of these distributions, the coding rate for $\bm{X}$ would in general decrease in the same order:

$$
p(\bm{x}^{n}) \rightarrow p(\bm{x}^{e}) \rightarrow p(\hat{\bm{x}}). \tag{3.1.9}
$$

This observation gives us a general guideline on how we may be able to pursue a distribution $p(\bm{x})$ which has a low-dimensional structure. It suggests two possible approaches:

  1. Starting with a general distribution (say a normal distribution) with high entropy, gradually transforming the distribution towards the (empirical) distribution of the data by reducing entropy.

  2. Among a large family of (parametric or non-parametric) distributions with explicit coding schemes that encode the given data, progressively searching for better coding schemes that give lower coding rates.

Conceptually, both approaches are essentially trying to do the same thing. For the first approach, we need to make sure such a path of transformation exists and is computable. For the second approach, it is necessary that the chosen family is rich enough to closely approximate (or contain) the ground truth distribution. For either approach, we need to ensure that solutions with lower entropy or better coding rates can be efficiently computed and converge quickly to the desired distribution (say, the distribution of real-world data such as images and texts). We will explore both approaches in the two remaining sections of this chapter.

3.2 Compression via Denoising

In this section, we describe a natural and computationally tractable way to learn a distribution $p(\bm{x})$: we first learn a parametric encoding of the distribution such that the representation has the minimum entropy or coding rate, and then use this encoding to transform high-entropy samples from a standard Gaussian into low-entropy samples from the target distribution, as illustrated in Figure 3.2. This methodology utilizes both of the approaches above in order to learn and sample from the distribution.

Figure 3.2: Illustration of an iterative denoising process that, starting from an isotropic Gaussian distribution, converges to an arbitrary data distribution.

3.2.1 Diffusion and Denoising Processes

We first want to find a procedure that transforms a given very noisy sample into a lower-entropy sample from the data distribution. Here, we describe a potential approach (one of many, but perhaps the most natural way to attack this problem). First, we find a way to gradually increase the entropy of existing samples from the data distribution. Then, we find an approximate inverse of this process. In general, the operation of increasing entropy does not have an inverse, as information about the original distribution may be destroyed. We will thus tackle a special case where (1) the operation of adding entropy takes on a simple, computable, and reversible form; and (2) we can obtain a (parametric) encoding of the data distribution, as alluded to in the pair of approaches above. As we will see, these two factors will ensure that our approach is possible.

We will increase the entropy in arguably the simplest possible way, i.e., by adding isotropic Gaussian noise. More precisely, given the random variable $\bm{x}$, we can consider the stochastic process $(\bm{x}_t)_{t\in[0,T]}$ which adds gradual noise to it, i.e.,

$$
\bm{x}_t \doteq \bm{x} + t\bm{g}, \qquad \forall t \in [0, T], \tag{3.2.1}
$$

where $T \in [0, \infty)$ is a time horizon and $\bm{g} \sim \mathcal{N}(\bm{0}, \bm{I})$ is drawn independently of $\bm{x}$. This process is an example of a diffusion process, so named because it spreads the probability mass out over all of $\mathbb{R}^{D}$ as time goes on, increasing the entropy over time. This intuition is confirmed graphically by Figure 3.3, and rigorously via the following theorem.

Theorem 3.2 (Simplified Version of Theorem B.2).

Suppose that $(\bm{x}_t)_{t\in[0,T]}$ follows the model (3.2.1). For any $t \in (0, T]$, the random variable $\bm{x}_t$ has differential entropy $h(\bm{x}_t) > -\infty$. Moreover, under certain technical conditions on $\bm{x}$,

$$
\frac{\mathrm{d}}{\mathrm{d}t} h(\bm{x}_t) > 0, \qquad \forall t \in (0, T], \tag{3.2.2}
$$

showing that the entropy of the noised $\bm{x}$ increases over time $t$.

The proof is elementary, but it is rather long, so we postpone it to Section B.2.1. The main as-yet unstated implication of this result is that $h(\bm{x}_t) > h(\bm{x})$ for every $t > 0$. To see this, note that if $h(\bm{x}) = -\infty$ then $h(\bm{x}_t) > -\infty$ for all $t > 0$, and if $h(\bm{x}) > -\infty$ then $h(\bm{x}_t) = h(\bm{x}) + \int_0^t \big[\frac{\mathrm{d}}{\mathrm{d}s} h(\bm{x}_s)\big] \mathrm{d}s > h(\bm{x})$ by the fundamental theorem of calculus; so in both cases $h(\bm{x}_t) > h(\bm{x})$ for every $t > 0$.
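To see this entropy growth numerically, the sketch below (using an arbitrary two-component 1-D Gaussian mixture, chosen purely for illustration) evaluates $h(\bm{x}_t)$ by numerical integration of $-\int p_t \log p_t$, where $p_t$ is the density of $\bm{x}_t = \bm{x} + t\bm{g}$; the values increase monotonically with $t$.

```python
import numpy as np
from scipy.integrate import quad
from scipy.stats import norm

# Illustrative 1-D Gaussian mixture: weights, means, standard deviations.
pi, mu, sigma = [0.5, 0.5], [-3.0, 3.0], [0.5, 1.0]

def density_t(x, t):
    """Density of x_t = x + t*g: each component's variance grows by t^2."""
    return sum(w * norm.pdf(x, loc=m, scale=np.sqrt(s**2 + t**2))
               for w, m, s in zip(pi, mu, sigma))

def entropy_t(t):
    """Differential entropy h(x_t) = -∫ p_t log p_t, by numerical integration."""
    integrand = lambda x: -density_t(x, t) * np.log(density_t(x, t) + 1e-300)
    val, _ = quad(integrand, -60, 60, limit=200)
    return val

for t in [0.0, 0.5, 1.0, 2.0, 4.0]:
    print(f"t = {t:3.1f}:  h(x_t) ≈ {entropy_t(t):.4f} nats")
```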

Figure 3.3: Diffusing a mixture of Gaussians. From left to right, we observe the evolution of the density as $t$ grows from $0$ to $10$, along with some representative samples. Each region is colored by its density ($0.0$ is completely white, $>0.01$ is very dark blue, and every other value maps to some shade of blue in between). We observe that the probability mass becomes less concentrated as $t$ increases, signaling that the entropy increases.

The inverse operation to adding noise is known as denoising. It is a classical and well-studied topic in signal processing and systems theory, with celebrated solutions such as the Wiener filter and the Kalman filter. Several problems discussed in Chapter 2, such as PCA, ICA, and dictionary learning, are specific instances of the denoising problem. For a fixed $t$ and the additive Gaussian noise model (3.2.1), the denoising problem can be formulated as attempting to learn a function $\bar{\bm{x}}^{\ast}(t, \cdot)$ which forms the best possible approximation (in expectation) of the true random variable $\bm{x}$, given both $t$ and $\bm{x}_t$:

$$
\bar{\bm{x}}^{\ast}(t, \cdot) \in \operatorname*{arg\,min}_{\bar{\bm{x}}(t,\cdot)} \; \mathbb{E}_{\bm{x}, \bm{x}_t} \big\| \bm{x} - \bar{\bm{x}}(t, \bm{x}_t) \big\|_2^2. \tag{3.2.3}
$$

The solution to this problem, when optimizing $\bar{\bm{x}}(t, \cdot)$ over all possible (square-integrable) functions, is the so-called Bayes optimal denoiser:

$$
\bar{\bm{x}}^{\ast}(t, \bm{\xi}) \doteq \mathbb{E}[\bm{x} \mid \bm{x}_t = \bm{\xi}]. \tag{3.2.4}
$$

This expression justifies the notation $\bar{\bm{x}}$, which is meant to suggest a conditional expectation (i.e., a conditional mean or conditional average). In short, it attempts to remove the noise from the noisy input, outputting the best possible guess (in expectation and w.r.t. the $\ell^2$ distance) of the (de-noised) original random variable.
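As a small empirical sketch of (3.2.3)-(3.2.4) (1-D Gaussian data with illustrative parameters): for $x \sim \mathcal{N}(0, \sigma^2)$ and $x_t = x + t g$, the conditional expectation is the shrinkage $\mathbb{E}[x \mid x_t] = \frac{\sigma^2}{\sigma^2 + t^2} x_t$, and minimizing the empirical mean squared error over linear denoisers $\bar{x}(t, x_t) = \alpha x_t$ recovers essentially the same factor.

```python
import numpy as np

rng = np.random.default_rng(0)
sigma, t, n = 2.0, 1.5, 1_000_000

x = sigma * rng.standard_normal(n)        # x ~ N(0, sigma^2)
x_t = x + t * rng.standard_normal(n)      # x_t = x + t * g

# Empirically MSE-optimal linear denoiser alpha * x_t (least squares in alpha).
alpha_hat = np.dot(x, x_t) / np.dot(x_t, x_t)

# Bayes optimal shrinkage factor for this Gaussian model.
alpha_star = sigma**2 / (sigma**2 + t**2)

print(f"empirical alpha: {alpha_hat:.4f},  Bayes optimal: {alpha_star:.4f}")
```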

Example 3.2 (Denoising Gaussian Noise from a Mixture of Gaussians).

In this example we compute the Bayes optimal denoiser for an incredibly important class of distributions, the Gaussian mixture model. To start, let us fix parameters for the distribution: mixture weights $\bm{\pi} \in \mathbb{R}^{K}$, component means $\{\bm{\mu}_k\}_{k=1}^{K} \subseteq \mathbb{R}^{D}$, and component covariances $\{\bm{\Sigma}_k\}_{k=1}^{K} \subseteq \mathsf{PSD}(D)$, where $\mathsf{PSD}(D)$ is the set of $D \times D$ symmetric positive semidefinite matrices. Now, suppose $\bm{x}$ is generated by the following two-step procedure:

  • First, an index (or label) $y \in [K]$ is sampled such that $y = k$ with probability $\pi_k$.

  • Second, $\bm{x}$ is sampled from the normal distribution $\mathcal{N}(\bm{\mu}_y, \bm{\Sigma}_y)$.

Then $\bm{x}$ has distribution

$$
\bm{x} \sim \sum_{k=1}^{K} \pi_k \mathcal{N}(\bm{\mu}_k, \bm{\Sigma}_k), \tag{3.2.5}
$$

and so

$$
\bm{x}_t = \bm{x} + t\bm{g} \sim \sum_{k=1}^{K} \pi_k \mathcal{N}(\bm{\mu}_k, \bm{\Sigma}_k + t^2\bm{I}). \tag{3.2.6}
$$

Let us define $\varphi(\bm{x}; \bm{\mu}, \bm{\Sigma})$ as the probability density of $\mathcal{N}(\bm{\mu}, \bm{\Sigma})$ evaluated at $\bm{x}$. In this notation, the density of $\bm{x}_t$ is

$$
p_t(\bm{x}_t) = \sum_{k=1}^{K} \pi_k \varphi(\bm{x}_t; \bm{\mu}_k, \bm{\Sigma}_k + t^2\bm{I}). \tag{3.2.7}
$$

Conditioned on $y$, the variables are jointly Gaussian: if we write $\bm{x} = \bm{\mu}_y + \bm{\Sigma}_y^{1/2}\bm{u}$, where $(\cdot)^{1/2}$ is the matrix square root and $\bm{u} \sim \mathcal{N}(\bm{0}, \bm{I})$ is independent of $y$ (and $\bm{g}$), then we have

$$
\begin{bmatrix} \bm{x} \\ \bm{x}_t \end{bmatrix} = \begin{bmatrix} \bm{\mu}_y \\ \bm{\mu}_y \end{bmatrix} + \begin{bmatrix} \bm{\Sigma}_y^{1/2} & \bm{0} \\ \bm{\Sigma}_y^{1/2} & t\bm{I} \end{bmatrix} \begin{bmatrix} \bm{u} \\ \bm{g} \end{bmatrix}. \tag{3.2.8}
$$

This shows that $\bm{x}$ and $\bm{x}_t$ are jointly Gaussian (conditioned on $y$) as claimed. Thus we can write

$$
\begin{bmatrix} \bm{x} \\ \bm{x}_t \end{bmatrix} \sim \mathcal{N}\left( \begin{bmatrix} \bm{\mu}_y \\ \bm{\mu}_y \end{bmatrix}, \begin{bmatrix} \bm{\Sigma}_y & \bm{\Sigma}_y \\ \bm{\Sigma}_y & \bm{\Sigma}_y + t^2\bm{I} \end{bmatrix} \right). \tag{3.2.9}
$$

Thus the conditional expectation of $\bm{x}$ given $\bm{x}_t$ (i.e., the Bayes optimal denoiser conditioned on $y$) is famously (Exercise 3.2)

$$
\mathbb{E}[\bm{x} \mid \bm{x}_t, y] = \bm{\mu}_y + \bm{\Sigma}_y(\bm{\Sigma}_y + t^2\bm{I})^{-1}(\bm{x}_t - \bm{\mu}_y). \tag{3.2.10}
$$

To find the overall Bayes optimal denoiser, we use the law of iterated expectation, obtaining

$$
\bar{\bm{x}}^{\ast}(t, \bm{x}_t) = \mathbb{E}[\bm{x} \mid \bm{x}_t] \tag{3.2.11}
$$
$$
= \mathbb{E}\big[\mathbb{E}[\bm{x} \mid \bm{x}_t, y] \mid \bm{x}_t\big] \tag{3.2.12}
$$
$$
= \sum_{k=1}^{K} \mathbb{P}[y = k \mid \bm{x}_t] \, \mathbb{E}[\bm{x} \mid \bm{x}_t, y = k]. \tag{3.2.13}
$$

The probability can be dealt with as follows. Let $p_{t\mid y}$ be the probability density of $\bm{x}_t$ conditioned on the value of $y$. Then

$$
\mathbb{P}[y = k \mid \bm{x}_t] = \frac{p_{t\mid y}(\bm{x}_t \mid k)\,\pi_k}{p_t(\bm{x}_t)} \tag{3.2.14}
$$
$$
= \frac{\pi_k \varphi(\bm{x}_t; \bm{\mu}_k, \bm{\Sigma}_k + t^2\bm{I})}{\sum_{i=1}^{K} \pi_i \varphi(\bm{x}_t; \bm{\mu}_i, \bm{\Sigma}_i + t^2\bm{I})}. \tag{3.2.15}
$$

On the other hand, the conditional expectation is as described before:

$$
\mathbb{E}[\bm{x} \mid \bm{x}_t, y = k] = \bm{\mu}_k + \bm{\Sigma}_k(\bm{\Sigma}_k + t^2\bm{I})^{-1}(\bm{x}_t - \bm{\mu}_k). \tag{3.2.16}
$$

So putting this all together, the true Bayes optimal denoiser is

$$
\bar{\bm{x}}^{\ast}(t, \bm{x}_t) = \sum_{k=1}^{K} \frac{\pi_k \varphi(\bm{x}_t; \bm{\mu}_k, \bm{\Sigma}_k + t^2\bm{I})}{\sum_{i=1}^{K} \pi_i \varphi(\bm{x}_t; \bm{\mu}_i, \bm{\Sigma}_i + t^2\bm{I})} \cdot \Big( \bm{\mu}_k + \bm{\Sigma}_k(\bm{\Sigma}_k + t^2\bm{I})^{-1}(\bm{x}_t - \bm{\mu}_k) \Big). \tag{3.2.17}
$$

This example is particularly important, and several special cases will give us great conceptual insight later. For now, let us attempt to extract some geometric intuition from the functional form of the optimal denoiser (3.2.17).

To try to understand (3.2.17) intuitively, let us first set $K = 1$ (i.e., a single Gaussian) such that $\bm{x} \sim \mathcal{N}(\bm{\mu}, \bm{\Sigma})$. Let us then diagonalize $\bm{\Sigma} = \bm{V}\bm{\Lambda}\bm{V}^{\top}$. Then the Bayes optimal denoiser is

$$
\bar{\bm{x}}^{\ast}(t, \bm{x}_t) = \bm{\mu} + \bm{\Sigma}(\bm{\Sigma} + t^2\bm{I})^{-1}(\bm{x}_t - \bm{\mu}) = \bm{\mu} + \bm{V} \begin{bmatrix} \lambda_1/(\lambda_1 + t^2) & & \\ & \ddots & \\ & & \lambda_D/(\lambda_D + t^2) \end{bmatrix} \bm{V}^{\top} (\bm{x}_t - \bm{\mu}), \tag{3.2.18}
$$

where $\lambda_1, \ldots, \lambda_D$ are the eigenvalues of $\bm{\Sigma}$. We can observe that this denoiser has three steps:

  • Translate the input $\bm{x}_t$ by subtracting the mean $\bm{\mu}$.

  • Contract the (translated) input $\bm{x}_t - \bm{\mu}$ in each eigenvector direction by a factor $\lambda_i/(\lambda_i + t^2)$. If some eigenvalues of $\bm{\Sigma}$ are zero (i.e., $\bm{\Sigma}$ is low-rank), the corresponding directions are immediately contracted to $0$ by the denoiser, ensuring that the output of the contraction is similarly low-rank.

  • Translate the output back by adding $\bm{\mu}$.

It is easy to show that the denoiser contracts the current $\bm{x}_t$ towards the mean $\bm{\mu}$:

$$
\|\bar{\bm{x}}^{\ast}(t, \bm{x}_t) - \bm{\mu}\|_2 \leq \|\bm{x}_t - \bm{\mu}\|_2. \tag{3.2.19}
$$

This is the geometric interpretation of the denoiser for a single Gaussian. The overall denoiser of the Gaussian mixture model (3.2.17) uses $K$ such denoisers, weighting their outputs by the posterior probabilities $\mathbb{P}[y = k \mid \bm{x}_t]$. If the means of the Gaussians are well separated, these posterior probabilities are very close to $0$ or $1$ near each mean or cluster. In this regime, the overall denoiser (3.2.17) has the same geometric interpretation as the single Gaussian denoiser above.

At first glance, such a contraction mapping (3.2.19) may appear similar to power iteration (see Section 2.1.2). However, the two are fundamentally different. Power iteration implements a contraction mapping towards a subspace, namely the subspace spanned by the first principal component. In contrast, the iterates in (3.2.19) converge to the mean $\bm{\mu}$ of the underlying distribution, which is a single point. $\blacksquare$
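The closed form (3.2.17) is straightforward to implement. Below is a minimal sketch (with made-up mixture parameters, purely for illustration) that evaluates the Bayes optimal denoiser for a Gaussian mixture and checks that it reduces the mean squared error relative to leaving the noisy samples untouched.

```python
import numpy as np
from scipy.stats import multivariate_normal

def gmm_bayes_denoiser(x_t, t, pi, mus, Sigmas):
    """Evaluate the Bayes optimal denoiser (3.2.17) at a single point x_t."""
    D = x_t.shape[0]
    # Posterior weights P[y = k | x_t], cf. (3.2.15).
    lik = np.array([w * multivariate_normal(mean=m, cov=S + t**2 * np.eye(D)).pdf(x_t)
                    for w, m, S in zip(pi, mus, Sigmas)])
    post = lik / lik.sum()
    # Per-component conditional means (3.2.16), averaged by the posterior.
    out = np.zeros(D)
    for w, m, S in zip(post, mus, Sigmas):
        out += w * (m + S @ np.linalg.solve(S + t**2 * np.eye(D), x_t - m))
    return out

# Illustrative 2-D mixture with two components.
rng = np.random.default_rng(0)
pi = [0.5, 0.5]
mus = [np.array([-4.0, 0.0]), np.array([4.0, 0.0])]
Sigmas = [np.diag([1.0, 0.1]), np.diag([0.1, 1.0])]

t, n = 1.0, 2000
y = rng.choice(2, size=n, p=pi)
x = np.array([rng.multivariate_normal(mus[k], Sigmas[k]) for k in y])
x_t = x + t * rng.standard_normal(x.shape)

x_bar = np.array([gmm_bayes_denoiser(xi, t, pi, mus, Sigmas) for xi in x_t])
print("MSE of noisy samples   :", np.mean(np.sum((x_t - x) ** 2, axis=1)))
print("MSE of denoised samples:", np.mean(np.sum((x_bar - x) ** 2, axis=1)))
```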

Figure 3.4: Bayes optimal denoiser and score of a Gaussian mixture model. In the same setting as Figure 3.3, we demonstrate the effect of the Bayes optimal denoiser $\bar{\bm{x}}^{\ast}$ by plotting $\bm{x}_t$ (red) and $\bar{\bm{x}}^{\ast}(t, \bm{x}_t)$ (green) for some choice of $t$ and $\bm{x}_t$. By Tweedie's formula (Theorem 3.3), the residual between them is proportional to the so-called (Hyvärinen) score $\nabla_{\bm{x}_t}\log p_t(\bm{x}_t)$. We can see that the score points towards the modes of the distribution of $\bm{x}_t$.

Intuitively, and as we can see from Example 3.2, the Bayes optimal denoiser $\bar{\bm{x}}^{\ast}(t, \cdot)$ should move its input $\bm{x}_t$ towards the modes of the distribution of $\bm{x}$. It turns out that we can quantify this by showing that the Bayes optimal denoiser takes a gradient ascent step on the log-density of $\bm{x}_t$, which (recall) we denoted $p_t$. That is, following the denoiser means moving from the input iterate to a region of higher probability within this (perturbed) distribution. For small $t$, the perturbation is small, so our initial intuition is (almost) exactly right. The picture is visualized in Figure 3.4 and rigorously formulated as Tweedie's formula [Rob56].

Theorem 3.3 (Tweedie’s Formula).

Suppose that $(\bm{x}_t)_{t\in[0,T]}$ obeys (3.2.1). Let $p_t$ be the density of $\bm{x}_t$ (as previously declared). Then

$$
\mathbb{E}[\bm{x} \mid \bm{x}_t] = \bm{x}_t + t^2 \nabla_{\bm{x}_t} \log p_t(\bm{x}_t). \tag{3.2.20}
$$
Proof.

For the proof, let us suppose that $\bm{x}$ has a density (even though the theorem is true without this assumption), and call this density $p$. Let $p_{0\mid t}$ and $p_{t\mid 0}$ be the conditional densities of $\bm{x} = \bm{x}_0$ given $\bm{x}_t$, and of $\bm{x}_t$ given $\bm{x}$, respectively. Let $\varphi(\bm{x}; \bm{\mu}, \bm{\Sigma})$ be the density of $\mathcal{N}(\bm{\mu}, \bm{\Sigma})$ evaluated at $\bm{x}$, so that $p_{t\mid 0}(\bm{x}_t \mid \bm{x}) = \varphi(\bm{x}_t; \bm{x}, t^2\bm{I})$. Then a simple calculation gives

\begin{align}
\nabla_{\bm{x}_{t}}\log p_{t}(\bm{x}_{t})
&=\frac{\nabla_{\bm{x}_{t}}p_{t}(\bm{x}_{t})}{p_{t}(\bm{x}_{t})} \tag{3.2.21}\\
&=\frac{1}{p_{t}(\bm{x}_{t})}\nabla_{\bm{x}_{t}}\int_{\mathbb{R}^{D}}p(\bm{x})\,p_{t\mid 0}(\bm{x}_{t}\mid\bm{x})\,\mathrm{d}\bm{x} \tag{3.2.22}\\
&=\frac{1}{p_{t}(\bm{x}_{t})}\nabla_{\bm{x}_{t}}\int_{\mathbb{R}^{D}}p(\bm{x})\,\varphi(\bm{x}_{t};\bm{x},t^{2}\bm{I})\,\mathrm{d}\bm{x} \tag{3.2.23}\\
&=\frac{1}{p_{t}(\bm{x}_{t})}\int_{\mathbb{R}^{D}}p(\bm{x})\,[\nabla_{\bm{x}_{t}}\varphi(\bm{x}_{t};\bm{x},t^{2}\bm{I})]\,\mathrm{d}\bm{x} \tag{3.2.24}\\
&=\frac{1}{p_{t}(\bm{x}_{t})}\int_{\mathbb{R}^{D}}p(\bm{x})\,\varphi(\bm{x}_{t};\bm{x},t^{2}\bm{I})\left[-\frac{\bm{x}_{t}-\bm{x}}{t^{2}}\right]\mathrm{d}\bm{x} \tag{3.2.25}\\
&=\frac{1}{t^{2}p_{t}(\bm{x}_{t})}\int_{\mathbb{R}^{D}}p(\bm{x})\,\varphi(\bm{x}_{t};\bm{x},t^{2}\bm{I})\,[\bm{x}-\bm{x}_{t}]\,\mathrm{d}\bm{x} \tag{3.2.26}\\
&=\frac{1}{t^{2}p_{t}(\bm{x}_{t})}\int_{\mathbb{R}^{D}}p(\bm{x})\,\varphi(\bm{x}_{t};\bm{x},t^{2}\bm{I})\,\bm{x}\,\mathrm{d}\bm{x}-\frac{\bm{x}_{t}}{t^{2}p_{t}(\bm{x}_{t})}\int_{\mathbb{R}^{D}}p(\bm{x})\,\varphi(\bm{x}_{t};\bm{x},t^{2}\bm{I})\,\mathrm{d}\bm{x} \tag{3.2.27}\\
&=\frac{1}{t^{2}p_{t}(\bm{x}_{t})}\int_{\mathbb{R}^{D}}p(\bm{x})\,p_{t\mid 0}(\bm{x}_{t}\mid\bm{x})\,\bm{x}\,\mathrm{d}\bm{x}-\frac{\bm{x}_{t}}{t^{2}p_{t}(\bm{x}_{t})}\,p_{t}(\bm{x}_{t}) \tag{3.2.28}\\
&=\frac{1}{t^{2}p_{t}(\bm{x}_{t})}\int_{\mathbb{R}^{D}}p_{t}(\bm{x}_{t})\,p_{0\mid t}(\bm{x}\mid\bm{x}_{t})\,\bm{x}\,\mathrm{d}\bm{x}-\frac{\bm{x}_{t}}{t^{2}p_{t}(\bm{x}_{t})}\,p_{t}(\bm{x}_{t}) \tag{3.2.29}\\
&=\frac{1}{t^{2}}\int_{\mathbb{R}^{D}}p_{0\mid t}(\bm{x}\mid\bm{x}_{t})\,\bm{x}\,\mathrm{d}\bm{x}-\frac{\bm{x}_{t}}{t^{2}} \tag{3.2.30}\\
&=\frac{1}{t^{2}}\operatorname{\mathbb{E}}[\bm{x}\mid\bm{x}_{t}]-\frac{\bm{x}_{t}}{t^{2}} \tag{3.2.31}\\
&=\frac{\operatorname{\mathbb{E}}[\bm{x}\mid\bm{x}_{t}]-\bm{x}_{t}}{t^{2}}. \tag{3.2.32}
\end{align}

Simple rearranging of the above equality proves the theorem. ∎

This result develops a connection between denoising and optimization: the Bayes-optimal denoiser takes a single step of gradient ascent on the log-density $\log p_{t}$ of the perturbed data, and the step size $t^{2}$ adaptively becomes smaller (i.e., the steps become more precise) as the perturbation to the data distribution grows smaller. The quantity $\nabla_{\bm{x}_{t}}\log p_{t}(\bm{x}_{t})$ is called the (Hyvärinen) score and frequently appears in discussions about denoising and related topics; it first appeared in a paper of Aapo Hyvärinen in the context of ICA [Hyv05].
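To make the identity (3.2.20) concrete, the following minimal sketch in Python checks it in a case where every quantity has a closed form; the one-dimensional standard Gaussian data model is purely an illustrative assumption, not part of the text. If $x\sim\operatorname{\mathcal{N}}(0,1)$, then $p_{t}=\operatorname{\mathcal{N}}(0,1+t^{2})$, the score is $-x_{t}/(1+t^{2})$, and the posterior mean is $x_{t}/(1+t^{2})$.

```python
# A minimal numerical check of (3.2.20) for 1-D standard Gaussian data,
# where p_t = N(0, 1 + t^2) and E[x | x_t] = x_t / (1 + t^2) in closed form.
import numpy as np

t = 0.7
x_t = np.linspace(-3.0, 3.0, 7)

score = -x_t / (1.0 + t**2)            # gradient of log p_t in the Gaussian case
rhs = x_t + t**2 * score               # right-hand side of (3.2.20)
posterior_mean = x_t / (1.0 + t**2)    # E[x | x_t] from joint Gaussianity

print(np.max(np.abs(rhs - posterior_mean)))   # ~1e-16: the two sides agree
```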

Similar to how one step of gradient descent is almost never sufficient to minimize an objective in practice when initializing far from the optimum, the output of the Bayes-optimal denoiser $\bar{\bm{x}}^{\ast}(t,\cdot)$ is almost never contained in a high-probability region of the data distribution when $t$ is large, especially when the data have low-dimensional structures. We illustrate this point explicitly in the following example.

Example 3.3 (Denoising a Two-Point Mixture).

Let $x$ be uniform on the two-point set $\{-1,+1\}$ and let $(x_{t})_{t\in[0,T]}$ follow (3.2.1). This is precisely a degenerate Gaussian mixture model with priors equal to $\frac{1}{2}$, means $\{-1,+1\}$, and covariances both equal to $0$. For a fixed $t>0$ we can use the calculation of the Bayes-optimal denoiser in (3.2.17) to obtain (proof as exercise)

\[
\bar{x}^{\ast}(t,x_{t})=\frac{\varphi(x_{t};+1,t^{2})-\varphi(x_{t};-1,t^{2})}{\varphi(x_{t};+1,t^{2})+\varphi(x_{t};-1,t^{2})}=\tanh\left(\frac{x_{t}}{t^{2}}\right). \tag{3.2.33}
\]

For $t$ near $0$, this quantity is near $\{-1,+1\}$ for almost all inputs $x_{t}$. However, for $t$ large, this quantity is not necessarily even approximately in the original support of $x$, which, remember, is $\{-1,+1\}$. In particular, for $x_{t}\approx 0$ it holds that $\bar{x}^{\ast}(t,x_{t})\approx 0$, which lies squarely in between the two possible points. Thus $\bar{x}^{\ast}$ will not output "realistic" $x$. More mathematically, the distribution of $\bar{x}^{\ast}(t,x_{t})$ is very different from the distribution of $x$. $\blacksquare$
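As a quick sanity check on Example 3.3, the short sketch below (an illustration, not code from the text) evaluates the two-component Bayes-optimal denoiser directly from the Gaussian densities, compares it to the closed form $\tanh(x_{t}/t^{2})$, and shows the collapse toward $0$ when $t$ is large.

```python
# Example 3.3 numerically: the Bayes-optimal denoiser for x uniform on {-1, +1}
# agrees with tanh(x_t / t^2), and collapses toward 0 when t is large.
import numpy as np

def phi(x, mu, var):
    """Density of N(mu, var) evaluated at x."""
    return np.exp(-(x - mu) ** 2 / (2 * var)) / np.sqrt(2 * np.pi * var)

def bayes_denoiser(x_t, t):
    num = phi(x_t, +1.0, t**2) - phi(x_t, -1.0, t**2)
    den = phi(x_t, +1.0, t**2) + phi(x_t, -1.0, t**2)
    return num / den

x_t = np.linspace(-3.0, 3.0, 9)
for t in (0.3, 3.0):
    assert np.allclose(bayes_denoiser(x_t, t), np.tanh(x_t / t**2))
print(bayes_denoiser(x_t, 0.3))   # close to -1 or +1 for almost every input
print(bayes_denoiser(x_t, 3.0))   # squashed toward 0, far from the support
```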

Therefore, if we want to denoise the very noisy sample $\bm{x}_{T}$ (where, recall, $T$ is the maximum time), we cannot just use the denoiser once. Instead, we must use the denoiser many times, analogously to gradient descent with decaying step sizes, to converge to a stationary point $\hat{\bm{x}}$. Namely, we shall use the denoiser to go from $\bm{x}_{T}$ to $\hat{\bm{x}}_{T-\delta}$, which approximates $\bm{x}_{T-\delta}$, then from $\hat{\bm{x}}_{T-\delta}$ to $\hat{\bm{x}}_{T-2\delta}$, and so on, all the way from $\hat{\bm{x}}_{\delta}$ to $\hat{\bm{x}}=\hat{\bm{x}}_{0}$. Each time we take a denoising step, the action of the denoiser becomes more like a gradient step on the original (log-)density.

More formally, we uniformly discretize $[0,T]$ into $L+1$ timesteps $0=t_{0}<t_{1}<\cdots<t_{L}=T$, i.e.,

\[
t_{\ell}=\frac{\ell}{L}T,\qquad \ell\in\{0,1,\dots,L\}. \tag{3.2.34}
\]

Then for each $\ell\in[L]=\{1,2,\dots,L\}$, going from $\ell=L$ to $\ell=1$, we can run the iteration

\begin{align}
\hat{\bm{x}}_{t_{\ell-1}}
&=\operatorname{\mathbb{E}}[\bm{x}_{t_{\ell-1}}\mid\bm{x}_{t_{\ell}}=\hat{\bm{x}}_{t_{\ell}}] \tag{3.2.35}\\
&=\operatorname{\mathbb{E}}[\bm{x}+t_{\ell-1}\bm{g}\mid\bm{x}_{t_{\ell}}=\hat{\bm{x}}_{t_{\ell}}] \tag{3.2.36}\\
&=\operatorname{\mathbb{E}}\left[\bm{x}+t_{\ell-1}\cdot\frac{\bm{x}_{t_{\ell}}-\bm{x}}{t_{\ell}}\mid\bm{x}_{t_{\ell}}=\hat{\bm{x}}_{t_{\ell}}\right] \tag{3.2.37}\\
&=\frac{t_{\ell-1}}{t_{\ell}}\hat{\bm{x}}_{t_{\ell}}+\left(1-\frac{t_{\ell-1}}{t_{\ell}}\right)\operatorname{\mathbb{E}}[\bm{x}\mid\bm{x}_{t_{\ell}}=\hat{\bm{x}}_{t_{\ell}}] \tag{3.2.38}\\
&=\left(1-\frac{1}{\ell}\right)\cdot\hat{\bm{x}}_{t_{\ell}}+\frac{1}{\ell}\cdot\bar{\bm{x}}^{\ast}(t_{\ell},\hat{\bm{x}}_{t_{\ell}}). \tag{3.2.39}
\end{align}

The effect of this iteration is as follows. At the beginning of the iteration, where $\ell$ is large, we barely trust the output of the denoiser and mostly keep the current iterate. This makes sense, as the denoiser's output can be quite far from the data distribution (cf. Example 3.3). When $\ell$ is small, the denoiser "locks on" to the modes of the data distribution, since a denoising step essentially takes a gradient step on the true distribution's log-density, and we can trust it not to produce unreasonable samples; the update then relies mostly on the output of the denoiser. At $\ell=1$ we even throw away the current iterate and just keep the output of the denoiser.
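The following toy sketch runs the recursion (3.2.39) under the two-point model of Example 3.3, whose Bayes-optimal denoiser is available in closed form; the choices $T=10$, $L=100$, and the sample size are our own illustrative assumptions. Starting from the very noisy $\hat{\bm{x}}_{T}$, the iterates end up concentrated near the support $\{-1,+1\}$.

```python
# A toy run of the iterative denoising recursion (3.2.39) for the two-point
# mixture of Example 3.3, whose Bayes-optimal denoiser is tanh(x_t / t^2).
import numpy as np

def bayes_denoiser(x_t, t):
    return np.tanh(x_t / t**2)        # closed form (3.2.33)

T, L, n = 10.0, 100, 1000
rng = np.random.default_rng(0)
x_hat = T * rng.standard_normal(n)    # start from x_T ~ N(0, T^2), cf. (3.2.65)

for ell in range(L, 0, -1):           # ell = L, L-1, ..., 1
    t_ell = ell * T / L
    x_hat = (1 - 1 / ell) * x_hat + (1 / ell) * bayes_denoiser(x_hat, t_ell)

# Fraction of iterates that ended up within 0.05 of the support {-1, +1}.
print(np.mean(np.abs(np.abs(x_hat) - 1.0) < 0.05))
```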

Figure 3.5: Denoising a low-rank mixture of Gaussians. Each panel shows samples from the true data distribution (gray, orange, red) and samples undergoing the denoising process (3.2.66) (light blue). At top left, the process has just started, and the noise is very large. As the process continues, the noisy samples are pushed further towards the support of the low-rank data distribution. Finally, at bottom right, the generated samples are perfectly aligned with the support of the data and look very much like samples drawn from the low-rank Gaussian mixture model.

The above is intuition for why we expect the denoising process to converge. We visualize the convergence process in $\mathbb{R}^{3}$ in Figure 3.5. We will develop some rigorous results about convergence later. For now, recall that we wanted to build a process to reduce the entropy. While we did do this in a roundabout way, by inverting a process which adds entropy, it is now time to pay the piper and confirm that our iterative denoising process reduces the entropy.

Theorem 3.4 (Simplified Version of Theorem B.3).

Suppose that $(\bm{x}_{t})_{t\in[0,T]}$ obeys (3.2.1). Then, under certain technical conditions on $\bm{x}$, for every $s<t$ with $s,t\in(0,T]$,

\[
h(\operatorname{\mathbb{E}}[\bm{x}_{s}\mid\bm{x}_{t}])<h(\bm{x}_{t}). \tag{3.2.40}
\]

The full statement of the theorem, and the proof itself, require some technical care, so they are postponed to Section B.2.2.

Finally, in practice we will usually not be able to compute $\bar{\bm{x}}^{\ast}(t,\cdot)$ for any $t$, since we do not have access to the distribution $p_{t}$. But we can try to learn the denoiser from data. Recall that the denoiser $\bar{\bm{x}}^{\ast}$ is defined in (3.2.3) as the minimizer of the mean-squared error $\operatorname{\mathbb{E}}\|\bar{\bm{x}}(t,\bm{x}_{t})-\bm{x}\|_{2}^{2}$. We can use this mean-squared error as a loss or objective function to learn the denoiser. For example, we can parameterize $\bar{\bm{x}}(t,\cdot)$ by a neural network, writing it as $\bar{\bm{x}}_{\theta}(t,\cdot)$, and optimize the loss over the parameter space $\Theta$:

\[
\min_{\theta\in\Theta}\operatorname{\mathbb{E}}_{\bm{x},\bm{x}_{t}}\|\bar{\bm{x}}_{\theta}(t,\bm{x}_{t})-\bm{x}\|_{2}^{2}. \tag{3.2.41}
\]

The solution to this optimization problem, implemented via gradient descent or a similar algorithm, will give us a denoiser $\bar{\bm{x}}_{\theta^{\ast}}(t,\cdot)$ which is a good approximation to $\bar{\bm{x}}^{\ast}(t,\cdot)$ (at least if the training works) and which we will use as our denoiser.
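As an illustration of what such training might look like, here is a hedged sketch in PyTorch; the two-dimensional synthetic data, the small MLP architecture, the optimizer, and all hyperparameters are our own illustrative assumptions, not prescriptions from the text.

```python
# A sketch of minimizing the denoising objective (3.2.41) with a small MLP.
# The data model, architecture, and hyperparameters are illustrative choices.
import torch
import torch.nn as nn

D, T = 2, 10.0
denoiser = nn.Sequential(
    nn.Linear(D + 1, 64), nn.ReLU(),
    nn.Linear(64, 64), nn.ReLU(),
    nn.Linear(64, D),
)  # maps (x_t, t) to an estimate of x
opt = torch.optim.Adam(denoiser.parameters(), lr=1e-3)

def sample_data(n):
    # Stand-in for samples of x: a thin (nearly 1-D) Gaussian in R^2.
    return torch.randn(n, D) * torch.tensor([1.0, 0.05])

for step in range(2000):
    x = sample_data(256)
    t = torch.rand(256, 1) * T + 1e-3           # draw t in (0, T]
    x_t = x + t * torch.randn_like(x)           # noising process (3.2.1)
    pred = denoiser(torch.cat([x_t, t], dim=1))
    loss = ((pred - x) ** 2).sum(dim=1).mean()  # empirical version of (3.2.41)
    opt.zero_grad()
    loss.backward()
    opt.step()
```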

What is a good architecture for this neural network $\bar{\bm{x}}_{\theta}(t,\cdot)$? To answer this question, we will examine the ubiquitous case of a Gaussian mixture model, whose denoiser we computed in Example 3.2. This model is relevant because it can approximate many types of distributions: in particular, given a distribution for $\bm{x}$, there is a Gaussian mixture model that approximates it arbitrarily well. So optimizing among the class of denoisers for Gaussian mixture models can give us something close to the optimal denoiser for the real data distribution.

In our case, we assume that $\bm{x}$ is low-dimensional, which loosely translates into the requirement that $\bm{x}$ is approximately distributed according to a mixture of low-rank Gaussians. Formally, we write

\[
\bm{x}\sim\frac{1}{K}\sum_{k=1}^{K}\operatorname{\mathcal{N}}(\bm{0},\bm{U}_{k}\bm{U}_{k}^{\top}) \tag{3.2.42}
\]

where $\bm{U}_{k}\in\mathsf{O}(D,P)\subseteq\mathbb{R}^{D\times P}$ is an orthogonal matrix. Then the optimal denoiser under (3.2.1) is (from Example 3.2)

\[
\bar{\bm{x}}^{\ast}(t,\bm{x}_{t})=\sum_{k=1}^{K}\frac{\varphi(\bm{x}_{t};\bm{0},\bm{U}_{k}\bm{U}_{k}^{\top}+t^{2}\bm{I})}{\sum_{i=1}^{K}\varphi(\bm{x}_{t};\bm{0},\bm{U}_{i}\bm{U}_{i}^{\top}+t^{2}\bm{I})}\cdot\left(\bm{U}_{k}\bm{U}_{k}^{\top}(\bm{U}_{k}\bm{U}_{k}^{\top}+t^{2}\bm{I})^{-1}\bm{x}_{t}\right). \tag{3.2.43}
\]

Notice that, both inside the density $\varphi$ and outside of it, we compute the inverse $(\bm{U}_{k}\bm{U}_{k}^{\top}+t^{2}\bm{I})^{-1}$. This is a low-rank perturbation of the full-rank matrix $t^{2}\bm{I}$, and thus ripe for simplification via the Sherman-Morrison-Woodbury identity: for matrices $\bm{A},\bm{C},\bm{U},\bm{V}$ such that $\bm{A}$ and $\bm{C}$ are invertible,

\[
(\bm{A}+\bm{U}\bm{C}\bm{V})^{-1}=\bm{A}^{-1}-\bm{A}^{-1}\bm{U}(\bm{C}^{-1}+\bm{V}\bm{A}^{-1}\bm{U})^{-1}\bm{V}\bm{A}^{-1}. \tag{3.2.44}
\]

We prove this identity in Exercise 3.3. For now, we apply it with $\bm{A}=t^{2}\bm{I}$, $\bm{U}=\bm{U}_{k}$, $\bm{V}=\bm{U}_{k}^{\top}$, and $\bm{C}=\bm{I}$, obtaining

\begin{align}
(\bm{U}_{k}\bm{U}_{k}^{\top}+t^{2}\bm{I})^{-1}
&=\frac{1}{t^{2}}\bm{I}-\frac{1}{t^{4}}\bm{U}_{k}\left(\bm{I}+\frac{1}{t^{2}}\bm{U}_{k}^{\top}\bm{U}_{k}\right)^{-1}\bm{U}_{k}^{\top} \tag{3.2.45}\\
&=\frac{1}{t^{2}}\bm{I}-\frac{1}{t^{4}\left(1+\frac{1}{t^{2}}\right)}\bm{U}_{k}\bm{U}_{k}^{\top} \tag{3.2.46}\\
&=\frac{1}{t^{2}}\left(\bm{I}-\frac{1}{1+t^{2}}\bm{U}_{k}\bm{U}_{k}^{\top}\right). \tag{3.2.47}
\end{align}
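This simplification is easy to verify numerically; the sketch below (with a random orthonormal-column $\bm{U}_{k}$ obtained from a QR factorization, an assumption made only for illustration) checks that (3.2.47) matches a direct matrix inverse.

```python
# A quick numerical check of the simplification (3.2.45)-(3.2.47), assuming
# U has orthonormal columns (obtained here from a QR factorization).
import numpy as np

D, P, t = 8, 3, 0.5
rng = np.random.default_rng(1)
U, _ = np.linalg.qr(rng.standard_normal((D, P)))      # U^T U = I_P

lhs = np.linalg.inv(U @ U.T + t**2 * np.eye(D))
rhs = (np.eye(D) - U @ U.T / (1 + t**2)) / t**2
print(np.allclose(lhs, rhs))                           # True
```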

Then we can compute the posterior probabilities as follows. Note that, since the $\bm{U}_{k}$ are all orthogonal, the determinants $\det(\bm{U}_{k}\bm{U}_{k}^{\top}+t^{2}\bm{I})$ are the same for every $k$. So

\begin{align}
\frac{\varphi(\bm{x}_{t};\bm{0},\bm{U}_{k}\bm{U}_{k}^{\top}+t^{2}\bm{I})}{\sum_{i=1}^{K}\varphi(\bm{x}_{t};\bm{0},\bm{U}_{i}\bm{U}_{i}^{\top}+t^{2}\bm{I})}
&=\frac{\exp\left(-\frac{1}{2}\bm{x}_{t}^{\top}(\bm{U}_{k}\bm{U}_{k}^{\top}+t^{2}\bm{I})^{-1}\bm{x}_{t}\right)}{\sum_{i=1}^{K}\exp\left(-\frac{1}{2}\bm{x}_{t}^{\top}(\bm{U}_{i}\bm{U}_{i}^{\top}+t^{2}\bm{I})^{-1}\bm{x}_{t}\right)} \tag{3.2.48}\\
&=\frac{\exp\left(-\frac{1}{2t^{2}}\bm{x}_{t}^{\top}\left(\bm{I}-\frac{1}{1+t^{2}}\bm{U}_{k}\bm{U}_{k}^{\top}\right)\bm{x}_{t}\right)}{\sum_{i=1}^{K}\exp\left(-\frac{1}{2t^{2}}\bm{x}_{t}^{\top}\left(\bm{I}-\frac{1}{1+t^{2}}\bm{U}_{i}\bm{U}_{i}^{\top}\right)\bm{x}_{t}\right)} \tag{3.2.49}\\
&=\frac{\exp\left(-\frac{1}{2t^{2}}\|\bm{x}_{t}\|_{2}^{2}+\frac{1}{2t^{2}(1+t^{2})}\|\bm{U}_{k}^{\top}\bm{x}_{t}\|_{2}^{2}\right)}{\sum_{i=1}^{K}\exp\left(-\frac{1}{2t^{2}}\|\bm{x}_{t}\|_{2}^{2}+\frac{1}{2t^{2}(1+t^{2})}\|\bm{U}_{i}^{\top}\bm{x}_{t}\|_{2}^{2}\right)} \tag{3.2.50}\\
&=\frac{\exp\left(-\frac{1}{2t^{2}}\|\bm{x}_{t}\|_{2}^{2}\right)\exp\left(\frac{1}{2t^{2}(1+t^{2})}\|\bm{U}_{k}^{\top}\bm{x}_{t}\|_{2}^{2}\right)}{\exp\left(-\frac{1}{2t^{2}}\|\bm{x}_{t}\|_{2}^{2}\right)\sum_{i=1}^{K}\exp\left(\frac{1}{2t^{2}(1+t^{2})}\|\bm{U}_{i}^{\top}\bm{x}_{t}\|_{2}^{2}\right)} \tag{3.2.51}\\
&=\frac{\exp\left(\frac{1}{2t^{2}(1+t^{2})}\|\bm{U}_{k}^{\top}\bm{x}_{t}\|_{2}^{2}\right)}{\sum_{i=1}^{K}\exp\left(\frac{1}{2t^{2}(1+t^{2})}\|\bm{U}_{i}^{\top}\bm{x}_{t}\|_{2}^{2}\right)}. \tag{3.2.52}
\end{align}

This is a softmax over the squared norms of the projections of $\bm{x}_{t}$ onto each subspace, $\|\bm{U}_{i}^{\top}\bm{x}_{t}\|_{2}^{2}$, with temperature $2t^{2}(1+t^{2})$. Meanwhile, the component denoisers can be written as

\begin{align}
\bm{U}_{k}\bm{U}_{k}^{\top}(\bm{U}_{k}\bm{U}_{k}^{\top}+t^{2}\bm{I})^{-1}\bm{x}_{t}
&=\frac{1}{t^{2}}\bm{U}_{k}\bm{U}_{k}^{\top}\left(\bm{I}-\frac{1}{1+t^{2}}\bm{U}_{k}\bm{U}_{k}^{\top}\right)\bm{x}_{t} \tag{3.2.53}\\
&=\frac{1}{t^{2}}\left(1-\frac{1}{1+t^{2}}\right)\bm{U}_{k}\bm{U}_{k}^{\top}\bm{x}_{t} \tag{3.2.54}\\
&=\frac{1}{1+t^{2}}\bm{U}_{k}\bm{U}_{k}^{\top}\bm{x}_{t}. \tag{3.2.55}
\end{align}

Putting these together, we have

\[
\bar{\bm{x}}^{\ast}(t,\bm{x}_{t})=\frac{1}{1+t^{2}}\sum_{k=1}^{K}\frac{\exp\left(\frac{1}{2t^{2}(1+t^{2})}\|\bm{U}_{k}^{\top}\bm{x}_{t}\|_{2}^{2}\right)}{\sum_{i=1}^{K}\exp\left(\frac{1}{2t^{2}(1+t^{2})}\|\bm{U}_{i}^{\top}\bm{x}_{t}\|_{2}^{2}\right)}\bm{U}_{k}\bm{U}_{k}^{\top}\bm{x}_{t}, \tag{3.2.56}
\]

i.e., a projection of $\bm{x}_{t}$ onto each of the $K$ subspaces, weighted by a softmax of a quadratic function of $\bm{x}_{t}$. This functional form is similar to an attention mechanism in a transformer architecture! As we will see in Chapter 4, this is no coincidence at all: the deep link between denoising and lossy compression (to be covered in Section 3.3) is part of what makes transformer denoisers so effective in practice. Overall, then, our Gaussian mixture model theory motivates the use of transformer-like neural networks for denoising.
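For concreteness, here is a small sketch implementing the closed-form denoiser (3.2.56); the random orthonormal bases $\bm{U}_{k}$ and the dimensions are illustrative assumptions. The softmax uses the usual max-subtraction for numerical stability, which does not change its value.

```python
# A sketch of the closed-form mixture denoiser (3.2.56) for K low-rank
# Gaussian components; the subspaces U_k are random orthonormal bases
# chosen only for illustration.
import numpy as np

def gmm_denoiser(x_t, t, Us):
    """x_t: (D,) array; Us: list of (D, P) matrices with orthonormal columns."""
    logits = np.array([np.sum((U.T @ x_t) ** 2) for U in Us]) / (2 * t**2 * (1 + t**2))
    w = np.exp(logits - logits.max())
    w = w / w.sum()                                    # softmax posterior weights
    proj = np.stack([U @ (U.T @ x_t) for U in Us])     # projections onto each subspace
    return (w[:, None] * proj).sum(axis=0) / (1 + t**2)

D, P, K = 8, 2, 3
rng = np.random.default_rng(2)
Us = [np.linalg.qr(rng.standard_normal((D, P)))[0] for _ in range(K)]
x_t = rng.standard_normal(D)
print(gmm_denoiser(x_t, 0.5, Us))
```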

Remark 3.1.

Connections between denoising a distribution and probabilistic PCA. Here, we would like to connect denoising a low-dimensional distribution to probabilistic PCA (see Section 2.1.3 for more details about probabilistic PCA). Suppose that we consider $K=1$ in (3.2.42), i.e., $\bm{x}\sim\operatorname{\mathcal{N}}(\bm{0},\bm{U}\bm{U}^{\top})$, where $\bm{U}\in\mathsf{O}(D,P)\subseteq\mathbb{R}^{D\times P}$ is an orthogonal matrix. According to (3.2.56), the Bayes-optimal denoiser is

\[
\bar{\bm{x}}^{\ast}(t,\bm{x}_{t})=\frac{1}{1+t^{2}}\bm{U}\bm{U}^{\top}\bm{x}_{t}. \tag{3.2.57}
\]

To learn this Bayes-optimal denoiser, we can accordingly parameterize the denoising operator $\bar{\bm{x}}(t,\bm{x}_{t})$ as follows:

\[
\bar{\bm{x}}(t,\bm{x}_{t})=\frac{1}{1+t^{2}}\bm{V}\bm{V}^{\top}\bm{x}_{t}, \tag{3.2.58}
\]

where $\bm{V}\in\mathsf{O}(D,P)$ contains the learnable parameters. Substituting this into the training loss (3.2.3) yields

\[
\min_{\bm{V}\in\mathsf{O}(D,P)}\operatorname{\mathbb{E}}_{\bm{x},\bm{x}_{t}}\left\|\bm{x}-\frac{1}{1+t^{2}}\bm{V}\bm{V}^{\top}\bm{x}_{t}\right\|_{2}^{2}=\operatorname{\mathbb{E}}_{\bm{x},\bm{g}}\left\|\bm{x}-\frac{1}{1+t^{2}}\bm{V}\bm{V}^{\top}(\bm{x}+t\bm{g})\right\|_{2}^{2}, \tag{3.2.59}
\]

where the equality is due to (3.2.1). Conditioned on $\bm{x}$, we compute

\begin{align}
&\operatorname{\mathbb{E}}_{\bm{g}}\left\|\bm{x}-\frac{1}{1+t^{2}}\bm{V}\bm{V}^{\top}(\bm{x}+t\bm{g})\right\|_{2}^{2} \tag{3.2.60}\\
&=\left\|\bm{x}-\frac{1}{1+t^{2}}\bm{V}\bm{V}^{\top}\bm{x}\right\|_{2}^{2}-\frac{2t}{1+t^{2}}\operatorname{\mathbb{E}}_{\bm{g}}\left\langle\bm{x}-\frac{1}{1+t^{2}}\bm{V}\bm{V}^{\top}\bm{x},\,\bm{V}\bm{V}^{\top}\bm{g}\right\rangle+\frac{t^{2}}{(1+t^{2})^{2}}\operatorname{\mathbb{E}}_{\bm{g}}\left\|\bm{V}\bm{V}^{\top}\bm{g}\right\|_{2}^{2} \tag{3.2.61}\\
&=\left\|\bm{x}-\frac{1}{1+t^{2}}\bm{V}\bm{V}^{\top}\bm{x}\right\|_{2}^{2}+\frac{t^{2}P}{(1+t^{2})^{2}}, \tag{3.2.62}
\end{align}

where the second equality follows from $\bm{g}\sim\operatorname{\mathcal{N}}(\bm{0},\bm{I})$ and $\operatorname{\mathbb{E}}_{\bm{g}}\|\bm{V}\bm{V}^{\top}\bm{g}\|_{2}^{2}=\operatorname{\mathbb{E}}_{\bm{g}}\|\bm{V}^{\top}\bm{g}\|_{2}^{2}=P$, due to $\bm{V}\in\mathsf{O}(D,P)$. Therefore, Problem (3.2.59) is equivalent to

\[
\min_{\bm{V}\in\mathsf{O}(D,P)}\operatorname{\mathbb{E}}_{\bm{x}}\left\|\bm{x}-\frac{1}{1+t^{2}}\bm{V}\bm{V}^{\top}\bm{x}\right\|_{2}^{2}=\operatorname{\mathbb{E}}_{\bm{x}}\|\bm{x}\|_{2}^{2}+\left(\frac{1}{(1+t^{2})^{2}}-\frac{2}{1+t^{2}}\right)\operatorname{\mathbb{E}}_{\bm{x}}\|\bm{V}^{\top}\bm{x}\|_{2}^{2}. \tag{3.2.63}
\]

This is further equivalent to

max𝑽𝖮(D,P)𝔼𝒙𝑽𝒙22,\displaystyle\max_{\bm{V}\in\mathsf{O}(D,P)}\operatorname{\mathbb{E}}_{\bm{x}}\|\bm{V}^{\top}\bm{x}\|_{2}^{2},roman_max start_POSTSUBSCRIPT bold_italic_V ∈ sansserif_O ( italic_D , italic_P ) end_POSTSUBSCRIPT blackboard_E start_POSTSUBSCRIPT bold_italic_x end_POSTSUBSCRIPT ∥ bold_italic_V start_POSTSUPERSCRIPT ⊤ end_POSTSUPERSCRIPT bold_italic_x ∥ start_POSTSUBSCRIPT 2 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT 2 end_POSTSUPERSCRIPT , (3.2.64)

which is essentially Problem (2.1.27).
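To make the connection to PCA concrete, here is a small numerical sketch (ours, not part of the text; NumPy assumed) checking that the top principal directions of the second-moment matrix of the data minimize the denoising objective (3.2.62) over this linear family of denoisers.

```python
import numpy as np

rng = np.random.default_rng(0)
D, P, N, t = 10, 2, 5000, 0.5

# Data concentrated near a P-dimensional subspace of R^D.
U = np.linalg.qr(rng.standard_normal((D, P)))[0]      # ground-truth orthonormal basis
X = rng.standard_normal((N, P)) @ U.T * 3.0           # strong signal on the subspace
X += 0.1 * rng.standard_normal((N, D))                # small isotropic perturbation

M2 = X.T @ X / N                                      # empirical second-moment matrix E[x x^T]
eigvals, eigvecs = np.linalg.eigh(M2)
V_star = eigvecs[:, -P:]                              # top-P principal directions

def objective(V):
    # Empirical version of E_x ||x - V V^T x / (1+t^2)||^2 + t^2 P / (1+t^2)^2, cf. (3.2.62).
    proj = X @ V @ V.T / (1.0 + t**2)
    return np.mean(np.sum((X - proj) ** 2, axis=1)) + t**2 * P / (1.0 + t**2) ** 2

V_rand = np.linalg.qr(rng.standard_normal((D, P)))[0]
assert objective(V_star) <= objective(V_rand) + 1e-9  # the PCA basis attains the smaller value
print(objective(V_star), objective(V_rand))
```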

Overall, the learned denoiser provides an (implicit, parametric) encoding scheme for the given data, since it can be used to denoise, i.e., to project noisy points onto the data distribution. Training a denoiser is thus equivalent to finding a better coding scheme, which partially fulfills one of the desiderata (the second) stated at the end of Section 3.1.3. In the sequel, we discuss how to fulfill the other (the first).

3.2.2 Learning and Sampling a Distribution via Iterative Denoising

Remember that at the end of Section 3.1.3, we discussed a pair of desiderata for pursuing a distribution with low-dimensional structure. The first such desideratum is to start with a normal distribution, say with high entropy, and gradually reduce its entropy until it reaches the distribution of the data. We will call this procedure sampling since we are generating new samples. It is now time for us to discuss how to do this with the toolkit we have built up.

We know how to denoise very noisy samples $\bm{x}_{T}$ to attain approximations $\hat{\bm{x}}$ whose distribution is close to that of the target random variable $\bm{x}$. But the desideratum says that, to sample, we want to start with a template distribution that carries no influence from the distribution of $\bm{x}$ and use the denoiser to guide the iterates towards the distribution of $\bm{x}$. How can we do this? One way is motivated as follows:

\begin{equation}
\frac{\bm{x}_{T}}{T}=\frac{\bm{x}+T\bm{g}}{T}=\frac{\bm{x}}{T}+\bm{g}\;\to\;\bm{g}\sim\operatorname{\mathcal{N}}(\bm{0},\bm{I}). \tag{3.2.65}
\end{equation}

Thus, for large $T$, the distribution of $\bm{x}_{T}$ is approximately $\operatorname{\mathcal{N}}(\bm{0},T^{2}\bm{I})$. This approximation is quite good for almost all practical distributions, and is visualized in Figure 3.6.

Figure 3.6: Visualizing $\bm{x}_{T}$ versus $\operatorname{\mathcal{N}}(\bm{0},T^{2}\bm{I})$. Left: a plot of Gaussian mixture model data $\bm{x}$. Right: a plot of $\bm{x}$ together with $\bm{x}_{T}$ and an independent sample of $\operatorname{\mathcal{N}}(\bm{0},T^{2}\bm{I})$, for $T=10$. On the right plot, $\bm{x}$ is drawn in the same colors as on the left; samples of $\bm{x}_{T}$ and $\operatorname{\mathcal{N}}(\bm{0},T^{2}\bm{I})$ are much larger on average than samples of $\bm{x}$, so $\bm{x}$ appears much smaller due to the scaling. Despite this large blow-up, the similarity between the distributions of $\bm{x}_{T}$ and $\operatorname{\mathcal{N}}(\bm{0},T^{2}\bm{I})$ is clearly visible.

So, discretizing $[0,T]$ into $0=t_{0}<t_{1}<\cdots<t_{L}=T$ uniformly using $t_{\ell}=T\ell/L$ (as in the previous section), one possible way to sample from pure noise is:

  • Sample $\hat{\bm{x}}_{T}\sim\operatorname{\mathcal{N}}(\bm{0},T^{2}\bm{I})$, independent of everything else.

  • Run the denoising iteration as in Section 3.2.1, i.e.,

\begin{equation}
\hat{\bm{x}}_{t_{\ell-1}}=\left(1-\frac{1}{\ell}\right)\hat{\bm{x}}_{t_{\ell}}+\frac{1}{\ell}\,\bar{\bm{x}}^{\ast}(t_{\ell},\hat{\bm{x}}_{t_{\ell}}). \tag{3.2.66}
\end{equation}

  • Output $\hat{\bm{x}}=\hat{\bm{x}}_{0}$.

Conceptually, this is all there is behind diffusion models, which transform noise into data samples in accordance with the first desideratum. However, a few more steps are needed before we obtain models that can actually sample from real data distributions, such as images, under practical resource constraints. In the sequel, we introduce and motivate several such steps.
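Before introducing these refinements, the naive sampler above can already be run end to end. The following sketch (our own, assuming NumPy) uses a single Gaussian target $\bm{x}\sim\operatorname{\mathcal{N}}(\bm{\mu},\bm{\Sigma})$, for which the Bayes optimal denoiser is available in closed form (the $K=1$ case of (3.2.17)): $\bar{\bm{x}}^{\ast}(t,\bm{x}_{t})=\bm{\mu}+\bm{\Sigma}(\bm{\Sigma}+t^{2}\bm{I})^{-1}(\bm{x}_{t}-\bm{\mu})$.

```python
import numpy as np

rng = np.random.default_rng(0)
D, T, L = 2, 10.0, 500
mu = np.array([3.0, -1.0])
Sigma = np.diag([2.0, 0.05])                 # an anisotropic (nearly low-dimensional) target

def denoiser(t, xt):
    # Bayes optimal denoiser E[x | x_t] for x ~ N(mu, Sigma), x_t = x + t*g.
    gain = Sigma @ np.linalg.inv(Sigma + t**2 * np.eye(D))
    return mu + gain @ (xt - mu)

x_hat = T * rng.standard_normal(D)           # x_hat_T ~ N(0, T^2 I)
for ell in range(L, 0, -1):                  # iterate (3.2.66) with t_ell = T*ell/L
    t_ell = T * ell / L
    x_hat = (1 - 1 / ell) * x_hat + (1 / ell) * denoiser(t_ell, x_hat)

print(x_hat)                                 # an approximate sample from N(mu, Sigma)
```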

Step 1: different discretizations.

The first step is motivated by the following observation: we do not need to spend so many denoising iterations at large $t$. If we look at Figure 3.5, we observe that the first 200 or 300 of the 500 iterations of the sampling process are spent merely contracting the noise towards the data distribution as a whole, before the remaining iterations push the samples towards a subspace. Given a fixed iteration count $L$, this signals that we should spend more timesteps $t_{\ell}$ near $t=0$ than near $t=T$. During sampling (and training), we can therefore use another discretization of $[0,T]$ into $0\leq t_{0}<t_{1}<\cdots<t_{L}\leq T$, such as an exponential discretization:

\begin{equation}
t_{\ell}=C_{1}(e^{C_{2}\ell}-1),\qquad\forall\ell\in\{0,1,\dots,L\}, \tag{3.2.67}
\end{equation}

where $C_{1},C_{2}>0$ are constants which can be tuned for optimal performance in practice; theoretical analyses often specify such optimal constants as well. Then the denoising/sampling iteration becomes

\begin{equation}
\hat{\bm{x}}_{t_{\ell-1}}\doteq\frac{t_{\ell-1}}{t_{\ell}}\hat{\bm{x}}_{t_{\ell}}+\left(1-\frac{t_{\ell-1}}{t_{\ell}}\right)\bar{\bm{x}}^{\ast}(t_{\ell},\hat{\bm{x}}_{t_{\ell}}), \tag{3.2.68}
\end{equation}

with, again, $\hat{\bm{x}}_{t_{L}}\sim\operatorname{\mathcal{N}}(\bm{0},t_{L}^{2}\bm{I})$.
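For instance, one concrete way to realize the exponential discretization (3.2.67) is to fix $C_{2}$ and solve for $C_{1}$ so that $t_{L}=T$ (a small sketch of ours; NumPy assumed):

```python
import numpy as np

def exponential_timesteps(T=10.0, L=50, C2=0.1):
    # t_ell = C1 * (exp(C2 * ell) - 1); choose C1 so that t_L = T (and hence t_0 = 0).
    C1 = T / (np.exp(C2 * L) - 1.0)
    ell = np.arange(L + 1)
    return C1 * (np.exp(C2 * ell) - 1.0)

t = exponential_timesteps()
print(t[:5])     # many small timesteps near t = 0 ...
print(t[-3:])    # ... and sparse, large steps near t = T
```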

Step 2: different noise models.

The second step is to consider slightly different noise models compared to (3.2.1). The basic motivation is as follows. In practice, the noise distribution $\operatorname{\mathcal{N}}(\bm{0},t_{L}^{2}\bm{I})$ becomes an increasingly poor approximation to the true distribution of $\bm{x}_{t_{L}}$ in high dimensions, i.e., (3.2.65) becomes an increasingly worse approximation, especially for anisotropic high-dimensional data. The increased distance between $\operatorname{\mathcal{N}}(\bm{0},t_{L}^{2}\bm{I})$ and the true distribution of $\bm{x}_{t_{L}}$ may cause the denoiser to perform worse in such circumstances. Theoretically, $\bm{x}_{t_{L}}$ never converges to any distribution as $t_{L}$ increases, so this setup is difficult to analyze end-to-end. The remedy is to simultaneously add noise and shrink the contribution of $\bm{x}$, such that $\bm{x}_{T}$ converges as $T\to\infty$. The noise level is denoted $\sigma\colon[0,T]\to\mathbb{R}_{\geq 0}$ and the shrinkage factor is denoted $\alpha\colon[0,T]\to\mathbb{R}_{\geq 0}$, where $\sigma$ is increasing, $\alpha$ is non-increasing, and

\begin{equation}
\bm{x}_{t}\doteq\alpha_{t}\bm{x}+\sigma_{t}\bm{g},\qquad\forall t\in[0,T]. \tag{3.2.69}
\end{equation}

The previous setup has $\alpha_{t}=1$ and $\sigma_{t}=t$; this is called the variance-exploding (VE) process. A popular choice which decreases the contribution of $\bm{x}$, as described above, has $T=1$ (so that $t\in[0,1]$), $\alpha_{t}=\sqrt{1-t^{2}}$, and $\sigma_{t}=t$; this is the variance-preserving (VP) process. Note that under the VP process, $\bm{x}_{1}\sim\operatorname{\mathcal{N}}(\bm{0},\bm{I})$ exactly, so we can sample from this standard distribution and iteratively denoise. As a result, the VP process is much easier to analyze theoretically and more stable empirically.\footnote{Why use the whole $\alpha,\sigma$ setup? As we will see in Exercise 3.5, it encapsulates and unifies many proposed processes, including the recently popular so-called flow matching process. Despite this, the VE and VP processes remain the most popular empirically and theoretically (so far), so we will focus on them in this section.}
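In code, the two processes amount to a choice of schedule functions (a minimal sketch under the conventions above):

```python
import numpy as np

def ve_schedule(t):
    # Variance-exploding: x_t = x + t * g, i.e. alpha_t = 1, sigma_t = t.
    return 1.0, t

def vp_schedule(t):
    # Variance-preserving on t in [0, 1]: alpha_t = sqrt(1 - t^2), sigma_t = t,
    # so alpha_t^2 + sigma_t^2 = 1 and x_1 ~ N(0, I) exactly.
    return np.sqrt(1.0 - t**2), t
```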

With this more general setup, Tweedie’s formula (3.2.20) becomes

\begin{equation}
\operatorname{\mathbb{E}}[\bm{x}\mid\bm{x}_{t}]=\frac{1}{\alpha_{t}}\left(\bm{x}_{t}+\sigma_{t}^{2}\nabla\log p_{t}(\bm{x}_{t})\right). \tag{3.2.70}
\end{equation}

The denoising iteration (3.2.68) becomes

\begin{equation}
\hat{\bm{x}}_{t_{\ell-1}}=\frac{\sigma_{t_{\ell-1}}}{\sigma_{t_{\ell}}}\hat{\bm{x}}_{t_{\ell}}+\left(\alpha_{t_{\ell-1}}-\frac{\sigma_{t_{\ell-1}}}{\sigma_{t_{\ell}}}\alpha_{t_{\ell}}\right)\bar{\bm{x}}^{\ast}(t_{\ell},\hat{\bm{x}}_{t_{\ell}}). \tag{3.2.71}
\end{equation}

Finally, the Gaussian mixture model denoiser (3.2.17) becomes

\begin{equation}
\bar{\bm{x}}^{\ast}(t,\bm{x}_{t})=\sum_{k=1}^{K}\frac{\pi_{k}\,\varphi(\bm{x}_{t};\alpha_{t}\bm{\mu}_{k},\alpha_{t}^{2}\bm{\Sigma}_{k}+\sigma_{t}^{2}\bm{I})}{\sum_{i=1}^{K}\pi_{i}\,\varphi(\bm{x}_{t};\alpha_{t}\bm{\mu}_{i},\alpha_{t}^{2}\bm{\Sigma}_{i}+\sigma_{t}^{2}\bm{I})}\left(\bm{\mu}_{k}+\alpha_{t}\bm{\Sigma}_{k}(\alpha_{t}^{2}\bm{\Sigma}_{k}+\sigma_{t}^{2}\bm{I})^{-1}(\bm{x}_{t}-\alpha_{t}\bm{\mu}_{k})\right). \tag{3.2.72}
\end{equation}
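As a concrete illustration, the following sketch (ours; NumPy and SciPy assumed) implements this Bayes optimal mixture denoiser, with `pis`, `mus`, and `Sigmas` denoting the mixture weights, means, and covariances, and `alpha`, `sigma` the schedule functions:

```python
import numpy as np
from scipy.stats import multivariate_normal

def gmm_denoiser(t, xt, pis, mus, Sigmas, alpha, sigma):
    """Bayes optimal denoiser x_bar*(t, x_t) for a K-component Gaussian mixture,
    under the noise model x_t = alpha_t * x + sigma_t * g, cf. (3.2.72)."""
    a, s = alpha(t), sigma(t)
    D = xt.shape[0]
    # Posterior responsibilities proportional to pi_k * phi(x_t; a*mu_k, a^2*Sigma_k + s^2*I).
    log_w = np.array([
        np.log(pi_k) + multivariate_normal.logpdf(xt, a * mu_k, a**2 * S_k + s**2 * np.eye(D))
        for pi_k, mu_k, S_k in zip(pis, mus, Sigmas)
    ])
    w = np.exp(log_w - log_w.max())
    w /= w.sum()
    # Component-wise posterior means mu_k + a*Sigma_k (a^2*Sigma_k + s^2*I)^{-1} (x_t - a*mu_k).
    out = np.zeros(D)
    for w_k, mu_k, S_k in zip(w, mus, Sigmas):
        gain = a * S_k @ np.linalg.inv(a**2 * S_k + s**2 * np.eye(D))
        out += w_k * (mu_k + gain @ (xt - a * mu_k))
    return out
```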

Figure 3.7 demonstrates iterations of the sampling procedure. Note that the denoising iteration (3.2.71) defines a sampling algorithm called the DDIM (“Denoising Diffusion Implicit Model”) sampler [SME20], which remains one of the most popular sampling algorithms for diffusion models today. We summarize it in Algorithm 3.1.

Figure 3.7: Denoising a mixture of Gaussians using the VP diffusion process. We use the same figure setup and data distribution as Figure 3.5. Note that compared to Figure 3.5, the noise distribution is much more concentrated around the origin.
Algorithm 3.1 Sampling using a denoiser.
1: An ordered list of timesteps $0\leq t_{0}<\cdots<t_{L}\leq T$ to use for sampling.
2: A denoiser $\bar{\bm{x}}\colon\{t_{\ell}\}_{\ell=1}^{L}\times\mathbb{R}^{D}\to\mathbb{R}^{D}$.
3: Scale and noise level functions $\alpha,\sigma\colon\{t_{\ell}\}_{\ell=0}^{L}\to\mathbb{R}_{\geq 0}$.
4: A sample $\hat{\bm{x}}$, approximately from the distribution of $\bm{x}$.
5: function DDIMSampler($\bar{\bm{x}},(t_{\ell})_{\ell=0}^{L},\alpha,\sigma$)
6:   Initialize $\hat{\bm{x}}_{t_{L}}\sim$ approximate distribution of $\bm{x}_{t_{L}}$ ▷ VP $\implies\operatorname{\mathcal{N}}(\bm{0},\bm{I})$; VE $\implies\operatorname{\mathcal{N}}(\bm{0},t_{L}^{2}\bm{I})$.
7:   for $\ell=L,L-1,\dots,1$ do
8:     Compute
       $\hat{\bm{x}}_{t_{\ell-1}}\doteq\dfrac{\sigma_{t_{\ell-1}}}{\sigma_{t_{\ell}}}\hat{\bm{x}}_{t_{\ell}}+\left(\alpha_{t_{\ell-1}}-\dfrac{\sigma_{t_{\ell-1}}}{\sigma_{t_{\ell}}}\alpha_{t_{\ell}}\right)\bar{\bm{x}}(t_{\ell},\hat{\bm{x}}_{t_{\ell}})$
9:   end for
10:  return $\hat{\bm{x}}_{t_{0}}$
11: end function
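A minimal code sketch of Algorithm 3.1 (ours; NumPy assumed), parameterized by the timesteps, the schedules, and any denoiser such as the mixture denoiser sketched above:

```python
import numpy as np

def ddim_sampler(denoiser, timesteps, alpha, sigma, D, rng):
    """Sketch of Algorithm 3.1: DDIM sampling with a given denoiser(t, x_t)."""
    t_L = timesteps[-1]
    # Initialize from the approximate distribution of x_{t_L}: N(0, sigma_{t_L}^2 I),
    # which is N(0, I) for VP and N(0, t_L^2 I) for VE.
    x = sigma(t_L) * rng.standard_normal(D)
    for ell in range(len(timesteps) - 1, 0, -1):
        t_cur, t_prev = timesteps[ell], timesteps[ell - 1]
        ratio = sigma(t_prev) / sigma(t_cur)
        # x_{t_{l-1}} = (sig_{l-1}/sig_l) x_{t_l} + (alf_{l-1} - (sig_{l-1}/sig_l) alf_l) x_bar(t_l, x_{t_l})
        x = ratio * x + (alpha(t_prev) - ratio * alpha(t_cur)) * denoiser(t_cur, x)
    return x
```

Passing the schedules as functions keeps the same code usable for both the VE and VP processes.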

Step 3: optimizing training pipelines.

If we used the procedure from Section 3.2.1 to learn a separate denoiser $\bar{\bm{x}}(t,\cdot)$ for each time $t$ used in the sampling algorithm, we would have to learn $L$ separate denoisers! This is highly inefficient: in the usual case we would have to train $L$ separate neural networks, taking $L$ times the training time and storage, and we would then be locked into these timesteps for sampling forever. Instead, we can train a single neural network to denoise across all times $t$, taking as input the continuous variables $\bm{x}_{t}$ and $t$ (instead of just $\bm{x}_{t}$ as before). Mechanically, our training loss averages over $t$, i.e., solves the following problem:

\begin{equation}
\min_{\theta}\operatorname{\mathbb{E}}_{t,\bm{x},\bm{x}_{t}}\|\bar{\bm{x}}_{\theta}(t,\bm{x}_{t})-\bm{x}\|_{2}^{2}. \tag{3.2.73}
\end{equation}

Similar to Step 1, where we used more timesteps near $t=0$ to ensure a better sampling process, we may want the denoiser to be of higher quality near $t=0$, and thereby weight the loss so that $t$ near $0$ has higher weight. Letting $w_{t}$ be the weight at time $t$, the weighted loss is

\begin{equation}
\min_{\theta}\operatorname{\mathbb{E}}_{t}\,w_{t}\operatorname{\mathbb{E}}_{\bm{x},\bm{x}_{t}}\|\bar{\bm{x}}_{\theta}(t,\bm{x}_{t})-\bm{x}\|_{2}^{2}. \tag{3.2.74}
\end{equation}

One reasonable choice of weight in practice is $w_{t}=\alpha_{t}/\sigma_{t}$. The precise reason will be covered in Step 4 below, but generally it up-weights the losses corresponding to $t$ near $0$ while remaining reasonably numerically stable. Also, since we cannot compute the expectation in practice, we estimate it by a straightforward Monte Carlo average. This series of changes has several conceptual and computational benefits: we do not need to train multiple denoisers, we can train on one set of timesteps and sample using a subset (or another set entirely), and so on. The full pipeline is given in Algorithm 3.2.

Algorithm 3.2 Learning a denoiser from data.
1: Dataset $\mathcal{D}\subseteq\mathbb{R}^{D}$.
2: An ordered list of timesteps $0\leq t_{0}<\cdots<t_{L}\leq T$.
3: A weighting function $w\colon\{t_{\ell}\}_{\ell=1}^{L}\to\mathbb{R}_{\geq 0}$.
4: Scale and noise level functions $\alpha,\sigma\colon\{t_{\ell}\}_{\ell=0}^{L}\to\mathbb{R}_{\geq 0}$.
5: A parameter space $\Theta$ and a denoiser architecture $\bar{\bm{x}}_{\theta}$.
6: An optimization algorithm for the parameters.
7: The number of optimization iterations $M$.
8: The number of Monte Carlo draws $N$ per iteration (to approximate the expectation in (3.2.74)).
9: A trained denoiser $\bar{\bm{x}}_{\theta^{\ast}}$.
10: function TrainDenoiser($\mathcal{D},\Theta$)
11:   Initialize $\theta^{(1)}\in\Theta$
12:   for $i\in[M]$ do
13:     for $n\in[N]$ do
14:       $\bm{x}_{n}^{(i)}\sim\mathcal{D}$ ▷ Draw a sample from the dataset.
15:       $t_{n}^{(i)}\overset{\mathrm{i.i.d.}}{\sim}\operatorname{\mathcal{U}}(\{t_{\ell}\}_{\ell=1}^{L})$ ▷ Sample a timestep.
16:       $\bm{g}_{n}^{(i)}\overset{\mathrm{i.i.d.}}{\sim}\operatorname{\mathcal{N}}(\bm{0},\bm{I})$ ▷ Sample a noise vector.
17:       $\bm{x}_{t,n}^{(i)}\doteq\alpha_{t_{n}^{(i)}}\bm{x}_{n}^{(i)}+\sigma_{t_{n}^{(i)}}\bm{g}_{n}^{(i)}$ ▷ Compute the noised sample.
18:       $w_{n}^{(i)}\doteq w_{t_{n}^{(i)}}$ ▷ Compute the loss weight.
19:     end for
20:     $\hat{\mathcal{L}}^{(i)}\doteq\frac{1}{N}\sum_{n=1}^{N}w_{n}^{(i)}\|\bm{x}_{n}^{(i)}-\bar{\bm{x}}_{\theta^{(i)}}(t_{n}^{(i)},\bm{x}_{t,n}^{(i)})\|_{2}^{2}$ ▷ Compute the loss estimate.
21:     $\theta^{(i+1)}\doteq\texttt{OptimizationUpdate}^{(i)}(\theta^{(i)},\nabla_{\theta^{(i)}}\hat{\mathcal{L}}^{(i)})$ ▷ Update parameters.
22:   end for
23:   return $\bar{\bm{x}}_{\theta^{(M+1)}}$
24: end function
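For completeness, here is a compact sketch of Algorithm 3.2 (ours, assuming PyTorch); the multilayer-perceptron architecture, the Adam optimizer, and all hyperparameters below are placeholder choices rather than recommendations, and the schedule and weight functions are assumed to accept batched tensors.

```python
import torch
import torch.nn as nn

def train_denoiser(X, timesteps, alpha, sigma, weight, M=10_000, N=128, lr=1e-3):
    """Sketch of Algorithm 3.2: minimize the weighted denoising loss (3.2.74)
    with a single network x_bar_theta(t, x_t) shared across all timesteps."""
    D = X.shape[1]
    model = nn.Sequential(nn.Linear(D + 1, 256), nn.ReLU(),
                          nn.Linear(256, 256), nn.ReLU(),
                          nn.Linear(256, D))                     # placeholder architecture
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    ts = torch.as_tensor(timesteps[1:], dtype=torch.float32)     # {t_1, ..., t_L}
    for _ in range(M):
        idx = torch.randint(0, X.shape[0], (N,))
        x = X[idx]                                               # x_n drawn from the dataset
        t = ts[torch.randint(0, len(ts), (N,))].unsqueeze(1)     # t_n ~ Unif({t_l})
        g = torch.randn_like(x)                                  # g_n ~ N(0, I)
        xt = alpha(t) * x + sigma(t) * g                         # noised sample
        pred = model(torch.cat([xt, t], dim=1))                  # x_bar_theta(t, x_t)
        loss = (weight(t) * (pred - x).pow(2).sum(dim=1, keepdim=True)).mean()
        opt.zero_grad(); loss.backward(); opt.step()
    return model
```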

(Optional) Step 4: changing the estimation target.

Note that it is common to instead reorient the whole denoising pipeline around noise predictors, i.e., estimates of $\operatorname{\mathbb{E}}[\bm{g}\mid\bm{x}_{t}]$. In practice, noise predictors are slightly easier to train because their output is (almost) always of a size comparable to a Gaussian random variable, so training is more numerically stable. By (3.2.69) we have

\begin{equation}
\bm{x}_{t}=\alpha_{t}\operatorname{\mathbb{E}}[\bm{x}\mid\bm{x}_{t}]+\sigma_{t}\operatorname{\mathbb{E}}[\bm{g}\mid\bm{x}_{t}]\implies\operatorname{\mathbb{E}}[\bm{g}\mid\bm{x}_{t}]=\frac{1}{\sigma_{t}}\left(\bm{x}_{t}-\alpha_{t}\operatorname{\mathbb{E}}[\bm{x}\mid\bm{x}_{t}]\right). \tag{3.2.75}
\end{equation}

Therefore any predictor for $\bm{x}$ can be turned into a predictor for $\bm{g}$ using the above relation, i.e.,

\begin{equation}
\bar{\bm{g}}(t,\bm{x}_{t})=\frac{1}{\sigma_{t}}\bm{x}_{t}-\frac{\alpha_{t}}{\sigma_{t}}\bar{\bm{x}}(t,\bm{x}_{t}), \tag{3.2.76}
\end{equation}

and vice versa. Thus a good network for estimating $\bar{\bm{g}}$ is the same as a good network for estimating $\bar{\bm{x}}$ plus a residual connection (as seen in, e.g., transformers). The corresponding losses also agree with the denoising loss, up to the factor $\alpha_{t}^{2}/\sigma_{t}^{2}$, i.e.,

\begin{equation}
\operatorname{\mathbb{E}}_{t}\,w_{t}\operatorname{\mathbb{E}}_{\bm{g},\bm{x}_{t}}\|\bm{g}-\bar{\bm{g}}(t,\bm{x}_{t})\|_{2}^{2}=\operatorname{\mathbb{E}}_{t}\,w_{t}\frac{\alpha_{t}^{2}}{\sigma_{t}^{2}}\operatorname{\mathbb{E}}_{\bm{x},\bm{x}_{t}}\|\bm{x}-\bar{\bm{x}}(t,\bm{x}_{t})\|_{2}^{2}. \tag{3.2.77}
\end{equation}
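In code, the conversion (3.2.76) between a denoiser and a noise predictor is a one-liner in either direction (a sketch; the schedule functions are as before):

```python
def noise_pred_from_denoiser(x_bar, alpha, sigma):
    # g_bar(t, x_t) = (x_t - alpha_t * x_bar(t, x_t)) / sigma_t, cf. (3.2.76)
    return lambda t, xt: (xt - alpha(t) * x_bar(t, xt)) / sigma(t)

def denoiser_from_noise_pred(g_bar, alpha, sigma):
    # x_bar(t, x_t) = (x_t - sigma_t * g_bar(t, x_t)) / alpha_t
    return lambda t, xt: (xt - sigma(t) * g_bar(t, xt)) / alpha(t)
```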

For the sake of completeness, we mention that other targets have been proposed for different tasks, e.g., $\operatorname{\mathbb{E}}[\frac{\mathrm{d}}{\mathrm{d}t}\bm{x}_{t}\mid\bm{x}_{t}]$ (called $v$-prediction or velocity prediction), but denoising and noise prediction remain the most commonly used. Throughout the rest of this book we only consider denoising.

We have made many changes to our original, platonic noising/denoising process. To assure ourselves that the new process still works in practice, we can compute numerical examples (such as Figure 3.7). To assure ourselves that it is theoretically sound, we can prove a bound on the error of the sampling algorithm. We now state such a bound from the literature, which shows that the output distribution of the sampler converges to the true distribution in the so-called total variation (TV) distance. The TV distance between two random variables $\bm{x}$ and $\bm{y}$ is defined as

\begin{equation}
\operatorname{\mathsf{TV}}(\bm{x},\bm{y})\doteq\sup_{A\subseteq\mathbb{R}^{D}}\left\lvert\operatorname{\mathbb{P}}[\bm{x}\in A]-\operatorname{\mathbb{P}}[\bm{y}\in A]\right\rvert. \tag{3.2.78}
\end{equation}

If the distributions of $\bm{x}$ and $\bm{y}$ are close, in the sense that they assign nearly the same probability to every event $A$, then the supremum is small. So the TV distance measures the closeness of random variables in distribution. (It is indeed a metric, as the name suggests; the proof is an exercise.)
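For distributions with densities (or mass functions) $p$ and $q$, the TV distance equals half of the $\ell^{1}$ distance between them, which for discrete distributions gives a one-line computation (a small sketch of ours):

```python
import numpy as np

def tv_distance(p, q):
    # For probability mass functions p, q: sup_A |P[x in A] - P[y in A]| = 0.5 * ||p - q||_1.
    return 0.5 * np.abs(np.asarray(p) - np.asarray(q)).sum()

print(tv_distance([0.5, 0.5, 0.0], [0.25, 0.25, 0.5]))   # 0.5
```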

Theorem 3.5 ([LY24] Theorem 1, Simplified).

Suppose that $\operatorname{\mathbb{E}}\|\bm{x}\|_{2}<\infty$. If $\bm{x}$ is denoised according to the VP process with an exponential discretization as in (3.2.67),\footnote{The precise definition is rather lengthy in our notation and only specified up to various absolute constants, so we omit it here for brevity; it can of course be found in the original paper [LY24].} the output $\hat{\bm{x}}$ of Algorithm 3.1 satisfies the total variation bound

\begin{equation}
\operatorname{\mathsf{TV}}(\bm{x},\hat{\bm{x}})=\tilde{\mathcal{O}}\Bigg(\underbrace{\frac{D}{L}}_{\text{discretization error}}+\underbrace{\sqrt{\frac{1}{L}\sum_{\ell=1}^{L}\frac{\alpha_{t_{\ell}}}{\sigma_{t_{\ell}}^{2}}\operatorname{\mathbb{E}}_{\bm{x},\bm{x}_{t_{\ell}}}\|\bar{\bm{x}}^{\ast}(t_{\ell},\bm{x}_{t_{\ell}})-\bar{\bm{x}}(t_{\ell},\bm{x}_{t_{\ell}})\|_{2}^{2}}}_{\text{average excess error of the denoiser}}\Bigg), \tag{3.2.79}
\end{equation}

where $\bar{\bm{x}}^{\ast}$ is the Bayes optimal denoiser for $\bm{x}$, and $\tilde{\mathcal{O}}$ is a version of the big-$\mathcal{O}$ notation which ignores logarithmic factors in $L$.

The very high-level proof technique is, as discussed earlier, to bound the error at each step, distinguish the error sources (between discretization and denoiser error), and carefully ensure that the errors do not accumulate too much (or even cancel out).

Note that if $L\to\infty$ and we correctly learn the Bayes optimal denoiser $\bar{\bm{x}}=\bar{\bm{x}}^{\ast}$ (so that the excess error is $0$), then the sampling process in Algorithm 3.1 yields a perfect (in distribution) inverse of the noising process, since the error rate in Theorem 3.5 goes to $0$,\footnote{There are similar results for VE processes, though to our knowledge none are as sharp as this one.} as heuristically argued previously.

Remark 3.2.

What if the data are low-dimensional, say supported on a low-dimensional subspace of the high-dimensional space $\mathbb{R}^{D}$? If the data distribution is compactly supported (say the data are normalized to the unit hypercube, which is often ensured as a pre-processing step for real data such as images), it is possible to do better. Namely, the authors of [LY24] also define a measure of approximate intrinsic dimension using the asymptotics of the so-called covering number, which is extremely similar in intuition (if not in implementation) to the rate distortion function presented in the next section. They then show that, using a small modification of the DDIM sampler in Algorithm 3.1 (i.e., slightly perturbing the update coefficients), the discretization error becomes

\begin{equation}
\tilde{\mathcal{O}}\left(\frac{\text{approximate intrinsic dimension}}{L}\right) \tag{3.2.80}
\end{equation}

instead of $\frac{D}{L}$ as in Theorem 3.5. Therefore, using this modified algorithm, $L$ does not have to be too large even as $D$ reaches the thousands or millions, since real data have low-dimensional structure. In practice, however, we use the DDIM sampler itself, so $L$ should have a mild dependence on $D$ to achieve consistent error rates. The exact choice of $L$ trades off the computational complexity (e.g., runtime or memory consumption) of sampling against the statistical complexity of learning a denoiser for low-dimensional structures. The value of $L$ is often different at training time (where a larger $L$ gives better coverage of the interval $[0,T]$, helping the network learn a relationship that generalizes over $t$) and at sampling time (where a smaller $L$ means more efficient sampling). One can even pick the timesteps adaptively at sampling time to optimize this tradeoff [BLZ+22].

Remark 3.3.

Various other works define the reverse process as moving backward in the time index $t$ using an explicit difference equation (or a differential equation in the limit $L\to\infty$), or as moving forward in time via the transformation $\bm{y}_{t}=\bm{x}_{T-t}$, so that as $t$ increases, $\bm{y}_{t}$ gets closer to $\bm{x}_{0}$. In this work we strive for consistency: we move forward in time to noise, and backward in time to denoise. If you are reading another work that is unclear about the time index, or implementing an algorithm that is similarly unclear, there is one check that works every time: the sampling iteration should always have a positive coefficient on both the denoiser term and the current iterate when moving from step to step. In general, though, many papers define their own notation, so translate between conventions with care.

Remark 3.4.

The theory presented at the end of the last Section 3.2.1 seems to suggest (loosely speaking) that, in practice, a transformer-like network is a good choice for learning or approximating a denoiser. This is reasonable, but what is the problem with using any generic neural network (such as a multi-layer perceptron (MLP)) and simply scaling it up? To see the problem, let us look at another special case of the Gaussian mixture model studied in Example 3.2. Namely, the empirical distribution is an instance of a degenerate Gaussian mixture model, with $K=N$ components $\operatorname{\mathcal{N}}(\bm{x}_{i},\bm{0})$ sampled with equal probability $\pi_{i}=\frac{1}{N}$. In this case the Bayes optimal denoiser is

\begin{equation}
\bar{\bm{x}}^{\ast}(t,\bm{x}_{t})=\sum_{i=1}^{N}\frac{e^{-\|\bm{x}_{t}-\alpha_{t}\bm{x}_{i}\|_{2}^{2}/(2\sigma_{t}^{2})}}{\sum_{j=1}^{N}e^{-\|\bm{x}_{t}-\alpha_{t}\bm{x}_{j}\|_{2}^{2}/(2\sigma_{t}^{2})}}\,\bm{x}_{i}. \tag{3.2.81}
\end{equation}

This is a convex combination of the data points $\bm{x}_{i}$, and the coefficients get “sharper” (i.e., closer to $0$ or $1$) as $t\to 0$. Notice that this denoiser optimally solves the denoising optimization problem (3.2.74) when the loss is computed by drawing $\bm{x}$ uniformly at random from a fixed finite dataset $\bm{X}=\{\bm{x}_{i}\}_{i=1}^{N}$, which is a very realistic setting. Thus, if our network architecture $\bar{\bm{x}}_{\theta}$ is expressive enough that optimal denoisers of the form (3.2.81) can be well-approximated, then the learned denoiser may do just that. Then our iterative denoising Algorithm 3.1 will sample exactly from the empirical distribution, re-generating samples in the training data, as certified by Theorem 3.5. This is a bad sampler, not really more interesting than a database of all training samples, so it is important to understand how to avoid this behavior in practice. The key is to design a network architecture which can well-approximate the true denoiser (say, corresponding to a low-rank distribution as in (3.2.56)) but not the empirical Bayes denoiser (3.2.81). Some work has explored this fine line and why modern diffusion models, which use transformer- and convolution-based network architectures, memorize and generalize in different regimes [KG24, NZM+24].
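To see this memorization behavior concretely, the following sketch (ours; NumPy assumed) implements the empirical denoiser (3.2.81) and runs the DDIM iteration with it under a VP schedule; the output lands essentially on one of the training points.

```python
import numpy as np

def empirical_denoiser(t, xt, X, alpha, sigma):
    # x_bar*(t, x_t) = softmax_i( -||x_t - alpha_t x_i||^2 / (2 sigma_t^2) ) applied to X, cf. (3.2.81)
    a, s = alpha(t), sigma(t)
    logits = -np.sum((xt - a * X) ** 2, axis=1) / (2 * s**2)
    w = np.exp(logits - logits.max())
    w /= w.sum()
    return w @ X

rng = np.random.default_rng(0)
X = rng.standard_normal((8, 2))                              # a tiny "training set"
alpha, sigma = (lambda t: np.sqrt(1 - t**2)), (lambda t: t)  # VP schedule
ts = np.linspace(1e-3, 1.0 - 1e-3, 200)
x = rng.standard_normal(2)                                   # x_{t_L} ~ N(0, I)
for ell in range(len(ts) - 1, 0, -1):
    r = sigma(ts[ell - 1]) / sigma(ts[ell])
    x = r * x + (alpha(ts[ell - 1]) - r * alpha(ts[ell])) * empirical_denoiser(ts[ell], x, X, alpha, sigma)

print(np.min(np.linalg.norm(X - x, axis=1)))   # tiny: the sampler essentially reproduces a training point
```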

At a high level, a denoiser which memorizes all the training points, as in (3.2.81), corresponds to a parametric model of the distribution with minimal coding rate, achieved by simply coding every sample separately. We will discuss this problem (and the seeming paradox it poses for our initial desiderata at the end of Section 3.1.3) from the perspective of information theory in the next section.

3.3 Compression via Lossy Coding

Let us recap what we have covered so far. We have discussed how to fit a denoiser $\bar{\bm{x}}_{\theta}$ using finite samples. We showed that this denoiser encodes a distribution, in that it is directly connected to the log-density via Tweedie’s formula (3.2.20). We then used it to gradually transform a pure-noise (high-entropy) distribution towards the learned distribution via iterative denoising. Thus, we have developed the first way of learning or pursuing a distribution laid out at the end of Section 3.1.3.

Nevertheless, in this methodology, the encoding of the distribution is implicit in the denoiser’s functional form and parameters, if any. In fact, astute readers may have noticed that for a general distribution we have never explicitly specified the functional form of the denoiser. In practice, it is typically modeled by a deep neural network with an empirically designed architecture. In addition, although we know the above denoising process reduces the entropy, we do not know by how much, nor do we know the entropy of the intermediate and resulting distributions.

Recall that our general goal is to model data from a (continuous) distribution with low-dimensional support. If our goal is to identify the “simplest” model that generates the data, one could consider three typical measures of parsimony: the dimension, the volume, or the (differential) entropy. If one uses the dimension, then the best model for a given dataset is obviously the empirical distribution itself, which is zero-dimensional. For all distributions with low-dimensional supports, the differential entropy is always negative infinity, and the volume of their supports is always zero. So, among all distributions with low-dimensional supports that could have generated the same data samples, how can we decide which one is better based on measures of parsimony that cannot distinguish among low-dimensional models at all? This section aims to address this seemingly baffling situation.

In the remainder of this chapter, we discuss a framework that allows us to alleviate the above technical difficulty by associating the learned distribution with an explicit computable encoding and decoding scheme, following the second approach suggested at the end of Section 3.1.3. As we will see, such an approach essentially allows us to accurately approximate the entropy of the learned distributions in terms of a (lossy) coding length or coding rate associated with the coding scheme. With such a measure, not only can we accurately measure how much the entropy is reduced, hence information gained, by any processing (including denoising) of the distribution, but we can also derive an explicit form of the optimal operator that can conduct such operations in the most efficient way. As we will see in the next Chapter 4, this will lead to a principled explanation for the architecture of deep networks, as well as to more efficient deep-architecture designs.

3.3.1 Necessity of Lossy Coding

We have discussed, multiple times, the following difficulty: if in the end we learn the distribution from finite samples, and our function class of denoisers is rich enough, how do we ensure that we sample from the true distribution (with low-dimensional support), rather than from some other distribution that could produce those finite samples with high probability? Let us reveal some of the conceptual and technical difficulties through a few concrete examples.

Example 3.4 (Volume, Dimension, and Entropy).

For the example shown at the top of Figure 3.8, suppose we have taken some samples from a uniform distribution on a line (say in a 2D plane). The volume of the line, and of the sample set, is zero. Geometrically, the empirical distribution on the finite sample set is the minimum-dimension distribution that could have produced it.\footnote{A set of discrete samples is zero-dimensional, whereas the supporting line is one-dimensional.} But this seemingly contrasts with yet another measure of complexity: entropy. The (differential) entropy of the distribution on the line is negative infinity, while the (discrete) entropy of the sample set is finite and positive. So we have a paradoxical situation with these common measures of parsimony or complexity: they cannot properly differentiate among (models for) distributions with low-dimensional supports at all, and some even differentiate them in exactly opposite ways.\footnote{Of course, strictly speaking, differential entropy and discrete entropy are not directly comparable.} $\blacksquare$

Example 3.5 (Density).

Consider the two sets of sampled data points shown in Figure 3.8. Geometrically, they are essentially the same: each set consists of eight points, and each point occurs with equal frequency 1/8. The only difference is that in the second data set, some points are “close” enough to be viewed as having a higher density around their respective “cluster.” Which one is more relevant to the true distribution that may have generated the samples? How can we reconcile such ambiguity in interpreting this kind of (empirical) distribution?

Figure 3.8: Eight points observed on a line.

\blacksquare

There is yet another technical difficulty associated with constructing an explicit encoding and decoding scheme for a data set. Given a sampled data set 𝑿 = [𝒙_1, …, 𝒙_N], how do we design a coding scheme that is implementable on machines with finite memory and computing resources? Note that even representing a single general real number requires an infinite number of digits or bits. Therefore, one may wonder whether the entropy of a distribution is a direct measure of the complexity of its (optimal) coding scheme. We examine this matter with another simple example.

Example 3.6 (Precision).

Consider a discrete distribution on the two values {e, π}: a random variable takes the value of Euler’s number e ≈ 2.71828 or the number π ≈ 3.14159, each with probability 1/2. The entropy of this distribution is H = 1 bit, which suggests that one may encode the two numbers by a single binary digit, 0 or 1, respectively. But can you realize a decoding scheme for this code on a finite-state machine? The answer is no, as it takes infinitely many bits to describe either number precisely. \blacksquare
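To make the point of this example concrete, here is a tiny sketch in Python (not from the text; the precision ϵ = 10⁻³ is an arbitrary assumption): once we accept a quantization error ϵ, each irrational value can be replaced by a nearby finite-precision rational, and the one-bit code becomes realizable on a finite machine.

```python
import math

eps = 1e-3  # accepted quantization error (an arbitrary illustrative choice)

# Hypothetical finite-precision codebook: each irrational value is replaced
# by the nearest multiple of eps, which needs only finitely many bits.
codebook = {0: round(math.e / eps) * eps, 1: round(math.pi / eps) * eps}

print(codebook)                          # approximately {0: 2.718, 1: 3.142}
print(abs(codebook[0] - math.e) <= eps)  # True: within eps of e
print(abs(codebook[1] - math.pi) <= eps) # True: within eps of pi
```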

Hence, it is generally impossible to have an encoding and decoding scheme that can precisely reproduce samples from an arbitrary real-valued distribution. (Footnote 11: That is, if one wants to encode such samples precisely, the only option is to memorize every single sample.) But there would be little practical value in encoding a distribution without being able to decode samples drawn from it.

So, to ensure that any encoding/decoding scheme is computable and implementable with finite memory and computational resources, we need to quantize the sample 𝒙 and encode it only up to a certain precision, say ϵ > 0. By doing so, in essence, we treat any two data points as equivalent if their distance is less than ϵ. More precisely, we would like to consider coding schemes

\bm{x} \mapsto \hat{\bm{x}}  (3.3.1)

such that the expected error caused by the quantization is bounded by ϵ. It is mathematically more convenient, and conceptually almost identical, to bound the expected squared error by ϵ², i.e.,

\operatorname{\mathbb{E}}\big[d(\bm{x},\hat{\bm{x}})^{2}\big] \leq \epsilon^{2}.  (3.3.2)

Typically, the distance d is chosen to be the Euclidean distance, i.e., the 2-norm. (Footnote 12: More generally, we can replace d² with any so-called divergence.) We will adopt this choice in the sequel.

3.3.2 Rate Distortion and Data Geometry

Of course, among all encoding schemes that satisfy the above constraint, we would like to choose the one that minimizes the resulting coding rate. For a given random variable 𝒙 and a precision ϵ, this rate is known as the rate distortion, denoted \mathcal{R}_{\epsilon}(\bm{x}). A deep theorem in information theory, originally proved by [Sha59], establishes that this rate can be expressed equivalently in purely probabilistic terms as

\mathcal{R}_{\epsilon}(\bm{x}) = \min_{p(\hat{\bm{x}}\mid\bm{x}):\,\operatorname{\mathbb{E}}[\|\bm{x}-\hat{\bm{x}}\|_{2}^{2}]\leq\epsilon^{2}} I(\bm{x};\hat{\bm{x}}),  (3.3.3)

where the quantity I(\bm{x};\hat{\bm{x}}) is known as the mutual information, defined by

I(\bm{x};\hat{\bm{x}}) = \operatorname{\mathsf{KL}}\big(p(\bm{x},\hat{\bm{x}}) \,\|\, p(\bm{x})\,p(\hat{\bm{x}})\big).  (3.3.4)

Note that the minimization in (3.3.3) is over all conditional distributions p(𝒙^𝒙)p(\hat{\bm{x}}\mid\bm{x})italic_p ( over^ start_ARG bold_italic_x end_ARG ∣ bold_italic_x ) that satisfy the distortion constraint 𝔼𝒙,𝒙^[𝒙𝒙^22]ϵ2\operatorname{\mathbb{E}}_{\bm{x},\hat{\bm{x}}}[\|\bm{x}-\hat{\bm{x}}\|_{2}^{2}]\leq\epsilon^{2}blackboard_E start_POSTSUBSCRIPT bold_italic_x , over^ start_ARG bold_italic_x end_ARG end_POSTSUBSCRIPT [ ∥ bold_italic_x - over^ start_ARG bold_italic_x end_ARG ∥ start_POSTSUBSCRIPT 2 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT 2 end_POSTSUPERSCRIPT ] ≤ italic_ϵ start_POSTSUPERSCRIPT 2 end_POSTSUPERSCRIPT. Each such conditional distribution induces a joint distribution p(𝒙,𝒙^)=p(𝒙^𝒙)p(𝒙)p(\bm{x},\hat{\bm{x}})=p(\hat{\bm{x}}\mid\bm{x})p(\bm{x})italic_p ( bold_italic_x , over^ start_ARG bold_italic_x end_ARG ) = italic_p ( over^ start_ARG bold_italic_x end_ARG ∣ bold_italic_x ) italic_p ( bold_italic_x ), which determines the mutual information (3.3.4). Many convenient properties of the mutual information (and hence the rate distortion) are implied by corresponding properties of the KL divergence (recall Theorem 3.1). From the definition, we know that ϵ(𝒙)\mathcal{R}_{\epsilon}(\bm{x})caligraphic_R start_POSTSUBSCRIPT italic_ϵ end_POSTSUBSCRIPT ( bold_italic_x ) is a non-increasing function in ϵ\epsilonitalic_ϵ.

Remark 3.5.

As it turns out, the rate distortion is an implementable approximation to the entropy of 𝒙 in the following sense. Assume that 𝒙 and 𝒙^ are continuous random vectors. Then the mutual information can be written as

I(\bm{x};\hat{\bm{x}}) = h(\bm{x}) - h(\bm{x}\mid\hat{\bm{x}}),  (3.3.5)

where h(\bm{x}\mid\hat{\bm{x}}) = -\operatorname{\mathbb{E}}[\log_{2}p(\bm{x}\mid\hat{\bm{x}})] is the conditional entropy of 𝒙 given 𝒙^. Hence, the minimal coding rate is achieved when the difference between the entropy of 𝒙 and the conditional entropy of 𝒙 given 𝒙^ is minimized among all distributions that satisfy the constraint \operatorname{\mathbb{E}}[\|\bm{x}-\hat{\bm{x}}\|_{2}^{2}] \leq \epsilon^{2}.

In fact, it is not necessary to assume that 𝒙 and 𝒙^ are continuous to obtain this type of conclusion. For example, if both random vectors are instead discrete, we have, after a suitable interpretation of the KL divergence for discrete-valued random vectors, that

I(\bm{x};\hat{\bm{x}}) = H(\bm{x}) - H(\bm{x}\mid\hat{\bm{x}}).  (3.3.6)

More generally, advanced mathematical notions from measure theory can be used to define the mutual information (and hence the rate distortion) for arbitrary random variables 𝒙\bm{x}bold_italic_x and 𝒙^\hat{\bm{x}}over^ start_ARG bold_italic_x end_ARG, including those with rather exotic low-dimensional distributions; see [CT91, §8.5].

Remark 3.6.

Given a set of data points 𝑿 = [𝒙_1, …, 𝒙_N], one can always interpret them as samples from a uniform discrete distribution that places probability 1/N on each of these N vectors. The entropy of such a distribution is H(𝑿) = \log_{2} N. (Footnote 13: Note again that even if we can encode these vectors at this coding rate, we cannot decode them to arbitrary precision.) However, even if 𝑿 is a uniform distribution on its samples, the coding rate \mathcal{R}_{\epsilon}(\bm{X}) achievable with a lossy coding scheme can be significantly lower than H(𝑿) if the samples are unevenly distributed and many are clustered close to each other. Therefore, for the second distribution shown in Figure 3.8 and a properly chosen quantization error ϵ, the achievable lossy coding rate can be significantly lower than that of coding it as a uniform distribution. (Footnote 14: Nevertheless, for this discrete uniform distribution, when ϵ is small enough, we always have H(𝑿) ≈ \mathcal{R}_{\epsilon}(\bm{X}).) Also notice that, with the notion of rate distortion, the difficulty discussed in Example 3.6 disappears: we can choose two rational numbers close enough to the two irrational numbers, and the resulting coding scheme has finite complexity.
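The following minimal numerical sketch (with hypothetical 1D points, not the exact data of Figure 3.8) illustrates this remark: the lossless uniform code over N = 8 samples always costs log₂ 8 = 3 bits per sample, whereas a lossy code at precision ϵ only needs as many codewords as there are ϵ-balls required to cover the samples, which is far fewer when the samples are clustered.

```python
import numpy as np

def greedy_epsilon_cover(points, eps):
    """Greedily pick codewords so every point is within eps of some codeword."""
    centers = []
    for p in points:
        if not any(abs(p - c) <= eps for c in centers):
            centers.append(p)
    return centers

# Two hypothetical 8-point sets on a line: one spread out, one clustered.
spread    = np.linspace(0.0, 7.0, 8)
clustered = np.array([0.0, 0.05, 0.1, 3.0, 3.05, 3.1, 6.0, 6.05])

eps = 0.2
for name, pts in [("spread", spread), ("clustered", clustered)]:
    centers = greedy_epsilon_cover(pts, eps)
    print(f"{name:9s}: lossless rate = {np.log2(len(pts)):.2f} bits/sample, "
          f"lossy rate ~ {np.log2(len(centers)):.2f} bits/sample "
          f"({len(centers)} codewords at eps={eps})")
```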

Example 3.7.

Sometimes one may face the opposite situation, in which we fix the coding rate first and seek a coding scheme that minimizes the distortion. For example, suppose we only want to use a fixed number of codes for points sampled from a distribution, and we want to design the codes so that the average or maximum distortion of the encoding/decoding scheme is minimized. Concretely, given a uniform distribution on a unit square, we may ask how precisely we can encode points drawn from this distribution with, say, n bits. This problem is equivalent to asking for the minimum radius (i.e., distortion) such that we can cover the unit square with 2^n discs of this radius. Figure 3.9 shows approximately optimal coverings of a square with n = 4, 6, 8, that is, with 2^n = 16, 64, 256 discs, respectively. Notice that the optimal radii of the discs decrease as the number of discs 2^n increases.

Figure 3.9: Approximations to the optimal solutions for 2^4, 2^6, and 2^8 discs covering a square, along with the corresponding radii, calculated using a heuristic optimization algorithm.

\blacksquare

Figure 3.10: The approximation of a low-dimensional distribution by ϵ-balls. As ϵ shrinks, the union of ϵ-balls approximates the support of the true distribution (black) increasingly well. Furthermore, the associated denoisers (whose input-output mapping is given by the arrows), obtained by approximating the true distribution by a mixture of Gaussians each with covariance (ϵ²/D)𝑰, increasingly well approximate the true denoisers. At large ϵ, such denoisers do not point near the true distribution at all, whereas at small ϵ they closely approximate the true denoisers. Theorem 3.6 establishes that this approximation characterizes the rate distortion function at small distortions ϵ, unifying the parallel approaches of coding rate minimization and denoising for learning low-dimensional distributions without pathologies.

It turns out to be a notoriously hard problem to obtain closed-form expressions for the rate distortion function (3.3.3) for general distributions p(𝒙)p(\bm{x})italic_p ( bold_italic_x ). However, as Example 3.7 suggests, there are important special cases where the geometry of the support of the distribution p(𝒙)p(\bm{x})italic_p ( bold_italic_x ) can be linked to the rate distortion function and hence to the optimal coding rate at distortion level ϵ\epsilonitalic_ϵ. In fact, this example can be generalized to any setting where the support of p(𝒙)p(\bm{x})italic_p ( bold_italic_x ) is a sufficiently regular compact set—including low-dimensional distributions—and p(𝒙)p(\bm{x})italic_p ( bold_italic_x ) is uniformly distributed on its support. This covers a vast number of cases of practical interest. We formalize this notion in the following result, which establishes this property for a special case.

Theorem 3.6.

Suppose that 𝒙 is a random variable whose support K ≐ Supp(𝒙) is a compact set. Define the covering number \mathcal{N}_{\epsilon}(K) as the minimum number of balls of radius ϵ needed to cover K, i.e.,

\mathcal{N}_{\epsilon}(K) \doteq \min\left\{n\in\mathbb{N} \colon \exists\,\bm{p}_{1},\dots,\bm{p}_{n}\in K \ \text{s.t.}\ K \subseteq \bigcup_{i=1}^{n}B_{\epsilon}(\bm{p}_{i})\right\},  (3.3.7)

where B_{\epsilon}(\bm{p}) = \{\bm{\xi}\in\mathbb{R}^{D} \mid \|\bm{\xi}-\bm{p}\|_{2}\leq\epsilon\} is the Euclidean ball of radius ϵ centered at 𝒑. Then it holds that

\mathcal{R}_{\epsilon}(\bm{x}) \leq \log_{2}\mathcal{N}_{\epsilon}(K).  (3.3.8)

If, in addition, 𝒙 is uniformly distributed on K and K is a mixture of mutually orthogonal low-rank subspaces, (Footnote 15: In fact, it is possible to treat highly irregular K, such as fractals, with a parallel result, but its statement becomes far more technical; cf. Riegler et al. [RBK18, RKB23]. We give a simple proof in Section B.3 which shows the result for mixtures of subspaces.) then a matching lower bound holds:

\mathcal{R}_{\epsilon}(\bm{x}) \geq \log_{2}\mathcal{N}_{\epsilon}(K) - O(D).  (3.3.9)
Proof.

A proof of this theorem is beyond the scope of this book and we defer it to Section B.3. ∎

The implication of Theorem 3.6 can be summarized as follows: for sufficiently accurate coding of the distribution of 𝒙, the minimum-rate coding at distortion ϵ is completely characterized by the sphere covering problem on the support of 𝒙. The core of the proof of Theorem 3.6 can indeed be generalized to more complex distributions, such as sufficiently incoherent mixtures of manifolds, but we leave this for future study. So the rate distortion can be thought of as a “probability-aware” way to approximate the support of the distribution of 𝒙 by a mixture of many small balls.
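As a sanity check on this geometric picture, the following sketch (our own illustration, with an arbitrarily chosen circle as the support) estimates the ϵ-covering number of samples from a one-dimensional manifold embedded in ℝ^D using a greedy ϵ-net: log₂ 𝒩_ϵ grows by roughly one bit each time ϵ is halved, reflecting the intrinsic dimension d = 1 rather than the ambient dimension D = 50.

```python
import numpy as np

rng = np.random.default_rng(0)

def greedy_cover_count(X, eps):
    """Greedy upper bound on the eps-covering number of the rows of X."""
    centers = []
    for x in X:
        if not centers or np.min(np.linalg.norm(np.array(centers) - x, axis=1)) > eps:
            centers.append(x)
    return len(centers)

D = 50                                   # ambient dimension
t = rng.uniform(0, 2 * np.pi, 5000)      # dense samples of a circle (intrinsic d = 1)
circle = np.zeros((t.size, D))
circle[:, 0], circle[:, 1] = np.cos(t), np.sin(t)

for eps in [0.4, 0.2, 0.1, 0.05]:
    N_eps = greedy_cover_count(circle, eps)
    print(f"eps = {eps:5.2f}:  log2 N_eps ~ {np.log2(N_eps):.2f}")
# Halving eps increases log2 N_eps by roughly one bit, consistent with d = 1,
# even though the ambient dimension is D = 50.
```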

We now discuss another connection between this and the denoising-diffusion-entropy complexity hierarchy we discussed earlier in this chapter.

Remark 3.7.

The key ingredient in the proof of the lower bound in Theorem 3.6 is an important result from information theory known as the Shannon lower bound for the rate distortion, named after Claude Shannon, who first derived it in a special case [Sha59]. It asserts the following estimate for the rate distortion function, for any random variable 𝒙\bm{x}bold_italic_x with a density p(𝒙)p(\bm{x})italic_p ( bold_italic_x ) and finite expected squared norm [LZ94]:

\mathcal{R}_{\epsilon}(\bm{x}) \geq h(\bm{x}) - \log_{2}\operatorname{vol}(B_{\epsilon}) - C_{D},  (3.3.10)

where C_D > 0 is a constant depending only on D. Moreover, this lower bound is actually sharp as ϵ → 0; that is,

\lim_{\epsilon\to 0}\left(\mathcal{R}_{\epsilon}(\bm{x}) - \left[h(\bm{x}) - \log_{2}\operatorname{vol}(B_{\epsilon}) - C_{D}\right]\right) = 0.  (3.3.11)

So when the distortion ϵ\epsilonitalic_ϵ is small, we can think solely in terms of the Shannon lower bound, rather than the (generally intractable) optimization problem defining the rate distortion (3.3.3).

The Shannon lower bound is the bridge between the coding rate, entropy minimization/denoising, and geometric sphere covering approaches to learning low-dimensional distributions. Notice that in the special case of a uniform density p(𝒙) on K, (3.3.10) becomes

\mathcal{R}_{\epsilon}(\bm{x}) \geq -\int_{K}\frac{1}{\operatorname{vol}(K)}\log_{2}\frac{1}{\operatorname{vol}(K)}\,\mathrm{d}\bm{\xi} - \log_{2}\operatorname{vol}(B_{\epsilon}) - C_{D}  (3.3.12)
= \log_{2}\big(\operatorname{vol}(K)/\operatorname{vol}(B_{\epsilon})\big) - C_{D}.  (3.3.13)

The ratio vol(K)/vol(B_ϵ) approximates, by a worst-case argument, the number of ϵ-balls needed to cover K, which is accurate for sufficiently regular sets K when ϵ is small (see Section B.3 for details). Meanwhile, recall the Gaussian denoising model 𝒙_ϵ = 𝒙 + ϵ𝒈 from earlier in this chapter, where 𝒈 ∼ 𝒩(𝟎, 𝑰) is independent of 𝒙. Interestingly, the differential entropy of the joint distribution of (𝒙, ϵ𝒈) can be calculated as

h(\bm{x},\epsilon\bm{g}) = -\int p(\bm{\xi})\,p(\bm{\gamma})\log_{2}\big(p(\bm{\xi})\,p(\bm{\gamma})\big)\,\mathrm{d}\bm{\xi}\,\mathrm{d}\bm{\gamma}  (3.3.14)
= h(\bm{x}) + h(\epsilon\bm{g}).  (3.3.15)

We have seen the Gaussian entropy calculated in Equation 3.1.4: when ϵ\epsilonitalic_ϵ is small, it is equal, up to additive constants, to the volumetric quantity log2vol(Bϵ)-\log_{2}\operatorname{vol}(B_{\epsilon})- roman_log start_POSTSUBSCRIPT 2 end_POSTSUBSCRIPT roman_vol ( italic_B start_POSTSUBSCRIPT italic_ϵ end_POSTSUBSCRIPT ) we have seen in the Shannon lower bound. In certain special cases (e.g., data supported on incoherent low-rank subspaces), when ϵ\epsilonitalic_ϵ is small and the support of 𝒙\bm{x}bold_italic_x is sufficiently regular, the distribution of 𝒙ϵ\bm{x}_{\epsilon}bold_italic_x start_POSTSUBSCRIPT italic_ϵ end_POSTSUBSCRIPT can even be well-approximated locally by the product of the distributions p(𝒙)p(\bm{x})italic_p ( bold_italic_x ) and p(𝒈)p(\bm{g})italic_p ( bold_italic_g ), justifying the above computation. Hence the Gaussian denoising process yields yet another interpretation of the Shannon lower bound, as arising from the entropy of a noisy version of 𝒙\bm{x}bold_italic_x, with noise level proportional to the distortion level ϵ\epsilonitalic_ϵ.

Thus, this finite rate distortion approach via sphere covering recovers and generalizes all the previous measures of complexity of a distribution, allowing us to differentiate between and rank different distributions in a unified way. These interrelated viewpoints are visualized in Figure 3.10.

For a general distribution at a finite distortion level, it is typically impossible to find its rate distortion function in analytical form; one must often resort to numerical computation. (Footnote 16: Interested readers may refer to [Bla72] for a classic algorithm that computes the rate distortion function numerically for a discrete distribution.) Nevertheless, as we will see, in our context we often need to know the rate distortion as an explicit function of a set of data points or their representations, because we want to use the coding rate as a measure of the goodness of the representations. An explicit analytical form makes it easy to determine how to transform the data distribution to improve the representation. So we should work with distributions whose rate distortion functions take explicit analytical forms. To this end, we start with the simplest, and also the most important, family of distributions.
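For a discrete source, the optimization in (3.3.3) can be carried out numerically with the classic Blahut–Arimoto iteration mentioned in the footnote [Bla72]. Below is a minimal sketch; the four-symbol alphabet, the squared-error distortion, and the slope parameters β are illustrative assumptions, and each β yields (approximately) one point (R, D) on the rate distortion curve.

```python
import numpy as np

def blahut_arimoto(p_x, dist, beta, n_iter=500):
    """Return (rate in bits, distortion) on the R(D) curve for slope parameter beta."""
    n_x, n_xhat = dist.shape
    Q = np.full((n_x, n_xhat), 1.0 / n_xhat)        # conditional p(xhat | x)
    for _ in range(n_iter):
        q = p_x @ Q                                  # marginal p(xhat)
        A = q * np.exp(-beta * dist)                 # unnormalized update
        Q = A / A.sum(axis=1, keepdims=True)
    q = p_x @ Q
    rate = np.sum(p_x[:, None] * Q * np.log2(Q / q[None, :]))   # mutual information
    distortion = np.sum(p_x[:, None] * Q * dist)                # expected distortion
    return rate, distortion

# Example: a 4-symbol source reproduced on the same alphabet with squared error.
symbols = np.array([0.0, 1.0, 2.0, 3.0])
p_x = np.array([0.4, 0.3, 0.2, 0.1])
dist = (symbols[:, None] - symbols[None, :]) ** 2

for beta in [0.5, 2.0, 8.0]:
    R, Dst = blahut_arimoto(p_x, dist, beta)
    print(f"beta = {beta:4.1f}:  R ~ {R:.3f} bits,  distortion ~ {Dst:.3f}")
```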

3.3.3 Lossy Coding Rate for a Low-Dimensional Gaussian

Now suppose we are given a set of data samples 𝑿 = [𝒙_1, …, 𝒙_N] from any distribution. (Footnote 17: Or these data points could be viewed as an (empirical) distribution themselves.) We would like to come up with a constructive scheme that can encode the data up to a certain precision, say

\bm{x}_{i} \mapsto \hat{\bm{x}}_{i}, \quad \text{subject to} \quad \|\bm{x}_{i}-\hat{\bm{x}}_{i}\|_{2} \leq \epsilon.  (3.3.16)

Notice that this is a sufficient, explicit, and interpretable condition ensuring that the data are encoded such that \frac{1}{N}\sum_{i=1}^{N}\|\bm{x}_{i}-\hat{\bm{x}}_{i}\|_{2}^{2} \leq \epsilon^{2}. This latter inequality is exactly the rate distortion constraint for the given empirical distribution and its encoding. For example, in Example 3.7 we used this simplified criterion to explicitly find the minimum distortion and an explicit coding scheme for a given coding rate.

Without loss of generality, let us assume the mean of 𝑿 is zero, i.e., \frac{1}{N}\sum_{i=1}^{N}\bm{x}_{i} = \bm{0}. Without any prior knowledge about the nature of the distribution behind 𝑿, we may view 𝑿 as sampled from a Gaussian distribution 𝒩(𝟎, 𝚺) (Footnote 18: It is known that, for a fixed variance, the Gaussian achieves the maximal entropy; that is, it gives an upper bound on the worst-case coding rate.) with the covariance

\bm{\Sigma} = \frac{1}{N}\bm{X}\bm{X}^{\top}.  (3.3.17)

Notice that geometrically 𝚺{\bm{\Sigma}}bold_Σ characterizes an ellipsoidal region where most of the samples 𝒙i\bm{x}_{i}bold_italic_x start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT reside.

We may view \hat{\bm{X}} = [\hat{\bm{x}}_{1},\ldots,\hat{\bm{x}}_{N}] as a noisy version of \bm{X} = [\bm{x}_{1},\ldots,\bm{x}_{N}]:

\hat{\bm{x}}_{i} = \bm{x}_{i} + \bm{w}_{i},  (3.3.18)

where \bm{w}_{i} \sim \mathcal{N}\big(\bm{0}, \tfrac{\epsilon^{2}}{D}\bm{I}\big) is Gaussian noise independent of 𝒙_i. Then the covariance of \hat{\bm{x}}_{i} is given by

\hat{\bm{\Sigma}} = \operatorname{\mathbb{E}}\big[\hat{\bm{x}}_{i}\hat{\bm{x}}_{i}^{\top}\big] = \frac{\epsilon^{2}}{D}\bm{I} + \frac{1}{N}\bm{X}\bm{X}^{\top}.  (3.3.19)

Note that the volume of the region spanned by the noisy vectors \hat{\bm{x}}_{i} is proportional to the square root of the determinant of this covariance matrix:

\operatorname{volume}(\hat{\bm{x}}_{i}) \propto \sqrt{\det\big(\hat{\bm{\Sigma}}\big)} = \sqrt{\det\left(\frac{\epsilon^{2}}{D}\bm{I} + \frac{1}{N}\bm{X}\bm{X}^{\top}\right)}.  (3.3.20)

The volume spanned by each noise vector \bm{w}_{i} is proportional to

\operatorname{volume}(\bm{w}_{i}) \propto \sqrt{\det\left(\frac{\epsilon^{2}}{D}\bm{I}\right)}.  (3.3.21)
Figure 3.11: Covering the region spanned by the data vectors using ϵ-balls. The larger the volume of the space, the more balls are needed, hence the more bits are needed to encode and enumerate the balls. Each real-valued vector 𝒙 can be encoded as the number of the ball into which it falls.

To encode vectors that fall into the region spanned by 𝒙^i\hat{\bm{x}}_{i}over^ start_ARG bold_italic_x end_ARG start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT, we can cover the region with non-overlapping balls of radius ϵ\epsilonitalic_ϵ, as illustrated in Figure 3.11. When the volume of the region spanned by 𝒙^i\hat{\bm{x}}_{i}over^ start_ARG bold_italic_x end_ARG start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT is significantly larger than the volume of the ϵ\epsilonitalic_ϵ-ball, the total number of balls that we need to cover the region is approximately equal to the ratio of the two volumes:

\#\,\epsilon\text{-balls} \approx \frac{\operatorname{volume}(\hat{\bm{x}}_{i})}{\operatorname{volume}(\bm{w}_{i})} = \sqrt{\det\left(\bm{I} + \frac{D}{N\epsilon^{2}}\bm{X}\bm{X}^{\top}\right)}.  (3.3.22)

If we use binary numbers to label all the ϵ-balls in the region of interest, the total number of bits needed is thus

\mathcal{R}_{\epsilon}(\bm{X}) \approx \log_{2}(\#\,\epsilon\text{-balls}) \approx R_{\epsilon}(\bm{X}) \doteq \frac{1}{2}\log\det\left(\bm{I} + \frac{D}{N\epsilon^{2}}\bm{X}\bm{X}^{\top}\right).  (3.3.23)
Example 3.8.

Figure 3.11 shows an example of a 2D distribution with an ellipsoidal support, approximating the support of a 2D Gaussian distribution. The region is covered by small balls of radius ϵ, and all the balls are numbered from 1 to, say, n. Then, given any vector 𝒙 in this region, we only need to determine the ϵ-ball whose center is closest to it, denoted \operatorname{ball}_{\epsilon}(\bm{x}). To remember 𝒙, we only need to remember the number of this ball, which takes \log_{2}(n) bits to store. If we need to decode 𝒙 from this number, we simply take 𝒙^ to be the center of the ball. This leads to an explicit encoding and decoding scheme:

\bm{x} \;\longrightarrow\; \operatorname{ball}_{\epsilon}(\bm{x}) \;\longrightarrow\; \hat{\bm{x}} = \text{center of } \operatorname{ball}_{\epsilon}(\bm{x}).  (3.3.24)

One may refer to these ball centers as the “codes” of a code book or dictionary for the encoding scheme. It is easy to see that the accuracy of this (lossy) encoding-decoding scheme is about the radius ϵ of the balls. Clearly \mathcal{R}_{\epsilon}(\bm{X}) is the average number of bits required to encode the ball number of each vector 𝒙 with this coding scheme, and hence can be called the coding rate associated with this scheme. \blacksquare
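A minimal sketch of the encode/decode scheme (3.3.24) is given below. The synthetic 2D data, the codebook size n = 256, and the use of a few k-means (Lloyd) iterations to place the ball centers are all illustrative assumptions; the k-means centers here merely stand in for an ϵ-ball covering of the region.

```python
import numpy as np

rng = np.random.default_rng(1)

# Hypothetical 2D data with an ellipsoidal spread, loosely in the spirit of Figure 3.11.
X = rng.normal(size=(2000, 2)) @ np.diag([3.0, 0.5])

# Codebook of n ball centers; encoding one vector costs log2(n) bits.
n = 256
centers = X[rng.choice(len(X), n, replace=False)].copy()
for _ in range(20):                     # a few Lloyd iterations to spread the centers
    labels = np.argmin(((X[:, None, :] - centers[None]) ** 2).sum(-1), axis=1)
    for k in range(n):
        if np.any(labels == k):
            centers[k] = X[labels == k].mean(axis=0)

def encode(x):          # x  ->  index of the nearest ball center, as in (3.3.24)
    return int(np.argmin(((centers - x) ** 2).sum(-1)))

def decode(code):       # index  ->  ball center, taken as the reconstruction x_hat
    return centers[code]

X_hat = np.array([decode(encode(x)) for x in X])
rms = np.sqrt(np.mean(np.sum((X - X_hat) ** 2, axis=1)))
print(f"bits per vector ~ {np.log2(n):.0f},  RMS quantization error ~ {rms:.3f}")
```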

From the above derivation, we know that the coding rate \mathcal{R}_{\epsilon}(\bm{X}) is (approximately) achievable with an explicit encoding (and decoding) scheme. It has two interesting properties:

  • First, one may notice that Rϵ(𝑿)R_{\epsilon}(\bm{X})italic_R start_POSTSUBSCRIPT italic_ϵ end_POSTSUBSCRIPT ( bold_italic_X ) closely resembles the rate distortion function of a Gaussian source [CT91]. Indeed, when ϵ\epsilonitalic_ϵ is small, the above expression is a close approximation to the rate distortion of a Gaussian source, as pointed out by [MDH+07].

  • Second, the same closed-form coding rate R_{\epsilon}(\bm{X}) can be derived as an approximation of \mathcal{R}_{\epsilon}(\bm{X}) if the data 𝑿 are assumed to lie on a linear subspace. This can be shown by properly quantizing the singular value decomposition (SVD) of \bm{X} = \bm{U}\bm{\Sigma}\bm{V}^{\top} and constructing a lossy coding scheme for vectors in the subspace spanned by 𝑼 [MDH+07].

In our context, the closed-form expression Rϵ(𝑿)R_{\epsilon}(\bm{X})italic_R start_POSTSUBSCRIPT italic_ϵ end_POSTSUBSCRIPT ( bold_italic_X ) is rather fundamental: it is the coding rate associated with an explicit and natural lossy coding scheme for data drawn from either a Gaussian distribution or a linear subspace. As we will see in the next chapter, this formula plays an important role in understanding the architecture of deep neural networks.
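Because R_ϵ(𝑿) in (3.3.23) is an explicit function of the data, it is trivial to evaluate. The following sketch (synthetic data; the dimensions, sample size, and ϵ are arbitrary choices) compares the rate of data lying on a 3-dimensional subspace of ℝ⁵⁰ with that of isotropic 50-dimensional data: the low-dimensional data cost far fewer bits, because only a few eigenvalues contribute to the log-determinant.

```python
import numpy as np

rng = np.random.default_rng(2)

def coding_rate(X, eps):
    """R_eps(X) = 1/2 logdet(I + D/(N eps^2) X X^T), with X of shape (D, N).
    Note: slogdet uses the natural log; divide by np.log(2) for bits."""
    D, N = X.shape
    _, logdet = np.linalg.slogdet(np.eye(D) + (D / (N * eps**2)) * X @ X.T)
    return 0.5 * logdet

D, N, d, eps = 50, 1000, 3, 0.1
U = np.linalg.qr(rng.normal(size=(D, d)))[0]          # random d-dimensional subspace
X_low  = U @ rng.normal(size=(d, N))                  # data on the subspace
X_full = rng.normal(size=(D, N))                      # isotropic data

print(f"R_eps, low-dim data  (d = {d}):  {coding_rate(X_low, eps):8.1f}")
print(f"R_eps, full-dim data (D = {D}):  {coding_rate(X_full, eps):8.1f}")
```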

3.3.4 Clustering a Mixture of Low-Dimensional Gaussians

As we have discussed before, the given dataset 𝑿\bm{X}bold_italic_X often has low-dimensional intrinsic structures. Hence, encoding it as a general Gaussian would be very redundant. If we can identify those intrinsic structures in 𝑿\bm{X}bold_italic_X, we could design much better coding schemes that give much lower coding rates. Or equivalently, the codes used to encode such 𝑿\bm{X}bold_italic_X can be compressed. We will see that compression gives a unifying computable way to identify such structures. In this section, we demonstrate this important idea with the most basic family of low-dimensional structures: a mixture of (low-dimensional) Gaussians or subspaces.

Example 3.9.

Figure 3.12 shows an example in which the data 𝑿\bm{X}bold_italic_X are distributed around two subspaces (or low-dimensional Gaussians). If they are viewed and coded together as one single Gaussian, the associated discrete (lossy) code book, represented by all the blue balls, is obviously very redundant. We can try to identify the locations of the two subspaces, denoted by S1S_{1}italic_S start_POSTSUBSCRIPT 1 end_POSTSUBSCRIPT and S2S_{2}italic_S start_POSTSUBSCRIPT 2 end_POSTSUBSCRIPT, and design a code book that only covers the two subspaces, i.e., the green balls. If we can correctly partition samples in the data 𝑿\bm{X}bold_italic_X into the two subspaces: 𝑿=[𝑿1,𝑿2]𝚷\bm{X}=[\bm{X}_{1},\bm{X}_{2}]\bm{\Pi}bold_italic_X = [ bold_italic_X start_POSTSUBSCRIPT 1 end_POSTSUBSCRIPT , bold_italic_X start_POSTSUBSCRIPT 2 end_POSTSUBSCRIPT ] bold_Π with 𝑿1S1\bm{X}_{1}\in S_{1}bold_italic_X start_POSTSUBSCRIPT 1 end_POSTSUBSCRIPT ∈ italic_S start_POSTSUBSCRIPT 1 end_POSTSUBSCRIPT and 𝑿2S2\bm{X}_{2}\in S_{2}bold_italic_X start_POSTSUBSCRIPT 2 end_POSTSUBSCRIPT ∈ italic_S start_POSTSUBSCRIPT 2 end_POSTSUBSCRIPT, where 𝚷\bm{\Pi}bold_Π denotes a permutation matrix, then the resulting coding rate for the data will be much lower. This gives a more parsimonious, hence more desirable, representation of the data. \blacksquare

Figure 3.12: Comparison of two lossy coding schemes for data distributed around two subspaces. One packs (blue) ϵ-balls into the entire space spanned by the two subspaces; the other packs balls only into a tubular neighborhood around each subspace. The latter obviously has a much smaller code book and results in a much lower coding rate for samples on the subspaces.

So, more generally speaking, if the data are drawn from any mixture of subspaces or low-dimensional Gaussians, it would be desirable to identify those components and encode the data based on the intrinsic dimensions of those components. It turns out that we do not lose much generality by assuming that the data are drawn from a mixture of low-dimensional Gaussians. This is because a mixture of Gaussians can closely approximate most general distributions [BDS16].

The clustering problem.

Now for this specific family of distributions, how can we effectively and efficiently identify those low-dimensional components from a set of samples

\bm{X} = [\bm{x}_{1},\bm{x}_{2},\ldots,\bm{x}_{N}],  (3.3.25)

drawn from them? In other words, given the whole data set 𝑿, we want to partition, or cluster, it into multiple, say K, subsets:

\bm{X}\bm{\Pi} = [\bm{X}_{1},\bm{X}_{2},\dots,\bm{X}_{K}],  (3.3.26)

where each subset consists of samples drawn from only one low-dimensional Gaussian or subspace, and 𝚷 is a permutation matrix indicating the membership of the partition. Note that, depending on the situation, the partition could be either deterministic or probabilistic. As shown in [MDH+07a], for a mixture of Gaussians a probabilistic partition does not lead to a lower coding rate, so for simplicity we here consider only deterministic partitions.

Clustering via lossy compression.

The main difficulty in solving the above clustering problem is that we normally know neither the number of clusters K nor the dimension of each component. The study of this clustering problem has a long history; the textbook [VMS16] gives a systematic and comprehensive coverage of different approaches. To find an effective approach, we first need to understand and clarify why we want to cluster at all. In other words, what exactly do we gain from clustering the data, compared with not clustering? How do we measure the gain? From the perspective of data compression, a correct clustering should lead to a more efficient encoding (and decoding) scheme.

For any given data set 𝑿\bm{X}bold_italic_X, there are already two obvious encoding schemes as the baseline. They represent two extreme ways to encode the data:

  • Simply view all the samples together as drawn from one single Gaussian. The associated coding rate is, as derived before, given by:

    \mathcal{R}_{\epsilon}(\bm{X}) \approx R_{\epsilon}(\bm{X}) = \frac{1}{2}\log\det\left(\bm{I} + \frac{D}{N\epsilon^{2}}\bm{X}\bm{X}^{\top}\right).  (3.3.27)

  • Simply memorize all the samples separately by assigning a different number to each sample. The coding rate would be:

    \mathcal{R}_{0}(\bm{X}) = \log(N).  (3.3.28)

Note that either coding scheme can become the “optimal” solution for certain (extreme) choices of the quantization error ϵ:

  1. Lazy Regime: If we choose ϵ to be extremely large, all samples in 𝑿 can be covered by a single ball. The rate is \mathcal{R}_{\epsilon} \rightarrow \frac{1}{2}\log\det(\bm{I}) = 0 as \epsilon \rightarrow \infty.

  2. Memorization Regime: If ϵ is extremely small, every sample in 𝑿 is covered by a different ϵ-ball, hence the total number of balls is N. The rate is \mathcal{R}_{\epsilon} \rightarrow \log(N) as \epsilon \rightarrow 0.

Note that the first scheme corresponds to the scenario when one does not care about anything interesting about the distribution at all. One does not want to spare any bit for anything informative. We call this the “lazy regime.” The second scheme corresponds to the scenario when one wants to decode every sample with an extremely high precision. So one would better “memorize” every sample. We call this the “memorization regime.”

Figure 3.13: A number of random samples on a 2D plane. Consider an ϵ-disc assigned to each sample with the sample as its center. The density of the samples increases from left to right.
Example 3.10.

To see when the memorization regime is or is not preferred, let us consider a number N of samples randomly distributed in a unit area on a 2D plane. (Footnote 19: Say the points are drawn by a Poisson process with density N points per unit area.) Imagine we try to design a lossy coding scheme with a fixed quantization error ϵ. This is equivalent to putting an ϵ-disc around each sample, as shown in Figure 3.13. When N is small, with high probability the discs do not overlap with each other, so a codebook of size N is necessary and essentially optimal. When N, or equivalently the density, reaches a certain critical value N_c, with high probability the discs start to overlap and connect into one cluster that covers the whole plane; this phenomenon is known as continuum “percolation” [Gil61, MM12]. When N grows beyond this value, the discs overlap heavily, and using N discs becomes very redundant, because we only want to encode points on the plane up to the given precision ϵ. The number of discs needed to cover all the samples is then much smaller than N. (Footnote 20: In fact, there are efficient algorithms to find such a covering [BBF+01].) \blacksquare

Both the lazy and memorization regimes are somewhat trivial and perhaps are of little theoretical or practical interest. Either scheme would be far from optimal when used to encode a large number of samples drawn from a distribution that has a compact and low-dimensional support. The interesting regime exists in between these two.

Example 3.11.
Figure 3.14: Top: 358 noisy samples drawn from two lines and one plane in \mathbb{R}^{3}. Bottom: the effect of varying ϵ on the clustering result and the coding rate. The red line marks the variance ϵ_0 of the Gaussian noise added to the samples.

Figure 3.14 shows an example with noisy samples drawn from two lines and one plane in \mathbb{R}^{3}. As we see from plot (c) on the right, the optimal coding rate decreases monotonically as we increase ϵ, as anticipated from the properties of the rate distortion function. Plots (a) and (b) show, as ϵ varies from very small (near zero) to very large (toward infinity), the number of clusters at which the coding rate is minimized. We can clearly see the lazy regime and the memorization regime at the two ends of the plots. But one can also notice in plot (b) that, when the quantization error ϵ is chosen to be around the level of the true noise variance ϵ_0, the optimal number of clusters is the “correct” number three, corresponding to the two lines and one plane. We informally refer to this middle regime as the “generalization regime.” Notice that a sharp phase transition takes place between these regimes. (Footnote 21: So far, to our best knowledge, there is no rigorous theoretical justification for these phase transition behaviors.) \blacksquare

From the above discussion and examples, we see that when the quantization error relative to the sample density (Footnote 22: or, equivalently, the sample density relative to the quantization error) is in a proper range, minimizing the lossy coding rate allows us to uncover the underlying (low-dimensional) distribution of the sampled data. Hence quantization, which started as a practical choice, seems to become necessary for learning a continuous distribution from an empirical distribution with finitely many samples. Although a rigorous theory explaining this phenomenon remains elusive, for learning purposes we care about how to exploit it to design algorithms that can find the correct distribution.

Let us use the simple example shown in Figure 3.12 to illustrate the basic idea. If one can partition all samples in 𝑿 into two clusters 𝑿_1 and 𝑿_2, with N_1 and N_2 samples respectively, then the associated coding rate would be (Footnote 23: We here ignore some overhead bits needed to encode the membership of each sample, say via Huffman coding.)

R_{\epsilon}^{c}(\bm{X}\mid\bm{\Pi}) = \frac{N_{1}}{N}R_{\epsilon}(\bm{X}_{1}) + \frac{N_{2}}{N}R_{\epsilon}(\bm{X}_{2}),  (3.3.29)

where 𝚷 indicates the membership of the partition. If the partition respects the low-dimensional structures of the distribution, in this case 𝑿_1 and 𝑿_2 belonging to the two subspaces respectively, then the resulting coding rate should be significantly smaller than that of either of the two basic schemes above:

R_{\epsilon}^{c}(\bm{X}\mid\bm{\Pi}) \ll R_{\epsilon}(\bm{X}), \qquad R_{\epsilon}^{c}(\bm{X}\mid\bm{\Pi}) \ll \mathcal{R}_{0}(\bm{X}).  (3.3.30)

In general, we can cast the clustering problem as an optimization problem that minimizes the coding rate over all partitions:

\min_{\bm{\Pi}}\ \left\{R_{\epsilon}^{c}(\bm{X}\mid\bm{\Pi}) \doteq \sum_{k=1}^{K}\frac{N_{k}}{N}R_{\epsilon}(\bm{X}_{k})\right\}.  (3.3.31)
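To make the objective (3.3.31) concrete, here is a small sketch (our own synthetic example: two orthogonal two-dimensional subspaces in ℝ²⁰, with an illustrative noise level and ϵ) that compares the segmented coding rate of the correct partition against a random partition and against coding all samples as a single Gaussian.

```python
import numpy as np

rng = np.random.default_rng(3)

def R(X, eps):
    """Closed-form rate (3.3.23) for X of shape (D, N); natural-log units."""
    D, N = X.shape
    return 0.5 * np.linalg.slogdet(np.eye(D) + (D / (N * eps**2)) * X @ X.T)[1]

def R_segmented(X, labels, eps):
    """Segmented rate (3.3.31): per-cluster rates weighted by cluster size."""
    N = X.shape[1]
    return sum((np.sum(labels == k) / N) * R(X[:, labels == k], eps)
               for k in np.unique(labels))

D, N, eps = 20, 400, 0.1
X1 = np.zeros((D, N // 2)); X1[:2]  = rng.normal(size=(2, N // 2))   # subspace 1
X2 = np.zeros((D, N // 2)); X2[2:4] = rng.normal(size=(2, N // 2))   # subspace 2
X = np.hstack([X1, X2]) + 0.01 * rng.normal(size=(D, N))             # small noise

true_labels   = np.repeat([0, 1], N // 2)
random_labels = rng.integers(0, 2, N)

print(f"one Gaussian      : {R(X, eps):8.1f}")
print(f"random partition  : {R_segmented(X, random_labels, eps):8.1f}")
print(f"correct partition : {R_segmented(X, true_labels, eps):8.1f}")
```

Minimizing R_ϵ^c over 𝚷 therefore rewards exactly those partitions that align with the low-dimensional components of the data.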

Optimization strategies for clustering.

The remaining question is how we optimize the above coding rate objective to find the optimal clusters. There are three natural approaches to this objective:

  1. We may start with the whole set \bm{X} as a single cluster (i.e., the lazy regime) and then search (say, randomly) for a partition that leads to a smaller coding rate.

  2. Inversely, we may start with each sample \bm{x}_{i} as its own cluster (i.e., the memorization regime) and search for merges of clusters that result in a smaller coding rate.

  3. Alternatively, if we could represent (or approximate) the membership \bm{\Pi} by some continuous parameters, we may use optimization methods such as gradient descent (GD).

The first approach is computationally unappealing, as the number of possible partitions one needs to try is exponential in the number of samples. For example, the number of partitions of \bm{X} into two subsets of equal size is \binom{N}{N/2}, which explodes as N becomes large. We will explore the third approach in Chapter 4, where we will see how the role of deep neural networks, transformers in particular, is connected with the coding rate objective.

The second approach was originally suggested in the work of [MDH+07a], which demonstrates the benefit of being able to evaluate the coding rate efficiently (say, in analytical form). With it, the (low-dimensional) clusters of the data can be found rather efficiently and effectively via the principle of minimizing coding length (MCL). Note that for a cluster \bm{X}_{k} with N_{k} samples, the number of bits needed to encode all the samples in \bm{X}_{k} is given by:242424In fact, a more accurate estimate of the coding length is L(\bm{X}_{k})=(N_{k}+D)R_{\epsilon}(\bm{X}_{k}), where the extra bits are used to encode the basis of the subspace [MDH+07a]. Here we omit this overhead for simplicity.

L(\bm{X}_{k})=N_{k}R_{\epsilon}(\bm{X}_{k}).   (3.3.32)

If we have two clusters \bm{X}_{k} and \bm{X}_{l} and we code their samples as two separate clusters, the number of bits needed is

L^{c}(\bm{X}_{k},\bm{X}_{l})=N_{k}R_{\epsilon}(\bm{X}_{k})+N_{l}R_{\epsilon}(\bm{X}_{l})-N_{k}\log\frac{N_{k}}{N_{k}+N_{l}}-N_{l}\log\frac{N_{l}}{N_{k}+N_{l}}.

The last two terms are the number of bits needed to encode the memberships of samples according to the Huffman code.

Then, given any two separate clusters \bm{X}_{k} and \bm{X}_{l}, we can decide whether or not to merge them based on the sign of the difference between the two coding lengths:

L(\bm{X}_{k}\cup\bm{X}_{l})-L^{c}(\bm{X}_{k},\bm{X}_{l}),   (3.3.33)

where \bm{X}_{k}\cup\bm{X}_{l} denotes the union of the sets of samples in \bm{X}_{k} and \bm{X}_{l}. If the difference is negative, the coding length becomes smaller when we merge the two clusters into one. This simple fact leads to the following clustering algorithm, proposed by [MDH+07a]:

Algorithm 3.3 Pairwise Steepest Descent of Coding Length
1: Input: N data points \{\bm{x}_{i}\}_{i=1}^{N}
2: Output: A set \mathcal{C} of clusters
3: procedure PairwiseSteepestDescentOfCodingLength(\{\bm{x}_{i}\}_{i=1}^{N})
4:     \mathcal{C}\leftarrow\{\{\bm{x}_{i}\}\}_{i=1}^{N}   ▷ Initialize N clusters \bm{X}_{k} with one element each
5:     while |\mathcal{C}|>1 do
6:         if \min_{\bm{X}_{k},\bm{X}_{l}\in\mathcal{C}}\,[L(\bm{X}_{k}\cup\bm{X}_{l})-L^{c}(\bm{X}_{k},\bm{X}_{l})]\geq 0 then   ▷ If no bits are saved by any merging
7:             return \mathcal{C}   ▷ Early return \mathcal{C} and exit
8:         else
9:             \bm{X}_{k^{\ast}},\bm{X}_{l^{\ast}}\leftarrow\operatorname{arg\,min}_{\bm{X}_{k},\bm{X}_{l}\in\mathcal{C}}\,[L(\bm{X}_{k}\cup\bm{X}_{l})-L^{c}(\bm{X}_{k},\bm{X}_{l})]   ▷ Merge the pair of clusters that saves the most bits
10:             \mathcal{C}\leftarrow[\mathcal{C}\setminus\{\bm{X}_{k^{\ast}},\bm{X}_{l^{\ast}}\}]\cup\{\bm{X}_{k^{\ast}}\cup\bm{X}_{l^{\ast}}\}   ▷ Remove the unmerged clusters and add back the merged one
11:         end if
12:     end while
13:     return \mathcal{C}   ▷ If all merges yield savings, return one cluster
14: end procedure

Note that this algorithm is tractable as the total number of (pairwise) comparisons and merges is about O(N2logN)O(N^{2}\log N)italic_O ( italic_N start_POSTSUPERSCRIPT 2 end_POSTSUPERSCRIPT roman_log italic_N ). However, due to its greedy nature, there is no theoretical guarantee that the process will converge to the globally optimal clustering solution. Nevertheless, as reported in [MDH+07a], in practice, this seemingly simple algorithm works extremely well. The clustering results plotted in Figure 3.14 were actually computed by this algorithm.
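For concreteness, the following Python sketch (our own rendering, not the implementation released with [MDH+07a]) instantiates Algorithm 3.3 with the log-det coding length; it naively recomputes all pairwise costs in every iteration, which suffices to illustrate the greedy merging but is less efficient than a careful implementation. All logarithms are in nats for simplicity.

```python
import numpy as np

def coding_length(X, eps=0.1):
    """L(X_k) = N_k * R_eps(X_k), ignoring the subspace-basis overhead (see footnote)."""
    D, N = X.shape
    R = 0.5 * np.linalg.slogdet(np.eye(D) + (D / (N * eps**2)) * X @ X.T)[1]
    return N * R

def pair_coding_length(Xk, Xl, eps=0.1):
    """L^c(X_k, X_l): code the two clusters separately, plus membership bits as in (3.3.33)."""
    Nk, Nl = Xk.shape[1], Xl.shape[1]
    Nt = Nk + Nl
    members = -Nk * np.log(Nk / Nt) - Nl * np.log(Nl / Nt)
    return coding_length(Xk, eps) + coding_length(Xl, eps) + members

def greedy_merge(points, eps=0.1):
    """Algorithm 3.3: start from singleton clusters and repeatedly perform the merge
    that reduces the total coding length the most; stop when no merge saves bits."""
    clusters = [np.asarray(p, dtype=float).reshape(-1, 1) for p in points]
    while len(clusters) > 1:
        best_pair, best_gain = None, 0.0
        for i in range(len(clusters)):
            for j in range(i + 1, len(clusters)):
                merged = np.hstack([clusters[i], clusters[j]])
                gain = coding_length(merged, eps) - pair_coding_length(clusters[i], clusters[j], eps)
                if gain < best_gain:            # negative gain: merging saves bits
                    best_pair, best_gain = (i, j), gain
        if best_pair is None:                   # no merge helps: return the current clusters
            return clusters
        i, j = best_pair
        clusters[i] = np.hstack([clusters[i], clusters[j]])
        del clusters[j]
    return clusters
```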

Example 3.12 (Image Segmentation).

The above measure of coding length and the associated clustering algorithm assume that the data distribution is a mixture of (low-dimensional) Gaussians. Although this seems somewhat idealistic, the measure and the algorithm can already be very useful and even powerful in scenarios where the model is (approximately) valid.

For example, a natural image typically consists of multiple regions with nearly homogeneous textures. If we take many small windows from each region, they should resemble samples drawn from a (low-dimensional) Gaussian, as illustrated in Figure 3.15. Figure 3.16 shows the results of image segmentation based on applying the above clustering algorithm to the image patches directly. More technical details regarding customizing the algorithm to the image segmentation problem can be found in [MRY+11]. \blacksquare
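As a rough illustration of the patch-as-vector view (the full segmentation pipeline of [MRY+11] involves further steps, such as feature compression and spatial regularization, which we omit), one could extract w\times w windows from an image and feed them to the greedy merging sketch above; the helper below is a hypothetical sketch under these simplifications.

```python
import numpy as np

def extract_patches(img, w=8, stride=8):
    """Collect non-overlapping w x w windows of a grayscale image (2-D array)
    as column vectors, producing a (w*w) x num_patches matrix."""
    H, W = img.shape
    cols = []
    for i in range(0, H - w + 1, stride):
        for j in range(0, W - w + 1, stride):
            cols.append(img[i:i + w, j:j + w].reshape(-1))
    return np.stack(cols, axis=1)

# Hypothetical usage with the greedy_merge sketch above:
# patches = extract_patches(img, w=8)
# clusters = greedy_merge([patches[:, i] for i in range(patches.shape[1])], eps=0.5)
```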

Figure 3.15: Image patches with a size of w\times w pixels.
Figure 3.16: Segmentation results based on the clustering algorithm applied to the image patches.

3.4 Maximizing Information Gain

So far in this chapter, we have discussed how to identify a distribution with low-dimensional structures through the principle of compression. As we have seen from the previous two sections, computational compression can be realized through either the denoising operation or clustering. Figure 3.17 illustrates this concept with our favorite example.

Figure 3.17: Identify a low-dimensional distribution with two subspaces (left) via denoising or clustering, starting from a generic random Gaussian distribution (right).

Of course, the ultimate goal for identifying a data distribution is to use it to facilitate certain subsequent tasks such as segmentation, classification, or generation (of images). Hence, how the resulting distribution is “represented” matters tremendously with respect to how information related to these subsequent tasks can be efficiently and effectively retrieved and utilized. This naturally raises a fundamental question: what makes a representation truly “good” for downstream use? In the following, we will explore the essential properties that a meaningful and useful representation should possess, and how these properties can be explicitly characterized and pursued via maximizing information gain.

How to measure the goodness of representations.

One may view a given dataset as samples of a random vector \bm{x} with a certain distribution in a high-dimensional space, say \mathbb{R}^{D}. Typically, the distribution of \bm{x} has a much lower intrinsic dimension than the ambient space. Generally speaking, learning a representation refers to learning a continuous mapping, say f(\cdot), that transforms \bm{x} into a so-called feature vector \bm{z} in another (typically lower-dimensional) space, say \mathbb{R}^{d}, with d<D. The hope is that, through such a mapping

𝒙Df(𝒙)𝒛d,\bm{x}\in\mathbb{R}^{D}\xrightarrow{\hskip 5.69054ptf(\bm{x})\hskip 5.69054pt}\bm{z}\in\mathbb{R}^{d},bold_italic_x ∈ blackboard_R start_POSTSUPERSCRIPT italic_D end_POSTSUPERSCRIPT start_ARROW start_OVERACCENT italic_f ( bold_italic_x ) end_OVERACCENT → end_ARROW bold_italic_z ∈ blackboard_R start_POSTSUPERSCRIPT italic_d end_POSTSUPERSCRIPT , (3.4.1)

the low-dimensional intrinsic structures of 𝒙\bm{x}bold_italic_x are identified and represented by 𝒛\bm{z}bold_italic_z in a more compact and structured way so as to facilitate subsequent tasks such as classification or generation. The feature 𝒛\bm{z}bold_italic_z can be viewed as a (learned) compact code for the original data 𝒙\bm{x}bold_italic_x, so the mapping ffitalic_f is also called an encoder. The fundamental question of representation learning is

What is a principled and effective measure for the goodness of representations?

Conceptually, the quality of a representation 𝒛\bm{z}bold_italic_z depends on how well it identifies the most relevant and sufficient information of 𝒙\bm{x}bold_italic_x for subsequent tasks and how efficiently it represents this information. For a long time, it was believed and argued that the “sufficiency” or “goodness” of a learned feature representation should be defined in terms of a specific task. For example, 𝒛\bm{z}bold_italic_z just needs to be sufficient for predicting the class label 𝒚\bm{y}bold_italic_y in a classification problem. Below, let us start with the classic problem of image classification and argue why such a notion of a task-specific “representation” is limited and needs to be generalized.

3.4.1 Linear Discriminative Representations

Suppose that \bm{x}\in\mathbb{R}^{D} is a random vector drawn from a mixture of K (component) distributions \mathcal{D}=\{\mathcal{D}_{k}\}_{k=1}^{K}. Given a finite set of i.i.d. samples \bm{X}=[\bm{x}_{1},\bm{x}_{2},\ldots,\bm{x}_{N}]\in\mathbb{R}^{D\times N} of the random vector \bm{x}, we seek a good representation through a continuous mapping f(\bm{x}):\mathbb{R}^{D}\rightarrow\mathbb{R}^{d} that captures the intrinsic structures of \bm{x} and best facilitates the subsequent classification task.252525Classification is the domain where deep learning demonstrated its initial success, sparking the explosive interest in deep networks. Although our study focuses on classification, we believe the ideas and principles can be naturally generalized to other settings, such as regression. To ease the task of learning the distribution \mathcal{D}, in the popular supervised classification setting, a true class label (or a code word for each class), usually represented by a one-hot vector \bm{y}_{i}\in\mathbb{R}^{K}, is given for each sample \bm{x}_{i}.

Encoding class information via cross entropy.

Extensive studies have shown that for many practical datasets (e.g., images, audio, and natural language), the (encoding) mapping from the data \bm{x} to its class label \bm{y} can be effectively modeled by training a deep network,262626Let us not yet worry about which network we should use and why. The purpose here is to consider any empirically tested deep network; we leave the justification of network architectures to the next chapter. here denoted as

f(\bm{x},\theta):\bm{x}\mapsto\bm{y}

with network parameters \theta\in\Theta, where \Theta denotes the parameter space. For the output f(\bm{x},\theta) to match well with the label \bm{y}, we would like to minimize the cross-entropy loss over a training set \{(\bm{x}_{i},\bm{y}_{i})\}_{i=1}^{N}:

\min_{\theta\in\Theta}\;-\mathbb{E}[\langle\bm{y},\log(f(\bm{x},\theta))\rangle]\approx-\frac{1}{N}\sum_{i=1}^{N}\langle\bm{y}_{i},\log\left(f(\bm{x}_{i},\theta)\right)\rangle.   (3.4.2)

The optimal network parameters θ\thetaitalic_θ are typically found by optimizing the above objective through an efficient gradient descent scheme, with gradients computed via back propagation (BP), as described in Section A.2.3 of Appendix A.
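As a small concrete instance of (3.4.2), the following NumPy sketch (ours) evaluates the empirical cross-entropy for one-hot labels and softmax outputs of some network f(\bm{x},\theta); the array shapes and the random toy inputs are assumptions for illustration only.

```python
import numpy as np

def softmax(logits):
    """Column-wise softmax: logits is K x N, output columns are class probabilities."""
    e = np.exp(logits - logits.max(axis=0, keepdims=True))
    return e / e.sum(axis=0, keepdims=True)

def cross_entropy(Y, P, tol=1e-12):
    """Empirical cross-entropy -1/N sum_i <y_i, log p_i> for one-hot labels Y (K x N)
    and predicted probabilities P (K x N), cf. (3.4.2)."""
    return -np.mean(np.sum(Y * np.log(P + tol), axis=0))

# Toy usage: 3 classes, 5 samples, random logits standing in for a network output.
rng = np.random.default_rng(0)
logits = rng.standard_normal((3, 5))
Y = np.eye(3)[:, rng.integers(0, 3, 5)]
print(cross_entropy(Y, softmax(logits)))
```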

Despite its effectiveness and enormous popularity, there are two serious limitations with this approach: 1) It aims only to predict the labels \bm{y}, even if they might be mislabeled; empirical studies show that deep networks, used as a “black box,” can even fit random labels [ZBH+17]. 2) With such end-to-end data fitting, despite plenty of empirical efforts to interpret the so-learned features, it is not clear to what extent the intermediate features learned by the network capture the intrinsic structures of the data that make meaningful classification possible in the first place. The precise geometric and statistical properties of the learned features are also often obscured, which leads to a lack of interpretability and of subsequent performance guarantees (e.g., generalizability, transferability, and robustness) in deep learning. Therefore, one of the goals of this section is to address such limitations by reformulating the objective towards learning explicitly meaningful and useful representations for the data \bm{x}, not limited to classification.

Figure 3.18: Evolution of penultimate-layer outputs of a VGG13 neural network trained on the CIFAR10 dataset with 3 randomly selected classes. Figure from [PHD20].

Minimal discriminative features via information bottleneck.

One popular approach to interpret the role of deep networks is to view outputs of intermediate layers of the network as selecting certain latent features 𝒛=f(𝒙,θ)d\bm{z}=f(\bm{x},\theta)\in\mathbb{R}^{d}bold_italic_z = italic_f ( bold_italic_x , italic_θ ) ∈ blackboard_R start_POSTSUPERSCRIPT italic_d end_POSTSUPERSCRIPT of the data that are discriminative among multiple classes. Learned representations 𝒛\bm{z}bold_italic_z then facilitate the subsequent classification task for predicting the class label 𝒚\bm{y}bold_italic_y by optimizing a classifier g(𝒛)g(\bm{z})italic_g ( bold_italic_z ):

\bm{x}\xrightarrow{\ f(\bm{x},\theta)\ }\bm{z}\xrightarrow{\ g(\bm{z})\ }\bm{y}.   (3.4.3)

We know from information theory [CT91] that the mutual information between two random variables, say 𝒙,𝒛\bm{x},\bm{z}bold_italic_x , bold_italic_z, is defined to be

I(\bm{x};\bm{z})=H(\bm{x})-H(\bm{x}\mid\bm{z}),   (3.4.4)

where H(𝒙|𝒛)H(\bm{x}|\bm{z})italic_H ( bold_italic_x | bold_italic_z ) is the conditional entropy of 𝒙\bm{x}bold_italic_x given 𝒛\bm{z}bold_italic_z. The mutual information is also known as the information gain: It measures how much the entropy of the random variable 𝒙\bm{x}bold_italic_x can be reduced once 𝒛\bm{z}bold_italic_z is given. Or equivalently, it measures how much information 𝒛\bm{z}bold_italic_z contains about 𝒙\bm{x}bold_italic_x. The information bottleneck (IB) formulation [TZ15] further hypothesizes that the role of the network is to learn 𝒛\bm{z}bold_italic_z as the minimal sufficient statistics for predicting 𝒚\bm{y}bold_italic_y. Formally, it seeks to maximize the mutual information I(𝒛,𝒚)I(\bm{z},\bm{y})italic_I ( bold_italic_z , bold_italic_y ) between 𝒛\bm{z}bold_italic_z and 𝒚\bm{y}bold_italic_y while minimizing the mutual information I(𝒙,𝒛)I(\bm{x},\bm{z})italic_I ( bold_italic_x , bold_italic_z ) between 𝒙\bm{x}bold_italic_x and 𝒛\bm{z}bold_italic_z:

\max_{\theta\in\Theta}\;\mathrm{IB}(\bm{x},\bm{y},\bm{z})\doteq I(\bm{z};\bm{y})-\beta I(\bm{x};\bm{z})\quad\mathrm{s.t.}\ \bm{z}=f(\bm{x},\theta),   (3.4.5)

where β>0\beta>0italic_β > 0.
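For intuition about the quantities in (3.4.4) and (3.4.5), the following sketch (ours) computes mutual information for a discrete joint distribution via the equivalent identity I(\bm{x};\bm{z})=H(\bm{x})+H(\bm{z})-H(\bm{x},\bm{z}); reliably estimating these quantities for continuous or degenerate feature distributions is precisely the caveat discussed next.

```python
import numpy as np

def entropy(p):
    """Shannon entropy (in nats) of a discrete distribution given as an array of probabilities."""
    p = p[p > 0]
    return -np.sum(p * np.log(p))

def mutual_information(p_xz):
    """I(x; z) = H(x) + H(z) - H(x, z) for a joint probability table p_xz (|X| x |Z|),
    which equals H(x) - H(x | z) in (3.4.4)."""
    p_x = p_xz.sum(axis=1)
    p_z = p_xz.sum(axis=0)
    return entropy(p_x) + entropy(p_z) - entropy(p_xz.ravel())

# Independent variables carry no information about each other:
print(mutual_information(np.outer([0.5, 0.5], [0.5, 0.5])))    # ~0.0
# A deterministic relation recovers the full entropy of x:
print(mutual_information(np.array([[0.5, 0.0], [0.0, 0.5]])))  # ~log 2
```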

Provided one can overcome certain caveats associated with this framework [KTV18], such as how to accurately evaluate mutual information from finite samples of degenerate distributions, it can be helpful in explaining certain behaviors of deep networks. For example, recent work [PHD20] indeed shows that the representations learned via the cross-entropy loss (3.4.2) exhibit a neural collapse phenomenon: features of each class are mapped to a one-dimensional vector, whereas all other information about the class is suppressed, as illustrated in Figure 3.18.

Remark 3.8.

Neural collapse refers to a phenomenon observed in deep neural networks trained for classification, where the learned feature representations and classifier weights exhibit highly symmetric and structured behavior during the terminal phase of training [PHD20, ZDZ+21]. Specifically, within each class, features collapse to their class mean, and across classes, these means become maximally separated, forming a simplex equiangular tight frame. The linear classifier aligns with the class means up to rescaling, so the last-layer classifier converges to choosing whichever class has the nearest training class mean. Neural collapse reveals deep connections between optimization dynamics, generalization, and the geometric structures arising in supervised learning.

From the above example of classification, we see that the so-learned representation gives a very simple encoder that essentially maps each class of data to only one code word: the one-hot vector representing that class. From the lossy compression perspective, such an encoder is too lossy to preserve the information in the data distribution. Other information, such as what is needed for tasks like image generation, is severely lost in such a supervised learning process. To remedy this situation, we want to learn a different encoding scheme such that the resulting feature representation captures much richer information about the data distribution, not limited to what is useful for classification alone.

Figure 3.19: After identifying the low-dimensional data distribution, we would like to further transform the data distribution into a more informative, structured representation: R is the number of \epsilon-balls covering the whole space and R^{c} is the sum of the numbers for all the subspaces (the green balls). \Delta R is their difference (the number of blue balls).

Linear discriminative representations.

Whether the given data 𝑿\bm{X}bold_italic_X of a mixed distribution 𝒟\mathcal{D}caligraphic_D can be effectively classified or clustered depends on how separable (or discriminative) the component distributions 𝒟k\mathcal{D}_{k}caligraphic_D start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT are (or can be made). One popular working assumption is that the distribution of each class has relatively low-dimensional intrinsic structures. Hence we may assume that the distribution 𝒟k\mathcal{D}_{k}caligraphic_D start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT of each class has a support on a low-dimensional submanifold, say k\mathcal{M}_{k}caligraphic_M start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT with dimension dkDd_{k}\ll Ditalic_d start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT ≪ italic_D, and the distribution 𝒟\mathcal{D}caligraphic_D of 𝒙\bm{x}bold_italic_x is supported on the mixture of those submanifolds, =k=1Kk\mathcal{M}=\cup_{k=1}^{K}\mathcal{M}_{k}caligraphic_M = ∪ start_POSTSUBSCRIPT italic_k = 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_K end_POSTSUPERSCRIPT caligraphic_M start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT, in the high-dimensional ambient space D\mathbb{R}^{D}blackboard_R start_POSTSUPERSCRIPT italic_D end_POSTSUPERSCRIPT.

Not only do we need to identify the low-dimensional distribution, but we also want to represent the distribution in a form that best facilitates subsequent tasks such as classification, clustering, and conditional generation (as we will see later). To do so, we require our learned feature representations to have the following properties:

  1. Within-Class Compressible: Features of samples from the same class should be strongly correlated in the sense that they belong to a low-dimensional linear subspace.

  2. Between-Class Discriminative: Features of samples from different classes should be highly uncorrelated and belong to different low-dimensional linear subspaces.

  3. Maximally Diverse Representation: The dimension (or variance) of the features of each class should be as large as possible as long as they are incoherent to the other classes.

We refer to such a representation as a linear discriminative representation (LDR). Notice that the first property aligns well with the objective of the classic principal component analysis (PCA) that we have discussed in Section 2.1.1. The second property resembles that of the classic linear discriminant analysis (LDA) [HTF09]. Figure 3.19 illustrates these properties with a simple example in which the data distribution is actually a mixture of two subspaces. Through compression (denoising or clustering), we first identify that the true data distribution is a mixture of two low-dimensional subspaces (middle) instead of a generic Gaussian distribution (left). We then would like to transform the distribution so that the two subspaces eventually become mutually incoherent/independent (right).

Remark 3.9.

Linear discriminant analysis (LDA) [HTF09] is a supervised dimensionality reduction technique that aims to find a linear projection of data that maximizes class separability. Specifically, given labeled data, LDA seeks a linear transformation that projects high-dimensional inputs onto a lower-dimensional space where the classes are maximally separated. Note that PCA is an unsupervised method that projects data onto directions of maximum variance without considering class labels. While PCA focuses purely on preserving global variance structure, LDA explicitly exploits label information to enhance discriminative power; see the comparison in Figure 3.20.

Figure 3.20: Comparison between PCA and LDA. Figures adopted from https://sebastianraschka.com/Articles/2014_python_lda.html.

The third property is also important because we want the learned features to reveal all possible causes of why one class is different from all other classes. For example, to tell “apple” from “orange”, we care not only about color but also shape and the leaves. Ideally, the dimension of each subspace {𝒮k}\{\mathcal{S}_{k}\}{ caligraphic_S start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT } should be equal to that of the corresponding submanifold k\mathcal{M}_{k}caligraphic_M start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT. This property will be important if we would like the map f(𝒙,θ)f(\bm{x},\theta)italic_f ( bold_italic_x , italic_θ ) to be invertible for tasks such as image generation. For example, if we draw different sample points from the feature subspace for “apple”, we should be able to decode them to generate diverse images of apples. The feature learned from minimizing the cross entropy (3.4.2) clearly does not have this property.

In general, although the intrinsic structures of each class/cluster may be low-dimensional, they are by no means simply linear (or Gaussian) in their original representation 𝒙\bm{x}bold_italic_x and they need to be made linear first, through some nonlinear transformation.272727We will discuss how this can be done explicitly in Chapter 5. Therefore, overall, we use the nonlinear transformation f(𝒙,θ)f(\bm{x},\theta)italic_f ( bold_italic_x , italic_θ ) to seek a representation of the data such that the subspaces that represent all the classes are maximally incoherent linear subspaces. To be more precise, we want to learn a mapping 𝒛=f(𝒙,θ)\bm{z}=f(\bm{x},\theta)bold_italic_z = italic_f ( bold_italic_x , italic_θ ) that maps each of the submanifolds kD\mathcal{M}_{k}\subset\mathbb{R}^{D}caligraphic_M start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT ⊂ blackboard_R start_POSTSUPERSCRIPT italic_D end_POSTSUPERSCRIPT (Figure 3.21 left) to a linear subspace 𝒮kd\mathcal{S}_{k}\subset\mathbb{R}^{d}caligraphic_S start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT ⊂ blackboard_R start_POSTSUPERSCRIPT italic_d end_POSTSUPERSCRIPT (Figure 3.21 right). To some extent, the resulting multiple subspaces {𝒮k}\{\mathcal{S}_{k}\}{ caligraphic_S start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT } can be viewed as discriminative generalized principal components [VMS16] or, if orthogonal, independent components [HO00] of the resulting features 𝒛\bm{z}bold_italic_z for the original data 𝒙\bm{x}bold_italic_x. As we will see in the next Chapter 4, deep networks precisely play the role of modeling and realizing this nonlinear transformation from the data distribution to linear discriminative representations.

3.4.2 The Principle of Maximal Coding Rate Reduction

Although the three properties—between-class discriminative, within-class compressible, and maximally diverse representation—for linear discriminative representations (LDRs) are all highly desired properties of the learned representation 𝒛\bm{z}bold_italic_z, they are by no means easy to obtain: Are these properties compatible so that we can expect to achieve them all at once? If so, is there a simple but principled objective that can measure the goodness of the resulting representations in terms of all these properties? The key to these questions is to find a principled “measure of compactness” or “information gain” for the distribution of a random variable 𝒛\bm{z}bold_italic_z or from its finite samples {𝒛i}i=1N\{\bm{z}_{i}\}_{i=1}^{N}{ bold_italic_z start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT } start_POSTSUBSCRIPT italic_i = 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_N end_POSTSUPERSCRIPT. Such a measure should directly and accurately characterize intrinsic geometric or statistical properties of the distribution, in terms of its intrinsic dimension or volume. Unlike the cross entropy (3.4.2) or information bottleneck (3.4.5), such a measure should not depend exclusively on class labels so that it can work in more general settings such as supervised, self-supervised, semi-supervised, and unsupervised settings.

Figure 3.21: The distribution \mathcal{D} of high-dimensional data \bm{x}\in\mathbb{R}^{D} is supported on a manifold \mathcal{M} and its classes on low-dimensional submanifolds \mathcal{M}_{k}. We aim to learn a mapping f(\bm{x},\theta) parameterized by \theta such that \bm{z}_{i}=f(\bm{x}_{i},\theta) lie on a union of maximally uncorrelated subspaces \{\mathcal{S}_{k}\}.

Without loss of generality, assume that the distribution 𝒟\mathcal{D}caligraphic_D of the random vector 𝒙\bm{x}bold_italic_x is supported on a mixture of distributions, i.e., 𝒟=k=1K𝒟k\mathcal{D}=\cup_{k=1}^{K}\mathcal{D}_{k}caligraphic_D = ∪ start_POSTSUBSCRIPT italic_k = 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_K end_POSTSUPERSCRIPT caligraphic_D start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT, where each 𝒟kD\mathcal{D}_{k}\subset\mathbb{R}^{D}caligraphic_D start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT ⊂ blackboard_R start_POSTSUPERSCRIPT italic_D end_POSTSUPERSCRIPT has a low intrinsic dimension in the high-dimensional ambient space D\mathbb{R}^{D}blackboard_R start_POSTSUPERSCRIPT italic_D end_POSTSUPERSCRIPT. Let 𝑿kD×Nk\bm{X}_{k}\in\mathbb{R}^{D\times N_{k}}bold_italic_X start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT ∈ blackboard_R start_POSTSUPERSCRIPT italic_D × italic_N start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT end_POSTSUPERSCRIPT denote the data matrix whose columns are samples drawn from the distribution 𝒟k\mathcal{D}_{k}caligraphic_D start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT, where NkN_{k}italic_N start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT denotes the number of samples for each k=1,,Kk=1,\dots,Kitalic_k = 1 , … , italic_K. Then, we use 𝑿=[𝑿1,,𝑿K]D×N\bm{X}=[\bm{X}_{1},\dots,\bm{X}_{K}]\in\mathbb{R}^{D\times N}bold_italic_X = [ bold_italic_X start_POSTSUBSCRIPT 1 end_POSTSUBSCRIPT , … , bold_italic_X start_POSTSUBSCRIPT italic_K end_POSTSUBSCRIPT ] ∈ blackboard_R start_POSTSUPERSCRIPT italic_D × italic_N end_POSTSUPERSCRIPT to denote all the samples, where N=k=1KNkN=\sum_{k=1}^{K}N_{k}italic_N = ∑ start_POSTSUBSCRIPT italic_k = 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_K end_POSTSUPERSCRIPT italic_N start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT. Recall that we also use 𝒙i\bm{x}_{i}bold_italic_x start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT to denote the iiitalic_i-th sample of 𝑿\bm{X}bold_italic_X, i.e., 𝑿=[𝒙1,,𝒙N]\bm{X}=[\bm{x}_{1},\dots,\bm{x}_{N}]bold_italic_X = [ bold_italic_x start_POSTSUBSCRIPT 1 end_POSTSUBSCRIPT , … , bold_italic_x start_POSTSUBSCRIPT italic_N end_POSTSUBSCRIPT ]. Under an encoding mapping:

\bm{x}\xrightarrow{\ f(\bm{x})\ }\bm{z},   (3.4.6)

the input samples are mapped to 𝒛i=f(𝒙i)\bm{z}_{i}=f(\bm{x}_{i})bold_italic_z start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT = italic_f ( bold_italic_x start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ) for each i=1,,Ni=1,\dots,Nitalic_i = 1 , … , italic_N. With an abuse of notation, we also write 𝒁k=f(𝑿k)\bm{Z}_{k}=f(\bm{X}_{k})bold_italic_Z start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT = italic_f ( bold_italic_X start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT ) and 𝒁=f(𝑿)\bm{Z}=f(\bm{X})bold_italic_Z = italic_f ( bold_italic_X ). Therefore, we have 𝒁=[𝒁1,,𝒁K]\bm{Z}=[\bm{Z}_{1},\dots,\bm{Z}_{K}]bold_italic_Z = [ bold_italic_Z start_POSTSUBSCRIPT 1 end_POSTSUBSCRIPT , … , bold_italic_Z start_POSTSUBSCRIPT italic_K end_POSTSUBSCRIPT ] and 𝒁=[𝒛1,𝒛N]\bm{Z}=[\bm{z}_{1},\dots\bm{z}_{N}]bold_italic_Z = [ bold_italic_z start_POSTSUBSCRIPT 1 end_POSTSUBSCRIPT , … bold_italic_z start_POSTSUBSCRIPT italic_N end_POSTSUBSCRIPT ].

On one hand, for learned features to be discriminative, features of different classes/clusters are preferred to be maximally incoherent to each other. Hence, they together should span a space of the largest possible volume (or dimension) and the coding rate of the whole set 𝒁\bm{Z}bold_italic_Z should be as large as possible. On the other hand, learned features of the same class/cluster should be highly correlated and coherent. Hence, each class/cluster should only span a space (or subspace) of a very small volume and the coding rate should be as small as possible. Now, we will introduce how to measure the coding rate of the learned features.

Coding rate of features.

Notably, a practical challenge in evaluating the coding rate is that the underlying distribution of the feature representations 𝒁\bm{Z}bold_italic_Z is typically unknown. To address this, we may approximate the features 𝒁=[𝒛1,,𝒛N]\bm{Z}=[\bm{z}_{1},\ldots,\bm{z}_{N}]bold_italic_Z = [ bold_italic_z start_POSTSUBSCRIPT 1 end_POSTSUBSCRIPT , … , bold_italic_z start_POSTSUBSCRIPT italic_N end_POSTSUBSCRIPT ] as samples drawn from a multivariate Gaussian distribution. Under this assumption, as discussed in Section 3.3.3, the compactness of the features 𝒁\bm{Z}bold_italic_Z as a whole can be measured in terms of the average coding length per sample, referred to as the coding rate, subject to a precision level ϵ>0\epsilon>0italic_ϵ > 0 (see (3.3.23)) defined as follows:

R_{\epsilon}(\bm{Z})=\frac{1}{2}\log\det\left(\bm{I}+\frac{d}{N\epsilon^{2}}\bm{Z}\bm{Z}^{\top}\right).   (3.4.7)

On the other hand, we hope that a nonlinear transformation f(\bm{x}) maps each class-specific submanifold \mathcal{M}_{k}\subset\mathbb{R}^{D} to a maximally incoherent linear subspace \mathcal{S}_{k}\subset\mathbb{R}^{d} such that the learned features \bm{Z} lie in a union of low-dimensional subspaces. This structure allows for a more accurate evaluation of the coding rate by analyzing each subspace separately. Recall that the columns of \bm{Z}_{k} denote the features of the samples in \bm{X}_{k} for each k=1,\dots,K. The coding rate for the features in \bm{Z}_{k} can be computed as follows:

R_{\epsilon}(\bm{Z}_{k})=\frac{N_{k}}{2N}\log\det\left(\bm{I}+\frac{d}{N_{k}\epsilon^{2}}\bm{Z}_{k}\bm{Z}_{k}^{\top}\right).   (3.4.8)

Then, the sum of the average coding rates of features in each class is

R_{\epsilon}^{c}(\bm{Z})\doteq\sum_{k=1}^{K}R_{\epsilon}(\bm{Z}_{k}).   (3.4.9)

Therefore, a good representation 𝒁\bm{Z}bold_italic_Z of 𝑿\bm{X}bold_italic_X is the one that achieves a large difference between the coding rate for the whole and that for all the classes:

\Delta R_{\epsilon}(\bm{Z})\doteq R_{\epsilon}(\bm{Z})-R_{\epsilon}^{c}(\bm{Z}).   (3.4.10)

Notice that, as per our discussions earlier in this chapter, this difference can be interpreted as the amount of “information gained” by identifying the correct low-dimensional clusters 𝒁k\bm{Z}_{k}bold_italic_Z start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT within the overall set 𝒁\bm{Z}bold_italic_Z.
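The rate reduction (3.4.10) is straightforward to evaluate from the formulas (3.4.7)-(3.4.9); the NumPy sketch below (ours, with an arbitrary choice of \epsilon) computes \Delta R_{\epsilon}(\bm{Z}) from a feature matrix and a label vector.

```python
import numpy as np

def rate(Z, eps=0.5):
    """R_eps(Z) = 1/2 logdet(I + d/(N eps^2) Z Z^T), cf. (3.4.7); Z is d x N."""
    d, N = Z.shape
    return 0.5 * np.linalg.slogdet(np.eye(d) + (d / (N * eps**2)) * Z @ Z.T)[1]

def rate_reduction(Z, labels, eps=0.5):
    """Delta R_eps(Z) = R_eps(Z) - sum_k R_eps(Z_k), with the class rates weighted
    by N_k/(2N) as in (3.4.8) and (3.4.9)."""
    d, N = Z.shape
    Rc = 0.0
    for k in np.unique(labels):
        Zk = Z[:, labels == k]
        Nk = Zk.shape[1]
        Rc += (Nk / (2 * N)) * np.linalg.slogdet(
            np.eye(d) + (d / (Nk * eps**2)) * Zk @ Zk.T)[1]
    return rate(Z, eps) - Rc
```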

If we choose our feature mapping f()f(\cdot)italic_f ( ⋅ ) to be a deep neural network f(,θ)f(\cdot,\theta)italic_f ( ⋅ , italic_θ ) with network parameters θ\thetaitalic_θ, the overall process of the feature representation and the resulting rate reduction can be illustrated by the following diagram:

\bm{X}\xrightarrow{\ f(\bm{x},\theta)\ }\bm{Z}\xrightarrow{\ \epsilon\ }\Delta R_{\epsilon}(\bm{Z}).   (3.4.11)

Note that \Delta R_{\epsilon} is monotonic in the scale of the features \bm{Z}. To ensure a fair comparison across different representations, it is essential to normalize the scale of the learned features. This can be achieved either by requiring the Frobenius norm of each class \bm{Z}_{k}\in\mathbb{R}^{d\times N_{k}} to scale with the number of samples in it, i.e., \|\bm{Z}_{k}\|_{F}^{2}=N_{k}, or by normalizing each feature to lie on the unit sphere, i.e., \bm{z}_{i}\in\mathbb{S}^{d-1}; here N_{k}=\mathrm{tr}(\bm{\Pi}_{k}) denotes the number of samples in the k-th class. This formulation offers a natural justification for the need for “batch normalization” in the practice of training deep neural networks [IS15].

Once the representations are comparable, the goal becomes to learn a set of features \bm{Z}=f(\bm{X},\theta) that maximizes the reduction between the coding rate of all features taken together and the sum of the coding rates of the features of each class:

\max_{\theta}\ \Delta R_{\epsilon}(\bm{Z})\doteq R_{\epsilon}(\bm{Z})-R_{\epsilon}^{c}(\bm{Z}),\quad\mathrm{s.t.}\ \bm{Z}=f(\bm{X},\theta),\ \|\bm{Z}_{k}\|_{F}^{2}=N_{k},\ k=1,\dots,K.   (3.4.12)

We refer to this as the principle of maximal coding rate reduction (MCR2), a true embodiment of Aristotle’s famous quote:

The whole is greater than the sum of its parts.

To learn the best representation, we require that the whole is maximally greater than the sum of its parts. Let us examine the example shown in Figure 3.19 again. From a compression perspective, the representation on the right is the most compact one in the sense that the difference between the coding rate when all features are encoded as a single Gaussian (blue) and that when the features are properly clustered and encoded as two separate subspaces (green) is maximal.282828Intuitively, the ratio between the “volume” of the whole space spanned by all features and that actually occupied by the features is maximal.
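To sketch how (3.4.12) can be optimized in practice, here is a minimal PyTorch example (ours, not the reference MCR2 implementation of [YCY+20]); the encoder architecture, the value of \epsilon, and the random placeholder data are all assumptions, and the features are normalized onto the unit sphere as discussed above.

```python
import torch
import torch.nn.functional as F

def rate(Z, eps=0.5):
    """R_eps(Z) as in (3.4.7); Z is d x N."""
    d, N = Z.shape
    return 0.5 * torch.logdet(torch.eye(d) + (d / (N * eps**2)) * Z @ Z.T)

def neg_rate_reduction(Z, labels, eps=0.5):
    """Negative of Delta R_eps(Z) in (3.4.10), so it can be minimized by gradient descent."""
    d, N = Z.shape
    Rc = 0.0
    for k in labels.unique():
        Zk = Z[:, labels == k]
        Nk = Zk.shape[1]
        Rc = Rc + (Nk / (2 * N)) * torch.logdet(torch.eye(d) + (d / (Nk * eps**2)) * Zk @ Zk.T)
    return -(rate(Z, eps) - Rc)

# Placeholder encoder and data (sizes are arbitrary assumptions).
encoder = torch.nn.Sequential(torch.nn.Linear(20, 64), torch.nn.ReLU(), torch.nn.Linear(64, 16))
opt = torch.optim.Adam(encoder.parameters(), lr=1e-3)
X = torch.randn(512, 20)                    # N x D data matrix
labels = torch.randint(0, 4, (512,))        # labels for K = 4 classes
for _ in range(100):
    Z = F.normalize(encoder(X), dim=1).T    # each z_i on the unit sphere; Z is d x N
    loss = neg_rate_reduction(Z, labels)
    opt.zero_grad(); loss.backward(); opt.step()
```

Normalizing the features before computing \Delta R_{\epsilon} is what keeps the objective from being trivially inflated by scaling, mirroring the constraint in (3.4.12).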

Note that the above MCR2 principle is designed for supervised learning problems, where the group memberships (or class labels) are known. However, this principle can be naturally extended to unsupervised learning problems by introducing a membership matrix, which encodes the (potentially soft) assignment of each data point to latent groups or clusters. Specifically, let \bm{\Pi}=\{\bm{\Pi}_{k}\}_{k=1}^{K}\subset\mathbb{R}^{N\times N} be a set of diagonal matrices whose diagonal entries encode the membership of the N samples in the K classes. That is, \bm{\Pi} lies in the simplex \Omega\doteq\{\bm{\Pi}:\bm{\Pi}_{k}\geq\bm{0},\ \sum_{k=1}^{K}\bm{\Pi}_{k}=\bm{I}_{N}\}. Then, we can define the average coding rate with respect to the partition \bm{\Pi} as

R_{\epsilon}^{c}(\bm{Z}\mid\bm{\Pi})\doteq\sum_{k=1}^{K}\frac{\mathrm{tr}(\bm{\Pi}_{k})}{2N}\log\det\left(\bm{I}+\frac{d}{\mathrm{tr}(\bm{\Pi}_{k})\epsilon^{2}}\bm{Z}\bm{\Pi}_{k}\bm{Z}^{\top}\right).   (3.4.13)

When \bm{Z} is given, R_{\epsilon}^{c}(\bm{Z}\mid\bm{\Pi}) is a concave function of \bm{\Pi}. The MCR2 principle for unsupervised learning problems then becomes the following:

\max_{\bm{\Pi},\theta}\ \Delta R_{\epsilon}(\bm{Z}\mid\bm{\Pi})\doteq R_{\epsilon}(\bm{Z})-R_{\epsilon}^{c}(\bm{Z}\mid\bm{\Pi})\quad\mathrm{s.t.}\ \bm{Z}=f(\bm{X},\theta),\ \|\bm{Z}\bm{\Pi}_{k}\|_{F}^{2}=N_{k},\ k=1,\dots,K,\ \bm{\Pi}\in\Omega.   (3.4.14)

Compared to (3.4.12), the formulation here allows for the joint optimization of both the group memberships and the network parameters. In particular, when \bm{\Pi} is fixed to a hard membership matrix that assigns the N data points into K groups, Problem (3.4.14) recovers Problem (3.4.12).
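Since each \bm{\Pi}_{k} is diagonal, the rate (3.4.13) only needs the vector of its diagonal entries. The sketch below (ours) evaluates R_{\epsilon}^{c}(\bm{Z}\mid\bm{\Pi}) for soft memberships stored as a K\times N array; a hard one-hot assignment recovers (3.4.9). It assumes every class receives nonzero total membership.

```python
import numpy as np

def membership_rate(Z, Pi, eps=0.5):
    """R_eps^c(Z | Pi) from (3.4.13). Z is d x N; Pi is K x N, where Pi[k] holds the
    diagonal of Pi_k, and each column of Pi sums to 1 (membership of one sample)."""
    d, N = Z.shape
    Rc = 0.0
    for k in range(Pi.shape[0]):
        w = Pi[k]                              # diagonal entries of Pi_k
        trk = w.sum()                          # tr(Pi_k); assumed > 0
        # Z diag(w) Z^T computed by scaling the columns of Z with the membership weights.
        Rc += (trk / (2 * N)) * np.linalg.slogdet(
            np.eye(d) + (d / (trk * eps**2)) * (Z * w) @ Z.T)[1]
    return Rc
```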

Figure 3.22: Local optimization landscape: according to Theorem 3.7, the global maximum of the rate reduction objective corresponds to a solution with mutually incoherent subspaces.

3.4.3 Optimization Properties of Coding Rate Reduction

In this subsection, we study the optimization properties of the MCR2 function by analyzing its optimal solutions and the structure of its optimization landscape. To get around the technical difficulty introduced by the neural networks, we consider a simplified version of Problem (3.4.12) as follows:

\max_{\bm{Z}}\ R_{\epsilon}(\bm{Z})-R_{\epsilon}^{c}(\bm{Z})\quad\mathrm{s.t.}\quad\|\bm{Z}_{k}\|_{F}^{2}=N_{k},\ k=1,\dots,K.   (3.4.15)

In theory, the MCR2 principle (3.4.15) enjoys great generality: it can be applied to representations \bm{Z} of any distributions as long as the rates R_{\epsilon} and R^{c}_{\epsilon} for those distributions can be accurately and efficiently evaluated. The optimal representation \bm{Z}^{\ast} should have some interesting geometric and statistical properties. Here we reveal the nice properties of the optimal representation in the special case of subspaces, which has many important use cases in machine learning. When the desired representation for \bm{Z} is multiple subspaces, the rates R_{\epsilon} and R^{c}_{\epsilon} in (3.4.15) are given by (3.4.7) and (3.4.9), respectively. At the maximal rate reduction, MCR2 achieves its optimal representation, denoted as \bm{Z}^{\ast}=[\bm{Z}_{1}^{*},\dots,\bm{Z}_{K}^{*}] with \mathrm{rank}(\bm{Z}_{k}^{*})\leq d_{k}. One can show that \bm{Z}^{\ast} has the following desired properties (see [YCY+20] for a formal statement and detailed proofs).

Theorem 3.7 (Characterization of Global Optimal Solutions).

Suppose $\bm{Z}^{\ast}=[\bm{Z}_{1}^{*},\dots,\bm{Z}_{K}^{*}]$ is a global optimal solution of Problem (3.4.15). The following statements hold:

  • Between-Class Discriminative: As long as the ambient space is adequately large ($d\geq\sum_{k=1}^{K}d_{k}$), the subspaces are all orthogonal to each other, i.e., $(\bm{Z}_{k}^{*})^{\top}\bm{Z}_{l}^{*}=\bm{0}$ for $k\neq l$.

  • Maximally Diverse Representation: As long as the coding precision is adequately high, i.e., $\epsilon^{4}<c\cdot\min_{k}\big\{\frac{N_{k}}{N}\frac{d^{2}}{d_{k}^{2}}\big\}$ for some constant $c>0$, each subspace achieves its maximal dimension, i.e., $\mathrm{rank}(\bm{Z}_{k}^{*})=d_{k}$. In addition, the largest $d_{k}-1$ singular values of $\bm{Z}_{k}^{*}$ are equal.

This theorem indicates that the MCR2 principle promotes embedding of data into multiple independent subspaces (as illustrated in Figure 3.22), with features distributed isotropically in each subspace (except for possibly one dimension). Notably, this theorem also confirms that the features learned by the MCR2 principle exhibit the desired low-dimensional discriminative properties discussed in Section 3.4.1. In addition, among all such discriminative representations, it prefers the one with the highest dimensions in the ambient space. This is substantially different from the objective of information bottleneck (3.4.5).
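To make the objective in (3.4.15) concrete, the following NumPy sketch evaluates the two rates and their difference for a given feature matrix $\bm{Z}\in\mathbb{R}^{d\times N}$ and class labels. It assumes that (3.4.7) and (3.4.9) take the usual log-det forms, $R_{\epsilon}(\bm{Z})=\frac{1}{2}\log\det\big(\bm{I}+\frac{d}{N\epsilon^{2}}\bm{Z}\bm{Z}^{\top}\big)$ and $R^{c}_{\epsilon}(\bm{Z})=\sum_{k}\frac{N_{k}}{2N}\log\det\big(\bm{I}+\frac{d}{N_{k}\epsilon^{2}}\bm{Z}_{k}\bm{Z}_{k}^{\top}\big)$; the function names and the default $\epsilon$ are ours, for illustration only.

```python
import numpy as np

def coding_rate(Z, eps=0.5):
    """R_eps(Z): rate to encode the columns of Z (shape d x N) up to precision eps."""
    d, N = Z.shape
    alpha = d / (N * eps**2)
    return 0.5 * np.linalg.slogdet(np.eye(d) + alpha * Z @ Z.T)[1]

def coding_rate_per_class(Z, labels, eps=0.5):
    """R^c_eps(Z): average rate when each class block Z_k is encoded separately."""
    d, N = Z.shape
    rate = 0.0
    for k in np.unique(labels):
        Zk = Z[:, labels == k]
        Nk = Zk.shape[1]
        alpha_k = d / (Nk * eps**2)
        rate += (Nk / (2 * N)) * np.linalg.slogdet(np.eye(d) + alpha_k * Zk @ Zk.T)[1]
    return rate

def rate_reduction(Z, labels, eps=0.5):
    """Delta R_eps = R_eps(Z) - R^c_eps(Z): the objective of (3.4.15)."""
    return coding_rate(Z, eps) - coding_rate_per_class(Z, labels, eps)
```

One can use this, for instance, to compare features sampled from two orthogonal subspaces against features collapsed onto a single subspace; in line with Theorem 3.7, the former should yield a larger rate reduction.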

Example 3.13 (Classification of Images on CIFAR-10).

We here present how the MCR2 objective helps learn better representations than the cross-entropy loss (3.4.2) for image classification. We adopt a popular neural network architecture, ResNet-18 [HZR+16a], to model the feature mapping $\bm{z}=f(\bm{x},\theta)$, and we optimize the network parameters $\theta$ to maximize the coding rate reduction. We evaluate the performance on the CIFAR-10 image classification dataset [KH+09].

(a) Evolution of $R_{\epsilon}$, $R^{c}_{\epsilon}$, and $\Delta R_{\epsilon}$ during the training process. (b) PCA of the learned features: (red) overall data; (blue) individual classes.
Figure 3.23: Evolution of the rates of MCR2 during training, and the principal components of the learned features.

Figure 3.24: Cosine similarity between features learned using the MCR2 objective (left) and the CE loss (right).

Figure 3.23(a) illustrates how the two rates and their difference (for both training and test data) evolve over training epochs: after an initial phase, $R_{\epsilon}$ gradually increases while $R^{c}_{\epsilon}$ decreases, indicating that the features $\bm{Z}$ expand as a whole while each class $\bm{Z}_{k}$ is being compressed. Figure 3.23(b) shows the distribution of singular values of each $\bm{Z}_{k}$. Figure 3.24 shows the cosine similarities between the learned features, sorted by class, comparing features learned with the cross-entropy loss (3.4.2) and with the MCR2 objective (3.4.12). From the plots, one can clearly see that the representations learned with the MCR2 loss are much more diverse than those learned with the cross-entropy loss. More details of this experiment can be found in [CYY+22]. $\blacksquare$
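For readers who wish to set up a similar experiment, the sketch below shows how the rate reduction can be written as a differentiable loss and minimized (in its negated form) with standard stochastic gradient methods. The feature normalization and the commented training step are illustrative assumptions; the exact recipe of [CYY+22] may differ.

```python
import torch

def mcr2_loss(Z, labels, eps=0.5):
    """Negative rate reduction -(R_eps(Z) - R^c_eps(Z)) for features Z of shape (N, d)."""
    N, d = Z.shape
    I = torch.eye(d, device=Z.device)
    R = 0.5 * torch.logdet(I + (d / (N * eps**2)) * Z.T @ Z)
    Rc = 0.0
    for k in labels.unique():
        Zk = Z[labels == k]
        Nk = Zk.shape[0]
        Rc = Rc + (Nk / (2 * N)) * torch.logdet(I + (d / (Nk * eps**2)) * Zk.T @ Zk)
    return -(R - Rc)  # minimizing this maximizes the rate reduction

# Sketch of one training step, where f is any feature network (e.g., a ResNet-18
# backbone followed by a linear head into R^d). Normalizing each feature to unit
# norm keeps ||Z_k||_F^2 = N_k, matching the constraint in (3.4.15).
#
# Z = torch.nn.functional.normalize(f(x), dim=1)
# loss = mcr2_loss(Z, y)
# loss.backward(); optimizer.step(); optimizer.zero_grad()
```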

However, there has been an apparent lack of justification for the network architectures used in the above experiments. It is yet unclear why the network adopted here (ResNet-18) is suitable for representing the map $f(\bm{x},\theta)$, let alone how to interpret the layer operators and the parameters $\theta$ learned inside. In the next chapter, we will show how to derive network architectures and components entirely as a “white box” from the desired objective (say, rate reduction).

Regularized MCR2.

The above theorem characterizes properties of the global optima of the rate reduction objective. What about other optima, such as local ones? The Frobenius-norm constraints make Problem (3.4.15) difficult to analyze from an optimization-theoretic perspective. We therefore consider the Lagrangian formulation of (3.4.15), which can be viewed as a tight relaxation, or even an equivalent problem, whose optimal solutions agree with those of (3.4.15) under specific settings of the regularization parameter; see [WLP+24, Proposition 1]. Specifically, the formulation we study, referred to henceforth as the regularized MCR2 problem, is as follows:

\[
\max_{\bm{Z}}\ R_{\epsilon}(\bm{Z})-R_{\epsilon}^{c}(\bm{Z})-\frac{\lambda}{2}\|\bm{Z}\|_{F}^{2}, \tag{3.4.16}
\]

where $\lambda>0$ is the regularization parameter. Although the program (3.4.16) is highly nonconcave and involves matrix inverses in its gradient computation, we can still explicitly characterize its local and global optima as follows.

Theorem 3.8 (Local and Global Optima).

Let $N_{k}$ denote the number of training samples in the $k$-th class for each $k\in\{1,\dots,K\}$, $N_{\max}\doteq\max\{N_{1},\dots,N_{K}\}$, $\alpha=d/(N\epsilon^{2})$, and $\alpha_{k}=d/(N_{k}\epsilon^{2})$ for each $k\in\{1,\dots,K\}$. Given a coding precision $\epsilon>0$, if the regularization parameter satisfies

\[
\lambda\in\left(0,\ \frac{d\left(\sqrt{N/N_{\max}}-1\right)}{N\left(\sqrt{N/N_{\max}}+1\right)\epsilon^{2}}\right], \tag{3.4.17}
\]

then the following statements hold:

(i) (Local maximizers) $\bm{Z}^{*}=[\bm{Z}_{1}^{*},\dots,\bm{Z}_{K}^{*}]$ is a local maximizer of Problem (3.4.16) if and only if each block $\bm{Z}_{k}^{*}$ admits the decomposition

\[
\bm{Z}_{k}^{*}=\left(\frac{\eta_{k}+\sqrt{\eta_{k}^{2}-4\lambda^{2}N/N_{k}}}{2\lambda\alpha_{k}}\right)^{1/2}\bm{U}_{k}\bm{V}_{k}^{\top}, \tag{3.4.18}
\]

where (a) $r_{k}=\mathrm{rank}(\bm{Z}_{k}^{*})$ satisfies $r_{k}\in[0,\min\{N_{k},d\})$ and $\sum_{k=1}^{K}r_{k}\leq\min\{N,d\}$; (b) $\bm{U}_{k}\in\mathcal{O}^{d\times r_{k}}$ satisfies $\bm{U}_{k}^{\top}\bm{U}_{l}=\bm{0}$ for all $k\neq l$, and $\bm{V}_{k}\in\mathcal{O}^{N_{k}\times r_{k}}$; and (c) $\eta_{k}=(\alpha_{k}-\alpha)-\lambda(N/N_{k}+1)$ for each $k\in\{1,\dots,K\}$.

(ii) (Global maximizers) $\bm{Z}^{*}=[\bm{Z}_{1}^{*},\dots,\bm{Z}_{K}^{*}]$ is a global maximizer of Problem (3.4.16) if and only if (a) it satisfies all of the above conditions with $\sum_{k=1}^{K}r_{k}=\min\{N,d\}$, and (b) for all $k\neq l\in[K]$ with $N_{k}<N_{l}$ and $r_{l}>0$, we have $r_{k}=\min\{N_{k},d\}$.

Figure 3.25: Global optimization landscape: According to [SQW15, LSJ+16] and Theorems 3.8 and 3.9, both global and local maxima of the (regularized) rate reduction objective correspond to solutions with mutually incoherent subspaces. All other critical points are strict saddle points.

This theorem explicitly characterizes the local and global optima of Problem (3.4.16). Intuitively, it shows that the features represented by each local maximizer of Problem (3.4.16) are low-dimensional and discriminative. Although we have characterized the local and global optimal solutions in Theorem 3.8, it remains unknown whether these solutions can be efficiently computed by applying GD to Problem (3.4.16), since GD may get stuck at other critical points such as saddle points. Fortunately, [SQW15, LSJ+16] showed that if a function is twice continuously differentiable and satisfies the strict saddle property, i.e., each critical point is either a local minimizer or a strict saddle point (we say a critical point of Problem (3.4.16) is a strict saddle point if it has a direction with strictly positive curvature [SQW15]; this includes classical saddle points with strictly positive curvature as well as local minimizers), then GD converges to a local minimizer almost surely with random initialization. We investigate the global optimization landscape of Problem (3.4.16) by characterizing all of its critical points as follows.

Theorem 3.9 (Benign Global Optimization Landscape).

Given a coding precision $\epsilon>0$, if the regularization parameter satisfies (3.4.17), then any critical point $\bm{Z}$ of Problem (3.4.16) is either a local maximizer or a strict saddle point.

Together, the above two theorems show that the learned features associated with each local maximizer of the rate reduction objective (not just the global maximizers) are structured as incoherent low-dimensional subspaces. Furthermore, the (regularized) rate reduction objective (3.4.12) has a very benign landscape, with only local maxima and strict saddles as critical points, as illustrated in Figure 3.25. According to [SQW15, LSJ+16], Theorems 3.8 and 3.9 imply that low-dimensional and discriminative representations (LDRs) can be efficiently found by applying (stochastic) gradient descent to the rate reduction objective (3.4.12) from random initialization. These results also indirectly explain the behavior in Figure 3.24: if the chosen network is expressive enough and trained well, the resulting representation is typically an incoherent linear representation that likely corresponds to a globally optimal solution. Interested readers are referred to [WLP+24] for the proofs.
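The benign landscape suggests that simple first-order ascent on (3.4.16) should find such structured solutions. Below is a minimal NumPy sketch of plain gradient ascent, using the closed-form gradient $\nabla_{\bm{Z}}\,\tfrac{1}{2}\log\det(\bm{I}+a\bm{Z}\bm{Z}^{\top})=a(\bm{I}+a\bm{Z}\bm{Z}^{\top})^{-1}\bm{Z}$; the step size, iteration count, and initialization are illustrative choices, not a prescription.

```python
import numpy as np

def regularized_mcr2_ascent(Z, labels, eps=0.5, lam=0.01, lr=0.1, steps=500):
    """Plain gradient ascent on the regularized MCR2 objective (3.4.16).

    Z has shape (d, N); labels has length N.
    """
    d, N = Z.shape
    alpha = d / (N * eps**2)
    classes = np.unique(labels)
    for _ in range(steps):
        # Gradient of the expansion term R_eps(Z).
        grad = alpha * np.linalg.solve(np.eye(d) + alpha * Z @ Z.T, Z)
        # Subtract the gradient of the compression term R^c_eps(Z).
        for k in classes:
            mask = labels == k
            Zk = Z[:, mask]
            Nk = Zk.shape[1]
            ak = d / (Nk * eps**2)
            grad[:, mask] -= (Nk / N) * ak * np.linalg.solve(np.eye(d) + ak * Zk @ Zk.T, Zk)
        # Subtract the gradient of the regularizer (lambda/2) * ||Z||_F^2.
        grad -= lam * Z
        Z = Z + lr * grad
    return Z
```

Starting from a small random $\bm{Z}$, one would expect the iterates to approach a configuration whose class blocks span mutually orthogonal subspaces, consistent with Theorems 3.8 and 3.9.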

3.5 Summary and Notes

The use of denoising and diffusion for sampling has a rich history. The first work that is clearly about a diffusion model is probably [SWM+15], but it was preceded by many works on denoising as a computational and statistical problem. The most relevant of these is probably [Hyv05], which explicitly uses the score function to denoise (as well as to perform independent component analysis). The most popular follow-ups, [HJA20, SE19], appeared essentially concurrently. Since then, thousands of papers have built on diffusion models; we will revisit this topic in Chapter 5.

Many of these works use a different stochastic process than the simple linear combination (3.2.69). In fact, all of the works listed above emphasize the need to add independent Gaussian noise at the beginning of each step of the forward process. Theoretically-minded work uses Brownian motion or stochastic differential equations to formulate the forward process [SSK+21]. However, since linear combinations of Gaussians are still Gaussian, the marginal distributions of such processes still take the form of (3.2.69). Most of our discussion requires only that the marginal distributions are what they are, and hence our overly simplistic model is quite enough for almost everything. In fact, the only place where the marginal distributions are not enough is when we derive an expression for $\mathbb{E}[\bm{x}_{s}\mid\bm{x}_{t}]$ in terms of $\mathbb{E}[\bm{x}\mid\bm{x}_{t}]$. Different (noising) processes give different such expressions, which can be used for sampling (and of course there are other ways to derive efficient samplers, such as the ever-popular DDPM sampler). The process in (3.2.69) is nevertheless a bona fide stochastic process, whose “natural” denoising iteration takes the form of the popular DDIM algorithm [SME20]. (Even this equivalence is not trivial; we cite [DGG+25] as a justification.)

On top of the theoretical work [LY24] covered in Section 1.3.1, and the lineage of work it builds on, which studies the sampling efficiency of diffusion models when the data has low-dimensional structure, there is a large body of work studying the training efficiency of diffusion models under low-dimensional structure. Specifically, [CHZ+23] and [WZZ+24] characterized the approximation and estimation error of denoisers when the data follows a mixture of low-rank Gaussians, showing that the number of training samples required to accurately learn the distribution scales with the intrinsic dimension of the data rather than the ambient dimension. There is also considerable methodological work that exploits the low-dimensional structure of the data for various tasks with diffusion models. We highlight three here: image editing [CZG+24], watermarking [LZQ24], and unlearning [CZL+25], though as always this list is not exhaustive.

3.6 Exercises and Extensions

Exercise 3.1.

Please show that (3.2.4) is the optimal solution of Problem (3.2.3).

Exercise 3.2.

Consider random vectors $\bm{x}\in\mathbb{R}^{D}$ and $\bm{y}\in\mathbb{R}^{d}$, such that the pair $(\bm{x},\bm{y})\in\mathbb{R}^{D+d}$ is jointly Gaussian. This means that

\[
\begin{bmatrix}\bm{x}\\ \bm{y}\end{bmatrix}\sim\mathcal{N}\left(\begin{bmatrix}\bm{\mu}_{\bm{x}}\\ \bm{\mu}_{\bm{y}}\end{bmatrix},\ \begin{bmatrix}\bm{\Sigma}_{\bm{x}}&\bm{\Sigma}_{\bm{x}\bm{y}}\\ \bm{\Sigma}_{\bm{x}\bm{y}}^{\top}&\bm{\Sigma}_{\bm{y}}\end{bmatrix}\right),
\]

where the mean and covariance parameters are given by

\[
\bm{\mu}_{\bm{x}}=\mathbb{E}[\bm{x}],\quad\bm{\mu}_{\bm{y}}=\mathbb{E}[\bm{y}],\quad\begin{bmatrix}\bm{\Sigma}_{\bm{x}}&\bm{\Sigma}_{\bm{x}\bm{y}}\\ \bm{\Sigma}_{\bm{x}\bm{y}}^{\top}&\bm{\Sigma}_{\bm{y}}\end{bmatrix}=\mathbb{E}\left[\begin{bmatrix}\bm{x}-\mathbb{E}[\bm{x}]\\ \bm{y}-\mathbb{E}[\bm{y}]\end{bmatrix}\begin{bmatrix}\bm{x}-\mathbb{E}[\bm{x}]\\ \bm{y}-\mathbb{E}[\bm{y}]\end{bmatrix}^{\top}\right].
\]

Assume that $\bm{\Sigma}_{\bm{y}}$ is positive definite (hence invertible); then positive semidefiniteness of the covariance matrix is equivalent to the Schur complement condition $\bm{\Sigma}_{\bm{x}}-\bm{\Sigma}_{\bm{x}\bm{y}}\bm{\Sigma}_{\bm{y}}^{-1}\bm{\Sigma}_{\bm{x}\bm{y}}^{\top}\succeq\bm{0}$.

In this exercise, we will prove that the conditional distribution $p_{\bm{x}\mid\bm{y}}$ is Gaussian: namely,

\[
p_{\bm{x}\mid\bm{y}}\sim\mathcal{N}\left(\bm{\mu}_{\bm{x}}+\bm{\Sigma}_{\bm{x}\bm{y}}\bm{\Sigma}_{\bm{y}}^{-1}(\bm{y}-\bm{\mu}_{\bm{y}}),\ \bm{\Sigma}_{\bm{x}}-\bm{\Sigma}_{\bm{x}\bm{y}}\bm{\Sigma}_{\bm{y}}^{-1}\bm{\Sigma}_{\bm{x}\bm{y}}^{\top}\right). \tag{3.6.1}
\]

A direct path to prove this result manipulates the defining ratio of densities $p_{\bm{x},\bm{y}}/p_{\bm{y}}$. We sketch an algebraically concise argument of this form below.

  1.

    Verify the following matrix identity for the covariance:

    \[
    \begin{bmatrix}\bm{\Sigma}_{\bm{x}}&\bm{\Sigma}_{\bm{x}\bm{y}}\\ \bm{\Sigma}_{\bm{x}\bm{y}}^{\top}&\bm{\Sigma}_{\bm{y}}\end{bmatrix}=\begin{bmatrix}\bm{I}_{D}&\bm{\Sigma}_{\bm{x}\bm{y}}\bm{\Sigma}_{\bm{y}}^{-1}\\ \bm{0}&\bm{I}_{d}\end{bmatrix}\begin{bmatrix}\bm{\Sigma}_{\bm{x}}-\bm{\Sigma}_{\bm{x}\bm{y}}\bm{\Sigma}_{\bm{y}}^{-1}\bm{\Sigma}_{\bm{x}\bm{y}}^{\top}&\bm{0}\\ \bm{0}&\bm{\Sigma}_{\bm{y}}\end{bmatrix}\begin{bmatrix}\bm{I}_{D}&\bm{0}\\ \bm{\Sigma}_{\bm{y}}^{-1}\bm{\Sigma}_{\bm{x}\bm{y}}^{\top}&\bm{I}_{d}\end{bmatrix}. \tag{3.6.2}
    \]

    One arrives at this identity by performing two rounds of (block) Gaussian elimination on the covariance matrix.

  2.

    Based on the previous identity, show that

    \[
    \begin{bmatrix}\bm{\Sigma}_{\bm{x}}&\bm{\Sigma}_{\bm{x}\bm{y}}\\ \bm{\Sigma}_{\bm{x}\bm{y}}^{\top}&\bm{\Sigma}_{\bm{y}}\end{bmatrix}^{-1}=\begin{bmatrix}\bm{I}_{D}&\bm{0}\\ -\bm{\Sigma}_{\bm{y}}^{-1}\bm{\Sigma}_{\bm{x}\bm{y}}^{\top}&\bm{I}_{d}\end{bmatrix}\begin{bmatrix}\left(\bm{\Sigma}_{\bm{x}}-\bm{\Sigma}_{\bm{x}\bm{y}}\bm{\Sigma}_{\bm{y}}^{-1}\bm{\Sigma}_{\bm{x}\bm{y}}^{\top}\right)^{-1}&\bm{0}\\ \bm{0}&\bm{\Sigma}_{\bm{y}}^{-1}\end{bmatrix}\begin{bmatrix}\bm{I}_{D}&-\bm{\Sigma}_{\bm{x}\bm{y}}\bm{\Sigma}_{\bm{y}}^{-1}\\ \bm{0}&\bm{I}_{d}\end{bmatrix} \tag{3.6.3}
    \]

    whenever the relevant inverses are defined. (In cases where the Schur complement term is not invertible, the same result holds with its inverse replaced by the Moore-Penrose pseudoinverse; in particular, the conditional distribution (3.6.1) becomes a degenerate Gaussian distribution.) Conclude that

    \begin{align}
    &\begin{bmatrix}\bm{x}-\bm{\mu}_{\bm{x}}\\ \bm{y}-\bm{\mu}_{\bm{y}}\end{bmatrix}^{\top}\begin{bmatrix}\bm{\Sigma}_{\bm{x}}&\bm{\Sigma}_{\bm{x}\bm{y}}\\ \bm{\Sigma}_{\bm{x}\bm{y}}^{\top}&\bm{\Sigma}_{\bm{y}}\end{bmatrix}^{-1}\begin{bmatrix}\bm{x}-\bm{\mu}_{\bm{x}}\\ \bm{y}-\bm{\mu}_{\bm{y}}\end{bmatrix} \tag{3.6.4}\\
    &\quad=\begin{bmatrix}\bm{x}-\left(\bm{\mu}_{\bm{x}}+\bm{\Sigma}_{\bm{x}\bm{y}}\bm{\Sigma}_{\bm{y}}^{-1}(\bm{y}-\bm{\mu}_{\bm{y}})\right)\\ \bm{y}-\bm{\mu}_{\bm{y}}\end{bmatrix}^{\top}\begin{bmatrix}\left(\bm{\Sigma}_{\bm{x}}-\bm{\Sigma}_{\bm{x}\bm{y}}\bm{\Sigma}_{\bm{y}}^{-1}\bm{\Sigma}_{\bm{x}\bm{y}}^{\top}\right)^{-1}&\bm{0}\\ \bm{0}&\bm{\Sigma}_{\bm{y}}^{-1}\end{bmatrix}\begin{bmatrix}\bm{x}-\left(\bm{\mu}_{\bm{x}}+\bm{\Sigma}_{\bm{x}\bm{y}}\bm{\Sigma}_{\bm{y}}^{-1}(\bm{y}-\bm{\mu}_{\bm{y}})\right)\\ \bm{y}-\bm{\mu}_{\bm{y}}\end{bmatrix}. \tag{3.6.5}
    \end{align}

    (Hint: To economize algebraic manipulations, note that the first and last matrices on the RHS of Equation 3.6.2 are transposes of one another.)

  3.

    By computing the ratio of densities $p_{\bm{x},\bm{y}}/p_{\bm{y}}$, prove Equation 3.6.1. (Hint: Using the previous identities, only minimal algebra should be necessary. For the normalizing constant, use Equation 3.6.3 to factor the determinant similarly.) A numerical sanity check of the identities in this exercise is sketched below.
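The check below verifies the block factorization (3.6.2) and the block inverse (3.6.3) on a randomly generated joint covariance; it is a quick sanity check, not a substitute for the proof, and the dimensions and seed are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
D, d = 4, 3

# Build a random positive definite joint covariance of size (D + d).
M = rng.standard_normal((D + d, D + d))
Sigma = M @ M.T + (D + d) * np.eye(D + d)
Sxx, Sxy, Syy = Sigma[:D, :D], Sigma[:D, D:], Sigma[D:, D:]

# Schur complement of Sigma_y in the joint covariance.
S = Sxx - Sxy @ np.linalg.solve(Syy, Sxy.T)

# Check the block factorization (3.6.2): Sigma = U * diag(S, Syy) * U^T.
U = np.block([[np.eye(D), Sxy @ np.linalg.inv(Syy)], [np.zeros((d, D)), np.eye(d)]])
Mid = np.block([[S, np.zeros((D, d))], [np.zeros((d, D)), Syy]])
assert np.allclose(Sigma, U @ Mid @ U.T)

# Check the block inverse (3.6.3) against a direct inverse.
L = np.block([[np.eye(D), np.zeros((D, d))], [-np.linalg.solve(Syy, Sxy.T), np.eye(d)]])
MidInv = np.block([[np.linalg.inv(S), np.zeros((D, d))], [np.zeros((d, D)), np.linalg.inv(Syy)]])
assert np.allclose(np.linalg.inv(Sigma), L @ MidInv @ L.T)
print("block factorization and block inverse identities verified")
```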

Exercise 3.3.

Show the Sherman-Morrison-Woodbury identity, i.e., for matrices $\bm{A}$, $\bm{C}$, $\bm{U}$, $\bm{V}$ such that $\bm{A}$, $\bm{C}$, and $\bm{A}+\bm{U}\bm{C}\bm{V}$ are invertible,

\[
(\bm{A}+\bm{U}\bm{C}\bm{V})^{-1}=\bm{A}^{-1}-\bm{A}^{-1}\bm{U}\left(\bm{C}^{-1}+\bm{V}\bm{A}^{-1}\bm{U}\right)^{-1}\bm{V}\bm{A}^{-1}. \tag{3.6.6}
\]
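A quick numerical check of (3.6.6) on random, well-conditioned matrices can catch transcription errors before attempting the proof; the sizes and construction below are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(1)
n, k = 6, 2

# Symmetric positive definite A and C guarantee the required invertibility.
MA = rng.standard_normal((n, n))
MC = rng.standard_normal((k, k))
A = MA @ MA.T + np.eye(n)
C = MC @ MC.T + np.eye(k)
U = rng.standard_normal((n, k))
V = rng.standard_normal((k, n))

lhs = np.linalg.inv(A + U @ C @ V)
Ainv = np.linalg.inv(A)
rhs = Ainv - Ainv @ U @ np.linalg.inv(np.linalg.inv(C) + V @ Ainv @ U) @ V @ Ainv
assert np.allclose(lhs, rhs)
print("Woodbury identity verified numerically")
```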
Exercise 3.4.

Rederive the following, assuming $\bm{x}_{t}$ follows the generalized noise model (3.2.69).

  • Tweedie’s formula: (3.2.70).

  • The DDIM iteration: (3.2.71).

  • The Bayes optimal denoiser for a Gaussian mixture model: (3.2.72).

Exercise 3.5.
  1.

    Implement the formulae derived in Exercise 3.4, building a sampler for Gaussian mixtures.

  2.

    Reproduce Figure 3.4 and Figure 3.7.

  3.

    We now introduce a separate process called Flow Matching (FM), as follows:

    \[
    \alpha_{t}=1-t,\qquad\sigma_{t}=t. \tag{3.6.7}
    \]

    Implement this process using the same framework, and test it for sampling in high dimensions. Which process seems to give better or more stable results? (A minimal scaffold for such a sampler is sketched below.)
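The scaffold below is one possible way to organize such an implementation. It assumes the forward model (3.2.69) takes the linear-combination form $\bm{x}_{t}=\alpha_{t}\bm{x}+\sigma_{t}\bm{\varepsilon}$ and uses a deterministic DDIM-style update driven by a user-supplied posterior-mean (denoiser) function; whether this matches the exact parameterization of (3.2.71) depends on the conventions used there, so treat it as a starting point rather than a reference solution.

```python
import numpy as np

# Two choices of (alpha_t, sigma_t) for x_t = alpha_t * x + sigma_t * noise.
def schedule_trig(t):       # a trigonometric (variance-preserving style) schedule
    return np.cos(np.pi * t / 2), np.sin(np.pi * t / 2)

def schedule_fm(t):         # the Flow Matching schedule (3.6.7)
    return 1.0 - t, t

def sample(x_hat_fn, dim, schedule, n_steps=100, rng=None):
    """Deterministic DDIM-style sampler:
    x_s = alpha_s * xhat + (sigma_s / sigma_t) * (x_t - alpha_t * xhat),
    where xhat = E[x | x_t] is supplied by x_hat_fn(x_t, t) (e.g., the Bayes
    optimal denoiser for a Gaussian mixture from Exercise 3.4)."""
    if rng is None:
        rng = np.random.default_rng()
    ts = np.linspace(1.0, 0.0, n_steps + 1)
    x = rng.standard_normal(dim)        # start from (approximately) pure noise
    for t, s in zip(ts[:-1], ts[1:]):
        a_t, s_t = schedule(t)
        a_s, s_s = schedule(s)
        xhat = x_hat_fn(x, t)           # posterior mean / denoiser
        x = a_s * xhat + (s_s / s_t) * (x - a_t * xhat)
    return x
```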

Exercise 3.6.

Please show the following properties of the $\log\det(\cdot)$ function.

  1.

    Show that

    \[
    f(\bm{X})=\log\det(\bm{X})
    \]

    is a concave function over the set of symmetric positive definite matrices. (Hint: A function $f$ is concave if and only if its restriction to any line within its domain, $g(t)=f(\bm{X}+t\bm{H})$, is concave for all $\bm{X}$ and $\bm{H}$.)

  2.

    Show that:

    \[
    \log\det(\bm{I}+\bm{X}^{\top}\bm{X})=\log\det(\bm{I}+\bm{X}\bm{X}^{\top}).
    \]
  3.

    Let $\bm{A}\in\mathbb{R}^{n\times n}$ be a positive definite matrix. Please show that

    \[
    \log\det(\bm{A})=\sum_{i=1}^{n}\log(\lambda_{i}), \tag{3.6.8}
    \]

    where $\lambda_{1},\lambda_{2},\dots,\lambda_{n}$ are the eigenvalues of $\bm{A}$. (A numerical check of the last two identities is sketched below.)
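The short check below confirms properties 2 and 3 numerically on random matrices; the sizes and seed are arbitrary, and it is of course no replacement for the requested proofs.

```python
import numpy as np

rng = np.random.default_rng(2)

# Property 2: log det(I + X^T X) = log det(I + X X^T), even for rectangular X.
X = rng.standard_normal((5, 3))
lhs = np.linalg.slogdet(np.eye(3) + X.T @ X)[1]
rhs = np.linalg.slogdet(np.eye(5) + X @ X.T)[1]
assert np.isclose(lhs, rhs)

# Property 3: for positive definite A, log det(A) equals the sum of log eigenvalues.
M = rng.standard_normal((4, 4))
A = M @ M.T + np.eye(4)
assert np.isclose(np.linalg.slogdet(A)[1], np.log(np.linalg.eigvalsh(A)).sum())
print("logdet identities verified numerically")
```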