Appendix A Optimization Methods

For since the fabric of the universe is most perfect and the work of a most wise Creator, nothing at all takes place in the universe in which some rule of maximum or minimum does not appear.

  – L. Euler

In this chapter, we give a brief introduction to some of the most basic but important optimization algorithms used in this book. The purpose is only to help the reader apply these algorithms to the problems studied in this book, not to provide a deep understanding of the algorithms themselves. Hence, we will not give thorough justifications, such as performance guarantees, for the algorithms introduced.

A.1 Steepest Descent

Optimization is concerned with the question of how to find where a function, say $\mathcal{L}(\theta)$, reaches its minimum value. Mathematically, this is stated as the problem:

\[
\operatorname*{arg\,min}_{\theta\in\Theta}\mathcal{L}(\theta), \tag{A.1.1}
\]

where $\Theta$ represents the domain to which the argument $\theta$ is confined. Often (and unless otherwise mentioned, in this chapter) $\Theta$ is simply $\mathbb{R}^{n}$. Without loss of generality, we assume here that the function $\mathcal{L}(\theta)$ is smooth.\footnote{In case the function $\mathcal{L}$ is not smooth, we replace its gradient with a so-called subgradient.}

The efficiency of finding the (global) minima depends on what information we have about the function $\mathcal{L}$. For most optimization problems considered in this book, the dimension of $\theta$, say $n$, is very large. That makes computing or accessing local information about $\mathcal{L}$ expensive. In particular, since the gradient $\nabla\mathcal{L}$ has $n$ entries, it is often reasonable to compute; however, the Hessian $\nabla^{2}\mathcal{L}$ has $n^{2}$ entries, which is often wildly impractical to compute (and the same goes for higher-order derivatives). Hence, it is typical to assume that we have zeroth-order information, i.e., we are able to evaluate $\mathcal{L}(\theta)$, and first-order information, i.e., we are able to evaluate $\nabla\mathcal{L}(\theta)$. Optimization theorists may rephrase this as saying we have a “first-order oracle.” All optimization algorithms that we introduce in this section only use a first-order oracle.\footnote{We refer the readers to the book by [WM22] for a more systematic introduction to optimization algorithms in a high-dimensional space, including algorithms assuming higher-order oracles.}

A.1.1 Vanilla Gradient Descent for Smooth Problems

The simplest and most widely used method for optimization is gradient descent (GD). It was first introduced by Cauchy in 1847. The idea is very simple: starting from an initial state, we iteratively take small steps such that each step reduces the value of the function $\mathcal{L}(\theta)$.

Suppose that the current state is $\theta$. We want to take a small step, say of distance $h$, in a direction, indicated by a vector $\bm{v}$, to reach a new state $\theta+h\bm{v}$ such that the value of the function decreases:

\[
\mathcal{L}(\theta+h\bm{v})\leq\mathcal{L}(\theta). \tag{A.1.2}
\]

To find such a direction $\bm{v}$, we can approximate $\mathcal{L}(\theta+h\bm{v})$ through a Taylor expansion around $h=0$:

\[
\mathcal{L}(\theta+h\bm{v})=\mathcal{L}(\theta)+h\langle\nabla\mathcal{L}(\theta),\bm{v}\rangle+o(h), \tag{A.1.3}
\]

where the inner product here (and in this chapter) is the $\ell^{2}$ inner product, i.e., $\langle\bm{x},\bm{y}\rangle=\bm{x}^{\top}\bm{y}$. To find the direction of steepest descent, we minimize this Taylor expansion over unit vectors $\bm{v}$. If $\nabla\mathcal{L}(\theta)=\bm{0}$, then the second term above is $0$ regardless of the value of $\bm{v}$, so we cannot make progress this way, i.e., the algorithm has converged. On the other hand, if $\nabla\mathcal{L}(\theta)\neq\bm{0}$, then it holds that

\[
\operatorname*{arg\,min}_{\substack{\bm{v}\in\mathbb{R}^{n}\\ \|\bm{v}\|_{2}=1}}\left[\mathcal{L}(\theta)+h\langle\nabla\mathcal{L}(\theta),\bm{v}\rangle\right]
=\operatorname*{arg\,min}_{\substack{\bm{v}\in\mathbb{R}^{n}\\ \|\bm{v}\|_{2}=1}}\langle\nabla\mathcal{L}(\theta),\bm{v}\rangle
=-\frac{\nabla\mathcal{L}(\theta)}{\|\nabla\mathcal{L}(\theta)\|_{2}}. \tag{A.1.4}
\]

In words, this means that the value of $\mathcal{L}(\theta+h\bm{v})$ decreases the fastest along the direction $\bm{v}=-\nabla\mathcal{L}(\theta)/\|\nabla\mathcal{L}(\theta)\|_{2}$, for small enough $h$. This leads to the gradient descent method: from the current state $\theta_{k}$ ($k=0,1,\ldots$), we take a step of size $h$ in the direction of the negative gradient to reach the next iterate,

\[
\theta_{k+1}=\theta_{k}-h\nabla\mathcal{L}(\theta_{k}). \tag{A.1.5}
\]

The step size $h$ is also called the learning rate in machine learning contexts.
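To make the iteration (A.1.5) concrete, here is a minimal NumPy sketch of gradient descent on a small least-squares objective; the particular matrix $A$, vector $b$, step size, and iteration count are illustrative choices, not prescriptions from the text.

```python
import numpy as np

# Toy objective: L(theta) = 0.5 * ||A theta - b||_2^2, with gradient A^T (A theta - b).
A = np.array([[2.0, 0.0],
              [0.0, 1.0]])
b = np.array([1.0, -1.0])

def loss(theta):
    r = A @ theta - b
    return 0.5 * r @ r

def grad(theta):
    return A.T @ (A @ theta - b)

h = 0.1                  # step size (learning rate)
theta = np.zeros(2)      # initial state theta_0
for k in range(100):
    theta = theta - h * grad(theta)   # theta_{k+1} = theta_k - h * grad L(theta_k)

print(theta, loss(theta))   # theta approaches the minimizer [0.5, -1.0]
```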

Step-Size Selection

The remaining question is: what should the step size $h$ be? If we choose $h$ to be too small, the value of the function may decrease very slowly, as shown by the plot in the middle of Figure A.1. If $h$ is too large, the value might not decrease at all, as shown by the plot on the right of Figure A.1.

Figure A.1: The effect of step size $h$ on the convergence of the gradient descent method.

So the step size $h$ should be chosen based on the landscape of the function $\mathcal{L}$ around $\theta_{k}$. Ideally, to choose the best step size $h$, we could solve the following optimization problem over the single variable $h$:

\[
h=\operatorname*{arg\,min}_{h\geq 0}\mathcal{L}(\theta_{k}-h\nabla\mathcal{L}(\theta_{k})). \tag{A.1.6}
\]

This method of choosing the step size is called line search. However, when the function $\mathcal{L}$ is complicated, which is usually the case when training a deep neural network, this one-dimensional optimization is itself difficult to solve at each iteration of gradient descent.
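For intuition, the following sketch approximates the line search (A.1.6) numerically at each iteration with a bounded one-dimensional minimization; the quadratic test function, the use of `scipy.optimize.minimize_scalar`, and the search interval are all assumptions made for the example.

```python
import numpy as np
from scipy.optimize import minimize_scalar

# Badly scaled quadratic: L(theta) = 0.5 * theta^T A theta.
A = np.diag([10.0, 1.0])
loss = lambda theta: 0.5 * theta @ A @ theta
grad = lambda theta: A @ theta

theta = np.array([1.0, 1.0])
for k in range(20):
    g = grad(theta)
    # Approximate exact line search: minimize phi(h) = L(theta_k - h * g) over h >= 0.
    phi = lambda h: loss(theta - h * g)
    h_star = minimize_scalar(phi, bounds=(0.0, 1.0), method="bounded").x
    theta = theta - h_star * g

print(theta, loss(theta))   # close to the minimizer at the origin
```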

Then how should we choose a proper step size $h$? One common and classical approach is to try to obtain a good approximation of the local landscape around the current state $\theta$ based on some knowledge about the overall landscape of the function $\mathcal{L}(\theta)$.

Figure A.2: Majorization-minimization scheme and intuition. A function $\mathcal{L}\colon\Theta\to\mathbb{R}$ has a global upper bound $u\colon\Theta\to\mathbb{R}$ which meets $\mathcal{L}$ at at least one point $\theta_{0}$. Then, finding the $\theta_{1}$ which minimizes $u$ will improve the value of $\mathcal{L}$ from $\mathcal{L}(\theta_{0})$. Note that similar results can be shown for local upper bounds.

Common conditions for the landscape of $\mathcal{L}(\theta)$ include:

  • $\alpha$-Strong Convexity. Recall that $\mathcal{L}$ is $\alpha$-strongly convex if its graph lies above a global quadratic lower bound of slope $\alpha$, i.e.,

    \[
    \mathcal{L}(\theta)\geq l_{\theta_{0},\alpha}(\theta)\doteq\mathcal{L}(\theta_{0})+\langle\nabla\mathcal{L}(\theta_{0}),\theta-\theta_{0}\rangle+\frac{\alpha}{2}\|\theta-\theta_{0}\|_{2}^{2} \tag{A.1.7}
    \]

    for any “base point” $\theta_{0}$. We say that $\mathcal{L}$ is convex if it is $0$-strongly convex, i.e., its graph lies above its tangents. It is easy to show (proof as exercise) that strongly convex functions have unique global minima. Another important fact (proof as exercise) is that $\alpha$-strongly convex twice-differentiable functions $\mathcal{L}$ have (symmetric) Hessians $\nabla^{2}\mathcal{L}$ whose minimum eigenvalue is $\geq\alpha$. For $\alpha>0$ this implies the Hessian is symmetric positive definite, and for $\alpha=0$ (i.e., $\mathcal{L}$ is convex) it implies that the Hessian is symmetric positive semidefinite.

  • $\beta$-Lipschitz Gradient (also called $\beta$-Smoothness). Recall that $\mathcal{L}$ has $\beta$-Lipschitz gradient if $\nabla\mathcal{L}$ exists and is $\beta$-Lipschitz, i.e.,

    \[
    \|\nabla\mathcal{L}(\theta)-\nabla\mathcal{L}(\theta_{0})\|_{2}\leq\beta\|\theta-\theta_{0}\|_{2} \tag{A.1.8}
    \]

    for any “base point” $\theta_{0}$. It is easy to show (proof as exercise) that this is equivalent to $\mathcal{L}$ having a global quadratic upper bound of slope $\beta$, i.e.,

    \[
    \mathcal{L}(\theta)\leq u_{\theta_{0},\beta}(\theta)\doteq\mathcal{L}(\theta_{0})+\langle\nabla\mathcal{L}(\theta_{0}),\theta-\theta_{0}\rangle+\frac{\beta}{2}\|\theta-\theta_{0}\|_{2}^{2} \tag{A.1.9}
    \]

    for any “base point” $\theta_{0}$. Another important fact (proof as exercise) is that convex, twice-differentiable functions with $\beta$-Lipschitz gradient have (symmetric) Hessians $\nabla^{2}\mathcal{L}$ whose largest eigenvalue is $\leq\beta$.

First, let us suppose that $\mathcal{L}$ has $\beta$-Lipschitz gradient (but is not necessarily even convex). We will use this occasion to introduce a common optimization theme: to minimize $\mathcal{L}$, we can minimize an upper bound on $\mathcal{L}$, which is justified by the following lemma, visualized in Figure A.2.

Lemma A.1 (Majorization-Minimization).

Suppose that $u\colon\Theta\to\mathbb{R}$ is a global upper bound on $\mathcal{L}$, namely $\mathcal{L}(\theta)\leq u(\theta)$ for all $\theta\in\Theta$. Suppose that they meet with equality at $\theta_{0}$, i.e., $\mathcal{L}(\theta_{0})=u(\theta_{0})$. Then

\[
\theta_{1}\in\operatorname*{arg\,min}_{\theta\in\Theta}u(\theta)\implies\mathcal{L}(\theta_{1})\leq u(\theta_{1})\leq u(\theta_{0})=\mathcal{L}(\theta_{0}). \tag{A.1.10}
\]

We will use this lemma to show that we can use the Lipschitz gradient property to ensure that each gradient step cannot worsen the value of $\mathcal{L}$. Indeed, at every base point $\theta_{0}$, we have that $u_{\theta_{0},\beta}$ is a global upper bound on $\mathcal{L}$, and $u_{\theta_{0},\beta}(\theta_{0})=\mathcal{L}(\theta_{0})$. Hence, by Lemma A.1,

\[
\text{if $\theta$ minimizes $u_{\theta_{0},\beta}$, then}\quad\mathcal{L}(\theta)\leq u_{\theta_{0},\beta}(\theta)\leq u_{\theta_{0},\beta}(\theta_{0})=\mathcal{L}(\theta_{0}). \tag{A.1.11}
\]

This motivates the following update: to obtain $\theta_{k+1}$ from $\theta_{k}$, we minimize the upper bound $u_{\theta_{k},\beta}$ over $\theta$ and set the minimizer to be $\theta_{k+1}$. By minimizing $u_{\theta_{k},\beta}$ (proof as exercise) we get

\[
\theta_{k+1}=\theta_{k}-\frac{1}{\beta}\nabla\mathcal{L}(\theta_{k})\implies\mathcal{L}(\theta_{k+1})\leq\mathcal{L}(\theta_{k}). \tag{A.1.12}
\]

This implies that the step size $h=1/\beta$ is a usable learning rate, but it does not provide a convergence rate or certify that $\mathcal{L}(\theta_{k})$ actually converges to $\min_{\theta}\mathcal{L}(\theta)$. That requires a little more rigor, which we now pursue.

Now, let us suppose that $\mathcal{L}$ is $\alpha$-strongly convex, has $\beta$-Lipschitz gradient, and has global optimum $\theta^{\star}$. We will show that $\theta_{k}$ converges directly to the unique global optimum $\theta^{\star}$, which is a very strong form of convergence. In particular, we will bound $\|\theta^{\star}-\theta_{k}\|_{2}$ using both the strong convexity and the Lipschitzness of the gradient of $\mathcal{L}$, i.e., taking a look at the neighborhood around $\theta_{k}$:\footnote{In this proof the $\beta$-Lipschitz gradient invocation step is a little non-trivial. We also leave this step as an exercise, with the hint to plug $\theta=\theta_{0}-h\nabla\mathcal{L}(\theta_{0})$ into the Lipschitz gradient identity.}

\[
\begin{aligned}
\|\theta^{\star}-\theta_{k+1}\|_{2}^{2}
&\leq\|\theta^{\star}-\theta_{k}+h\nabla\mathcal{L}(\theta_{k})\|_{2}^{2} &&\text{(A.1.13)}\\
&=\|\theta^{\star}-\theta_{k}\|_{2}^{2}+2h\langle\nabla\mathcal{L}(\theta_{k}),\theta^{\star}-\theta_{k}\rangle+h^{2}\|\nabla\mathcal{L}(\theta_{k})\|_{2}^{2} &&\text{(A.1.14)}\\
&\leq\|\theta^{\star}-\theta_{k}\|_{2}^{2}+2h\left(\mathcal{L}(\theta^{\star})-\mathcal{L}(\theta_{k})-\frac{\alpha}{2}\|\theta^{\star}-\theta_{k}\|_{2}^{2}\right)+h^{2}\|\nabla\mathcal{L}(\theta_{k})\|_{2}^{2} &&\text{($\alpha$-SC) (A.1.15)}\\
&=\left(1-\alpha h\right)\|\theta^{\star}-\theta_{k}\|_{2}^{2}+2h\left(\mathcal{L}(\theta^{\star})-\mathcal{L}(\theta_{k})\right)+h^{2}\|\nabla\mathcal{L}(\theta_{k})\|_{2}^{2} &&\text{(A.1.16)}\\
&\leq\left(1-\alpha h\right)\|\theta^{\star}-\theta_{k}\|_{2}^{2}+2h\left(\mathcal{L}(\theta^{\star})-\mathcal{L}(\theta_{k})\right)+2h^{2}\beta\left(\mathcal{L}(\theta_{k})-\mathcal{L}(\theta^{\star})\right) &&\text{($\beta$-LG) (A.1.17)}\\
&=\left(1-\alpha h\right)\|\theta^{\star}-\theta_{k}\|_{2}^{2}-2h(1-\beta h)\left(\mathcal{L}(\theta_{k})-\mathcal{L}(\theta^{\star})\right). &&\text{(A.1.18)}
\end{aligned}
\]

In order to ensure that the gradient descent iteration makes progress, we must pick the step size so that $1-\beta h\geq 0$, i.e., $h\leq 1/\beta$ (since $\alpha\leq\beta$, this also guarantees $1-\alpha h\geq 0$). With such a choice, we have

\[
\begin{aligned}
\|\theta^{\star}-\theta_{k+1}\|_{2}^{2}
&\leq(1-\alpha h)\|\theta^{\star}-\theta_{k}\|_{2}^{2}\leq(1-\alpha h)^{2}\|\theta^{\star}-\theta_{k-1}\|_{2}^{2}\leq\cdots &&\text{(A.1.19)}\\
&\leq(1-\alpha h)^{k+1}\|\theta^{\star}-\theta_{0}\|_{2}^{2}. &&\text{(A.1.20)}
\end{aligned}
\]

In order to minimize the right-hand side, we can set $h=1/\beta$, which yields

\[
\|\theta^{\star}-\theta_{k+1}\|_{2}^{2}\leq(1-\alpha/\beta)^{k+1}\|\theta^{\star}-\theta_{0}\|_{2}^{2}, \tag{A.1.21}
\]

showing convergence to the global optimum with exponentially decaying error. Notice that here we used the convergence analysis to obtain a favorable step size of $h=1/\beta$. This motif will recur in this section.
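The sketch below numerically checks the bound (A.1.21) on a simple quadratic that is $\alpha$-strongly convex with $\beta$-Lipschitz gradient; the particular values of $\alpha$, $\beta$, and the initialization are arbitrary choices for illustration.

```python
import numpy as np

# Quadratic L(theta) = 0.5 * theta^T D theta with D = diag(alpha, beta), which is
# alpha-strongly convex and has beta-Lipschitz gradient; its minimizer is the origin.
alpha, beta = 1.0, 10.0
D = np.diag([alpha, beta])
grad = lambda theta: D @ theta

theta_star = np.zeros(2)
theta = np.array([3.0, -2.0])                   # theta_0
h = 1.0 / beta                                  # step size from the analysis
err0 = np.sum((theta - theta_star) ** 2)        # ||theta_star - theta_0||_2^2

for k in range(50):
    theta = theta - h * grad(theta)
    bound = (1 - alpha / beta) ** (k + 1) * err0    # right-hand side of (A.1.21)
    assert np.sum((theta - theta_star) ** 2) <= bound + 1e-12

print("bound (A.1.21) holds for all 50 iterations")
```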

We end this subsection with a caveat: finding a global optimum is (usually) impractically hard outside of such well-behaved settings. Under certain conditions, we can ensure that the gradient descent iterates converge to a local optimum. Also, under more relaxed conditions, we can ensure local convergence, i.e., that the iterates converge to a (global or local) optimum if the sequence is initialized close enough to that optimum.

A.1.2 Preconditioned Gradient Descent for Badly-Conditioned Problems

Figure A.3: The negative gradient $-\nabla\mathcal{L}_{\lambda}$ and the pre-conditioned (Newton's method step) vector field $-[\nabla^{2}\mathcal{L}_{\lambda}]^{-1}[\nabla\mathcal{L}_{\lambda}]$, where $\lambda=19$. There is a section of the space where following the negative gradient vector field makes very little progress towards finding the minimum, but in all cases following the Newton's method vector field achieves equal speed of progress towards the optimum since the gradient is whitened. Since the Hessian here is diagonal, adaptive learning rate algorithms (e.g., Adam, as will be discussed later in the section) can make similar progress as Newton's method, but a non-axis-aligned Hessian may even prevent Adam from succeeding quickly.

Newton’s Method

There are some smooth, strongly convex problems on which gradient descent nonetheless does quite poorly. For example, let $\lambda\geq 0$ and consider $\mathcal{L}_{\lambda}\colon\mathbb{R}^{2}\to\mathbb{R}$ of the form

\[
\mathcal{L}_{\lambda}(\theta)=\mathcal{L}_{\lambda}\!\left(\begin{bmatrix}\theta_{1}\\ \theta_{2}\end{bmatrix}\right)\doteq\frac{1}{2}\left\{(1+\lambda)\theta_{1}^{2}+\theta_{2}^{2}\right\}=\frac{1}{2}\theta^{\top}\begin{bmatrix}1+\lambda&0\\ 0&1\end{bmatrix}\theta. \tag{A.1.22}
\]

This problem is $1$-strongly convex and has $(1+\lambda)$-Lipschitz gradient. The convergence rate from the previous analysis is then geometric with rate $1-1/(1+\lambda)$. For large $\lambda$, this rate is close to $1$, so convergence is not very fast. In this section, we will introduce a class of optimization methods which can successfully optimize such badly-conditioned functions.

The key lies in the objective's curvature, which is given by the Hessian. Suppose that (as a counterfactual) we had a second-order oracle which would allow us to compute $\mathcal{L}(\theta)$, $\nabla\mathcal{L}(\theta)$, and $\nabla^{2}\mathcal{L}(\theta)$. Then, instead of picking a descent direction $\bm{v}$ to optimize the first-order Taylor expansion around $\theta$, we could optimize the second-order Taylor expansion instead. Intuitively, this would allow us to incorporate curvature information into the update.

Let us carry out this computation. The second-order Taylor expansion of $\mathcal{L}(\theta+h\bm{v})$ around $h=0$ is

\[
\mathcal{L}(\theta+h\bm{v})=\mathcal{L}(\theta)+h\langle\nabla\mathcal{L}(\theta),\bm{v}\rangle+\frac{1}{2}h^{2}\langle[\nabla^{2}\mathcal{L}(\theta)]\bm{v},\bm{v}\rangle+o(h^{2}). \tag{A.1.23}
\]

Then we can compute the descent direction:

\[
\operatorname*{arg\,min}_{\substack{\bm{v}\in\mathbb{R}^{n}\\ \|\bm{v}\|_{2}=1}}\left[\mathcal{L}(\theta)+h\langle\nabla\mathcal{L}(\theta),\bm{v}\rangle+\frac{1}{2}h^{2}\langle[\nabla^{2}\mathcal{L}(\theta)]\bm{v},\bm{v}\rangle\right]
=\operatorname*{arg\,min}_{\substack{\bm{v}\in\mathbb{R}^{n}\\ \|\bm{v}\|_{2}=1}}\left[\langle\nabla\mathcal{L}(\theta),\bm{v}\rangle+\frac{1}{2}h\langle[\nabla^{2}\mathcal{L}(\theta)]\bm{v},\bm{v}\rangle\right]. \tag{A.1.24}
\]

This optimization problem is a little difficult to solve because of the constraint $\|\bm{v}\|_{2}=1$. But in practice we never normalize the descent direction $\bm{v}$, and instead use the step size $h$ to control the size of the update. So let us just solve the above problem over all vectors $\bm{v}\in\mathbb{R}^{n}$:\footnote{If $\nabla^{2}\mathcal{L}(\theta)$ is not invertible, then we can replace $[\nabla^{2}\mathcal{L}(\theta)]^{-1}$ with the Moore–Penrose pseudoinverse of $\nabla^{2}\mathcal{L}(\theta)$.}

\[
\operatorname*{arg\,min}_{\bm{v}\in\mathbb{R}^{n}}\left[\langle\nabla\mathcal{L}(\theta),\bm{v}\rangle+\frac{1}{2}h\langle[\nabla^{2}\mathcal{L}(\theta)]\bm{v},\bm{v}\rangle\right]=-\frac{1}{h}[\nabla^{2}\mathcal{L}(\theta)]^{-1}[\nabla\mathcal{L}(\theta)]. \tag{A.1.25}
\]

We can thus use the iteration

\[
\theta_{k+1}=\theta_{k}-[\nabla^{2}\mathcal{L}(\theta_{k})]^{-1}[\nabla\mathcal{L}(\theta_{k})], \tag{A.1.26}
\]

(this is the celebrated Newton’s method), or

\[
\theta_{k+1}=\theta_{k}-h[\nabla^{2}\mathcal{L}(\theta_{k})]^{-1}[\nabla\mathcal{L}(\theta_{k})], \tag{A.1.27}
\]

(which is called damped Newton's method). Since the quadratic $\mathcal{L}_{\lambda}$ is equal to its own second-order Taylor expansion, running Newton's method reaches the global minimum in a single step, no matter how large $\lambda$ is. Figure A.3 gives some intuition about poorly conditioned functions and gradient steps versus Newton's steps.
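As a small numerical illustration of this contrast, the sketch below runs gradient descent with $h=1/(1+\lambda)$ and a single Newton step on the quadratic (A.1.22); the value $\lambda=19$ matches Figure A.3, while the initialization and iteration count are arbitrary.

```python
import numpy as np

lam = 19.0
H = np.array([[1.0 + lam, 0.0],
              [0.0, 1.0]])                     # Hessian of L_lambda in (A.1.22)
grad = lambda theta: H @ theta

theta0 = np.array([1.0, 1.0])

# Gradient descent with h = 1/beta = 1/(1 + lam): geometric rate 1 - 1/(1 + lam).
theta_gd, h = theta0.copy(), 1.0 / (1.0 + lam)
for k in range(100):
    theta_gd = theta_gd - h * grad(theta_gd)

# Newton's method: a single step lands on the minimizer, since L_lambda is quadratic.
theta_nt = theta0 - np.linalg.solve(H, grad(theta0))

print(np.linalg.norm(theta_gd))   # still visibly nonzero after 100 gradient steps
print(np.linalg.norm(theta_nt))   # zero (up to round-off) after one Newton step
```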

PGD

In practice, we do not have a second-order oracle which allows us to compute $\nabla^{2}\mathcal{L}(\theta)$. Instead, we can attempt to learn an approximation to it alongside the parameter update from $\theta_{k}$ to $\theta_{k+1}$.

How do we learn such an approximation? We shall find some equations which the Hessian's inverse satisfies, and then try to update our approximation so that it satisfies these equations. Namely, taking the Taylor series of $\nabla\mathcal{L}(\theta+\delta_{\theta})$ around the point $\theta$, we obtain

\[
\underbrace{\nabla\mathcal{L}(\theta+\delta_{\theta})-\nabla\mathcal{L}(\theta)}_{\doteq\,\delta_{\bm{g}}}=[\nabla^{2}\mathcal{L}(\theta)]\delta_{\theta}+o(\|\delta_{\theta}\|_{2}). \tag{A.1.28}
\]

In this case we have

\[
\delta_{\bm{g}}\approx[\nabla^{2}\mathcal{L}(\theta)]\delta_{\theta}\implies\delta_{\theta}\approx[\nabla^{2}\mathcal{L}(\theta)]^{-1}\delta_{\bm{g}}. \tag{A.1.29}
\]

We can now try to learn a symmetric positive semidefinite pre-conditioner $P\in\mathbb{R}^{n\times n}$ such that

\[
\delta_{\theta}\approx P\delta_{\bm{g}}, \tag{A.1.30}
\]

updating it at each iteration along with $\theta_{k}$. Namely, we have the PSGD iteration

\[
\begin{aligned}
P_{k}&=\mathrm{PreconditionerUpdate}(P_{k-1};\theta_{k},\nabla\mathcal{L}(\theta_{k})) &&\text{(A.1.31)}\\
\theta_{k+1}&=\theta_{k}-hP_{k}\nabla\mathcal{L}(\theta_{k}). &&\text{(A.1.32)}
\end{aligned}
\]

This update raises two questions: how can we even use $P$ (since we already said we cannot store an $n\times n$ matrix), and how can we update $P$ at each iteration? The answers are closely related: we never materialize $P$ in computer memory, but instead represent it using a low-rank factorization (or comparable methods such as Kronecker factorization, which is particularly suited to the form of deep neural networks). The preconditioner update step is then designed to exploit the structure of this representation.
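The sketch below illustrates the preconditioning idea $\delta_{\theta}\approx P\delta_{\bm{g}}$ on a badly conditioned quadratic, using a diagonal preconditioner fitted from finite-difference secant pairs before the iterations begin; this simplified scheme is only an illustration of (A.1.30)–(A.1.32) and is not the actual PSGD update rule of [Li17].

```python
import numpy as np

# Badly conditioned quadratic: L(theta) = 0.5 * theta^T D theta.
D = np.diag([100.0, 1.0])
grad = lambda theta: D @ theta

# Fit a diagonal preconditioner P = diag(p) from finite-difference secant pairs,
# using the relation delta_theta ~= P delta_g from (A.1.30).
theta0 = np.array([1.0, 1.0])
eps = 1e-3
p = np.zeros(2)
for i in range(2):
    e = np.zeros(2)
    e[i] = eps
    d_g = grad(theta0 + e) - grad(theta0)   # delta_g ~= Hessian @ delta_theta
    p[i] = eps / d_g[i]

# Preconditioned gradient descent: theta_{k+1} = theta_k - h * P grad L(theta_k).
theta, h = theta0.copy(), 1.0
for k in range(10):
    theta = theta - h * p * grad(theta)

print(theta)   # converges quickly in both coordinates despite the bad conditioning
```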

We end this subsection with a caveat: in deep learning, for example, $\mathcal{L}$ is not a convex function, and so Newton's method (and approximations to it) do not make sense. In this case we fall back on the geometric intuition of Newton's method on convex functions, say from Figure A.3: the inverse Hessian whitens the gradients. Thus, instead of a Hessian-approximating preconditioner, we can adjust the above procedures to learn a more general whitening transformation for the gradient. This is the idea behind the original proposal of PSGD [Li17], which contains more information about how to store and update the preconditioner, and behind more modern optimizers like Muon [LSY+25].

A.1.3 Proximal Gradient Descent for Non-Smooth Problems

However, even in rather simple problems such as LASSO or dictionary learning, the problem is not strongly convex but merely convex, and the objective is no longer smooth but rather the sum of a smooth function and a non-smooth regularizer (such as the $\ell^{1}$ norm). Such problems are solved by proximal optimization algorithms, which generalize steepest descent to non-smooth objectives.

Formally, let us say that

\[
\mathcal{L}(\theta)\doteq\mathcal{S}(\theta)+\mathcal{R}(\theta) \tag{A.1.33}
\]

where $\mathcal{S}$ is smooth, say with $\beta$-Lipschitz gradient, and $\mathcal{R}$ is non-smooth (i.e., rough). The proximal gradient algorithm generalizes the steepest descent algorithm by using the majorization-minimization framework (i.e., Lemma A.1) with a different global upper bound. We construct such an upper bound by asking: what if we take the Lipschitz gradient upper bound of $\mathcal{S}$ but leave $\mathcal{R}$ alone? Namely, we have

\[
\mathcal{L}(\theta_{1})=\mathcal{S}(\theta_{1})+\mathcal{R}(\theta_{1})\leq u_{\theta_{0},\beta}(\theta_{1})\doteq\mathcal{S}(\theta_{0})+\langle\nabla\mathcal{S}(\theta_{0}),\theta_{1}-\theta_{0}\rangle+\frac{\beta}{2}\|\theta_{1}-\theta_{0}\|_{2}^{2}+\mathcal{R}(\theta_{1}). \tag{A.1.34}
\]

Note that (proof as exercise)

\[
\operatorname*{arg\,min}_{\theta_{1}}u_{\theta_{0},\beta}(\theta_{1})=\operatorname*{arg\,min}_{\theta_{1}}\left[\frac{\beta}{2}\left\|\theta_{1}-\left(\theta_{0}-\frac{1}{\beta}\nabla\mathcal{S}(\theta_{0})\right)\right\|_{2}^{2}+\mathcal{R}(\theta_{1})\right]. \tag{A.1.35}
\]

Now if we try to minimize the upper bound $u_{\theta_{0},\beta}$, we are picking a $\theta_{1}$ that:

  • is close to the gradient update $\theta_{0}-\frac{1}{\beta}\nabla\mathcal{S}(\theta_{0})$;

  • has a small value of the regularizer $\mathcal{R}(\theta_{1})$;

and trades off these properties according to the smoothness parameter $\beta$. Accordingly, let us define the proximal operator

\[
\operatorname{prox}_{h,\mathcal{R}}(\theta)\doteq\operatorname*{arg\,min}_{\theta_{1}}\left[\frac{1}{2h}\|\theta_{1}-\theta\|_{2}^{2}+\mathcal{R}(\theta_{1})\right]. \tag{A.1.36}
\]

Then, we can define the proximal gradient descent iteration which, at each iteration, minimizes the upper bound $u_{\theta_{k},h^{-1}}$, i.e.,

\[
\theta_{k+1}=\operatorname{prox}_{h,\mathcal{R}}(\theta_{k}-h\nabla\mathcal{S}(\theta_{k})). \tag{A.1.37}
\]

Convergence proofs are possible when $h\leq 1/\beta$, but we do not give any in this section.
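A generic sketch of the iteration (A.1.37) is given below, with the smooth gradient, the proximal operator, and the step size passed in as arguments; the concrete demonstration uses a squared-$\ell^{2}$ regularizer, whose proximal operator has the simple closed form $v\mapsto v/(1+h\lambda)$, and all numerical values are made up for illustration.

```python
import numpy as np

def proximal_gradient(grad_S, prox_R, theta0, h, iters=200):
    """Iterate theta_{k+1} = prox_{h,R}(theta_k - h * grad S(theta_k)), cf. (A.1.37)."""
    theta = theta0.copy()
    for _ in range(iters):
        theta = prox_R(theta - h * grad_S(theta), h)
    return theta

# Illustration: S(theta) = 0.5 * ||A theta - b||^2 and R(theta) = (lam / 2) * ||theta||^2,
# whose proximal operator has the closed form prox_{h,R}(v) = v / (1 + h * lam).
A = np.array([[2.0, 1.0],
              [1.0, 3.0],
              [0.0, 1.0]])
b = np.array([1.0, 2.0, 3.0])
lam = 0.1
grad_S = lambda theta: A.T @ (A @ theta - b)
prox_R = lambda v, h: v / (1.0 + h * lam)

beta = np.linalg.norm(A.T @ A, 2)                      # Lipschitz constant of grad S
theta_hat = proximal_gradient(grad_S, prox_R, np.zeros(2), h=1.0 / beta)

# Sanity check against the closed-form ridge-regression solution:
print(theta_hat)
print(np.linalg.solve(A.T @ A + lam * np.eye(2), A.T @ b))
```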

One remaining question is: how can we compute the proximal operator? At first glance, it seems like we have traded one intractable minimization problem for another. Since we have not made any assumption on $\mathcal{R}$ so far, the framework applies even when $\mathcal{R}$ is a very complex function (such as a neural network loss), which would require us to solve a neural network training problem in order to compute a single proximal operator. However, in practice, for simple regularizers $\mathcal{R}$ such as those we use in this manuscript, the proximal operators are easy to compute or even available in closed form. We give a few below (the proofs are an exercise).

Example A.1.

Let $\Gamma\subseteq\Theta$ be a set, and let $\chi_{\Gamma}$ be the characteristic function on $\Gamma$, i.e.,

\[
\chi_{\Gamma}(\theta)\doteq\begin{cases}0,&\text{if}\ \theta\in\Gamma\\ +\infty,&\text{if}\ \theta\notin\Gamma.\end{cases} \tag{A.1.38}
\]

Then the proximal operator of $\chi_{\Gamma}$ is a projection, i.e.,

\[
\operatorname{prox}_{h,\chi_{\Gamma}}(\theta)=\operatorname*{arg\,min}_{\theta_{1}\in\Gamma}\frac{1}{2}\|\theta_{1}-\theta\|_{2}^{2}=\operatorname*{arg\,min}_{\theta_{1}\in\Gamma}\|\theta_{1}-\theta\|_{2}. \tag{A.1.39}
\]

\blacksquare

Example A.2.

The $\ell^{1}$ norm has a proximal operator which performs soft thresholding: if we define

\[
S_{h}(\theta)\doteq\operatorname{prox}_{h,\lambda\|\cdot\|_{1}}(\theta)=\operatorname*{arg\,min}_{\theta_{1}}\left[\frac{1}{2h}\|\theta_{1}-\theta\|_{2}^{2}+\lambda\|\theta_{1}\|_{1}\right], \tag{A.1.40}
\]

then $S_{h}(\theta)$ is given coordinate-wise by

\[
S_{h}(\theta)_{i}=\begin{cases}\theta_{i}-h\lambda,&\text{if}\ \theta_{i}\geq h\lambda\\ 0,&\text{if}\ \theta_{i}\in[-h\lambda,h\lambda]\\ \theta_{i}+h\lambda,&\text{if}\ \theta_{i}\leq-h\lambda\end{cases}
=\begin{cases}\max\{\lvert\theta_{i}\rvert-h\lambda,0\}\operatorname{sign}(\theta_{i}),&\text{if}\ \lvert\theta_{i}\rvert\geq h\lambda\\ 0,&\text{if}\ \lvert\theta_{i}\rvert<h\lambda.\end{cases} \tag{A.1.41}
\]

The proximal gradient iteration with the smooth part of the objective being least-squares and the non-smooth part being the $\ell^{1}$ norm (hence using this soft-thresholding proximal operator) is called the Iterative Shrinkage-Thresholding Algorithm (ISTA). $\blacksquare$
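Below is a compact sketch of ISTA for the LASSO objective $\frac{1}{2}\|A\theta-b\|_{2}^{2}+\lambda\|\theta\|_{1}$, built from the soft-thresholding operator (A.1.41); the synthetic data, regularization weight, and iteration count are assumptions for the sake of the example.

```python
import numpy as np

def soft_threshold(v, t):
    """Coordinate-wise soft thresholding at level t, cf. (A.1.41) with t = h * lambda."""
    return np.sign(v) * np.maximum(np.abs(v) - t, 0.0)

def ista(A, b, lam, iters=500):
    """Minimize 0.5 * ||A theta - b||_2^2 + lam * ||theta||_1 by proximal gradient."""
    beta = np.linalg.norm(A.T @ A, 2)       # Lipschitz constant of the smooth part
    h = 1.0 / beta
    theta = np.zeros(A.shape[1])
    for _ in range(iters):
        theta = soft_threshold(theta - h * (A.T @ (A @ theta - b)), h * lam)
    return theta

# Tiny illustration with a sparse ground truth:
rng = np.random.default_rng(0)
A = rng.standard_normal((50, 20))
theta_true = np.zeros(20)
theta_true[[3, 7]] = [1.0, -2.0]
b = A @ theta_true + 0.01 * rng.standard_normal(50)
print(ista(A, b, lam=0.1).round(2))         # most coordinates end up exactly zero
```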

Example A.3.

In Chapter 4 we use a proximal operator corresponding to the $\ell^{1}$ norm plus the characteristic function of the positive orthant $\mathbb{R}_{+}^{n}\doteq\{\bm{x}\in\mathbb{R}^{n}\colon x_{i}\geq 0\ \forall i\}$, namely: if we define

\[
T_{h}(\theta)\doteq\operatorname{prox}_{h,\lambda\|\cdot\|_{1}+\chi_{\mathbb{R}_{+}^{n}}}(\theta)=\operatorname*{arg\,min}_{\theta_{1}\in\mathbb{R}_{+}^{n}}\left[\frac{1}{2h}\|\theta_{1}-\theta\|_{2}^{2}+\lambda\|\theta_{1}\|_{1}\right], \tag{A.1.42}
\]

and this proximal operator has the entrywise closed form

Th(θ)imax{θihλ,0}.T_{h}(\theta)_{i}\doteq\max\{\theta_{i}-h\lambda,0\}.italic_T start_POSTSUBSCRIPT italic_h end_POSTSUBSCRIPT ( italic_θ ) start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ≐ roman_max { italic_θ start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT - italic_h italic_λ , 0 } . (A.1.43)

This proximal operator yields the non-negative ISTA that is invoked in Chapter 4 and beyond. \blacksquare
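In code, the only change relative to the plain soft-thresholding step above is that the threshold becomes one-sided; a minimal sketch, reusing the naming of the previous snippet:

```python
def nonneg_threshold(theta, t):
    # T_t(theta)_i = max(theta_i - t, 0), cf. (A.1.43):
    # soft-threshold and project onto the non-negative orthant in one step.
    return np.maximum(theta - t, 0.0)
```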

A.1.4 Stochastic Gradient Descent for Large-Scale Problems

In deep learning, the objective function \mathcal{L}caligraphic_L usually cannot be computed exactly, and instead at each optimization step it is estimated using finite samples (say, using a mini-batch). A common way to model this situation is to define a stochastic loss function ω(θ)\mathcal{L}_{\omega}(\theta)caligraphic_L start_POSTSUBSCRIPT italic_ω end_POSTSUBSCRIPT ( italic_θ ) where ω\omegaitalic_ω is some “source of randomness”. For example, ω\omegaitalic_ω could contain the indices of the samples in a batch over which to compute the loss. Then, we would like to minimize (θ)𝔼ω[ω(θ)]\mathcal{L}(\theta)\doteq\operatorname{\mathbb{E}}_{\omega}[\mathcal{L}_{\omega}(\theta)]caligraphic_L ( italic_θ ) ≐ blackboard_E start_POSTSUBSCRIPT italic_ω end_POSTSUBSCRIPT [ caligraphic_L start_POSTSUBSCRIPT italic_ω end_POSTSUBSCRIPT ( italic_θ ) ] over θ\thetaitalic_θ, given access to a stochastic first-order oracle: given θ\thetaitalic_θ, we can sample ω\omegaitalic_ω and compute ω(θ)\mathcal{L}_{\omega}(\theta)caligraphic_L start_POSTSUBSCRIPT italic_ω end_POSTSUBSCRIPT ( italic_θ ) and θω(θ)\nabla_{\theta}\mathcal{L}_{\omega}(\theta)∇ start_POSTSUBSCRIPT italic_θ end_POSTSUBSCRIPT caligraphic_L start_POSTSUBSCRIPT italic_ω end_POSTSUBSCRIPT ( italic_θ ). This minimization problem is called a stochastic optimization problem.

The basic first-order stochastic algorithm is stochastic gradient descent: at each iteration kkitalic_k we sample ωk\omega_{k}italic_ω start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT, define kωk\mathcal{L}_{k}\doteq\mathcal{L}_{\omega_{k}}caligraphic_L start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT ≐ caligraphic_L start_POSTSUBSCRIPT italic_ω start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT end_POSTSUBSCRIPT, and perform a gradient step on k\mathcal{L}_{k}caligraphic_L start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT, i.e.,

θk+1=θkhk(θk).\theta_{k+1}=\theta_{k}-h\nabla\mathcal{L}_{k}(\theta_{k}).italic_θ start_POSTSUBSCRIPT italic_k + 1 end_POSTSUBSCRIPT = italic_θ start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT - italic_h ∇ caligraphic_L start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT ( italic_θ start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT ) . (A.1.44)

However, even for very simple problems we cannot expect the same type of convergence as we obtained in gradient descent. For example, suppose that there are mmitalic_m possible values for ω{1,,m}\omega\in\{1,\dots,m\}italic_ω ∈ { 1 , … , italic_m } which it takes with equal probability, and there are mmitalic_m possible targets ξ1,,ξm\xi_{1},\dots,\xi_{m}italic_ξ start_POSTSUBSCRIPT 1 end_POSTSUBSCRIPT , … , italic_ξ start_POSTSUBSCRIPT italic_m end_POSTSUBSCRIPT, such that the loss function ω\mathcal{L}_{\omega}caligraphic_L start_POSTSUBSCRIPT italic_ω end_POSTSUBSCRIPT is

ω(θ)12θξω22.\mathcal{L}_{\omega}(\theta)\doteq\frac{1}{2}\|\theta-\xi_{\omega}\|_{2}^{2}.caligraphic_L start_POSTSUBSCRIPT italic_ω end_POSTSUBSCRIPT ( italic_θ ) ≐ divide start_ARG 1 end_ARG start_ARG 2 end_ARG ∥ italic_θ - italic_ξ start_POSTSUBSCRIPT italic_ω end_POSTSUBSCRIPT ∥ start_POSTSUBSCRIPT 2 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT 2 end_POSTSUPERSCRIPT . (A.1.45)

Then $\operatorname*{arg\ min}_{\theta}\operatorname{\mathbb{E}}[\mathcal{L}_{\omega}(\theta)]=\frac{1}{m}\sum_{i=1}^{m}\xi_{i}$, but stochastic gradient descent can “pinball” around the global optimum and fail to converge, as visualized in Figure A.4.

Figure A.4: Stochastic gradient descent may not converge, even for very benign objectives, whereas Nesterov momentum does. For even simple quadratic objectives, stochastic gradient descent iterates may pinball around the global optimum, while the Nesterov-averaged gradients align to point toward the optimum.
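A few lines of NumPy make this pinballing visible on the toy problem above; the number of targets, the step size, and the random seed are arbitrary illustrative choices.

```python
import numpy as np

rng = np.random.default_rng(1)
xi = rng.standard_normal((10, 2))    # m = 10 targets in R^2
opt = xi.mean(axis=0)                # global minimizer of the expected loss

theta, h = np.zeros(2), 0.5
for k in range(1000):
    omega = rng.integers(len(xi))            # sample a component loss uniformly
    theta = theta - h * (theta - xi[omega])  # SGD step on L_omega, cf. (A.1.44)
# With a constant step size, theta keeps jumping around opt instead of converging.
print(np.linalg.norm(theta - opt))
```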

In order to fix this, we can either average the parameters $\theta_{k}$ or average the gradients $\nabla\mathcal{L}_{k}(\theta_{k})$ over time. If we average the parameters, then (using Figure A.4 as a mental model) the pinballing issue disappears, since the averaged iterate draws ever closer to the center. Accordingly, most theoretical convergence proofs study the convergence of the averaged iterate $\frac{1}{k}\sum_{i=0}^{k}\theta_{i}$ to the global minimum. If we average the gradients, we eventually learn an averaged gradient $\frac{1}{k}\sum_{i=0}^{k}\nabla\mathcal{L}_{i}(\theta_{i})$ which changes little from one iteration to the next and therefore does not pinball.

In practice, instead of using an arithmetic average, we take an exponential moving average (EMA) of the parameters (this is called Polyak momentum) or of the gradients (this is called Nesterov momentum). Nesterov momentum is more popular, and it is the variant we study here.

A stochastic gradient descent iteration with Nesterov momentum is as follows:

𝒈k\displaystyle\bm{g}_{k}bold_italic_g start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT =β𝒈k1+(1β)k(θk)\displaystyle=\beta\bm{g}_{k-1}+(1-\beta)\nabla\mathcal{L}_{k}(\theta_{k})= italic_β bold_italic_g start_POSTSUBSCRIPT italic_k - 1 end_POSTSUBSCRIPT + ( 1 - italic_β ) ∇ caligraphic_L start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT ( italic_θ start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT ) (A.1.46)
θk+1\displaystyle\theta_{k+1}italic_θ start_POSTSUBSCRIPT italic_k + 1 end_POSTSUBSCRIPT =θkh𝒈k.\displaystyle=\theta_{k}-h\bm{g}_{k}.= italic_θ start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT - italic_h bold_italic_g start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT . (A.1.47)

We do not go through a convergence proof (see Chapter 7 of [GG23] for an example). However, Nesterov momentum handles our toy case in Figure A.4 easily (see the right-hand figure): it stops pinballing and eventually converges to the global optimum.
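Continuing the toy sketch above (and reusing its xi, opt, and rng), the following adds the gradient EMA of (A.1.46)–(A.1.47); the values of beta and h are again illustrative.

```python
theta, g = np.zeros(2), np.zeros(2)
h, beta = 0.5, 0.99
for k in range(1000):
    omega = rng.integers(len(xi))
    grad = theta - xi[omega]          # gradient of L_omega at theta
    g = beta * g + (1 - beta) * grad  # EMA of stochastic gradients, cf. (A.1.46)
    theta = theta - h * g             # update with the averaged gradient, cf. (A.1.47)
# The averaged gradient points toward the mean target, so the iterates settle near opt.
print(np.linalg.norm(theta - opt))
```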

We end with a caveat: one can show that Polyak momentum and Nesterov momentum are equivalent, for certain choices of parameter settings. Then it is also possible to show that a decaying learning rate schedule (i.e., the learning rate hhitalic_h depends on the iteration kkitalic_k, and its limit is hk0h_{k}\to 0italic_h start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT → 0 as kk\to\inftyitalic_k → ∞) with plain SGD (or PSGD) can mimic the effect of momentum. Namely, [DCM+23] shows that if the SGD algorithm lasts KKitalic_K iterations, the gradient norms are bounded (θk)2G\|\nabla\mathcal{L}(\theta_{k})\|_{2}\leq G∥ ∇ caligraphic_L ( italic_θ start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT ) ∥ start_POSTSUBSCRIPT 2 end_POSTSUBSCRIPT ≤ italic_G, and we define Dθ0θ2D\doteq\|\theta_{0}-\theta^{\star}\|_{2}italic_D ≐ ∥ italic_θ start_POSTSUBSCRIPT 0 end_POSTSUBSCRIPT - italic_θ start_POSTSUPERSCRIPT ⋆ end_POSTSUPERSCRIPT ∥ start_POSTSUBSCRIPT 2 end_POSTSUBSCRIPT, then plain SGD iterates θk\theta_{k}italic_θ start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT satisfy the rate 𝔼[(θk)(θ)]DG/K\operatorname{\mathbb{E}}[\mathcal{L}(\theta_{k})-\mathcal{L}(\theta^{\star})]\leq DG/\sqrt{K}blackboard_E [ caligraphic_L ( italic_θ start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT ) - caligraphic_L ( italic_θ start_POSTSUPERSCRIPT ⋆ end_POSTSUPERSCRIPT ) ] ≤ italic_D italic_G / square-root start_ARG italic_K end_ARG — but only so long as the learning rate hk=(D/[GK])(1k/K)h_{k}=(D/[G\sqrt{K}])(1-k/K)italic_h start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT = ( italic_D / [ italic_G square-root start_ARG italic_K end_ARG ] ) ( 1 - italic_k / italic_K ) decays linearly with time. This matches learning rate schedules used in practice. Indeed, surprisingly, such a theory of convex optimization can predict many empirical phenomena in deep networks [SHT+25], despite deep learning optimization being highly non-convex and non-smooth in the worst case. It is so far unclear why this is the case.
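In code, such a schedule is a one-liner; a minimal sketch of the linearly decaying step size described above, with D, G, and K supplied by the user:

```python
def sgd_step_size(k, K, D, G):
    # h_k = (D / (G * sqrt(K))) * (1 - k / K), for k = 0, ..., K - 1.
    return (D / (G * K ** 0.5)) * (1 - k / K)
```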

A.1.5 Putting Everything Together: Adam

The gradient descent scheme proposes an iteration of the form

θk+1=θk+h𝒗k,\theta_{k+1}=\theta_{k}+h\bm{v}_{k},italic_θ start_POSTSUBSCRIPT italic_k + 1 end_POSTSUBSCRIPT = italic_θ start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT + italic_h bold_italic_v start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT , (A.1.48)

where (recall) 𝒗k\bm{v}_{k}bold_italic_v start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT is chosen to be (proportional to) the steepest descent vector in the Euclidean norm:

𝒗k=(θk)(θk)2argmin𝒗n𝒗2=1(θk),𝒗.\bm{v}_{k}=-\frac{\nabla\mathcal{L}(\theta_{k})}{\|\nabla\mathcal{L}(\theta_{k})\|_{2}}\in\operatorname*{arg\ min}_{\begin{subarray}{c}\bm{v}\in\mathbb{R}^{n}\\ \|\bm{v}\|_{2}=1\end{subarray}}\langle\nabla\mathcal{L}(\theta_{k}),\bm{v}\rangle.bold_italic_v start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT = - divide start_ARG ∇ caligraphic_L ( italic_θ start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT ) end_ARG start_ARG ∥ ∇ caligraphic_L ( italic_θ start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT ) ∥ start_POSTSUBSCRIPT 2 end_POSTSUBSCRIPT end_ARG ∈ start_OPERATOR roman_arg roman_min end_OPERATOR start_POSTSUBSCRIPT start_ARG start_ROW start_CELL bold_italic_v ∈ blackboard_R start_POSTSUPERSCRIPT italic_n end_POSTSUPERSCRIPT end_CELL end_ROW start_ROW start_CELL ∥ bold_italic_v ∥ start_POSTSUBSCRIPT 2 end_POSTSUBSCRIPT = 1 end_CELL end_ROW end_ARG end_POSTSUBSCRIPT ⟨ ∇ caligraphic_L ( italic_θ start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT ) , bold_italic_v ⟩ . (A.1.49)

However, in the context of deep learning optimization, nothing forces us to use the Euclidean norm; indeed, the “natural geometry” of the parameter space is not well respected by the Euclidean norm, since small changes in the parameters can lead to very large changes in the output of the network for a fixed input. If we were instead to use a generic norm $\|\cdot\|$ on the parameter space $\mathbb{R}^{n}$, we would obtain a different steepest-descent direction, corresponding to the so-called dual norm:

𝒗kargmin𝒗n𝒗=1(θk),𝒗.\bm{v}_{k}\in\operatorname*{arg\ min}_{\begin{subarray}{c}\bm{v}\in\mathbb{R}^{n}\\ \|\bm{v}\|=1\end{subarray}}\langle\nabla\mathcal{L}(\theta_{k}),\bm{v}\rangle.bold_italic_v start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT ∈ start_OPERATOR roman_arg roman_min end_OPERATOR start_POSTSUBSCRIPT start_ARG start_ROW start_CELL bold_italic_v ∈ blackboard_R start_POSTSUPERSCRIPT italic_n end_POSTSUPERSCRIPT end_CELL end_ROW start_ROW start_CELL ∥ bold_italic_v ∥ = 1 end_CELL end_ROW end_ARG end_POSTSUBSCRIPT ⟨ ∇ caligraphic_L ( italic_θ start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT ) , bold_italic_v ⟩ . (A.1.50)

For instance, if we were to use the \ell^{\infty}roman_ℓ start_POSTSUPERSCRIPT ∞ end_POSTSUPERSCRIPT-norm, it is possible to show that

𝒗k=sign((θk))argmin𝒗n𝒗=1(θk),𝒗,\bm{v}_{k}=-\operatorname{sign}(\nabla\mathcal{L}(\theta_{k}))\in\operatorname*{arg\ min}_{\begin{subarray}{c}\bm{v}\in\mathbb{R}^{n}\\ \|\bm{v}\|_{\infty}=1\end{subarray}}\langle\nabla\mathcal{L}(\theta_{k}),\bm{v}\rangle,bold_italic_v start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT = - roman_sign ( ∇ caligraphic_L ( italic_θ start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT ) ) ∈ start_OPERATOR roman_arg roman_min end_OPERATOR start_POSTSUBSCRIPT start_ARG start_ROW start_CELL bold_italic_v ∈ blackboard_R start_POSTSUPERSCRIPT italic_n end_POSTSUPERSCRIPT end_CELL end_ROW start_ROW start_CELL ∥ bold_italic_v ∥ start_POSTSUBSCRIPT ∞ end_POSTSUBSCRIPT = 1 end_CELL end_ROW end_ARG end_POSTSUBSCRIPT ⟨ ∇ caligraphic_L ( italic_θ start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT ) , bold_italic_v ⟩ , (A.1.51)

where sign(𝒙)i=sign(xi){1,0,1}\operatorname{sign}(\bm{x})_{i}=\operatorname{sign}(x_{i})\in\{-1,0,1\}roman_sign ( bold_italic_x ) start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT = roman_sign ( italic_x start_POSTSUBSCRIPT italic_i end_POSTSUBSCRIPT ) ∈ { - 1 , 0 , 1 }. Thus if we were so-inclined, we could use the so-called sign-gradient descent:

θk+1=θkhsign((θk)).\theta_{k+1}=\theta_{k}-h\operatorname{sign}(\nabla\mathcal{L}(\theta_{k})).italic_θ start_POSTSUBSCRIPT italic_k + 1 end_POSTSUBSCRIPT = italic_θ start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT - italic_h roman_sign ( ∇ caligraphic_L ( italic_θ start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT ) ) . (A.1.52)
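In code, the sign-gradient update is a one-line change from plain gradient descent; a minimal sketch, where grad_L is a placeholder for whatever routine computes $\nabla\mathcal{L}$:

```python
import numpy as np

def sign_gd_step(theta, grad_L, h):
    # Steepest descent in the l-infinity norm: every coordinate moves by +h, -h, or 0.
    return theta - h * np.sign(grad_L(theta))
```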

From sign-gradient descent, we can derive the famous Adam optimization algorithm. Note that for a scalar xx\in\mathbb{R}italic_x ∈ blackboard_R we can write

sign(x)=x|x|=xx2.\operatorname{sign}(x)=\frac{x}{\lvert x\rvert}=\frac{x}{\sqrt{x^{2}}}.roman_sign ( italic_x ) = divide start_ARG italic_x end_ARG start_ARG | italic_x | end_ARG = divide start_ARG italic_x end_ARG start_ARG square-root start_ARG italic_x start_POSTSUPERSCRIPT 2 end_POSTSUPERSCRIPT end_ARG end_ARG . (A.1.53)

Similarly, for a vector $\bm{x}\in\mathbb{R}^{n}$ we write (where $\odot$ and $\oslash$ denote element-wise multiplication and division)

\operatorname{sign}(\bm{x})=\bm{x}\oslash[\bm{x}^{\odot 2}]^{\odot(1/2)}. (A.1.54)

Using this representation we can write (A.1.52) as

\theta_{k+1}=\theta_{k}-h\left([\nabla\mathcal{L}(\theta_{k})]\oslash[\nabla\mathcal{L}(\theta_{k})^{\odot 2}]^{\odot\frac{1}{2}}\right). (A.1.55)

Now consider the stochastic regime where we are optimizing a different loss LkL_{k}italic_L start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT at each iteration. In SGD, we “tracked” (i.e., took an average of) the gradients using Nesterov momentum. Here, we can track both the gradient and the squared gradient using momentum, i.e.,

\bm{g}_{k}=\beta^{1}\bm{g}_{k-1}+(1-\beta^{1})\nabla\mathcal{L}_{k}(\theta_{k}) (A.1.56)
\bm{s}_{k}=\beta^{2}\bm{s}_{k-1}+(1-\beta^{2})[\nabla\mathcal{L}_{k}(\theta_{k})]^{\odot 2} (A.1.57)
\theta_{k+1}=\theta_{k}-h\,\bm{g}_{k}\oslash\bm{s}_{k}^{\odot\frac{1}{2}}, (A.1.58)

where $\beta^{1},\beta^{2}\in[0,1]$ are the momentum parameters. The algorithm defined by this iteration is the celebrated Adam optimizer, the most widely used optimizer in deep learning. (In practice, to avoid division-by-zero errors, one divides by $\bm{s}_{k}^{\odot(1/2)}+\varepsilon\bm{1}_{n}$, where $\varepsilon$ is small, say on the order of $10^{-8}$.) While convergence proofs for Adam are more involved, the iteration follows from the same steepest-descent principle we have used so far, so we should expect that, for a small enough learning rate, each update improves the loss.

Another way to view Adam, which partially explains its empirical success, is that it dynamically updates the learning rates for each parameter based on the squared gradients. In particular, notice that we can write

\theta_{k+1}=\theta_{k}-\eta_{k}\odot\bm{g}_{k}\qquad\text{where}\qquad\eta_{k}=h\,\bm{s}_{k}^{\odot(-\frac{1}{2})}, (A.1.59)

where ηk\eta_{k}italic_η start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT is the parameter-wise adaptively set learning rate. This scheme is called adaptive because if the gradient of a particular parameter is large up to iteration kkitalic_k, then the learning rate for this parameter becomes smaller to compensate, and vice versa, as can be seen from the above equation.
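Putting (A.1.56)–(A.1.58) into code gives the following minimal sketch of a single Adam step (without the bias-correction terms that the standard implementation also applies); the default values of h, beta1, beta2, and eps are illustrative, and grad stands for the current stochastic gradient.

```python
import numpy as np

def adam_step(theta, g, s, grad, h=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    # One iteration of (A.1.56)-(A.1.58): g and s are the running EMAs of the
    # gradient and of the squared gradient; all operations are element-wise.
    g = beta1 * g + (1 - beta1) * grad
    s = beta2 * s + (1 - beta2) * grad ** 2
    theta = theta - h * g / (np.sqrt(s) + eps)   # eps guards against division by zero
    return theta, g, s
```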

A.2 Computing Gradients via Automatic Differentiation

Above, we discussed several optimization algorithms for deep networks which assumed access to a first-order oracle, i.e., a device which would allow us to compute (θ)\mathcal{L}(\theta)caligraphic_L ( italic_θ ) and (θ)\nabla\mathcal{L}(\theta)∇ caligraphic_L ( italic_θ ). For simple functions \mathcal{L}caligraphic_L, it is possible to do this by hand. However, for deep neural networks, this quickly becomes tedious, and hinders rapid experimentation. Hence we require a general algorithm which would allow us to efficiently compute the gradients of arbitrary (sub)differentiable network architectures.

In this section, we introduce the basics of automatic differentiation (AD or autodiff), which is a computationally efficient way to compute gradients and Jacobians of general functions f:mnf:\mathbb{R}^{m}\to\mathbb{R}^{n}italic_f : blackboard_R start_POSTSUPERSCRIPT italic_m end_POSTSUPERSCRIPT → blackboard_R start_POSTSUPERSCRIPT italic_n end_POSTSUPERSCRIPT. We will show how this leads to the backpropagation algorithm for computing gradients of loss functions involving neural networks. A summary of the structure of this section is as follows:

  1. We introduce differentials, a convenient formalism for calculating and organizing the derivatives of functions between high-dimensional parameter spaces (which may themselves be products of other spaces involving matrices, tensors, etc.).

  2. We describe the basics of forward-mode and reverse-mode automatic differentiation, which involves considerations that are important for efficient computation of gradients/Jacobians for different kinds of functions arising in machine learning applications.

  3. We describe backpropagation, in the special case of a loss function applied to a stack-of-layers neural network, as an instantiation of reverse-mode automatic differentiation.

Our treatment will err on the mathematical side, so as to give the reader a solid understanding of the underlying machinery. The reader should couple this understanding with a strong grasp of the practical aspects of automatic differentiation for deep learning, for example as presented in the outstanding tutorial of [Kar22a].

A.2.1 Differentials

A full accounting of this subsection is given in the excellent guide [BEJ25]. To motivate differentials, let us first consider the simple example of a differentiable function :\mathcal{L}\colon\mathbb{R}\to\mathbb{R}caligraphic_L : blackboard_R → blackboard_R acting on a parameter θ\thetaitalic_θ. We can write

(θ)(θ0)=(θ0)(θθ0)+o(|θθ0|).\mathcal{L}(\theta)-\mathcal{L}(\theta_{0})=\mathcal{L}^{\prime}(\theta_{0})\cdot(\theta-\theta_{0})+o(\lvert\theta-\theta_{0}\rvert).caligraphic_L ( italic_θ ) - caligraphic_L ( italic_θ start_POSTSUBSCRIPT 0 end_POSTSUBSCRIPT ) = caligraphic_L start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( italic_θ start_POSTSUBSCRIPT 0 end_POSTSUBSCRIPT ) ⋅ ( italic_θ - italic_θ start_POSTSUBSCRIPT 0 end_POSTSUBSCRIPT ) + italic_o ( | italic_θ - italic_θ start_POSTSUBSCRIPT 0 end_POSTSUBSCRIPT | ) . (A.2.1)

If we take δθθθ0\delta\theta\doteq\theta-\theta_{0}italic_δ italic_θ ≐ italic_θ - italic_θ start_POSTSUBSCRIPT 0 end_POSTSUBSCRIPT and δ(θ0+δθ)(θ0)\delta\mathcal{L}\doteq\mathcal{L}(\theta_{0}+\delta\theta)-\mathcal{L}(\theta_{0})italic_δ caligraphic_L ≐ caligraphic_L ( italic_θ start_POSTSUBSCRIPT 0 end_POSTSUBSCRIPT + italic_δ italic_θ ) - caligraphic_L ( italic_θ start_POSTSUBSCRIPT 0 end_POSTSUBSCRIPT ), we can write

δ=(θ0)δθ+o(|δθ|).\delta\mathcal{L}=\mathcal{L}^{\prime}(\theta_{0})\cdot\delta\theta+o(\lvert\delta\theta\rvert).italic_δ caligraphic_L = caligraphic_L start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( italic_θ start_POSTSUBSCRIPT 0 end_POSTSUBSCRIPT ) ⋅ italic_δ italic_θ + italic_o ( | italic_δ italic_θ | ) . (A.2.2)

We will (non-rigorously) define dθ\mathrm{d}\thetaroman_d italic_θ and d\mathrm{d}\mathcal{L}roman_d caligraphic_L, i.e., the differentials of θ\thetaitalic_θ and \mathcal{L}caligraphic_L, to be infinitesimally small changes in θ\thetaitalic_θ and \mathcal{L}caligraphic_L. Think of them as what one gets when δθ\delta\thetaitalic_δ italic_θ (and therefore δ\delta\mathcal{L}italic_δ caligraphic_L) are extremely small. The goal of differential calculus, in some sense, is to study the relationships between the differentials dθ\mathrm{d}\thetaroman_d italic_θ and d\mathrm{d}\mathcal{L}roman_d caligraphic_L, namely, seeing how small changes in the input of a function change the output. We should note that the differential dθ\mathrm{d}\thetaroman_d italic_θ is the same shape as θ\thetaitalic_θ, and the differential d\mathrm{d}\mathcal{L}roman_d caligraphic_L is the same shape as \mathcal{L}caligraphic_L. In particular, we can write

d=(θ)dθ,\mathrm{d}\mathcal{L}=\mathcal{L}^{\prime}(\theta)\cdot\mathrm{d}\theta,roman_d caligraphic_L = caligraphic_L start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( italic_θ ) ⋅ roman_d italic_θ , (A.2.3)

where we have used the convention that all higher powers of $\mathrm{d}\theta$, such as $(\mathrm{d}\theta)^{2}$, are equal to 0.

Let us see how this works in higher dimensions, i.e., for $\mathcal{L}\colon\mathbb{R}^{n}\to\mathbb{R}$. Then we still have

d=(θ)dθ\mathrm{d}\mathcal{L}=\mathcal{L}^{\prime}(\theta)\cdot\mathrm{d}\thetaroman_d caligraphic_L = caligraphic_L start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( italic_θ ) ⋅ roman_d italic_θ (A.2.4)

for some notion of a derivative (θ)\mathcal{L}^{\prime}(\theta)caligraphic_L start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( italic_θ ). Since θ\thetaitalic_θ (hence dθ\mathrm{d}\thetaroman_d italic_θ) is a column vector here and \mathcal{L}caligraphic_L (hence d\mathrm{d}\mathcal{L}roman_d caligraphic_L) is a scalar, we must have that (θ)\mathcal{L}^{\prime}(\theta)caligraphic_L start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( italic_θ ) is a row vector. In this case, (θ)\mathcal{L}^{\prime}(\theta)caligraphic_L start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( italic_θ ) is the Jacobian of \mathcal{L}caligraphic_L w.r.t. θ\thetaitalic_θ. Here notice that we have set all higher powers and products of coordinates of dθ\mathrm{d}\thetaroman_d italic_θ to 0. In sum,

All products and powers 2\geq 2≥ 2 of differentials are equal to 0.

Now consider a higher-order tensor function $\mathcal{L}\colon\mathbb{R}^{m\times n}\to\mathbb{R}^{p\times q}$. Our basic linearization equation is insufficient for this case: $\mathrm{d}\mathcal{L}=\mathcal{L}^{\prime}(\theta)\cdot\mathrm{d}\theta$ does not make sense, since $\theta$ is an $m\times n$ matrix while $\mathrm{d}\mathcal{L}$ is a $p\times q$ matrix, so there is no vector or matrix shape for $\mathcal{L}^{\prime}(\theta)$ that works in general (multiplying an $m\times n$ matrix by a fixed matrix cannot produce a $p\times q$ matrix unless the dimensions happen to match). So we must adopt a slightly more general interpretation.

Namely, we consider (θ)\mathcal{L}^{\prime}(\theta)caligraphic_L start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( italic_θ ) as a linear transformation whose input is θ\thetaitalic_θ-space and whose output is \mathcal{L}caligraphic_L-space, which takes in a small change in θ\thetaitalic_θ and outputs the corresponding small change in \mathcal{L}caligraphic_L. Namely, we can write

d=(θ)[dθ].\mathrm{d}\mathcal{L}=\mathcal{L}^{\prime}(\theta)[\mathrm{d}\theta].roman_d caligraphic_L = caligraphic_L start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( italic_θ ) [ roman_d italic_θ ] . (A.2.5)

In the previous cases, (θ)\mathcal{L}^{\prime}(\theta)caligraphic_L start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( italic_θ ) was first a linear operator \mathbb{R}\to\mathbb{R}blackboard_R → blackboard_R whose action was to multiply its input by the scalar derivative of \mathcal{L}caligraphic_L with respect to θ\thetaitalic_θ, and then a linear operator n\mathbb{R}^{n}\to\mathbb{R}blackboard_R start_POSTSUPERSCRIPT italic_n end_POSTSUPERSCRIPT → blackboard_R whose action was to multiply its input by the Jacobian derivative of \mathcal{L}caligraphic_L with respect to θ\thetaitalic_θ. In general (θ)\mathcal{L}^{\prime}(\theta)caligraphic_L start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( italic_θ ) is the “derivative” of \mathcal{L}caligraphic_L w.r.t. θ\thetaitalic_θ. Think of \mathcal{L}^{\prime}caligraphic_L start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT as a generalized version of the Jacobian of \mathcal{L}caligraphic_L. As such, it follows some simple derivative rules, most crucially the chain rule.

Theorem A.1 (Differential Chain Rule).

Suppose =fg\mathcal{L}=f\circ gcaligraphic_L = italic_f ∘ italic_g where ffitalic_f and ggitalic_g are differentiable. Then

d=f(g(θ))g(θ)[dθ],\mathrm{d}\mathcal{L}=f^{\prime}(g(\theta))g^{\prime}(\theta)[\mathrm{d}\theta],roman_d caligraphic_L = italic_f start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( italic_g ( italic_θ ) ) italic_g start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( italic_θ ) [ roman_d italic_θ ] , (A.2.6)

where (as usual) multiplication indicates composition of linear operators. In particular,

(θ)=f(g(θ))g(θ)\mathcal{L}^{\prime}(\theta)=f^{\prime}(g(\theta))g^{\prime}(\theta)caligraphic_L start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( italic_θ ) = italic_f start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( italic_g ( italic_θ ) ) italic_g start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( italic_θ ) (A.2.7)

in the sense of equality of linear operators.

It is productive to think of the multivariate chain rule in functorial terms: composition of functions gets ‘turned into’ matrix multiplication of Jacobians (composition of linear operators!). We illustrate the power of this result and this perspective through several examples.

Example A.4.

Consider the function f(𝑿)=𝑾𝑿+𝒃𝟏f(\bm{X})=\bm{W}\bm{X}+\bm{b}\bm{1}^{\top}italic_f ( bold_italic_X ) = bold_italic_W bold_italic_X + bold_italic_b bold_1 start_POSTSUPERSCRIPT ⊤ end_POSTSUPERSCRIPT. Then

df=f(𝑿+d𝑿)f(𝑿)=[𝑾(𝑿+d𝑿)+𝒃𝟏][𝑾𝑿+𝒃𝟏]=𝑾d𝑿.\mathrm{d}f=f(\bm{X}+\mathrm{d}\bm{X})-f(\bm{X})=[\bm{W}(\bm{X}+\mathrm{d}\bm{X})+\bm{b}\bm{1}^{\top}]-[\bm{W}\bm{X}+\bm{b}\bm{1}^{\top}]=\bm{W}\mathrm{d}\bm{X}.roman_d italic_f = italic_f ( bold_italic_X + roman_d bold_italic_X ) - italic_f ( bold_italic_X ) = [ bold_italic_W ( bold_italic_X + roman_d bold_italic_X ) + bold_italic_b bold_1 start_POSTSUPERSCRIPT ⊤ end_POSTSUPERSCRIPT ] - [ bold_italic_W bold_italic_X + bold_italic_b bold_1 start_POSTSUPERSCRIPT ⊤ end_POSTSUPERSCRIPT ] = bold_italic_W roman_d bold_italic_X . (A.2.8)

Thus the derivative of an affine function w.r.t. its input is

f(𝑿)[d𝑿]=𝑾d𝑿f(𝑿)=𝑾.f^{\prime}(\bm{X})[\mathrm{d}\bm{X}]=\bm{W}\mathrm{d}\bm{X}\implies f^{\prime}(\bm{X})=\bm{W}.italic_f start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( bold_italic_X ) [ roman_d bold_italic_X ] = bold_italic_W roman_d bold_italic_X ⟹ italic_f start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( bold_italic_X ) = bold_italic_W . (A.2.9)

Notice that ff^{\prime}italic_f start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT is constant. On the other hand, consider the function g(𝑾,𝒃)=𝑾𝑿+𝒃𝟏g(\bm{W},\bm{b})=\bm{W}\bm{X}+\bm{b}\bm{1}^{\top}italic_g ( bold_italic_W , bold_italic_b ) = bold_italic_W bold_italic_X + bold_italic_b bold_1 start_POSTSUPERSCRIPT ⊤ end_POSTSUPERSCRIPT. Then

\mathrm{d}g=g(\bm{W}+\mathrm{d}\bm{W},\bm{b}+\mathrm{d}\bm{b})-g(\bm{W},\bm{b})=[(\bm{W}+\mathrm{d}\bm{W})\bm{X}+(\bm{b}+\mathrm{d}\bm{b})\bm{1}^{\top}]-[\bm{W}\bm{X}+\bm{b}\bm{1}^{\top}] (A.2.10)
=(\mathrm{d}\bm{W})\bm{X}+(\mathrm{d}\bm{b})\bm{1}^{\top}=g^{\prime}(\bm{W},\bm{b})[\mathrm{d}\bm{W},\mathrm{d}\bm{b}]. (A.2.11)

Notice that this derivative is constant in 𝑾,𝒃\bm{W},\bm{b}bold_italic_W , bold_italic_b (which makes sense since ggitalic_g itself is linear) and linear in the differential inputs d𝑾,d𝒃\mathrm{d}\bm{W},\mathrm{d}\bm{b}roman_d bold_italic_W , roman_d bold_italic_b. \blacksquare

Example A.5.

Consider the function $f=gh$, where $g,h$ are differentiable functions whose outputs can be multiplied together. Then $f=p\circ v$, where $v=(g,h)$ and $p(a,b)=ab$. Applying the chain rule, we have

df=p(v(x))v(x)[dx].\mathrm{d}f=p^{\prime}(v(x))v^{\prime}(x)[\mathrm{d}x].roman_d italic_f = italic_p start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( italic_v ( italic_x ) ) italic_v start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( italic_x ) [ roman_d italic_x ] . (A.2.12)

To compute v(x)v^{\prime}(x)italic_v start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( italic_x ) we can compute

dv=v(x)[dx]=v(x+dx)v(x)=[g(x+dx)g(x)h(x+dx)h(x)]=[g(x)[dx]h(x)[dx]].\mathrm{d}v=v^{\prime}(x)[\mathrm{d}x]=v(x+\mathrm{d}x)-v(x)=\begin{bmatrix}g(x+\mathrm{d}x)-g(x)\\ h(x+\mathrm{d}x)-h(x)\end{bmatrix}=\begin{bmatrix}g^{\prime}(x)[\mathrm{d}x]\\ h^{\prime}(x)[\mathrm{d}x]\end{bmatrix}.roman_d italic_v = italic_v start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( italic_x ) [ roman_d italic_x ] = italic_v ( italic_x + roman_d italic_x ) - italic_v ( italic_x ) = [ start_ARG start_ROW start_CELL italic_g ( italic_x + roman_d italic_x ) - italic_g ( italic_x ) end_CELL end_ROW start_ROW start_CELL italic_h ( italic_x + roman_d italic_x ) - italic_h ( italic_x ) end_CELL end_ROW end_ARG ] = [ start_ARG start_ROW start_CELL italic_g start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( italic_x ) [ roman_d italic_x ] end_CELL end_ROW start_ROW start_CELL italic_h start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( italic_x ) [ roman_d italic_x ] end_CELL end_ROW end_ARG ] . (A.2.13)

To compute pp^{\prime}italic_p start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT we can compute

dp\displaystyle\mathrm{d}proman_d italic_p =p(a,b)[da,db]=p(a+da,b+db)p(a,b)=(a+da)(b+db)ab\displaystyle=p^{\prime}(a,b)[\mathrm{d}a,\mathrm{d}b]=p(a+\mathrm{d}a,b+\mathrm{d}b)-p(a,b)=(a+\mathrm{d}a)(b+\mathrm{d}b)-ab= italic_p start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( italic_a , italic_b ) [ roman_d italic_a , roman_d italic_b ] = italic_p ( italic_a + roman_d italic_a , italic_b + roman_d italic_b ) - italic_p ( italic_a , italic_b ) = ( italic_a + roman_d italic_a ) ( italic_b + roman_d italic_b ) - italic_a italic_b (A.2.14)
=(da)b+a(db)+(da)(db)=(da)b+a(db),\displaystyle=(\mathrm{d}a)b+a(\mathrm{d}b)+(\mathrm{d}a)(\mathrm{d}b)=(\mathrm{d}a)b+a(\mathrm{d}b),= ( roman_d italic_a ) italic_b + italic_a ( roman_d italic_b ) + ( roman_d italic_a ) ( roman_d italic_b ) = ( roman_d italic_a ) italic_b + italic_a ( roman_d italic_b ) , (A.2.15)

where (recall) the product of the differentials da\mathrm{d}aroman_d italic_a and db\mathrm{d}broman_d italic_b is set to 0. Therefore

p(a,b)[da,db]=(da)b+(db)a.p^{\prime}(a,b)[\mathrm{d}a,\mathrm{d}b]=(\mathrm{d}a)b+(\mathrm{d}b)a.italic_p start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( italic_a , italic_b ) [ roman_d italic_a , roman_d italic_b ] = ( roman_d italic_a ) italic_b + ( roman_d italic_b ) italic_a . (A.2.16)

Putting these together, we find

f(x)[dx]\displaystyle f^{\prime}(x)[\mathrm{d}x]italic_f start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( italic_x ) [ roman_d italic_x ] =p(v(x))v(x)[dx]=p(g(x),h(x))[g(x)[dx],h(x)[dx]]\displaystyle=p^{\prime}(v(x))v^{\prime}(x)[\mathrm{d}x]=p^{\prime}(g(x),h(x))[g^{\prime}(x)[\mathrm{d}x],h^{\prime}(x)[\mathrm{d}x]]= italic_p start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( italic_v ( italic_x ) ) italic_v start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( italic_x ) [ roman_d italic_x ] = italic_p start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( italic_g ( italic_x ) , italic_h ( italic_x ) ) [ italic_g start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( italic_x ) [ roman_d italic_x ] , italic_h start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( italic_x ) [ roman_d italic_x ] ] (A.2.17)
=(g(x)[dx])h(x)+g(x)(h(x)[dx]).\displaystyle=(g^{\prime}(x)[\mathrm{d}x])h(x)+g(x)(h^{\prime}(x)[\mathrm{d}x]).= ( italic_g start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( italic_x ) [ roman_d italic_x ] ) italic_h ( italic_x ) + italic_g ( italic_x ) ( italic_h start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( italic_x ) [ roman_d italic_x ] ) . (A.2.18)

This gives

f(x)[dx]=(g(x)[dx])h(x)+g(x)(h(x)[dx]).f^{\prime}(x)[\mathrm{d}x]=(g^{\prime}(x)[\mathrm{d}x])h(x)+g(x)(h^{\prime}(x)[\mathrm{d}x]).italic_f start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( italic_x ) [ roman_d italic_x ] = ( italic_g start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( italic_x ) [ roman_d italic_x ] ) italic_h ( italic_x ) + italic_g ( italic_x ) ( italic_h start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( italic_x ) [ roman_d italic_x ] ) . (A.2.19)

If for example we say that f,g,h:f,g,h\colon\mathbb{R}\to\mathbb{R}italic_f , italic_g , italic_h : blackboard_R → blackboard_R then everything commutes so

f(x)[dx]=(g(x)h(x)+g(x)h(x))[dx]f(x)=g(x)h(x)+g(x)h(x)f^{\prime}(x)[\mathrm{d}x]=(g^{\prime}(x)h(x)+g(x)h^{\prime}(x))[\mathrm{d}x]\implies f^{\prime}(x)=g^{\prime}(x)h(x)+g(x)h^{\prime}(x)italic_f start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( italic_x ) [ roman_d italic_x ] = ( italic_g start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( italic_x ) italic_h ( italic_x ) + italic_g ( italic_x ) italic_h start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( italic_x ) ) [ roman_d italic_x ] ⟹ italic_f start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( italic_x ) = italic_g start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( italic_x ) italic_h ( italic_x ) + italic_g ( italic_x ) italic_h start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( italic_x ) (A.2.20)

which is the familiar product rule. \blacksquare

Example A.6.

Consider the function f(𝑨)=𝑨𝑨𝑩𝑨f(\bm{A})=\bm{A}^{\top}\bm{A}\bm{B}\bm{A}italic_f ( bold_italic_A ) = bold_italic_A start_POSTSUPERSCRIPT ⊤ end_POSTSUPERSCRIPT bold_italic_A bold_italic_B bold_italic_A where 𝑨\bm{A}bold_italic_A is a matrix and 𝑩\bm{B}bold_italic_B is a constant matrix. Then, letting f=ghf=ghitalic_f = italic_g italic_h where g(𝑨)=𝑨𝑨g(\bm{A})=\bm{A}^{\top}\bm{A}italic_g ( bold_italic_A ) = bold_italic_A start_POSTSUPERSCRIPT ⊤ end_POSTSUPERSCRIPT bold_italic_A and h(𝑨)=𝑩𝑨h(\bm{A})=\bm{B}\bm{A}italic_h ( bold_italic_A ) = bold_italic_B bold_italic_A, we can use the product rule to obtain

f(𝑨)[d𝑨]\displaystyle f^{\prime}(\bm{A})[\mathrm{d}\bm{A}]italic_f start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( bold_italic_A ) [ roman_d bold_italic_A ] =(g(𝑨)[d𝑨])h(𝑨)+g(𝑨)(h(𝑨)[d𝑨])\displaystyle=(g^{\prime}(\bm{A})[\mathrm{d}\bm{A}])h(\bm{A})+g(\bm{A})(h^{\prime}(\bm{A})[\mathrm{d}\bm{A}])= ( italic_g start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( bold_italic_A ) [ roman_d bold_italic_A ] ) italic_h ( bold_italic_A ) + italic_g ( bold_italic_A ) ( italic_h start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( bold_italic_A ) [ roman_d bold_italic_A ] ) (A.2.21)
=((d𝑨)𝑨+𝑨(d𝑨))𝑩𝑨+𝑨𝑨𝑩(d𝑨).\displaystyle=((\mathrm{d}\bm{A})^{\top}\bm{A}+\bm{A}^{\top}(\mathrm{d}\bm{A}))\bm{B}\bm{A}+\bm{A}^{\top}\bm{A}\bm{B}(\mathrm{d}\bm{A}).= ( ( roman_d bold_italic_A ) start_POSTSUPERSCRIPT ⊤ end_POSTSUPERSCRIPT bold_italic_A + bold_italic_A start_POSTSUPERSCRIPT ⊤ end_POSTSUPERSCRIPT ( roman_d bold_italic_A ) ) bold_italic_B bold_italic_A + bold_italic_A start_POSTSUPERSCRIPT ⊤ end_POSTSUPERSCRIPT bold_italic_A bold_italic_B ( roman_d bold_italic_A ) . (A.2.22)

\blacksquare
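Such formulas are easy to sanity-check numerically by comparing the claimed differential against an actual finite difference; here is a minimal sketch for Example A.6, with randomly chosen A, B and a small perturbation dA (all names are illustrative).

```python
import numpy as np

rng = np.random.default_rng(2)
A, B = rng.standard_normal((4, 4)), rng.standard_normal((4, 4))
dA = 1e-6 * rng.standard_normal((4, 4))

f = lambda M: M.T @ M @ B @ M
lhs = f(A + dA) - f(A)                                    # actual change in f
rhs = (dA.T @ A + A.T @ dA) @ B @ A + A.T @ A @ B @ dA    # f'(A)[dA] from (A.2.22)
print(np.linalg.norm(lhs - rhs))                          # of order ||dA||^2, i.e. tiny
```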

Example A.7.

Consider the function f:m×n×km×nf\colon\mathbb{R}^{m\times n\times k}\to\mathbb{R}^{m\times n}italic_f : blackboard_R start_POSTSUPERSCRIPT italic_m × italic_n × italic_k end_POSTSUPERSCRIPT → blackboard_R start_POSTSUPERSCRIPT italic_m × italic_n end_POSTSUPERSCRIPT given by

f(𝑨)ij=t=1kAijt.f(\bm{A})_{ij}=\sum_{t=1}^{k}A_{ijt}.italic_f ( bold_italic_A ) start_POSTSUBSCRIPT italic_i italic_j end_POSTSUBSCRIPT = ∑ start_POSTSUBSCRIPT italic_t = 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_k end_POSTSUPERSCRIPT italic_A start_POSTSUBSCRIPT italic_i italic_j italic_t end_POSTSUBSCRIPT . (A.2.23)

We cannot write a (matrix-valued) Jacobian or gradient for this function. But we can compute its differential just fine:

dfij=[f(𝑨+d𝑨)f(𝑨)]ij=t=1kd𝑨ijt=𝟏k(d𝑨)ij.\mathrm{d}f_{ij}=[f(\bm{A}+\mathrm{d}\bm{A})-f(\bm{A})]_{ij}=\sum_{t=1}^{k}\mathrm{d}\bm{A}_{ijt}=\bm{1}_{k}^{\top}(\mathrm{d}\bm{A})_{ij}.roman_d italic_f start_POSTSUBSCRIPT italic_i italic_j end_POSTSUBSCRIPT = [ italic_f ( bold_italic_A + roman_d bold_italic_A ) - italic_f ( bold_italic_A ) ] start_POSTSUBSCRIPT italic_i italic_j end_POSTSUBSCRIPT = ∑ start_POSTSUBSCRIPT italic_t = 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_k end_POSTSUPERSCRIPT roman_d bold_italic_A start_POSTSUBSCRIPT italic_i italic_j italic_t end_POSTSUBSCRIPT = bold_1 start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT start_POSTSUPERSCRIPT ⊤ end_POSTSUPERSCRIPT ( roman_d bold_italic_A ) start_POSTSUBSCRIPT italic_i italic_j end_POSTSUBSCRIPT . (A.2.24)

So

(f(𝑨)[d𝑨])ij=𝟏k(d𝑨)ij,(f^{\prime}(\bm{A})[\mathrm{d}\bm{A}])_{ij}=\bm{1}_{k}^{\top}(\mathrm{d}\bm{A})_{ij},( italic_f start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( bold_italic_A ) [ roman_d bold_italic_A ] ) start_POSTSUBSCRIPT italic_i italic_j end_POSTSUBSCRIPT = bold_1 start_POSTSUBSCRIPT italic_k end_POSTSUBSCRIPT start_POSTSUPERSCRIPT ⊤ end_POSTSUPERSCRIPT ( roman_d bold_italic_A ) start_POSTSUBSCRIPT italic_i italic_j end_POSTSUBSCRIPT , (A.2.25)

which represents a higher-order tensor multiplication operation that is nonetheless well-defined. \blacksquare

This gives us all the technology we need to compute differentials of everything. The last thing we cover in this section is a method to compute gradients using the differential. Namely, for a function \mathcal{L}caligraphic_L whose output is a scalar, the gradient \nabla\mathcal{L}∇ caligraphic_L is defined as

d=(θ)[dθ]=(θ),dθ,\mathrm{d}\mathcal{L}=\mathcal{L}^{\prime}(\theta)[\mathrm{d}\theta]=\langle\nabla\mathcal{L}(\theta),\mathrm{d}\theta\rangle,roman_d caligraphic_L = caligraphic_L start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( italic_θ ) [ roman_d italic_θ ] = ⟨ ∇ caligraphic_L ( italic_θ ) , roman_d italic_θ ⟩ , (A.2.26)

where the inner product here is the “standard” inner product for the specified objects (i.e., for vectors it’s the 2\ell^{2}roman_ℓ start_POSTSUPERSCRIPT 2 end_POSTSUPERSCRIPT inner product, whereas for matrices it’s the Frobenius inner product, and for higher-order tensors it’s the analogous sum-of-coordinates inner product). This definition is the correct generalization of the ‘familiar’ example of the gradient of a function from n\mathbb{R}^{n}blackboard_R start_POSTSUPERSCRIPT italic_n end_POSTSUPERSCRIPT to \mathbb{R}blackboard_R as the vector of partial derivatives—a version of Taylor’s theorem for general functions f:mnf:\mathbb{R}^{m}\to\mathbb{R}^{n}italic_f : blackboard_R start_POSTSUPERSCRIPT italic_m end_POSTSUPERSCRIPT → blackboard_R start_POSTSUPERSCRIPT italic_n end_POSTSUPERSCRIPT makes this connection rigorous. So one way to compute the gradient \nabla\mathcal{L}∇ caligraphic_L is to compute the differential d\mathrm{d}\mathcal{L}roman_d caligraphic_L and rewrite it in the form something,dθ\langle\text{something},\mathrm{d}\theta\rangle⟨ something , roman_d italic_θ ⟩, then that “something” is the gradient.
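As a small illustration of this recipe, take $\mathcal{L}(\theta)=\frac{1}{2}\|\bm{A}\theta-\bm{y}\|_{2}^{2}$: expanding the differential gives $\mathrm{d}\mathcal{L}=\langle\bm{A}^{\top}(\bm{A}\theta-\bm{y}),\mathrm{d}\theta\rangle$, so the gradient is $\bm{A}^{\top}(\bm{A}\theta-\bm{y})$. The sketch below checks this numerically; the shapes and the size of the perturbation are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(3)
A, y = rng.standard_normal((5, 3)), rng.standard_normal(5)
theta = rng.standard_normal(3)
dtheta = 1e-6 * rng.standard_normal(3)

L = lambda t: 0.5 * np.sum((A @ t - y) ** 2)
grad = A.T @ (A @ theta - y)          # the "something" read off from the differential
# The first-order change in L and the inner product <grad, dtheta> agree to leading order.
print(L(theta + dtheta) - L(theta), grad @ dtheta)
```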

A.2.2 Automatic Differentiation

The main idea of AD is to compute the chain rule efficiently. The basic problem we need to cope with is the following. In the optimization section of the appendix, we considered that the parameter space Θ\Thetaroman_Θ was an abstract Euclidean space like n\mathbb{R}^{n}blackboard_R start_POSTSUPERSCRIPT italic_n end_POSTSUPERSCRIPT. In practice the parameters are really some collection of vectors, matrices, and higher-order objects: Θ=m×n×n×r×q×p×r×q×\Theta=\mathbb{R}^{m\times n}\times\mathbb{R}^{n}\times\mathbb{R}^{r\times q\times p}\times\mathbb{R}^{r\times q}\times\cdotsroman_Θ = blackboard_R start_POSTSUPERSCRIPT italic_m × italic_n end_POSTSUPERSCRIPT × blackboard_R start_POSTSUPERSCRIPT italic_n end_POSTSUPERSCRIPT × blackboard_R start_POSTSUPERSCRIPT italic_r × italic_q × italic_p end_POSTSUPERSCRIPT × blackboard_R start_POSTSUPERSCRIPT italic_r × italic_q end_POSTSUPERSCRIPT × ⋯. While in theory this is the same thing as a large parameter space n\mathbb{R}^{n^{\prime}}blackboard_R start_POSTSUPERSCRIPT italic_n start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT end_POSTSUPERSCRIPT for some (very large) nn^{\prime}italic_n start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT, computationally efficient algorithms for differentiation must treat these two spaces differently. Forward and reverse mode automatic differentiation are two different schemes for performing this computation.

Let us do a simple example to start. Let \mathcal{L}caligraphic_L be defined by =abc\mathcal{L}=a\circ b\circ ccaligraphic_L = italic_a ∘ italic_b ∘ italic_c where a,b,ca,b,citalic_a , italic_b , italic_c are differentiable. Then the chain rule gives

(θ)=a(b(c(θ)))b(c(θ))c(θ).\mathcal{L}^{\prime}(\theta)=a^{\prime}(b(c(\theta)))b^{\prime}(c(\theta))c^{\prime}(\theta).caligraphic_L start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( italic_θ ) = italic_a start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( italic_b ( italic_c ( italic_θ ) ) ) italic_b start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( italic_c ( italic_θ ) ) italic_c start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( italic_θ ) . (A.2.27)

To compute (θ)\mathcal{L}(\theta)caligraphic_L ( italic_θ ), we first compute c(θ)c(\theta)italic_c ( italic_θ ) then b(c(θ))b(c(\theta))italic_b ( italic_c ( italic_θ ) ) then a(b(c(θ)))a(b(c(\theta)))italic_a ( italic_b ( italic_c ( italic_θ ) ) ), and store them all. There are two ways to compute (θ)\mathcal{L}^{\prime}(\theta)caligraphic_L start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( italic_θ ). The forward-mode AD will compute

c(θ)b(c(θ))c(θ)a(b(c(θ)))b(c(θ))c(θ)c^{\prime}(\theta)\implies b^{\prime}(c(\theta))c^{\prime}(\theta)\implies a^{\prime}(b(c(\theta)))b^{\prime}(c(\theta))c^{\prime}(\theta)italic_c start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( italic_θ ) ⟹ italic_b start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( italic_c ( italic_θ ) ) italic_c start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( italic_θ ) ⟹ italic_a start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( italic_b ( italic_c ( italic_θ ) ) ) italic_b start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( italic_c ( italic_θ ) ) italic_c start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( italic_θ ) (A.2.28)

i.e., computing the derivatives “from the bottom-up”. The reverse mode AD will compute

a(b(c(θ)))a(b(c(θ)))b(c(θ))a(b(c(θ)))b(c(θ))c(θ),a^{\prime}(b(c(\theta)))\implies a^{\prime}(b(c(\theta)))b^{\prime}(c(\theta))\implies a^{\prime}(b(c(\theta)))b^{\prime}(c(\theta))c^{\prime}(\theta),italic_a start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( italic_b ( italic_c ( italic_θ ) ) ) ⟹ italic_a start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( italic_b ( italic_c ( italic_θ ) ) ) italic_b start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( italic_c ( italic_θ ) ) ⟹ italic_a start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( italic_b ( italic_c ( italic_θ ) ) ) italic_b start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( italic_c ( italic_θ ) ) italic_c start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( italic_θ ) , (A.2.29)

i.e., computing the derivatives “from the top down”. To see why this matters, suppose that f:psf\colon\mathbb{R}^{p}\to\mathbb{R}^{s}italic_f : blackboard_R start_POSTSUPERSCRIPT italic_p end_POSTSUPERSCRIPT → blackboard_R start_POSTSUPERSCRIPT italic_s end_POSTSUPERSCRIPT is given by f=abcf=a\circ b\circ citalic_f = italic_a ∘ italic_b ∘ italic_c where a:rsa\colon\mathbb{R}^{r}\to\mathbb{R}^{s}italic_a : blackboard_R start_POSTSUPERSCRIPT italic_r end_POSTSUPERSCRIPT → blackboard_R start_POSTSUPERSCRIPT italic_s end_POSTSUPERSCRIPT, b:qrb\colon\mathbb{R}^{q}\to\mathbb{R}^{r}italic_b : blackboard_R start_POSTSUPERSCRIPT italic_q end_POSTSUPERSCRIPT → blackboard_R start_POSTSUPERSCRIPT italic_r end_POSTSUPERSCRIPT, c:pqc\colon\mathbb{R}^{p}\to\mathbb{R}^{q}italic_c : blackboard_R start_POSTSUPERSCRIPT italic_p end_POSTSUPERSCRIPT → blackboard_R start_POSTSUPERSCRIPT italic_q end_POSTSUPERSCRIPT. Then the chain rule is:

f(𝒙)=a(b(c(𝒙)))b(c(𝒙))c(𝒙)f^{\prime}(\bm{x})=a^{\prime}(b(c(\bm{x})))b^{\prime}(c(\bm{x}))c^{\prime}(\bm{x})italic_f start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( bold_italic_x ) = italic_a start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( italic_b ( italic_c ( bold_italic_x ) ) ) italic_b start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( italic_c ( bold_italic_x ) ) italic_c start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( bold_italic_x ) (A.2.30)

where (recall) ff^{\prime}italic_f start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT is the derivative, in this case the Jacobian (since the input and output of each function are both vectors). Assuming that computing each Jacobian is trivial and the only cost is multiplying the Jacobians together, forward-mode AD has the following computational costs (assuming that multiplying Am×n,Bn×kA\in\mathbb{R}^{m\times n},B\in\mathbb{R}^{n\times k}italic_A ∈ blackboard_R start_POSTSUPERSCRIPT italic_m × italic_n end_POSTSUPERSCRIPT , italic_B ∈ blackboard_R start_POSTSUPERSCRIPT italic_n × italic_k end_POSTSUPERSCRIPT takes 𝒪(mnk)\mathcal{O}(mnk)caligraphic_O ( italic_m italic_n italic_k ) time):

computingc(𝒙)q×ptakes negligible time\displaystyle\text{computing}\ c^{\prime}(\bm{x})\in\mathbb{R}^{q\times p}\ \text{takes negligible time}computing italic_c start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( bold_italic_x ) ∈ blackboard_R start_POSTSUPERSCRIPT italic_q × italic_p end_POSTSUPERSCRIPT takes negligible time (A.2.31)
computingb(c(𝒙))c(x)r×ptakes 𝒪(pqr) time\displaystyle\text{computing}\ b^{\prime}(c(\bm{x}))c^{\prime}(x)\in\mathbb{R}^{r\times p}\ \text{takes $\mathcal{O}(pqr)$ time}computing italic_b start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( italic_c ( bold_italic_x ) ) italic_c start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( italic_x ) ∈ blackboard_R start_POSTSUPERSCRIPT italic_r × italic_p end_POSTSUPERSCRIPT takes caligraphic_O ( italic_p italic_q italic_r ) time (A.2.32)
computinga(b(c(𝒙)))b(c(𝒙))c(𝒙)s×ptakes 𝒪(pqr+prs) time.\displaystyle\text{computing}\ a^{\prime}(b(c(\bm{x})))b^{\prime}(c(\bm{x}))c^{\prime}(\bm{x})\in\mathbb{R}^{s\times p}\ \text{takes $\mathcal{O}(pqr+prs)$ time.}computing italic_a start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( italic_b ( italic_c ( bold_italic_x ) ) ) italic_b start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( italic_c ( bold_italic_x ) ) italic_c start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( bold_italic_x ) ∈ blackboard_R start_POSTSUPERSCRIPT italic_s × italic_p end_POSTSUPERSCRIPT takes caligraphic_O ( italic_p italic_q italic_r + italic_p italic_r italic_s ) time. (A.2.33)

Meanwhile, doing reverse-mode AD has the following computational costs:

computinga(b(c(𝒙)))s×rtakes negligible time\displaystyle\text{computing}\ a^{\prime}(b(c(\bm{x})))\in\mathbb{R}^{s\times r}\ \text{takes negligible time}computing italic_a start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( italic_b ( italic_c ( bold_italic_x ) ) ) ∈ blackboard_R start_POSTSUPERSCRIPT italic_s × italic_r end_POSTSUPERSCRIPT takes negligible time (A.2.34)
computinga(b(c(𝒙)))b(c(𝒙))s×qtakes 𝒪(qrs) time\displaystyle\text{computing}\ a^{\prime}(b(c(\bm{x})))b^{\prime}(c(\bm{x}))\in\mathbb{R}^{s\times q}\ \text{takes $\mathcal{O}(qrs)$ time}computing italic_a start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( italic_b ( italic_c ( bold_italic_x ) ) ) italic_b start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( italic_c ( bold_italic_x ) ) ∈ blackboard_R start_POSTSUPERSCRIPT italic_s × italic_q end_POSTSUPERSCRIPT takes caligraphic_O ( italic_q italic_r italic_s ) time (A.2.35)
computinga(b(c(𝒙)))b(c(𝒙))c(𝒙)s×ptakes 𝒪(qrs+pqs) time.\displaystyle\text{computing}\ a^{\prime}(b(c(\bm{x})))b^{\prime}(c(\bm{x}))c^{\prime}(\bm{x})\in\mathbb{R}^{s\times p}\ \text{takes $\mathcal{O}(qrs+pqs)$ time.}computing italic_a start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( italic_b ( italic_c ( bold_italic_x ) ) ) italic_b start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( italic_c ( bold_italic_x ) ) italic_c start_POSTSUPERSCRIPT ′ end_POSTSUPERSCRIPT ( bold_italic_x ) ∈ blackboard_R start_POSTSUPERSCRIPT italic_s × italic_p end_POSTSUPERSCRIPT takes caligraphic_O ( italic_q italic_r italic_s + italic_p italic_q italic_s ) time. (A.2.36)

In other words, forward-mode AD takes $\mathcal{O}(p(qr+rs))$ time, while reverse-mode AD takes $\mathcal{O}(s(pq+qr))$ time. In general, these costs are very different!
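To see the gap concretely, here is a small sketch (our own illustration, assuming NumPy; the Jacobians are just random matrices standing in for $a^{\prime}$, $b^{\prime}$, $c^{\prime}$) that multiplies the same chain of Jacobians in the two orders, with many inputs ($p$ large) and a single output ($s=1$):

```python
import time
import numpy as np

# The same Jacobian chain product a'(.) b'(.) c'(.), associated two ways.
p, q, r, s = 2000, 2000, 2000, 1
rng = np.random.default_rng(0)
Jc = rng.standard_normal((q, p))   # stands in for c'(x)
Jb = rng.standard_normal((r, q))   # stands in for b'(c(x))
Ja = rng.standard_normal((s, r))   # stands in for a'(b(c(x)))

t0 = time.perf_counter()
forward = Ja @ (Jb @ Jc)           # bottom-up order: O(pqr + prs) flops
t_forward = time.perf_counter() - t0

t0 = time.perf_counter()
reverse = (Ja @ Jb) @ Jc           # top-down order: O(qrs + pqs) flops
t_reverse = time.perf_counter() - t0

assert np.allclose(forward, reverse)      # same matrix, very different cost
print(f"bottom-up: {t_forward:.3f}s, top-down: {t_reverse:.4f}s")
```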

More generally, suppose that f=fLf1f=f^{L}\circ\cdots\circ f^{1}italic_f = italic_f start_POSTSUPERSCRIPT italic_L end_POSTSUPERSCRIPT ∘ ⋯ ∘ italic_f start_POSTSUPERSCRIPT 1 end_POSTSUPERSCRIPT where each f:d1df^{\ell}\colon\mathbb{R}^{d^{\ell-1}}\to\mathbb{R}^{d^{\ell}}italic_f start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT : blackboard_R start_POSTSUPERSCRIPT italic_d start_POSTSUPERSCRIPT roman_ℓ - 1 end_POSTSUPERSCRIPT end_POSTSUPERSCRIPT → blackboard_R start_POSTSUPERSCRIPT italic_d start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT end_POSTSUPERSCRIPT, so that f:d0dLf\colon\mathbb{R}^{d^{0}}\to\mathbb{R}^{d^{L}}italic_f : blackboard_R start_POSTSUPERSCRIPT italic_d start_POSTSUPERSCRIPT 0 end_POSTSUPERSCRIPT end_POSTSUPERSCRIPT → blackboard_R start_POSTSUPERSCRIPT italic_d start_POSTSUPERSCRIPT italic_L end_POSTSUPERSCRIPT end_POSTSUPERSCRIPT. Then the forward-mode AD takes 𝒪(d0(=2Ld1d))\mathcal{O}(d^{0}(\sum_{\ell=2}^{L}d^{\ell-1}d^{\ell}))caligraphic_O ( italic_d start_POSTSUPERSCRIPT 0 end_POSTSUPERSCRIPT ( ∑ start_POSTSUBSCRIPT roman_ℓ = 2 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_L end_POSTSUPERSCRIPT italic_d start_POSTSUPERSCRIPT roman_ℓ - 1 end_POSTSUPERSCRIPT italic_d start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT ) ) time while the reverse-mode AD takes 𝒪(dL(=1L1d1d))\mathcal{O}(d^{L}(\sum_{\ell=1}^{L-1}d^{\ell-1}d^{\ell}))caligraphic_O ( italic_d start_POSTSUPERSCRIPT italic_L end_POSTSUPERSCRIPT ( ∑ start_POSTSUBSCRIPT roman_ℓ = 1 end_POSTSUBSCRIPT start_POSTSUPERSCRIPT italic_L - 1 end_POSTSUPERSCRIPT italic_d start_POSTSUPERSCRIPT roman_ℓ - 1 end_POSTSUPERSCRIPT italic_d start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT ) ) time. From the above rates, we see that all else equal:

  • If the function to optimize has more outputs than inputs (i.e., dL>d0d^{L}>d^{0}italic_d start_POSTSUPERSCRIPT italic_L end_POSTSUPERSCRIPT > italic_d start_POSTSUPERSCRIPT 0 end_POSTSUPERSCRIPT), use forward-mode AD.

  • If the function to optimize has more inputs than outputs (i.e., d0>dLd^{0}>d^{L}italic_d start_POSTSUPERSCRIPT 0 end_POSTSUPERSCRIPT > italic_d start_POSTSUPERSCRIPT italic_L end_POSTSUPERSCRIPT), use reverse-mode AD.

In a neural network, we compute the gradient of a loss function :Θ\mathcal{L}\colon\Theta\to\mathbb{R}caligraphic_L : roman_Θ → blackboard_R, where the parameter space Θ\Thetaroman_Θ is usually very high-dimensional. So in practice we always use reverse-mode AD for training neural networks. Reverse-mode AD, in the context of training neural networks, is called backpropagation.

A.2.3 Backpropagation

In this section, we will discuss algorithmic backpropagation using a simple yet completely practical example. Suppose that we fix an input-label pair $(\bm{X},\bm{y})$ and a network architecture $f_{\theta}=f_{\theta}^{L}\circ\cdots\circ f_{\theta}^{1}\circ f_{\theta}^{\mathrm{emb}}$ with a task-specific head $h_{\theta^{\mathrm{head}}}$, where $\theta=(\theta^{\mathrm{emb}},\theta^{1},\dots,\theta^{L},\theta^{\mathrm{head}})$, and write

𝒁θ1(𝑿)fθemb(𝑿),\displaystyle\bm{Z}_{\theta}^{1}(\bm{X})\doteq f_{\theta}^{\mathrm{emb}}(\bm{X}),bold_italic_Z start_POSTSUBSCRIPT italic_θ end_POSTSUBSCRIPT start_POSTSUPERSCRIPT 1 end_POSTSUPERSCRIPT ( bold_italic_X ) ≐ italic_f start_POSTSUBSCRIPT italic_θ end_POSTSUBSCRIPT start_POSTSUPERSCRIPT roman_emb end_POSTSUPERSCRIPT ( bold_italic_X ) , (A.2.37)
𝒁θ+1(𝑿)fθ(𝒁θ(𝑿)),{1,,L},\displaystyle\bm{Z}_{\theta}^{\ell+1}(\bm{X})\doteq f_{\theta}^{\ell}(\bm{Z}_{\theta}^{\ell}(\bm{X})),\quad\forall\ell\in\{1,\dots,L\},bold_italic_Z start_POSTSUBSCRIPT italic_θ end_POSTSUBSCRIPT start_POSTSUPERSCRIPT roman_ℓ + 1 end_POSTSUPERSCRIPT ( bold_italic_X ) ≐ italic_f start_POSTSUBSCRIPT italic_θ end_POSTSUBSCRIPT start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT ( bold_italic_Z start_POSTSUBSCRIPT italic_θ end_POSTSUBSCRIPT start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT ( bold_italic_X ) ) , ∀ roman_ℓ ∈ { 1 , … , italic_L } , (A.2.38)
\hat{\bm{y}}_{\theta}(\bm{X})\doteq h_{\theta^{\mathrm{head}}}(\bm{Z}_{\theta}^{L+1}(\bm{X})). (A.2.39)

Then, we can define the loss on this one term by

(θ)𝖫(𝒚,𝒚^θ(𝑿)),\mathcal{L}(\theta)\doteq\mathsf{L}(\bm{y},\hat{\bm{y}}_{\theta}(\bm{X})),caligraphic_L ( italic_θ ) ≐ sansserif_L ( bold_italic_y , over^ start_ARG bold_italic_y end_ARG start_POSTSUBSCRIPT italic_θ end_POSTSUBSCRIPT ( bold_italic_X ) ) , (A.2.40)

where $\mathsf{L}$ is a differentiable function of its second argument. Then reverse-mode AD instructs us to compute the derivatives in the order $\theta^{\mathrm{head}},\theta^{L},\dots,\theta^{1},\theta^{\mathrm{emb}}$.

To carry out this computation, let us make some changes to the notation.

  • First, let us change the notation to emphasize the dependency structure between the variables. Namely,

    𝒁1femb(𝑿,θemb)\displaystyle\bm{Z}^{1}\doteq f^{\mathrm{emb}}(\bm{X},\theta^{\mathrm{emb}})bold_italic_Z start_POSTSUPERSCRIPT 1 end_POSTSUPERSCRIPT ≐ italic_f start_POSTSUPERSCRIPT roman_emb end_POSTSUPERSCRIPT ( bold_italic_X , italic_θ start_POSTSUPERSCRIPT roman_emb end_POSTSUPERSCRIPT ) (A.2.41)
    𝒁+1f(𝒁,θ),{1,,L},\displaystyle\bm{Z}^{\ell+1}\doteq f^{\ell}(\bm{Z}^{\ell},\theta^{\ell}),\qquad\forall\ell\in\{1,\dots,L\},bold_italic_Z start_POSTSUPERSCRIPT roman_ℓ + 1 end_POSTSUPERSCRIPT ≐ italic_f start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT ( bold_italic_Z start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT , italic_θ start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT ) , ∀ roman_ℓ ∈ { 1 , … , italic_L } , (A.2.42)
    𝒚^h(𝒁L+1,θhead),\displaystyle\hat{\bm{y}}\doteq h(\bm{Z}^{L+1},\theta^{\mathrm{head}}),over^ start_ARG bold_italic_y end_ARG ≐ italic_h ( bold_italic_Z start_POSTSUPERSCRIPT italic_L + 1 end_POSTSUPERSCRIPT , italic_θ start_POSTSUPERSCRIPT roman_head end_POSTSUPERSCRIPT ) , (A.2.43)
    𝖫(𝒚,𝒚^).\displaystyle\mathcal{L}\doteq\mathsf{L}(\bm{y},\hat{\bm{y}}).caligraphic_L ≐ sansserif_L ( bold_italic_y , over^ start_ARG bold_italic_y end_ARG ) . (A.2.44)
  • Then, instead of writing the derivative as $f^{\prime}$, we explicitly notate the independent variable and write the derivative as, for example, $\frac{\mathrm{d}f}{\mathrm{d}\theta^{1}}$. This is because there are many variables in our model, and we differentiate with respect to one at a time.

We can start by computing the appropriate differentials. First, for θhead\theta^{\mathrm{head}}italic_θ start_POSTSUPERSCRIPT roman_head end_POSTSUPERSCRIPT we have

d\displaystyle\mathrm{d}\mathcal{L}roman_d caligraphic_L =d𝖫\displaystyle=\mathrm{d}\mathsf{L}= roman_d sansserif_L (A.2.45)
=d𝖫d𝒚^d𝒚^\displaystyle=\frac{\mathrm{d}\mathsf{L}}{\mathrm{d}\hat{\bm{y}}}\cdot\mathrm{d}\hat{\bm{y}}= divide start_ARG roman_d sansserif_L end_ARG start_ARG roman_d over^ start_ARG bold_italic_y end_ARG end_ARG ⋅ roman_d over^ start_ARG bold_italic_y end_ARG (A.2.46)
=d𝖫d𝒚^d(h(𝒁L+1,θhead))\displaystyle=\frac{\mathrm{d}\mathsf{L}}{\mathrm{d}\hat{\bm{y}}}\cdot\mathrm{d}{(h(\bm{Z}^{L+1},\theta^{\mathrm{head}}))}= divide start_ARG roman_d sansserif_L end_ARG start_ARG roman_d over^ start_ARG bold_italic_y end_ARG end_ARG ⋅ roman_d ( italic_h ( bold_italic_Z start_POSTSUPERSCRIPT italic_L + 1 end_POSTSUPERSCRIPT , italic_θ start_POSTSUPERSCRIPT roman_head end_POSTSUPERSCRIPT ) ) (A.2.47)
=d𝖫d𝒚^[dhd𝒁L+1d𝒁L+1+dhdθheaddθhead].\displaystyle=\frac{\mathrm{d}\mathsf{L}}{\mathrm{d}\hat{\bm{y}}}\left[\frac{\mathrm{d}h}{\mathrm{d}\bm{Z}^{L+1}}\cdot\mathrm{d}\bm{Z}^{L+1}+\frac{\mathrm{d}h}{\mathrm{d}\theta^{\mathrm{head}}}\cdot\mathrm{d}\theta^{\mathrm{head}}\right].= divide start_ARG roman_d sansserif_L end_ARG start_ARG roman_d over^ start_ARG bold_italic_y end_ARG end_ARG [ divide start_ARG roman_d italic_h end_ARG start_ARG roman_d bold_italic_Z start_POSTSUPERSCRIPT italic_L + 1 end_POSTSUPERSCRIPT end_ARG ⋅ roman_d bold_italic_Z start_POSTSUPERSCRIPT italic_L + 1 end_POSTSUPERSCRIPT + divide start_ARG roman_d italic_h end_ARG start_ARG roman_d italic_θ start_POSTSUPERSCRIPT roman_head end_POSTSUPERSCRIPT end_ARG ⋅ roman_d italic_θ start_POSTSUPERSCRIPT roman_head end_POSTSUPERSCRIPT ] . (A.2.48)

Now since 𝒁L+1\bm{Z}^{L+1}bold_italic_Z start_POSTSUPERSCRIPT italic_L + 1 end_POSTSUPERSCRIPT does not depend on θhead\theta^{\mathrm{head}}italic_θ start_POSTSUPERSCRIPT roman_head end_POSTSUPERSCRIPT, we have d𝒁L+1=0\mathrm{d}\bm{Z}^{L+1}=0roman_d bold_italic_Z start_POSTSUPERSCRIPT italic_L + 1 end_POSTSUPERSCRIPT = 0, so in the end it holds (using the fact that the gradient is the transpose of the derivative for a function 𝖫\mathsf{L}sansserif_L whose codomain is \mathbb{R}blackboard_R):

d\displaystyle\mathrm{d}\mathcal{L}roman_d caligraphic_L =d𝖫d𝒚^dhdθheaddθhead\displaystyle=\frac{\mathrm{d}\mathsf{L}}{\mathrm{d}\hat{\bm{y}}}\cdot\frac{\mathrm{d}h}{\mathrm{d}\theta^{\mathrm{head}}}\cdot\mathrm{d}\theta^{\mathrm{head}}= divide start_ARG roman_d sansserif_L end_ARG start_ARG roman_d over^ start_ARG bold_italic_y end_ARG end_ARG ⋅ divide start_ARG roman_d italic_h end_ARG start_ARG roman_d italic_θ start_POSTSUPERSCRIPT roman_head end_POSTSUPERSCRIPT end_ARG ⋅ roman_d italic_θ start_POSTSUPERSCRIPT roman_head end_POSTSUPERSCRIPT (A.2.49)
=[𝒚^𝖫]dhdθheaddθhead\displaystyle=[\nabla_{\hat{\bm{y}}}\mathsf{L}]^{\top}\frac{\mathrm{d}h}{\mathrm{d}\theta^{\mathrm{head}}}\cdot\mathrm{d}\theta^{\mathrm{head}}= [ ∇ start_POSTSUBSCRIPT over^ start_ARG bold_italic_y end_ARG end_POSTSUBSCRIPT sansserif_L ] start_POSTSUPERSCRIPT ⊤ end_POSTSUPERSCRIPT divide start_ARG roman_d italic_h end_ARG start_ARG roman_d italic_θ start_POSTSUPERSCRIPT roman_head end_POSTSUPERSCRIPT end_ARG ⋅ roman_d italic_θ start_POSTSUPERSCRIPT roman_head end_POSTSUPERSCRIPT (A.2.50)
=𝒚^𝖫,dhdθheaddθhead\displaystyle=\left\langle\nabla_{\hat{\bm{y}}}\mathsf{L},\frac{\mathrm{d}h}{\mathrm{d}\theta^{\mathrm{head}}}\cdot\mathrm{d}\theta^{\mathrm{head}}\right\rangle= ⟨ ∇ start_POSTSUBSCRIPT over^ start_ARG bold_italic_y end_ARG end_POSTSUBSCRIPT sansserif_L , divide start_ARG roman_d italic_h end_ARG start_ARG roman_d italic_θ start_POSTSUPERSCRIPT roman_head end_POSTSUPERSCRIPT end_ARG ⋅ roman_d italic_θ start_POSTSUPERSCRIPT roman_head end_POSTSUPERSCRIPT ⟩ (A.2.51)
=(dhdθhead)𝒚^𝖫,dθhead\displaystyle=\left\langle\left(\frac{\mathrm{d}h}{\mathrm{d}\theta^{\mathrm{head}}}\right)^{*}\nabla_{\hat{\bm{y}}}\mathsf{L},\mathrm{d}\theta^{\mathrm{head}}\right\rangle= ⟨ ( divide start_ARG roman_d italic_h end_ARG start_ARG roman_d italic_θ start_POSTSUPERSCRIPT roman_head end_POSTSUPERSCRIPT end_ARG ) start_POSTSUPERSCRIPT ∗ end_POSTSUPERSCRIPT ∇ start_POSTSUBSCRIPT over^ start_ARG bold_italic_y end_ARG end_POSTSUBSCRIPT sansserif_L , roman_d italic_θ start_POSTSUPERSCRIPT roman_head end_POSTSUPERSCRIPT ⟩ (A.2.52)
=θhead,dθhead.\displaystyle=\langle\nabla_{\theta^{\mathrm{head}}}\mathcal{L},\mathrm{d}\theta^{\mathrm{head}}\rangle.= ⟨ ∇ start_POSTSUBSCRIPT italic_θ start_POSTSUPERSCRIPT roman_head end_POSTSUPERSCRIPT end_POSTSUBSCRIPT caligraphic_L , roman_d italic_θ start_POSTSUPERSCRIPT roman_head end_POSTSUPERSCRIPT ⟩ . (A.2.53)

Thus to compute $\nabla_{\theta^{\mathrm{head}}}\mathcal{L}$, we compute the gradient $\nabla_{\hat{\bm{y}}}\mathsf{L}$ and the adjoint\footnote{The adjoint is a generalization of the transpose to arbitrary linear transformations between inner product spaces: for a linear transformation $T$ between such spaces, the adjoint $T^{*}$ is defined by the identity $\langle T\bm{x},\bm{y}\rangle=\langle\bm{x},T^{*}\bm{y}\rangle$. In finite dimensions (i.e., all cases relevant to this book), the adjoint exists and is unique.} of the derivative $\frac{\mathrm{d}h}{\mathrm{d}\theta^{\mathrm{head}}}$, and apply the adjoint linear transformation to the gradient. In practice, both can be computed by hand, but many modern computational frameworks can automatically construct these derivatives (and/or their adjoints) given code for the “forward pass,” i.e., the loss function computation. Extending this automatic construction to as many functions as possible is an area of active research; the resource [BEJ25] describes one basic approach in some detail. For this reason, backpropagation is also called the adjoint method: we use adjoint derivatives to compute the gradient.
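For intuition about the adjoint, here is a minimal numerical check (our own illustration, assuming NumPy): for a linear map $T(\bm{x})=\bm{M}\bm{x}$ between Euclidean spaces equipped with the standard inner products, the adjoint is simply $T^{*}(\bm{y})=\bm{M}^{\top}\bm{y}$, as the defining identity confirms on random inputs.

```python
import numpy as np

# For T(x) = M x with the standard inner products, the adjoint is T*(y) = M^T y.
# Check the defining identity <T x, y> = <x, T* y> numerically.
rng = np.random.default_rng(0)
M = rng.standard_normal((4, 6))
x = rng.standard_normal(6)
y = rng.standard_normal(4)
assert np.isclose((M @ x) @ y, x @ (M.T @ y))
```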

Now let us compute the differentials w.r.t. some θ\theta^{\ell}italic_θ start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT:

d\displaystyle\mathrm{d}\mathcal{L}roman_d caligraphic_L =dd𝒁+1d𝒁+1\displaystyle=\frac{\mathrm{d}\mathcal{L}}{\mathrm{d}\bm{Z}^{\ell+1}}\cdot\mathrm{d}\bm{Z}^{\ell+1}= divide start_ARG roman_d caligraphic_L end_ARG start_ARG roman_d bold_italic_Z start_POSTSUPERSCRIPT roman_ℓ + 1 end_POSTSUPERSCRIPT end_ARG ⋅ roman_d bold_italic_Z start_POSTSUPERSCRIPT roman_ℓ + 1 end_POSTSUPERSCRIPT (A.2.54)
=dd𝒁+1d(f(𝒁,θ))\displaystyle=\frac{\mathrm{d}\mathcal{L}}{\mathrm{d}\bm{Z}^{\ell+1}}\cdot\mathrm{d}{(f^{\ell}(\bm{Z}^{\ell},\theta^{\ell}))}= divide start_ARG roman_d caligraphic_L end_ARG start_ARG roman_d bold_italic_Z start_POSTSUPERSCRIPT roman_ℓ + 1 end_POSTSUPERSCRIPT end_ARG ⋅ roman_d ( italic_f start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT ( bold_italic_Z start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT , italic_θ start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT ) ) (A.2.55)
=dd𝒁+1[dfd𝒁d𝒁+dfdθdθ]\displaystyle=\frac{\mathrm{d}\mathcal{L}}{\mathrm{d}\bm{Z}^{\ell+1}}\left[\frac{\mathrm{d}f^{\ell}}{\mathrm{d}\bm{Z}^{\ell}}\cdot\mathrm{d}\bm{Z}^{\ell}+\frac{\mathrm{d}f^{\ell}}{\mathrm{d}\theta^{\ell}}\cdot\mathrm{d}\theta^{\ell}\right]= divide start_ARG roman_d caligraphic_L end_ARG start_ARG roman_d bold_italic_Z start_POSTSUPERSCRIPT roman_ℓ + 1 end_POSTSUPERSCRIPT end_ARG [ divide start_ARG roman_d italic_f start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT end_ARG start_ARG roman_d bold_italic_Z start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT end_ARG ⋅ roman_d bold_italic_Z start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT + divide start_ARG roman_d italic_f start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT end_ARG start_ARG roman_d italic_θ start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT end_ARG ⋅ roman_d italic_θ start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT ] (A.2.56)
=dd𝒁+1dfdθdθ(b/c 𝒁 isn’t fn. of θ so d𝒁=0)\displaystyle=\frac{\mathrm{d}\mathcal{L}}{\mathrm{d}\bm{Z}^{\ell+1}}\cdot\frac{\mathrm{d}f^{\ell}}{\mathrm{d}\theta^{\ell}}\cdot\mathrm{d}\theta^{\ell}\qquad\text{(b/c~{}$\bm{Z}^{\ell}$ isn't fn.~{}of $\theta^{\ell}$ so $\mathrm{d}\bm{Z}^{\ell}=0$)}= divide start_ARG roman_d caligraphic_L end_ARG start_ARG roman_d bold_italic_Z start_POSTSUPERSCRIPT roman_ℓ + 1 end_POSTSUPERSCRIPT end_ARG ⋅ divide start_ARG roman_d italic_f start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT end_ARG start_ARG roman_d italic_θ start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT end_ARG ⋅ roman_d italic_θ start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT (b/c bold_italic_Z start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT isn’t fn. of italic_θ start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT so roman_d bold_italic_Z start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT = 0 ) (A.2.57)
=[𝒁+1]dfdθdθ\displaystyle=[\nabla_{\bm{Z}^{\ell+1}}\mathcal{L}]^{\top}\frac{\mathrm{d}f^{\ell}}{\mathrm{d}\theta^{\ell}}\cdot\mathrm{d}\theta^{\ell}= [ ∇ start_POSTSUBSCRIPT bold_italic_Z start_POSTSUPERSCRIPT roman_ℓ + 1 end_POSTSUPERSCRIPT end_POSTSUBSCRIPT caligraphic_L ] start_POSTSUPERSCRIPT ⊤ end_POSTSUPERSCRIPT divide start_ARG roman_d italic_f start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT end_ARG start_ARG roman_d italic_θ start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT end_ARG ⋅ roman_d italic_θ start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT (A.2.58)
=𝒁+1,dfdθdθ\displaystyle=\left\langle\nabla_{\bm{Z}^{\ell+1}}\mathcal{L},\frac{\mathrm{d}f^{\ell}}{\mathrm{d}\theta^{\ell}}\cdot\mathrm{d}\theta^{\ell}\right\rangle= ⟨ ∇ start_POSTSUBSCRIPT bold_italic_Z start_POSTSUPERSCRIPT roman_ℓ + 1 end_POSTSUPERSCRIPT end_POSTSUBSCRIPT caligraphic_L , divide start_ARG roman_d italic_f start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT end_ARG start_ARG roman_d italic_θ start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT end_ARG ⋅ roman_d italic_θ start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT ⟩ (A.2.59)
=(dfdθ)𝒁+1,dθ\displaystyle=\left\langle\left(\frac{\mathrm{d}f^{\ell}}{\mathrm{d}\theta^{\ell}}\right)^{*}\nabla_{\bm{Z}^{\ell+1}}\mathcal{L},\mathrm{d}\theta^{\ell}\right\rangle= ⟨ ( divide start_ARG roman_d italic_f start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT end_ARG start_ARG roman_d italic_θ start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT end_ARG ) start_POSTSUPERSCRIPT ∗ end_POSTSUPERSCRIPT ∇ start_POSTSUBSCRIPT bold_italic_Z start_POSTSUPERSCRIPT roman_ℓ + 1 end_POSTSUPERSCRIPT end_POSTSUBSCRIPT caligraphic_L , roman_d italic_θ start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT ⟩ (A.2.60)
=θ,dθ.\displaystyle=\langle\nabla_{\theta^{\ell}}\mathcal{L},\mathrm{d}\theta^{\ell}\rangle.= ⟨ ∇ start_POSTSUBSCRIPT italic_θ start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT end_POSTSUBSCRIPT caligraphic_L , roman_d italic_θ start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT ⟩ . (A.2.61)

Thus to compute θ\nabla_{\theta^{\ell}}\mathcal{L}∇ start_POSTSUBSCRIPT italic_θ start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT end_POSTSUBSCRIPT caligraphic_L we compute 𝒁+1\nabla_{\bm{Z}^{\ell+1}}\mathcal{L}∇ start_POSTSUBSCRIPT bold_italic_Z start_POSTSUPERSCRIPT roman_ℓ + 1 end_POSTSUPERSCRIPT end_POSTSUBSCRIPT caligraphic_L then apply the adjoint derivative (dfdθ)\left(\frac{\mathrm{d}f^{\ell}}{\mathrm{d}\theta^{\ell}}\right)^{*}( divide start_ARG roman_d italic_f start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT end_ARG start_ARG roman_d italic_θ start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT end_ARG ) start_POSTSUPERSCRIPT ∗ end_POSTSUPERSCRIPT to it. Since θ\nabla_{\theta^{\ell}}\mathcal{L}∇ start_POSTSUBSCRIPT italic_θ start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT end_POSTSUBSCRIPT caligraphic_L depends on 𝒁+1\nabla_{\bm{Z}^{\ell+1}}\mathcal{L}∇ start_POSTSUBSCRIPT bold_italic_Z start_POSTSUPERSCRIPT roman_ℓ + 1 end_POSTSUPERSCRIPT end_POSTSUBSCRIPT caligraphic_L, we also want to be able to compute the gradients w.r.t. 𝒁\bm{Z}^{\ell}bold_italic_Z start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT. This can be computed in the exact same way:

d\displaystyle\mathrm{d}\mathcal{L}roman_d caligraphic_L =dd𝒁+1d𝒁+1\displaystyle=\frac{\mathrm{d}\mathcal{L}}{\mathrm{d}\bm{Z}^{\ell+1}}\cdot\mathrm{d}\bm{Z}^{\ell+1}= divide start_ARG roman_d caligraphic_L end_ARG start_ARG roman_d bold_italic_Z start_POSTSUPERSCRIPT roman_ℓ + 1 end_POSTSUPERSCRIPT end_ARG ⋅ roman_d bold_italic_Z start_POSTSUPERSCRIPT roman_ℓ + 1 end_POSTSUPERSCRIPT (A.2.62)
=dd𝒁+1d(f(𝒁,θ))\displaystyle=\frac{\mathrm{d}\mathcal{L}}{\mathrm{d}\bm{Z}^{\ell+1}}\cdot\mathrm{d}{(f^{\ell}(\bm{Z}^{\ell},\theta^{\ell}))}= divide start_ARG roman_d caligraphic_L end_ARG start_ARG roman_d bold_italic_Z start_POSTSUPERSCRIPT roman_ℓ + 1 end_POSTSUPERSCRIPT end_ARG ⋅ roman_d ( italic_f start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT ( bold_italic_Z start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT , italic_θ start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT ) ) (A.2.63)
=dd𝒁+1[dfd𝒁d𝒁+dfdθdθ]\displaystyle=\frac{\mathrm{d}\mathcal{L}}{\mathrm{d}\bm{Z}^{\ell+1}}\left[\frac{\mathrm{d}f^{\ell}}{\mathrm{d}\bm{Z}^{\ell}}\cdot\mathrm{d}\bm{Z}^{\ell}+\frac{\mathrm{d}f^{\ell}}{\mathrm{d}\theta^{\ell}}\cdot\mathrm{d}\theta^{\ell}\right]= divide start_ARG roman_d caligraphic_L end_ARG start_ARG roman_d bold_italic_Z start_POSTSUPERSCRIPT roman_ℓ + 1 end_POSTSUPERSCRIPT end_ARG [ divide start_ARG roman_d italic_f start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT end_ARG start_ARG roman_d bold_italic_Z start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT end_ARG ⋅ roman_d bold_italic_Z start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT + divide start_ARG roman_d italic_f start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT end_ARG start_ARG roman_d italic_θ start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT end_ARG ⋅ roman_d italic_θ start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT ] (A.2.64)
=dd𝒁+1dfd𝒁d𝒁(b/c θ isn’t fn. of 𝒁 so dθ=0 this time)\displaystyle=\frac{\mathrm{d}\mathcal{L}}{\mathrm{d}\bm{Z}^{\ell+1}}\cdot\frac{\mathrm{d}f^{\ell}}{\mathrm{d}\bm{Z}^{\ell}}\cdot\mathrm{d}\bm{Z}^{\ell}\qquad\text{(b/c~{}~{}$\theta^{\ell}$ isn't fn.~{}of $\bm{Z}^{\ell}$ so $\mathrm{d}\theta^{\ell}=0$ this time)}= divide start_ARG roman_d caligraphic_L end_ARG start_ARG roman_d bold_italic_Z start_POSTSUPERSCRIPT roman_ℓ + 1 end_POSTSUPERSCRIPT end_ARG ⋅ divide start_ARG roman_d italic_f start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT end_ARG start_ARG roman_d bold_italic_Z start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT end_ARG ⋅ roman_d bold_italic_Z start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT (b/c italic_θ start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT isn’t fn. of bold_italic_Z start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT so roman_d italic_θ start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT = 0 this time) (A.2.65)
=(dfd𝒁)𝒁+1,d𝒁(same machine)\displaystyle=\left\langle\left(\frac{\mathrm{d}f^{\ell}}{\mathrm{d}\bm{Z}^{\ell}}\right)^{*}\nabla_{\bm{Z}^{\ell+1}}\mathcal{L},\mathrm{d}\bm{Z}^{\ell}\right\rangle\qquad\text{(same machine)}= ⟨ ( divide start_ARG roman_d italic_f start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT end_ARG start_ARG roman_d bold_italic_Z start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT end_ARG ) start_POSTSUPERSCRIPT ∗ end_POSTSUPERSCRIPT ∇ start_POSTSUBSCRIPT bold_italic_Z start_POSTSUPERSCRIPT roman_ℓ + 1 end_POSTSUPERSCRIPT end_POSTSUBSCRIPT caligraphic_L , roman_d bold_italic_Z start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT ⟩ (same machine) (A.2.66)
=𝒁,d𝒁.\displaystyle=\langle\nabla_{\bm{Z}^{\ell}}\mathcal{L},\mathrm{d}\bm{Z}^{\ell}\rangle.= ⟨ ∇ start_POSTSUBSCRIPT bold_italic_Z start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT end_POSTSUBSCRIPT caligraphic_L , roman_d bold_italic_Z start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT ⟩ . (A.2.67)

Thus to compute 𝒁\nabla_{\bm{Z}^{\ell}}\mathcal{L}∇ start_POSTSUBSCRIPT bold_italic_Z start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT end_POSTSUBSCRIPT caligraphic_L we compute 𝒁+1\nabla_{\bm{Z}^{\ell+1}}\mathcal{L}∇ start_POSTSUBSCRIPT bold_italic_Z start_POSTSUPERSCRIPT roman_ℓ + 1 end_POSTSUPERSCRIPT end_POSTSUBSCRIPT caligraphic_L then apply the adjoint derivative (dfd𝒁)\left(\frac{\mathrm{d}f^{\ell}}{\mathrm{d}\bm{Z}^{\ell}}\right)^{*}( divide start_ARG roman_d italic_f start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT end_ARG start_ARG roman_d bold_italic_Z start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT end_ARG ) start_POSTSUPERSCRIPT ∗ end_POSTSUPERSCRIPT to it. So we have a recursion to compute 𝒁\nabla_{\bm{Z}^{\ell}}\mathcal{L}∇ start_POSTSUBSCRIPT bold_italic_Z start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT end_POSTSUBSCRIPT caligraphic_L for all \ellroman_ℓ, with base case 𝒁L+1\nabla_{\bm{Z}^{L+1}}\mathcal{L}∇ start_POSTSUBSCRIPT bold_italic_Z start_POSTSUPERSCRIPT italic_L + 1 end_POSTSUPERSCRIPT end_POSTSUBSCRIPT caligraphic_L, which is given by

d\displaystyle\mathrm{d}\mathcal{L}roman_d caligraphic_L =d𝖫\displaystyle=\mathrm{d}\mathsf{L}= roman_d sansserif_L (A.2.68)
=d𝖫d𝒚^d𝒚^\displaystyle=\frac{\mathrm{d}\mathsf{L}}{\mathrm{d}\hat{\bm{y}}}\cdot\mathrm{d}\hat{\bm{y}}= divide start_ARG roman_d sansserif_L end_ARG start_ARG roman_d over^ start_ARG bold_italic_y end_ARG end_ARG ⋅ roman_d over^ start_ARG bold_italic_y end_ARG (A.2.69)
=d𝖫d𝒚^dh(𝒁L+1,θhead)\displaystyle=\frac{\mathrm{d}\mathsf{L}}{\mathrm{d}\hat{\bm{y}}}\cdot\mathrm{d}{h(\bm{Z}^{L+1},\theta^{\mathrm{head}})}= divide start_ARG roman_d sansserif_L end_ARG start_ARG roman_d over^ start_ARG bold_italic_y end_ARG end_ARG ⋅ roman_d italic_h ( bold_italic_Z start_POSTSUPERSCRIPT italic_L + 1 end_POSTSUPERSCRIPT , italic_θ start_POSTSUPERSCRIPT roman_head end_POSTSUPERSCRIPT ) (A.2.70)
=d𝖫d𝒚^[dhd𝒁L+1d𝒁L+1+dhdθheaddθhead]\displaystyle=\frac{\mathrm{d}\mathsf{L}}{\mathrm{d}\hat{\bm{y}}}\left[\frac{\mathrm{d}h}{\mathrm{d}\bm{Z}^{L+1}}\cdot\mathrm{d}\bm{Z}^{L+1}+\frac{\mathrm{d}h}{\mathrm{d}\theta^{\mathrm{head}}}\cdot\mathrm{d}\theta^{\mathrm{head}}\right]= divide start_ARG roman_d sansserif_L end_ARG start_ARG roman_d over^ start_ARG bold_italic_y end_ARG end_ARG [ divide start_ARG roman_d italic_h end_ARG start_ARG roman_d bold_italic_Z start_POSTSUPERSCRIPT italic_L + 1 end_POSTSUPERSCRIPT end_ARG ⋅ roman_d bold_italic_Z start_POSTSUPERSCRIPT italic_L + 1 end_POSTSUPERSCRIPT + divide start_ARG roman_d italic_h end_ARG start_ARG roman_d italic_θ start_POSTSUPERSCRIPT roman_head end_POSTSUPERSCRIPT end_ARG ⋅ roman_d italic_θ start_POSTSUPERSCRIPT roman_head end_POSTSUPERSCRIPT ] (A.2.71)
=d𝖫d𝒚^dhd𝒁L+1d𝒁L+1\displaystyle=\frac{\mathrm{d}\mathsf{L}}{\mathrm{d}\hat{\bm{y}}}\cdot\frac{\mathrm{d}h}{\mathrm{d}\bm{Z}^{L+1}}\cdot\mathrm{d}\bm{Z}^{L+1}= divide start_ARG roman_d sansserif_L end_ARG start_ARG roman_d over^ start_ARG bold_italic_y end_ARG end_ARG ⋅ divide start_ARG roman_d italic_h end_ARG start_ARG roman_d bold_italic_Z start_POSTSUPERSCRIPT italic_L + 1 end_POSTSUPERSCRIPT end_ARG ⋅ roman_d bold_italic_Z start_POSTSUPERSCRIPT italic_L + 1 end_POSTSUPERSCRIPT (A.2.72)
=(dhd𝒁L+1)𝒚^𝖫,d𝒁L+1\displaystyle=\left\langle\left(\frac{\mathrm{d}h}{\mathrm{d}\bm{Z}^{L+1}}\right)^{*}\nabla_{\hat{\bm{y}}}\mathsf{L},\mathrm{d}\bm{Z}^{L+1}\right\rangle= ⟨ ( divide start_ARG roman_d italic_h end_ARG start_ARG roman_d bold_italic_Z start_POSTSUPERSCRIPT italic_L + 1 end_POSTSUPERSCRIPT end_ARG ) start_POSTSUPERSCRIPT ∗ end_POSTSUPERSCRIPT ∇ start_POSTSUBSCRIPT over^ start_ARG bold_italic_y end_ARG end_POSTSUBSCRIPT sansserif_L , roman_d bold_italic_Z start_POSTSUPERSCRIPT italic_L + 1 end_POSTSUPERSCRIPT ⟩ (A.2.73)
=𝒁L+1,d𝒁L+1.\displaystyle=\langle\nabla_{\bm{Z}^{L+1}}\mathcal{L},\mathrm{d}\bm{Z}^{L+1}\rangle.= ⟨ ∇ start_POSTSUBSCRIPT bold_italic_Z start_POSTSUPERSCRIPT italic_L + 1 end_POSTSUPERSCRIPT end_POSTSUBSCRIPT caligraphic_L , roman_d bold_italic_Z start_POSTSUPERSCRIPT italic_L + 1 end_POSTSUPERSCRIPT ⟩ . (A.2.74)

Thus we have the recursion:

𝒁L+1\displaystyle\nabla_{\bm{Z}^{L+1}}\mathcal{L}∇ start_POSTSUBSCRIPT bold_italic_Z start_POSTSUPERSCRIPT italic_L + 1 end_POSTSUPERSCRIPT end_POSTSUBSCRIPT caligraphic_L =(dhd𝒁L+1)𝒚^𝖫\displaystyle=\left(\frac{\mathrm{d}h}{\mathrm{d}\bm{Z}^{L+1}}\right)^{*}\nabla_{\hat{\bm{y}}}\mathsf{L}= ( divide start_ARG roman_d italic_h end_ARG start_ARG roman_d bold_italic_Z start_POSTSUPERSCRIPT italic_L + 1 end_POSTSUPERSCRIPT end_ARG ) start_POSTSUPERSCRIPT ∗ end_POSTSUPERSCRIPT ∇ start_POSTSUBSCRIPT over^ start_ARG bold_italic_y end_ARG end_POSTSUBSCRIPT sansserif_L (A.2.75)
𝒁L\displaystyle\nabla_{\bm{Z}^{L}}\mathcal{L}∇ start_POSTSUBSCRIPT bold_italic_Z start_POSTSUPERSCRIPT italic_L end_POSTSUPERSCRIPT end_POSTSUBSCRIPT caligraphic_L =(dfLd𝒁L)𝒁L+1\displaystyle=\left(\frac{\mathrm{d}f^{L}}{\mathrm{d}\bm{Z}^{L}}\right)^{*}\nabla_{\bm{Z}^{L+1}}\mathcal{L}= ( divide start_ARG roman_d italic_f start_POSTSUPERSCRIPT italic_L end_POSTSUPERSCRIPT end_ARG start_ARG roman_d bold_italic_Z start_POSTSUPERSCRIPT italic_L end_POSTSUPERSCRIPT end_ARG ) start_POSTSUPERSCRIPT ∗ end_POSTSUPERSCRIPT ∇ start_POSTSUBSCRIPT bold_italic_Z start_POSTSUPERSCRIPT italic_L + 1 end_POSTSUPERSCRIPT end_POSTSUBSCRIPT caligraphic_L (A.2.76)
θL\displaystyle\nabla_{\theta^{L}}\mathcal{L}∇ start_POSTSUBSCRIPT italic_θ start_POSTSUPERSCRIPT italic_L end_POSTSUPERSCRIPT end_POSTSUBSCRIPT caligraphic_L =(dfLdθL)𝒁L+1\displaystyle=\left(\frac{\mathrm{d}f^{L}}{\mathrm{d}\theta^{L}}\right)^{*}\nabla_{\bm{Z}^{L+1}}\mathcal{L}= ( divide start_ARG roman_d italic_f start_POSTSUPERSCRIPT italic_L end_POSTSUPERSCRIPT end_ARG start_ARG roman_d italic_θ start_POSTSUPERSCRIPT italic_L end_POSTSUPERSCRIPT end_ARG ) start_POSTSUPERSCRIPT ∗ end_POSTSUPERSCRIPT ∇ start_POSTSUBSCRIPT bold_italic_Z start_POSTSUPERSCRIPT italic_L + 1 end_POSTSUPERSCRIPT end_POSTSUBSCRIPT caligraphic_L (A.2.77)
\displaystyle\vdots (A.2.78)
𝒁1\displaystyle\nabla_{\bm{Z}^{1}}\mathcal{L}∇ start_POSTSUBSCRIPT bold_italic_Z start_POSTSUPERSCRIPT 1 end_POSTSUPERSCRIPT end_POSTSUBSCRIPT caligraphic_L =(df1d𝒁1)𝒁2\displaystyle=\left(\frac{\mathrm{d}f^{1}}{\mathrm{d}\bm{Z}^{1}}\right)^{*}\nabla_{\bm{Z}^{2}}\mathcal{L}= ( divide start_ARG roman_d italic_f start_POSTSUPERSCRIPT 1 end_POSTSUPERSCRIPT end_ARG start_ARG roman_d bold_italic_Z start_POSTSUPERSCRIPT 1 end_POSTSUPERSCRIPT end_ARG ) start_POSTSUPERSCRIPT ∗ end_POSTSUPERSCRIPT ∇ start_POSTSUBSCRIPT bold_italic_Z start_POSTSUPERSCRIPT 2 end_POSTSUPERSCRIPT end_POSTSUBSCRIPT caligraphic_L (A.2.79)
θ1\displaystyle\nabla_{\theta^{1}}\mathcal{L}∇ start_POSTSUBSCRIPT italic_θ start_POSTSUPERSCRIPT 1 end_POSTSUPERSCRIPT end_POSTSUBSCRIPT caligraphic_L =(df1dθ1)𝒁2\displaystyle=\left(\frac{\mathrm{d}f^{1}}{\mathrm{d}\theta^{1}}\right)^{*}\nabla_{\bm{Z}^{2}}\mathcal{L}= ( divide start_ARG roman_d italic_f start_POSTSUPERSCRIPT 1 end_POSTSUPERSCRIPT end_ARG start_ARG roman_d italic_θ start_POSTSUPERSCRIPT 1 end_POSTSUPERSCRIPT end_ARG ) start_POSTSUPERSCRIPT ∗ end_POSTSUPERSCRIPT ∇ start_POSTSUBSCRIPT bold_italic_Z start_POSTSUPERSCRIPT 2 end_POSTSUPERSCRIPT end_POSTSUBSCRIPT caligraphic_L (A.2.80)
θemb\displaystyle\nabla_{\theta^{\mathrm{emb}}}\mathcal{L}∇ start_POSTSUBSCRIPT italic_θ start_POSTSUPERSCRIPT roman_emb end_POSTSUPERSCRIPT end_POSTSUBSCRIPT caligraphic_L =(dfembdθemb)𝒁1.\displaystyle=\left(\frac{\mathrm{d}f^{\mathrm{emb}}}{\mathrm{d}\theta^{\mathrm{emb}}}\right)^{*}\nabla_{\bm{Z}^{1}}\mathcal{L}.= ( divide start_ARG roman_d italic_f start_POSTSUPERSCRIPT roman_emb end_POSTSUPERSCRIPT end_ARG start_ARG roman_d italic_θ start_POSTSUPERSCRIPT roman_emb end_POSTSUPERSCRIPT end_ARG ) start_POSTSUPERSCRIPT ∗ end_POSTSUPERSCRIPT ∇ start_POSTSUBSCRIPT bold_italic_Z start_POSTSUPERSCRIPT 1 end_POSTSUPERSCRIPT end_POSTSUBSCRIPT caligraphic_L . (A.2.81)

This gives us a computationally efficient algorithm to find all gradients in the whole network.
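To make the recursion concrete, the following sketch (our own illustration, assuming NumPy; the architecture is an arbitrary choice of two tanh layers with a linear head, an identity embedding, unbatched vector inputs, and a squared-error loss) implements equations (A.2.75)–(A.2.81) directly and checks one parameter gradient against a finite difference:

```python
import numpy as np

# Backward recursion: propagate grad_Z from the top down and, at each layer,
# apply the adjoint of the derivative w.r.t. that layer's own parameters.
rng = np.random.default_rng(0)
dims = [4, 5, 5, 3]                           # input, two hidden widths, output
Ws = [rng.standard_normal((dims[i + 1], dims[i])) for i in range(3)]
bs = [rng.standard_normal(dims[i + 1]) for i in range(3)]
x, y = rng.standard_normal(dims[0]), rng.standard_normal(dims[-1])

def forward(Ws, bs, x):
    Z = [x]                                   # Z^1 = x (identity embedding)
    for W, b in zip(Ws[:-1], bs[:-1]):
        Z.append(np.tanh(W @ Z[-1] + b))      # Z^{l+1} = f^l(Z^l, theta^l)
    return Z, Ws[-1] @ Z[-1] + bs[-1]         # linear head h(Z^{L+1}, theta^head)

def backward(Ws, bs, Z, y_hat, y):
    grads_W, grads_b = [None] * len(Ws), [None] * len(bs)
    g_yhat = y_hat - y                        # grad of 0.5*||y_hat - y||^2
    grads_W[-1], grads_b[-1] = np.outer(g_yhat, Z[-1]), g_yhat   # head params
    g_Z = Ws[-1].T @ g_yhat                   # grad w.r.t. Z^{L+1}
    for l in range(len(Ws) - 2, -1, -1):      # layers L, ..., 1 (top down)
        u = g_Z * (1.0 - Z[l + 1] ** 2)       # tanh'(a) = 1 - tanh(a)^2
        grads_W[l], grads_b[l] = np.outer(u, Z[l]), u
        g_Z = Ws[l].T @ u                     # grad w.r.t. Z^l
    return grads_W, grads_b

def loss(Ws, bs):
    _, y_hat = forward(Ws, bs, x)
    return 0.5 * np.sum((y_hat - y) ** 2)

Z, y_hat = forward(Ws, bs, x)
grads_W, grads_b = backward(Ws, bs, Z, y_hat, y)

# Central finite-difference check of one weight entry of the first layer.
eps = 1e-6
Wp = [W.copy() for W in Ws]; Wp[0][2, 1] += eps
Wm = [W.copy() for W in Ws]; Wm[0][2, 1] -= eps
fd = (loss(Wp, bs) - loss(Wm, bs)) / (2 * eps)
assert np.isclose(grads_W[0][2, 1], fd, atol=1e-6)
```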

We’ll finish this section by computing the adjoint derivative for a simple layer.

Example A.8.

Consider the “linear” (affine) layer $f^{\ell}$ with input $\bm{Z}\in\mathbb{R}^{d\times n}$ and parameters $\bm{W}^{\ell}\in\mathbb{R}^{m\times d}$ and $\bm{b}^{\ell}\in\mathbb{R}^{m}$, namely

f^{\ell}(\bm{Z},\bm{W}^{\ell},\bm{b}^{\ell})\doteq\bm{W}^{\ell}\bm{Z}+\bm{b}^{\ell}\bm{1}^{\top},\qquad\theta^{\ell}=(\bm{W}^{\ell},\bm{b}^{\ell}). (A.2.82)

We can compute the differential w.r.t. both parameters as

\mathrm{d}f^{\ell}=[(\bm{W}^{\ell}+\mathrm{d}\bm{W}^{\ell})\bm{Z}+(\bm{b}^{\ell}+\mathrm{d}\bm{b}^{\ell})\bm{1}^{\top}]-[\bm{W}^{\ell}\bm{Z}+\bm{b}^{\ell}\bm{1}^{\top}] (A.2.83)
=(d𝑾)𝒁+(d𝒃)𝟏.\displaystyle=(\mathrm{d}\bm{W}^{\ell})\bm{Z}+(\mathrm{d}\bm{b}^{\ell})\bm{1}^{\top}.= ( roman_d bold_italic_W start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT ) bold_italic_Z + ( roman_d bold_italic_b start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT ) bold_1 start_POSTSUPERSCRIPT ⊤ end_POSTSUPERSCRIPT . (A.2.84)

Thus the derivative of this transformation is

dfd(𝑾,𝒃)[d𝑾,d𝒃]=(d𝑾)𝒁+(d𝒃)𝟏,\frac{\mathrm{d}f^{\ell}}{\mathrm{d}(\bm{W}^{\ell},\bm{b}^{\ell})}[\mathrm{d}\bm{W}^{\ell},\mathrm{d}\bm{b}^{\ell}]=(\mathrm{d}\bm{W}^{\ell})\bm{Z}+(\mathrm{d}\bm{b}^{\ell})\bm{1}^{\top},divide start_ARG roman_d italic_f start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT end_ARG start_ARG roman_d ( bold_italic_W start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT , bold_italic_b start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT ) end_ARG [ roman_d bold_italic_W start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT , roman_d bold_italic_b start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT ] = ( roman_d bold_italic_W start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT ) bold_italic_Z + ( roman_d bold_italic_b start_POSTSUPERSCRIPT roman_ℓ end_POSTSUPERSCRIPT ) bold_1 start_POSTSUPERSCRIPT ⊤ end_POSTSUPERSCRIPT , (A.2.85)

namely, representing the following linear transformation from m×d×m\mathbb{R}^{m\times d}\times\mathbb{R}^{m}blackboard_R start_POSTSUPERSCRIPT italic_m × italic_d end_POSTSUPERSCRIPT × blackboard_R start_POSTSUPERSCRIPT italic_m end_POSTSUPERSCRIPT to m×n\mathbb{R}^{m\times n}blackboard_R start_POSTSUPERSCRIPT italic_m × italic_n end_POSTSUPERSCRIPT:

T[𝑨,𝒖]=𝑨𝒁+𝒖𝟏.T[\bm{A},\bm{u}]=\bm{A}\bm{Z}+\bm{u}\bm{1}^{\top}.italic_T [ bold_italic_A , bold_italic_u ] = bold_italic_A bold_italic_Z + bold_italic_u bold_1 start_POSTSUPERSCRIPT ⊤ end_POSTSUPERSCRIPT . (A.2.86)

We calculate the adjoint T:m×nm×d×mT^{*}\colon\mathbb{R}^{m\times n}\to\mathbb{R}^{m\times d}\times\mathbb{R}^{m}italic_T start_POSTSUPERSCRIPT ∗ end_POSTSUPERSCRIPT : blackboard_R start_POSTSUPERSCRIPT italic_m × italic_n end_POSTSUPERSCRIPT → blackboard_R start_POSTSUPERSCRIPT italic_m × italic_d end_POSTSUPERSCRIPT × blackboard_R start_POSTSUPERSCRIPT italic_m end_POSTSUPERSCRIPT w.r.t. the sum-over-coordinates (Frobenius) inner product by the following procedure:

T[𝑨,𝒖],𝑩m×n\displaystyle\langle T[\bm{A},\bm{u}],\bm{B}\rangle_{\mathbb{R}^{m\times n}}⟨ italic_T [ bold_italic_A , bold_italic_u ] , bold_italic_B ⟩ start_POSTSUBSCRIPT blackboard_R start_POSTSUPERSCRIPT italic_m × italic_n end_POSTSUPERSCRIPT end_POSTSUBSCRIPT =tr((𝑨𝒁+𝒖𝟏)𝑩)\displaystyle=\operatorname{tr}((\bm{A}\bm{Z}+\bm{u}\bm{1}^{\top})\bm{B}^{\top})= roman_tr ( ( bold_italic_A bold_italic_Z + bold_italic_u bold_1 start_POSTSUPERSCRIPT ⊤ end_POSTSUPERSCRIPT ) bold_italic_B start_POSTSUPERSCRIPT ⊤ end_POSTSUPERSCRIPT ) (A.2.87)
=tr(𝑨𝒁𝑩+𝒖𝟏𝑩)\displaystyle=\operatorname{tr}(\bm{A}\bm{Z}\bm{B}^{\top}+\bm{u}\bm{1}^{\top}\bm{B}^{\top})= roman_tr ( bold_italic_A bold_italic_Z bold_italic_B start_POSTSUPERSCRIPT ⊤ end_POSTSUPERSCRIPT + bold_italic_u bold_1 start_POSTSUPERSCRIPT ⊤ end_POSTSUPERSCRIPT bold_italic_B start_POSTSUPERSCRIPT ⊤ end_POSTSUPERSCRIPT ) (A.2.88)
=tr(𝑨𝒁𝑩)+tr(𝒖𝟏𝑩)\displaystyle=\operatorname{tr}(\bm{A}\bm{Z}\bm{B}^{\top})+\operatorname{tr}(\bm{u}\bm{1}^{\top}\bm{B}^{\top})= roman_tr ( bold_italic_A bold_italic_Z bold_italic_B start_POSTSUPERSCRIPT ⊤ end_POSTSUPERSCRIPT ) + roman_tr ( bold_italic_u bold_1 start_POSTSUPERSCRIPT ⊤ end_POSTSUPERSCRIPT bold_italic_B start_POSTSUPERSCRIPT ⊤ end_POSTSUPERSCRIPT ) (A.2.89)
=tr(𝑩𝒁𝑨)+tr(𝟏𝑩𝒖)\displaystyle=\operatorname{tr}(\bm{B}\bm{Z}^{\top}\bm{A}^{\top})+\operatorname{tr}(\bm{1}^{\top}\bm{B}^{\top}\bm{u})= roman_tr ( bold_italic_B bold_italic_Z start_POSTSUPERSCRIPT ⊤ end_POSTSUPERSCRIPT bold_italic_A start_POSTSUPERSCRIPT ⊤ end_POSTSUPERSCRIPT ) + roman_tr ( bold_1 start_POSTSUPERSCRIPT ⊤ end_POSTSUPERSCRIPT bold_italic_B start_POSTSUPERSCRIPT ⊤ end_POSTSUPERSCRIPT bold_italic_u ) (A.2.90)
=𝑩𝒁,𝑨m×d+𝑩𝟏,𝒖m\displaystyle=\langle\bm{B}\bm{Z}^{\top},\bm{A}\rangle_{\mathbb{R}^{m\times d}}+\langle\bm{B}\bm{1},\bm{u}\rangle_{\mathbb{R}^{m}}= ⟨ bold_italic_B bold_italic_Z start_POSTSUPERSCRIPT ⊤ end_POSTSUPERSCRIPT , bold_italic_A ⟩ start_POSTSUBSCRIPT blackboard_R start_POSTSUPERSCRIPT italic_m × italic_d end_POSTSUPERSCRIPT end_POSTSUBSCRIPT + ⟨ bold_italic_B bold_1 , bold_italic_u ⟩ start_POSTSUBSCRIPT blackboard_R start_POSTSUPERSCRIPT italic_m end_POSTSUPERSCRIPT end_POSTSUBSCRIPT (A.2.91)
=\langle(\bm{B}\bm{Z}^{\top},\bm{B}\bm{1}),(\bm{A},\bm{u})\rangle_{\mathbb{R}^{m\times d}\times\mathbb{R}^{m}}. (A.2.92)

So $T^{*}\bm{B}=(\bm{B}\bm{Z}^{\top},\bm{B}\bm{1})$. $\blacksquare$
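A quick numerical sanity check of this adjoint (our own illustration, assuming NumPy and arbitrary small dimensions) verifies the defining identity $\langle T[\bm{A},\bm{u}],\bm{B}\rangle=\langle(\bm{A},\bm{u}),T^{*}\bm{B}\rangle$:

```python
import numpy as np

# T[A, u] = A Z + u 1^T, and the claimed adjoint T* B = (B Z^T, B 1).
rng = np.random.default_rng(0)
m, d, n = 3, 4, 5
Z = rng.standard_normal((d, n))
A = rng.standard_normal((m, d))
u = rng.standard_normal(m)
B = rng.standard_normal((m, n))

lhs = np.sum((A @ Z + np.outer(u, np.ones(n))) * B)    # <T[A, u], B>
rhs = np.sum((B @ Z.T) * A) + (B @ np.ones(n)) @ u     # <(B Z^T, B 1), (A, u)>
assert np.isclose(lhs, rhs)
```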

Note that, as a simple application of the chain rule, both backpropagation and automatic differentiation work over general “computational graphs”, i.e., compositions of (simple) functions. We present all examples as neural network layers because this is the most common case in practice.

A.3 Game Theory and Minimax Optimization

In certain cases, such as in Chapter 5, a learning problem cannot be reduced to a single optimization problem; rather, it involves multiple, potentially opposing components of the system, each trying to minimize its own objective. Examples of this paradigm include distribution learning via generative adversarial networks (GAN) and closed-loop transcription (CTRL). We model such a system as a two-player game, in which two “players” (i.e., components), called Player 1 and Player 2, try to minimize their objectives $\mathcal{L}^{1}$ and $\mathcal{L}^{2}$ respectively. Player 1 picks parameters $\theta\in\Theta$ and Player 2 picks parameters $\eta\in\mathrm{H}$. In this book we consider the special case of zero-sum games, i.e., games with a common objective $\mathcal{L}$ such that $\mathcal{L}=-\mathcal{L}^{1}=\mathcal{L}^{2}$.

Our first, very preliminary example is as follows. Suppose that there exist functions $u(\theta)$ and $v(\eta)$ such that

(θ,η)=u(θ)+v(η).\mathcal{L}(\theta,\eta)=-u(\theta)+v(\eta).caligraphic_L ( italic_θ , italic_η ) = - italic_u ( italic_θ ) + italic_v ( italic_η ) . (A.3.1)

Then each player's objective is independent of the other player's choice, and each player should simply seek their own optimum:

θargminθΘu(θ),ηargminηHv(η).\theta^{\star}\in\operatorname*{arg\ min}_{\theta\in\Theta}u(\theta),\qquad\eta^{\star}\in\operatorname*{arg\ min}_{\eta\in\mathrm{H}}v(\eta).italic_θ start_POSTSUPERSCRIPT ⋆ end_POSTSUPERSCRIPT ∈ start_OPERATOR roman_arg roman_min end_OPERATOR start_POSTSUBSCRIPT italic_θ ∈ roman_Θ end_POSTSUBSCRIPT italic_u ( italic_θ ) , italic_η start_POSTSUPERSCRIPT ⋆ end_POSTSUPERSCRIPT ∈ start_OPERATOR roman_arg roman_min end_OPERATOR start_POSTSUBSCRIPT italic_η ∈ roman_H end_POSTSUBSCRIPT italic_v ( italic_η ) . (A.3.2)

The pair $(\theta^{\star},\eta^{\star})$ is a simple special case of an equilibrium: a situation in which neither player, given the chance to move, would want to, since moving would only make their own situation worse. However, not all games are so trivial; many have more complicated objectives and information structures.

In this book, the relevant game-theoretic formalism is a Stackelberg game (also called a sequential game). In this formalism, one player (without loss of generality Player 1, also called the leader) picks their parameters before the other (Player 2, also called the follower), and the follower may use full knowledge of the leader's choice when making their own choice. The corresponding notion of equilibrium is a Stackelberg equilibrium. To motivate it, note that since Player 2 (the follower) can choose $\eta$ in reaction to the choice $\theta_{1}$ made by Player 1 (the leader), Player 2 will choose an $\eta$ which minimizes $\mathcal{L}(\theta_{1},\cdot)$. A rational Player 1 anticipates this, and so picks a $\theta_{1}$ such that the outcome under Player 2's best response is as favorable as possible. More formally, let $\mathcal{S}(\theta)\doteq\operatorname*{arg\ min}_{\eta\in\mathrm{H}}\mathcal{L}(\theta,\eta)$ denote the set of $\eta$ minimizing $\mathcal{L}(\theta,\eta)$, i.e., the set of all $\eta$ which Player 2 may play given that Player 1 has played $\theta$. Then $(\theta^{\star},\eta^{\star})$ is a Stackelberg equilibrium if

θargmaxθΘminη𝒮(θ)(θ,η),ηargminηH(θ,η).\theta^{\star}\in\operatorname*{arg\ max}_{\theta\in\Theta}\min_{\eta\in\mathcal{S}(\theta)}\mathcal{L}(\theta,\eta),\qquad\eta^{\star}\in\operatorname*{arg\ min}_{\eta\in\mathrm{H}}\mathcal{L}(\theta^{\star},\eta).italic_θ start_POSTSUPERSCRIPT ⋆ end_POSTSUPERSCRIPT ∈ start_OPERATOR roman_arg roman_max end_OPERATOR start_POSTSUBSCRIPT italic_θ ∈ roman_Θ end_POSTSUBSCRIPT roman_min start_POSTSUBSCRIPT italic_η ∈ caligraphic_S ( italic_θ ) end_POSTSUBSCRIPT caligraphic_L ( italic_θ , italic_η ) , italic_η start_POSTSUPERSCRIPT ⋆ end_POSTSUPERSCRIPT ∈ start_OPERATOR roman_arg roman_min end_OPERATOR start_POSTSUBSCRIPT italic_η ∈ roman_H end_POSTSUBSCRIPT caligraphic_L ( italic_θ start_POSTSUPERSCRIPT ⋆ end_POSTSUPERSCRIPT , italic_η ) . (A.3.3)

In fact, as we prove below, in the context of two-player zero-sum Stackelberg games, $(\theta^{\star},\eta^{\star})$ is a Stackelberg equilibrium if and only if

θargmaxθΘminηH(θ,η),ηargminηH(θ,η),\theta^{\star}\in\operatorname*{arg\ max}_{\theta\in\Theta}\min_{\eta\in\mathrm{H}}\mathcal{L}(\theta,\eta),\qquad\eta^{\star}\in\operatorname*{arg\ min}_{\eta\in\mathrm{H}}\mathcal{L}(\theta^{\star},\eta),italic_θ start_POSTSUPERSCRIPT ⋆ end_POSTSUPERSCRIPT ∈ start_OPERATOR roman_arg roman_max end_OPERATOR start_POSTSUBSCRIPT italic_θ ∈ roman_Θ end_POSTSUBSCRIPT roman_min start_POSTSUBSCRIPT italic_η ∈ roman_H end_POSTSUBSCRIPT caligraphic_L ( italic_θ , italic_η ) , italic_η start_POSTSUPERSCRIPT ⋆ end_POSTSUPERSCRIPT ∈ start_OPERATOR roman_arg roman_min end_OPERATOR start_POSTSUBSCRIPT italic_η ∈ roman_H end_POSTSUBSCRIPT caligraphic_L ( italic_θ start_POSTSUPERSCRIPT ⋆ end_POSTSUPERSCRIPT , italic_η ) , (A.3.4)

(note that the notation $\mathcal{S}(\theta)$ is neither used nor needed).

Proof.

Note that

𝒮(θ)=argminηH(θ,η),\mathcal{S}(\theta)=\operatorname*{arg\ min}_{\eta\in\mathrm{H}}\mathcal{L}(\theta,\eta),caligraphic_S ( italic_θ ) = start_OPERATOR roman_arg roman_min end_OPERATOR start_POSTSUBSCRIPT italic_η ∈ roman_H end_POSTSUBSCRIPT caligraphic_L ( italic_θ , italic_η ) , (A.3.5)

so it holds

\min_{\eta\in\mathcal{S}(\theta)}\mathcal{L}(\theta,\eta)=\min_{\eta\in\operatorname*{arg\ min}_{\eta^{\prime}\in\mathrm{H}}\mathcal{L}(\theta,\eta^{\prime})}\mathcal{L}(\theta,\eta)=\min_{\eta\in\mathrm{H}}\mathcal{L}(\theta,\eta). (A.3.6)

Hence the outer problem over $\theta$ in (A.3.3) coincides with that in (A.3.4), and the conditions on $\eta^{\star}$ are identical, so the two characterizations are equivalent. $\blacksquare$
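To make the characterization (A.3.4) concrete, here is a brute-force sketch on a toy zero-sum game with small finite strategy sets (our own illustration, assuming NumPy; the payoff matrix is made up):

```python
import numpy as np

# Brute-force a Stackelberg equilibrium using
# theta* in argmax_theta min_eta L(theta, eta) (Player 1 maximizes L).
L = np.array([[ 1.0, -2.0,  0.5],
              [ 0.0,  3.0, -1.0],
              [-0.5,  1.0,  2.0]])          # L[theta, eta]

inner_min = L.min(axis=1)                   # follower's best-response value
theta_star = int(np.argmax(inner_min))      # leader's choice
eta_star = int(np.argmin(L[theta_star]))    # follower's reaction
print(f"theta*={theta_star}, eta*={eta_star}, value={L[theta_star, eta_star]}")
```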

In the rest of the section, we will briefly discuss some algorithmic approaches to learning Stackelberg equilibria. The intuition is that learning an equilibrium lets the different parts of the system automatically negotiate the tradeoffs between the objectives they each want to optimize.
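As a preview of the simplest such approach, the sketch below (our own illustration on a toy smooth objective that is concave in $\theta$ and convex in $\eta$; not an algorithm prescribed by the text) runs simultaneous gradient ascent-descent: Player 1 ascends $\mathcal{L}$ in $\theta$ while Player 2 descends $\mathcal{L}$ in $\eta$, and the iterates approach the equilibrium $(0,0)$.

```python
# Toy objective: L(theta, eta) = -theta**2 + theta*eta + eta**2,
# concave in theta, convex in eta, with unique equilibrium at (0, 0).
def grad_theta(theta, eta):
    return -2.0 * theta + eta        # dL/dtheta

def grad_eta(theta, eta):
    return theta + 2.0 * eta         # dL/deta

theta, eta, step = 3.0, -2.0, 0.1
for _ in range(500):
    # Simultaneous update: Player 1 ascends, Player 2 descends.
    theta, eta = (theta + step * grad_theta(theta, eta),
                  eta - step * grad_eta(theta, eta))

assert abs(theta) < 1e-3 and abs(eta) < 1e-3
print(f"approximate equilibrium: theta={theta:.4f}, eta={eta:.4f}")
```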

A caveat is in order: in two-player zero-sum games, if it holds that

maxθΘminηH(θ,η)=minηHmaxθΘ(θ,η)\max_{\theta\in\Theta}\min_{\eta\in\mathrm{H}}\mathcal{L}(\theta,\eta)=\min_{\eta\in\mathrm{H}}\max_{\theta\in\Theta}\mathcal{L}(\theta,\eta)roman_max start_POSTSUBSCRIPT italic_θ ∈ roman_Θ end_POSTSUBSCRIPT roman_min start_POSTSUBSCRIPT italic_η ∈ roman_H end_POSTSUBSCRIPT caligraphic_L ( italic_θ , italic_η ) = roman_min start_POSTSUBSCRIPT italic_η ∈ roman_H end_POSTSUBSCRIPT roman_max start_POSTSUBSCRIPT italic_θ ∈ roman_Θ end_POSTSUBSCRIPT caligraphic_L ( italic_θ , italic_η ) (A.3.7)

then every Stackelberg equilibrium is a saddle point,\footnote{Famously called a Nash equilibrium.} i.e.,

θargmaxθΘ(θ,η),ηargminηH(θ,η),\theta^{\star}\in\operatorname*{arg\ max}_{\theta\in\Theta}\mathcal{L}(\theta,\eta^{\star}),\qquad\eta^{\star}\in\operatorname*{arg\ min}_{\eta\in\mathrm{H}}\mathcal{L}(\theta^{\star},\eta),italic_θ start_POSTSUPERSCRIPT ⋆ end_POSTSUPERSCRIPT ∈ start_OPERATOR roman_arg roman_max end_OPERATOR start_POSTSUBSCRIPT italic_θ ∈ roman_Θ end_POSTSUBSCRIPT caligraphic_L ( italic_θ , italic_η start_POSTSUPERSCRIPT ⋆ end_POSTSUPERSCRIPT ) , italic_η start_POSTSUPERSCRIPT ⋆ end_POSTSUPERSCRIPT ∈ start_OPERATOR roman_arg roman_min end_OPERATOR start_POSTSUBSCRIPT italic_η ∈ roman_H end_POSTSUBSCRIPT caligraphic_L ( italic_θ start_POSTSUPERSCRIPT ⋆ end_POSTSUPERSCRIPT , italic_η ) , (A.3.8)

and vice versa, and furthermore each Stackelberg equilibrium has the (same) objective value

\max_{\theta\in\Theta}\min_{\eta\in\mathrm{H}}\mathcal{L}(\theta,\eta).
Proof.

Suppose that indeed

\max_{\theta\in\Theta}\min_{\eta\in\mathrm{H}}\mathcal{L}(\theta,\eta)=\min_{\eta\in\mathrm{H}}\max_{\theta\in\Theta}\mathcal{L}(\theta,\eta).    (A.3.9)

First suppose that $(\theta^{\star},\eta^{\star})$ is a saddle point. We will show it is a Stackelberg equilibrium. By definition we have

\min_{\eta\in\mathrm{H}}\mathcal{L}(\theta^{\star},\eta)=\mathcal{L}(\theta^{\star},\eta^{\star})=\max_{\theta\in\Theta}\mathcal{L}(\theta,\eta^{\star}).    (A.3.10)

It then holds for any $\theta$ and $\eta$ that

\min_{\eta\in\mathrm{H}}\mathcal{L}(\theta,\eta)\leq\mathcal{L}(\theta,\eta^{\star})\leq\mathcal{L}(\theta^{\star},\eta^{\star}).    (A.3.11)

Therefore, taking the maximum over $\theta$,

\max_{\theta\in\Theta}\min_{\eta\in\mathrm{H}}\mathcal{L}(\theta,\eta)\leq\mathcal{L}(\theta^{\star},\eta^{\star}).    (A.3.12)

Completely symmetrically,

\mathcal{L}(\theta^{\star},\eta^{\star})\leq\mathcal{L}(\theta^{\star},\eta)\leq\max_{\theta\in\Theta}\mathcal{L}(\theta,\eta)\implies\mathcal{L}(\theta^{\star},\eta^{\star})\leq\min_{\eta\in\mathrm{H}}\max_{\theta\in\Theta}\mathcal{L}(\theta,\eta).    (A.3.13)

Therefore, since $\max\min=\min\max$ by assumption, we have

\max_{\theta\in\Theta}\min_{\eta\in\mathrm{H}}\mathcal{L}(\theta,\eta)=\mathcal{L}(\theta^{\star},\eta^{\star})=\min_{\eta\in\mathrm{H}}\max_{\theta\in\Theta}\mathcal{L}(\theta,\eta).    (A.3.14)

In particular, it holds that

\theta^{\star}\in\operatorname*{arg\ max}_{\theta\in\Theta}\min_{\eta\in\mathrm{H}}\mathcal{L}(\theta,\eta).    (A.3.15)

From the saddle point condition we have $\eta^{\star}\in\operatorname*{arg\ min}_{\eta\in\mathrm{H}}\mathcal{L}(\theta^{\star},\eta)$. So $(\theta^{\star},\eta^{\star})$ is a Stackelberg equilibrium. Furthermore, we have also proved that all saddle points obey (A.3.14).

Now let $(\theta^{\star},\eta^{\star})$ be a Stackelberg equilibrium. We claim that it is a saddle point, which completes the proof. By the definition of the Stackelberg (i.e., max-min) equilibrium,

\max_{\theta\in\Theta}\min_{\eta\in\mathrm{H}}\mathcal{L}(\theta,\eta)=\mathcal{L}(\theta^{\star},\eta^{\star}).    (A.3.16)

Then by the $\min\max=\max\min$ assumption we have

\max_{\theta\in\Theta}\min_{\eta\in\mathrm{H}}\mathcal{L}(\theta,\eta)=\mathcal{L}(\theta^{\star},\eta^{\star})=\min_{\eta\in\mathrm{H}}\max_{\theta\in\Theta}\mathcal{L}(\theta,\eta).    (A.3.17)

This proves that all Stackelberg equilibria have the desired objective value $\mathcal{L}(\theta^{\star},\eta^{\star})$. Now we want to show that

\theta^{\star}\in\operatorname*{arg\ max}_{\theta\in\Theta}\mathcal{L}(\theta,\eta^{\star}),\qquad\eta^{\star}\in\operatorname*{arg\ min}_{\eta\in\mathrm{H}}\mathcal{L}(\theta^{\star},\eta).    (A.3.18)

Indeed, the latter assertion holds by definition of the Stackelberg equilibrium, so only the former need be proved. Namely, we will show that

\max_{\theta\in\Theta}\mathcal{L}(\theta,\eta^{\star})=\mathcal{L}(\theta^{\star},\eta^{\star}).    (A.3.19)

To show this, note that by definition of the $\max$,

\mathcal{L}(\theta^{\star},\eta^{\star})\leq\max_{\theta\in\Theta}\mathcal{L}(\theta,\eta^{\star}),    (A.3.20)

meanwhile we have $\min_{\eta\in\mathrm{H}}\mathcal{L}(\theta,\eta)\leq\mathcal{L}(\theta,\eta^{\star})$ for every $\theta$, so

\mathcal{L}(\theta^{\star},\eta^{\star})=\max_{\theta\in\Theta}\min_{\eta\in\mathrm{H}}\mathcal{L}(\theta,\eta)\leq\max_{\theta\in\Theta}\mathcal{L}(\theta,\eta^{\star}).    (A.3.21)

Therefore it holds

\mathcal{L}(\theta^{\star},\eta^{\star})=\max_{\theta\in\Theta}\mathcal{L}(\theta,\eta^{\star}),    (A.3.22)

and the proof is complete. ∎

Conditions under which the $\min\max=\max\min$ equality (A.3.7) holds are given by so-called minimax theorems, the most famous of which is due to von Neumann. However, in the settings considered in this book, this property usually does not hold; the sketch below gives a simple example where the two values differ.
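As a concrete illustration (our own toy example, not from the main text), consider a $2\times 2$ matrix game restricted to pure strategies. For the matching-pennies payoff matrix, the leader's max-min value is strictly smaller than the follower's min-max value, so (A.3.7) fails and no pure-strategy saddle point exists. A minimal sketch in JAX:

# Matching pennies with pure strategies (our own example): max min < min max.
import jax.numpy as jnp

A = jnp.array([[1., -1.],
               [-1., 1.]])              # L(theta, eta): theta picks the row, eta picks the column

max_min = jnp.max(jnp.min(A, axis=1))   # leader commits first:  max_theta min_eta L
min_max = jnp.min(jnp.max(A, axis=0))   # follower commits first: min_eta max_theta L

print(float(max_min), float(min_max))   # -1.0  1.0  -> the two values differ, so (A.3.7) fails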

A.3.1 Learning Stackelberg Equilibria

How can we learn Stackelberg equilibria with gradient-based methods? In general this is intractable: learning a Stackelberg equilibrium is at least as hard as computing a global minimizer of a generic loss function (take, say, a shared objective $\mathcal{L}(\theta,\eta)$ that depends only on $\eta$). As such, we can only expect the following two types of convergence guarantees:

  • When $\mathcal{L}$ is (strongly) concave in the first argument $\theta$ and (strongly) convex in the second argument $\eta$ (as well as having Lipschitz gradients in both arguments), we can achieve exponentially fast convergence to a Stackelberg equilibrium.

  • When $\mathcal{L}$ is neither concave in $\theta$ nor convex in $\eta$, we can achieve local convergence guarantees: namely, if we initialize the parameters near a (local) Stackelberg equilibrium and the local optimization geometry is benign, then we can learn that equilibrium efficiently.

The former situation is exactly analogous to the case of single-player optimization, where we proved that gradient descent converges exponentially fast for strongly convex objectives which have Lipschitz gradient. The latter situation is also analogous to the case of single-player optimization, although we did not cover it in depth due to technical difficulty; indeed there exist local convergence guarantees for nonconvex objectives which have locally nice geometry.

The algorithm in both cases is the same, called Gradient Descent-Ascent (GDA). To motivate GDA, suppose we are trying to learn $\theta^{\star}$. We could do gradient ascent on the function $\theta\mapsto\min_{\eta\in\mathrm{H}}\mathcal{L}(\theta,\eta)$. But then we would need to take the derivative of this function in $\theta$. To see how to do this, suppose that $\mathcal{L}$ is strongly convex in $\eta$, so that there is a unique minimizer $\eta^{\star}(\theta)$ of $\mathcal{L}(\theta,\cdot)$. Then Danskin's theorem says that

\nabla_{\theta}\left[\min_{\eta\in\mathrm{H}}\mathcal{L}(\theta,\eta)\right]=\nabla_{\theta}\mathcal{L}(\theta,\texttt{stop\_grad}(\eta^{\star}(\theta))),    (A.3.23)

where the gradient is taken only with respect to the first argument (i.e., it is not a total derivative, which would require computing the Jacobian of $\eta^{\star}(\theta)$ with respect to $\theta$), as indicated by the stop-gradient operator.888In the case that $\mathcal{L}$ is not strongly convex in $\eta$ but merely convex, Danskin's theorem can be stated in terms of subdifferentials. If $\mathcal{L}$ is not convex in $\eta$ at all, then this derivative may not be well-defined, but one can obtain (local) convergence guarantees for the resulting algorithm anyway. Hence we use Danskin's theorem as a motivation, not a justification, for our algorithms. Danskin's theorem can be generalized to more relaxed settings by the so-called envelope theorems. In order to take the derivative in $\theta$, we need to set up a secondary process that also optimizes $\eta$ to obtain an approximation of $\eta^{\star}(\theta)$. We can do this through the following algorithmic template:

\eta_{k+1}=\eta_{k+1}^{T+1};\qquad\eta_{k+1}^{t+1}=\eta_{k+1}^{t}-h\nabla_{\eta}\mathcal{L}(\theta_{k},\eta_{k+1}^{t}),\quad\forall t\in[T];\qquad\eta_{k+1}^{1}=\eta_{k}    (A.3.24)
\theta_{k+1}=\theta_{k}+h\nabla_{\theta}\mathcal{L}(\theta_{k},\eta_{k+1}).    (A.3.25)

That is, we take $T$ steps of gradient descent to update $\eta_{k}$ (hopefully driving it to the minimizer of $\mathcal{L}(\theta_{k},\cdot)$), and then take one gradient ascent step to update $\theta_{k}$. As a bonus, on top of estimating $\theta_{K}\approx\operatorname*{arg\ max}_{\theta}\min_{\eta}\mathcal{L}(\theta,\eta)$, we also learn $\eta_{K}\approx\operatorname*{arg\ min}_{\eta}\mathcal{L}(\theta_{K},\eta)$, so that $(\theta_{K},\eta_{K})$ is an approximate Stackelberg equilibrium.
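To make the template concrete, here is a minimal sketch (our own toy objective, not from the text) of (A.3.24)-(A.3.25) in JAX, on a quadratic game that is strongly concave in $\theta$, strongly convex in $\eta$, and has its Stackelberg equilibrium at the origin:

import jax
import jax.numpy as jnp

# Toy objective (our own example): strongly concave in theta, strongly convex in eta,
# with unique Stackelberg equilibrium (theta*, eta*) = (0, 0).
def L(theta, eta):
    return -0.5 * theta**2 + theta * eta + 0.5 * eta**2

grad_theta = jax.grad(L, argnums=0)   # dL/dtheta
grad_eta   = jax.grad(L, argnums=1)   # dL/deta

h, T, K = 0.1, 20, 200                # step size, inner steps, outer iterations
theta, eta = jnp.array(2.0), jnp.array(-1.5)
for k in range(K):
    for t in range(T):                # inner loop (A.3.24): eta tracks argmin_eta L(theta_k, .)
        eta = eta - h * grad_eta(theta, eta)
    theta = theta + h * grad_theta(theta, eta)   # outer ascent step (A.3.25)

print(float(theta), float(eta))       # both approach the Stackelberg equilibrium (0, 0)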

This method is rarely used in practice, as it requires $T+1$ gradient computations to update $\theta$ just once. Instead, we use the so-called (simultaneous) Gradient Descent-Ascent (GDA) iteration

\theta_{k+1}=\theta_{k}+h\nabla_{\theta}\mathcal{L}(\theta_{k},\eta_{k})    (A.3.26)
\eta_{k+1}=\eta_{k}-Th\nabla_{\eta}\mathcal{L}(\theta_{k},\eta_{k}),    (A.3.27)

which can be implemented efficiently via a single gradient evaluation at $(\theta_{k},\eta_{k})$ per iteration. The crucial idea is that, to keep this method close to the inefficient iteration above, we use an $\eta$ update that is $T$ times faster than the $\theta$ update (the two iterations can be seen as nearly the same by taking a linearization of the dynamics).
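A minimal sketch of the simultaneous GDA iteration (A.3.26)-(A.3.27), again on the toy quadratic game from the previous sketch (our own example); note that both partial gradients come from a single call:

import jax
import jax.numpy as jnp

def L(theta, eta):                    # same toy objective as above (our own example)
    return -0.5 * theta**2 + theta * eta + 0.5 * eta**2

grads = jax.grad(L, argnums=(0, 1))   # one call returns (dL/dtheta, dL/deta)

h, T, K = 0.01, 10, 2000              # theta step h, eta step T*h
theta, eta = jnp.array(2.0), jnp.array(-1.5)
for k in range(K):
    g_theta, g_eta = grads(theta, eta)
    theta = theta + h * g_theta       # ascent on theta (A.3.26)
    eta   = eta - T * h * g_eta       # T-times-faster descent on eta (A.3.27)

print(float(theta), float(eta))       # -> approximately (0, 0)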

It is crucial to pick $T$ sensibly. How can we do that? In the sequel, we discuss two configurations of $T$ that lead to convergence of GDA to a Stackelberg equilibrium under different assumptions.

Convergence of One-Timescale GDA to Stackelberg Equilibrium

If $T=1$ (this setting is called one-timescale because the $\theta$ and $\eta$ updates proceed at the same rate), then the GDA algorithm becomes

\theta_{k+1}=\theta_{k}+h\nabla_{\theta}\mathcal{L}(\theta_{k},\eta_{k}),\qquad\eta_{k+1}=\eta_{k}-h\nabla_{\eta}\mathcal{L}(\theta_{k},\eta_{k}).    (A.3.28)

If $\mathcal{L}$ has Lipschitz gradients in $\theta$ and $\eta$, is strongly concave in $\theta$ and strongly convex in $\eta$, and is coercive (i.e., $\lim_{\|\theta\|_{2},\|\eta\|_{2}\to\infty}\mathcal{L}(\theta,\eta)=\infty$), then all saddle points are Stackelberg equilibria and vice versa. The work [ZAK24] shows that GDA (again, with $T=1$) with a sufficiently small step size $h$ converges exponentially fast to a saddle point (hence a Stackelberg equilibrium) whenever one exists, analogously to gradient descent for strongly convex functions with Lipschitz gradients.999We do not provide the step size or the base of the exponential rate, since they are complicated functions of the strong convexity/concavity constants and the Lipschitz constant of the gradient; the paper [ZAK24] provides the precise parameter values. To our knowledge, results of this flavor constitute the only known rigorous justification for single-timescale GDA.
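As a quick numerical sanity check (our own toy example, not the exact setting of [ZAK24]), one-timescale GDA on a strongly-concave/strongly-convex quadratic indeed converges geometrically to the saddle point:

import jax
import jax.numpy as jnp

def L(theta, eta):                    # our own toy objective: saddle point at (0, 0)
    return -0.5 * theta**2 + theta * eta + 0.5 * eta**2

grads = jax.grad(L, argnums=(0, 1))

h = 0.1
theta, eta = jnp.array(2.0), jnp.array(-1.5)
for k in range(101):
    if k % 25 == 0:
        print(k, float(jnp.sqrt(theta**2 + eta**2)))   # distance to the saddle point shrinks geometrically
    g_theta, g_eta = grads(theta, eta)
    theta, eta = theta + h * g_theta, eta - h * g_eta  # one-timescale GDA (A.3.28)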

Local Convergence of Two-Timescale GDA to Stackelberg Equilibrium

Strong convexity/concavity is a global property, and none of the games we study in this book have objectives that are globally strongly concave in $\theta$ and strongly convex in $\eta$. In this case, the best we can hope for is local convergence to Stackelberg equilibria: if the parameters are initialized close to a Stackelberg equilibrium, then GDA can converge to it, given an appropriate step size $h$ and timescale $T$.

In fact, our results also hold for a version of the local Stackelberg equilibrium called the differential Stackelberg equilibrium, which was introduced in [FCR19] (though we use the precise definition in [LFD+22]), and which we define as follows. A point $(\theta^{\star},\eta^{\star})$ is a differential Stackelberg equilibrium if:

  • $\nabla_{\eta}\mathcal{L}(\theta^{\star},\eta^{\star})=\bm{0}$;

  • $\nabla_{\eta}^{2}\mathcal{L}(\theta^{\star},\eta^{\star})$ is symmetric positive definite;

  • $\nabla_{\theta}\mathcal{L}(\theta^{\star},\eta^{\star})=\bm{0}$;

  • $\left(\nabla_{\theta}^{2}\mathcal{L}-\left[\frac{\mathrm{d}}{\mathrm{d}\eta}\nabla_{\theta}\mathcal{L}\right]\left[\nabla_{\eta}^{2}\mathcal{L}\right]^{-1}\left[\frac{\mathrm{d}}{\mathrm{d}\theta}\nabla_{\eta}\mathcal{L}\right]\right)(\theta^{\star},\eta^{\star})$ is symmetric negative definite.

Notice that the last condition is a condition on the (total) Hessian

\nabla^{2}\mathcal{L}(\theta,\eta)=\begin{bmatrix}\nabla_{\theta}^{2}\mathcal{L}(\theta,\eta)&\frac{\mathrm{d}}{\mathrm{d}\eta}\nabla_{\theta}\mathcal{L}(\theta,\eta)\\ \frac{\mathrm{d}}{\mathrm{d}\theta}\nabla_{\eta}\mathcal{L}(\theta,\eta)&\nabla_{\eta}^{2}\mathcal{L}(\theta,\eta)\end{bmatrix}    (A.3.29)

or, more precisely, on its Schur complement with respect to the $\nabla_{\eta}^{2}\mathcal{L}$ block, which is required to be negative definite. If we look at the computation of $\nabla_{\theta}[\min_{\eta\in\mathrm{H}}\mathcal{L}(\theta,\eta)]$ furnished by Danskin's theorem, the last two criteria are exactly constraints on the gradient and Hessian of the function $\theta\mapsto\min_{\eta\in\mathrm{H}}\mathcal{L}(\theta,\eta)$, ensuring that the gradient is $\bm{0}$ and the Hessian is negative definite. This intuition tells us that we can expect each Stackelberg equilibrium to be a differential Stackelberg equilibrium; [FCR20] confirms this rigorously (up to some technical conditions).
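For concreteness, the four conditions above can be checked numerically by assembling the Hessian blocks with automatic differentiation. The following is a minimal sketch in JAX (our own toy objective and candidate point, not from the text):

import jax
import jax.numpy as jnp

# Our own toy objective with vector-valued players; candidate point (0, 0).
def L(theta, eta):
    return -jnp.sum(theta**2) + jnp.dot(theta, eta) + jnp.sum(eta**2)

theta_star, eta_star = jnp.zeros(2), jnp.zeros(2)

g_theta = jax.grad(L, argnums=0)(theta_star, eta_star)                        # should be 0
g_eta   = jax.grad(L, argnums=1)(theta_star, eta_star)                        # should be 0
H_tt = jax.hessian(L, argnums=0)(theta_star, eta_star)                        # d^2 L / d theta^2
H_ee = jax.hessian(L, argnums=1)(theta_star, eta_star)                        # d^2 L / d eta^2
H_te = jax.jacfwd(jax.grad(L, argnums=0), argnums=1)(theta_star, eta_star)    # d/d eta of grad_theta L
H_et = jax.jacfwd(jax.grad(L, argnums=1), argnums=0)(theta_star, eta_star)    # d/d theta of grad_eta L

# Hessian of theta -> min_eta L(theta, eta): the Schur complement of the eta block.
schur = H_tt - H_te @ jnp.linalg.inv(H_ee) @ H_et

print(bool(jnp.allclose(g_theta, 0.0)), bool(jnp.allclose(g_eta, 0.0)))   # first-order conditions
print(jnp.linalg.eigvalsh(H_ee))    # all positive  -> positive definite in eta
print(jnp.linalg.eigvalsh(schur))   # all negative  -> negative definite Schur complement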

Analogously to the notion of a strict local optimum in single-player optimization (where we require $\nabla^{2}\mathcal{L}(\theta^{\star})$ to be positive definite), the definition of differential Stackelberg equilibrium implies that $\mathcal{L}(\theta,\cdot)$ is locally (strictly) convex in a neighborhood of the equilibrium, and that $\min_{\eta\in\mathrm{H}}\mathcal{L}(\cdot,\eta)$ is locally (strictly) concave in the same region.

In this context, we present the result from [LFD+22]. Let $(\theta^{\star},\eta^{\star})$ be a differential Stackelberg equilibrium. Suppose that $\mathcal{L}$ has Lipschitz gradients, i.e.,

\max\left\{\|\nabla_{\theta}^{2}\mathcal{L}(\theta^{\star},\eta^{\star})\|_{2},\;\|\nabla_{\eta}^{2}\mathcal{L}(\theta^{\star},\eta^{\star})\|_{2},\;\left\|\frac{\mathrm{d}}{\mathrm{d}\eta}\nabla_{\theta}\mathcal{L}(\theta^{\star},\eta^{\star})\right\|_{2}\right\}\leq\beta.    (A.3.30)

Further define the local strong convexity/concavity parameters of $\mathcal{L}(\theta,\cdot)$ and $\min_{\eta}\mathcal{L}(\cdot,\eta)$, respectively, as

\mu_{\eta}=\lambda_{\min}\left(\nabla_{\eta}^{2}\mathcal{L}(\theta^{\star},\eta^{\star})\right),\qquad\mu_{\theta}=\min\left\{\beta,\;-\lambda_{\max}\left(\left(\nabla_{\theta}^{2}\mathcal{L}-\left[\frac{\mathrm{d}}{\mathrm{d}\eta}\nabla_{\theta}\mathcal{L}\right]\left[\nabla_{\eta}^{2}\mathcal{L}\right]^{-1}\left[\frac{\mathrm{d}}{\mathrm{d}\theta}\nabla_{\eta}\mathcal{L}\right]\right)(\theta^{\star},\eta^{\star})\right)\right\}.    (A.3.31)

Then define the local condition numbers as

\kappa_{\eta}=\beta/\mu_{\eta},\qquad\kappa_{\theta}=\beta/\mu_{\theta}.    (A.3.32)

The paper [LFD+22] says that if we take step size $h=\frac{1}{4\beta T}$ and take $T\geq 2\kappa$, so that the algorithm is

\theta_{k+1}=\theta_{k}+\frac{1}{4\beta T}\nabla_{\theta}\mathcal{L}(\theta_{k},\eta_{k}),    (A.3.33)
\eta_{k+1}=\eta_{k}-\frac{1}{4\beta}\nabla_{\eta}\mathcal{L}(\theta_{k},\eta_{k}),    (A.3.34)

the total Hessian $\nabla^{2}\mathcal{L}(\theta^{\star},\eta^{\star})$ is diagonalizable, and we initialize $(\theta_{0},\eta_{0})$ close enough to $(\theta^{\star},\eta^{\star})$, then there are positive constants $c_{0},c_{1}>0$ such that

\|(\theta_{k},\eta_{k})-(\theta^{\star},\eta^{\star})\|_{2}\leq c_{0}\left(1-\frac{c_{1}}{T\kappa_{\theta}}\right)^{k}\|(\theta_{0},\eta_{0})-(\theta^{\star},\eta^{\star})\|_{2},    (A.3.35)

implying exponential convergence to the differential Stackelberg equilibrium.

A.3.2 Practical Considerations when Learning Stackelberg Equilibria

In practice, we do not know how to initialize parameters close to a (differential) Stackelberg equilibrium. However, due to symmetries within the objective, including those induced by overparameterization of the neural networks being trained, one can (heuristically) expect that most initializations are close to some Stackelberg equilibrium. We also do not know how to compute the step size $h$ or the timescale $T$, since they depend on properties of the loss $\mathcal{L}$ at the equilibrium. In practice, there are some common approaches:

  • Take $T=1$ (equal step sizes), and use updates for $\theta$ and $\eta$ derived from a learning-rate-adaptive optimizer like Adam (as opposed to vanilla GD). Here, you hope (but do not know) that the optimizer can adjust the learning rates to learn a good equilibrium; a minimal sketch of this option is given after this list.

  • Take $T$ to be some large constant like $T=10^{6}$, which implies that $\eta$ equilibrates $10^{6}$ times as fast as $\theta$. Here you can also use Adam-style updates and hope that they fix the timescale.

  • Let $T$ depend on the iteration $k$, with $T_{k}\to\infty$ as $k\to\infty$. This schedule was studied (also in the case of noise) by Borkar in [Bor97].
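The following sketch illustrates the first option above ($T=1$ with Adam-style updates). It uses our own toy objective, assumes the optax library for the Adam updates, and comes with no convergence guarantee; it only shows how such updates are typically wired up:

import jax
import jax.numpy as jnp
import optax   # assumed available; provides Adam-style updates

def L(theta, eta):                    # our own toy concave-convex objective
    return -0.5 * jnp.sum(theta**2) + jnp.dot(theta, eta) + 0.5 * jnp.sum(eta**2)

grads = jax.grad(L, argnums=(0, 1))

theta, eta = jnp.ones(3), -jnp.ones(3)
opt_theta, opt_eta = optax.adam(1e-3), optax.adam(1e-3)
state_theta, state_eta = opt_theta.init(theta), opt_eta.init(eta)

for k in range(5000):
    g_theta, g_eta = grads(theta, eta)
    # Adam is a minimizer, so negate g_theta to take an *ascent* step on theta.
    upd_theta, state_theta = opt_theta.update(-g_theta, state_theta)
    upd_eta,   state_eta   = opt_eta.update(g_eta, state_eta)
    theta = optax.apply_updates(theta, upd_theta)
    eta   = optax.apply_updates(eta, upd_eta)

print(theta, eta)   # should hover near the equilibrium at the origin (no guarantee in general)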

For example, these schemes can be used when training CTRL-style models (see Chapter 5), where the encoder is Player 1 and the decoder is Player 2. Some theory about CTRL is given in Theorem 5.1.

A.4 Exercises

Exercise A.1.

We have shown that for a smooth function $f$, gradient descent converges linearly to the global optimum if $f$ is strongly convex. However, in general nonconvex optimization, we do not have convexity, let alone strong convexity. Fortunately, in some cases $f$ satisfies the so-called $\mu$-Polyak-Łojasiewicz (PL) inequality, i.e., there exists a constant $\mu>0$ such that for all $\theta$,

\frac{1}{2}\|\nabla f(\theta)\|_{2}^{2}\geq\mu\left(f(\theta)-f(\theta^{\star})\right),

where $\theta^{\star}$ is a minimizer of $f$.

Please show that, under the PL inequality and the assumption that $f$ is $\beta$-smooth, gradient descent (A.1.12) converges linearly to the optimal value $f(\theta^{\star})$.

Exercise A.2.

Compute the differential and adjoint derivative of the softmax function, defined as follows.

\operatorname{\mathrm{softmax}}\left(\begin{bmatrix}x_{1}\\ \vdots\\ x_{n}\end{bmatrix}\right)=\frac{1}{\sum_{i=1}^{n}e^{x_{i}}}\begin{bmatrix}e^{x_{1}}\\ \vdots\\ e^{x_{n}}\end{bmatrix}.    (A.4.1)
Exercise A.3.

Carry through the backpropagation computation for an $L$-layer MLP, as defined in Section 7.2.3.

Exercise A.4.

Carry through the backpropagation computation for an $L$-layer transformer, as defined in Section 7.2.3.

Exercise A.5.

Carry through the backpropagation computation for an autoencoder with $L$ encoder layers and $L$ decoder layers (without necessarily specifying an architecture).