Material related to this page, as well as additional exercises, can be found in LLA Chapter 9.2. Reviewing CalcBLUE2 Chapter 5 on the chain rule is recommended.
over $x \in \mathbb{R}^n$, where we look for the $x \in \mathbb{R}^n$ that makes the value of the cost function $f : \mathbb{R}^n \to \mathbb{R}$ as small as possible. We saw that one way to find either a local or global minimum $x^*$ is gradient descent. Starting at an initial guess $x^{(0)}$, we iteratively update our guess via
where $\nabla f(x^{(k)}) \in \mathbb{R}^n$ is the gradient of $f$ evaluated at the current guess, and $s > 0$ is a step size chosen large enough to make progress towards $x^*$, but not so big as to overshoot.
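To make this concrete, here is a minimal NumPy sketch of the iteration, assuming the standard update $x^{(k+1)} = x^{(k)} - s\,\nabla f(x^{(k)})$; the quadratic cost in the usage example is an illustrative choice, not one from these notes.

```python
import numpy as np

def gradient_descent(grad_f, x0, s=0.1, num_iters=100):
    """Iterate x^(k+1) = x^(k) - s * grad_f(x^(k)) starting from x0."""
    x = np.asarray(x0, dtype=float)
    for _ in range(num_iters):
        x = x - s * grad_f(x)
    return x

# Illustrative usage: f(x) = ||x - b||^2 has gradient 2(x - b), so GD should approach b.
b = np.array([1.0, -2.0, 3.0])
print(gradient_descent(lambda x: 2 * (x - b), x0=np.zeros(3)))  # close to [1, -2, 3]
```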
Today, we’ll focus our attention on optimization problems (1) for which the cost function takes the following special form
i.e., cost functions $f$ that decompose into a sum of $N$ “sub-costs” $f_i$. Problems with cost functions of the form (3) are particularly common in machine learning.
For example, a typical problem setup in machine learning is as follows (we saw an example of this when we studied least squares for data-fitting). We are given a set of training data $\{(z_i, y_i)\}$, $i = 1, \ldots, N$, comprised of “inputs” $z_i \in \mathbb{R}^p$ and “outputs” $y_i \in \mathbb{R}^p$. Our goal is to find a set of weights $x \in \mathbb{R}^n$ which parametrize a model such that $m(z_i; x) \approx y_i$ on our training data. A common way of doing this is to minimize a loss function of the form
where each term $\ell(m(z_i; x) - y_i)$ penalizes the difference between our model prediction $m(z_i; x)$ on input $z_i$ and the observed output $y_i$. In this setting, the loss function (4) takes the form (3), with $f_i = \frac{1}{N}\ell(m(z_i; x) - y_i)$ measuring the error between our prediction $\hat{y}_i = m(z_i; x)$ and the true output $y_i$.
A common choice for the “sub-loss” function is $\ell(e) = \|e\|^2$, leading to a least-squares regression problem, but note that most other choices of loss function are compatible with the following discussion.
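As a quick sketch of what a cost of the form (4) looks like in code, here is the squared-error case for a hypothetical linear model $m(z; x) = z^\top x$ with scalar outputs (both choices are illustrative assumptions, not fixed by the notes above):

```python
import numpy as np

def total_loss(x, Z, Y):
    """loss(x) = (1/N) * sum_i l(m(z_i; x) - y_i), with l(e) = e**2 and m(z; x) = z @ x."""
    N = len(Z)
    return sum((z_i @ x - y_i) ** 2 for z_i, y_i in zip(Z, Y)) / N
```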
Now suppose that we want to implement gradient descent (GD) on the loss function (4). Our first step is to compute the gradient $\nabla_x \operatorname{loss}((z_i, y_i); x)$. Because of the sum structure of (4), we have that:
i.e., the gradient of the loss function is the sum of the gradients of the “sub-losses” on each of the $i = 1, \ldots, N$ data points.
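Continuing the illustrative least-squares sketch from above, the full gradient can be assembled one data point at a time; the per-sample formula $\nabla_x \frac{1}{N}(z_i^\top x - y_i)^2 = \frac{2}{N}(z_i^\top x - y_i)\, z_i$ is exactly what the chain rule developed below will produce.

```python
import numpy as np

def total_gradient(x, Z, Y):
    """Sum of the per-sample gradients (2/N) * (z_i @ x - y_i) * z_i for the sketch above."""
    N = len(Z)
    return sum(2.0 * (z_i @ x - y_i) * z_i for z_i, y_i in zip(Z, Y)) / N
```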
Our task now is therefore to compute the gradient $\nabla_x \ell(m(z_i; x) - y_i)$. This requires the multivariate chain rule, as $f_i(x) = \ell(m(z_i; x) - y_i)$ is a composition of the functions $\ell(e)$, $e = w - y_i$, and $w = m(z_i; x)$.
If we define $g = g(f)$ and $f = f(x)$, then we can rewrite (6) as $\frac{dh}{dx} = \frac{dg}{df} \cdot \frac{df}{dx}$. This is a useful way of writing things, as we can “cancel” $df$ on the RHS to check that our formula is correct.
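For instance, with the illustrative choices $g(f) = \sin f$ and $f(x) = x^2$ (not examples from these notes), the rule gives
$$\frac{dh}{dx} = \frac{dg}{df} \cdot \frac{df}{dx} = \cos(f(x)) \cdot 2x = 2x \cos(x^2),$$
which matches differentiating $h(x) = \sin(x^2)$ directly.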
Generalizing slightly, suppose now that $f : \mathbb{R}^n \to \mathbb{R}$ maps a vector $x \in \mathbb{R}^n$ to $f(x) \in \mathbb{R}$. Then for $h(x) = g(f(x))$, we have:
which we see is a natural generalization of equation (6). It will be convenient for us later to define $\frac{df}{dx} = \nabla_x f(x)^T$ and $\frac{dh}{dx} = \nabla_x h(x)^T$. Again defining $g = g(f)$ and $f = f(x)$, we can rewrite (7) as $\frac{dh}{dx} = \frac{dg}{df} \cdot \frac{df}{dx}$, which looks exactly the same as before!
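As a quick illustrative example (again, not one from these notes), take $g(f) = f^2$ and $f(x) = a^\top x$ for a fixed $a \in \mathbb{R}^n$. Then
$$\frac{dh}{dx} = \frac{dg}{df} \cdot \frac{df}{dx} = 2 f(x) \cdot a^\top = 2 (a^\top x)\, a^\top, \qquad \text{i.e.,} \qquad \nabla_x h(x) = 2 (a^\top x)\, a,$$
which agrees with expanding $h(x) = (a^\top x)^2$ and differentiating directly.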
Now, let's apply these ideas to computing the gradient of $h(x) = \ell(m(z_i; x) - y_i)$, where we'll assume for now that $m(z_i; x), y_i \in \mathbb{R}$. Applying (7), we get
Note that (13) is defined by a matrix-matrix multiplication of an $m \times p$ and a $p \times n$ matrix, meaning $\frac{dh}{dx} \in \mathbb{R}^{m \times n}$. The claim is that the $(i,j)$th entry of $\frac{dh}{dx}$ is the rate of change of $h_i(x) = g_i(f(x))$ with respect to $x_j$. From (12) and (13), we have
$$\left(\frac{dh}{dx}\right)_{i,j} = \underbrace{\frac{dg_i}{df}}_{i\text{th row of } \frac{dg}{df}} \cdot \underbrace{\begin{bmatrix} \frac{\partial f_1}{\partial x_j} \\ \vdots \\ \frac{\partial f_p}{\partial x_j} \end{bmatrix}}_{j\text{th column of } \frac{df}{dx}} = \frac{\partial g_i}{\partial f_1} \cdot \frac{\partial f_1}{\partial x_j} + \cdots + \frac{\partial g_i}{\partial f_p} \cdot \frac{\partial f_p}{\partial x_j},$$
which is precisely the expression we were looking for. The “cancellation rule” tells us that each term in the sum computes the contribution to $\frac{\partial g_i}{\partial x_j}$ flowing through the corresponding “$f_k$” channel.
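Here is a small NumPy sanity check of this matrix chain rule against a finite-difference estimate; the particular $f$ and $g$ below are made-up smooth maps chosen only for illustration.

```python
import numpy as np

n, p, m = 4, 3, 2
A = np.random.randn(p, n)   # f(x) = tanh(A x),        f: R^n -> R^p
B = np.random.randn(m, p)   # g(u) = B (u^2 elementwise), g: R^p -> R^m

f = lambda x: np.tanh(A @ x)
g = lambda u: B @ (u ** 2)
h = lambda x: g(f(x))

x = np.random.randn(n)
J_f = (1 - np.tanh(A @ x) ** 2)[:, None] * A   # df/dx, a p x n matrix
J_g = B * (2 * f(x))                           # dg/df, an m x p matrix
J_chain = J_g @ J_f                            # dh/dx via the matrix chain rule

# Finite-difference estimate of dh/dx, column by column.
eps = 1e-6
J_fd = np.column_stack([(h(x + eps * e) - h(x - eps * e)) / (2 * eps)
                        for e in np.eye(n)])
print(np.max(np.abs(J_chain - J_fd)))  # tiny: only finite-difference error remains
```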
We can apply this formula recursively to our function class (10) to obtain the formula:
which is a fully general matrix chain rule. We'll use (15) next to explore the central idea behind backpropagation, which has been a key technical enabler of contemporary deep learning.
where $X_i$ is an $\mathbb{R}^{p_i \times (n_i + 1)}$ matrix with entries given by $x_i \in \mathbb{R}^{p_i (n_i + 1)}$, and $\sigma$ is a pointwise nonlinearity $\sigma(x) = (\sigma(x_1), \ldots, \sigma(x_n))$ called an activation function (more on these next lecture).
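A minimal NumPy sketch of this layer map, assuming the composition $m(z; x) = m_L(\cdots m_1(z; X_1) \cdots; X_L)$ and using $\sigma = \tanh$ purely as a placeholder activation (the notes defer the choice of $\sigma$ to next lecture):

```python
import numpy as np

def layer(X_i, m_prev, sigma=np.tanh):
    """One layer m_i = sigma(X_i @ [m_{i-1}; 1]); the appended 1 carries the bias column of X_i."""
    return sigma(X_i @ np.append(m_prev, 1.0))

def forward(layers, z, sigma=np.tanh):
    """Compose the layers to evaluate m(z; x) = m_L( ... m_1(z; X_1) ... ; X_L)."""
    m = z
    for X_i in layers:
        m = layer(X_i, m, sigma)
    return m
```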
Applying our matrix chain rule to $\ell(m(x) - y_i)$ (we won't write $z_i$ to save space), we get the expression
Here, $\frac{\partial \ell}{\partial m}$ is a $p_L$-dimensional row vector, and $\frac{\partial m_i}{\partial m_{i-1}}$ is a $p_i \times p_{i-1}$ matrix.
In modern architectures, the layer dimensions $p_i$, also called layer widths, can be very large (on the order of hundreds of thousands or even millions), meaning the $\frac{\partial m_i}{\partial m_{i-1}}$ matrices are very, very large! Too large, in fact, to store in memory.
Fortunately, since $\frac{\partial \ell}{\partial m}$ is a row vector, we can build $\frac{\partial \ell}{\partial x}$ by sequentially computing inner products. For example, if $\frac{\partial m_L}{\partial m_{L-1}} = \begin{bmatrix} a_1 & \cdots & a_{p_{L-1}} \end{bmatrix}$,
meaning we only ever need to store $\frac{\partial \ell}{\partial m_L}$ and $a_i$ in memory at any given time, which is only $2 p_L$ numbers, as opposed to $p_L \times p_{L-1}$ numbers! Then once we've computed $\frac{\partial \ell}{\partial m_L} \frac{\partial m_L}{\partial m_{L-1}}$, which is now a $p_{L-1}$-dimensional row vector, we can continue our way down the chain.
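In code, the idea might look like the following sketch, where `column(i)` is a hypothetical callback that produces the $i$th column $a_i$ of $\frac{\partial m_L}{\partial m_{L-1}}$ on demand:

```python
import numpy as np

def row_times_jacobian(dl_dmL, column, num_cols):
    """Compute dl/dm_{L-1} = (dl/dm_L) @ (dm_L/dm_{L-1}) one inner product at a time.
    `column(i)` is a hypothetical callback returning the i-th column a_i of dm_L/dm_{L-1},
    so the full p_L x p_{L-1} Jacobian is never held in memory all at once."""
    return np.array([dl_dmL @ column(i) for i in range(num_cols)])
```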
What's left to do is compute the partial derivatives! Let's break down $\frac{\partial \ell}{\partial x}$ into partial derivatives with respect to a layer's parameters $x_i$. For layer $L$, we have:
Since $x_L$ appears in the last layer, it shows up right away in the first term above, which is the derivative of $m_L(m_{L-1}; x_L)$ with respect to $x_L$ (the 2nd argument). The second term,
which measures how $m_L$ changes with respect to changes in $m_{L-1}$ caused by changes in $x_L$, is zero because $m_{L-1}$ does not depend on $x_L$ at all! This is a key observation in the backpropagation algorithm!
Let's proceed to compute the derivative with respect to the parameter $x_{L-1}$:
where $\frac{\partial \ell}{\partial m_j}$ will have been computed at the layer above. This is another key piece of backpropagation!
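Putting these two observations together, here is a compact sketch of the resulting backward pass for the tanh layers sketched earlier; the squared-error loss and the $\tanh$ activation are illustrative assumptions, not requirements of the derivation.

```python
import numpy as np

def backprop(layers, z, y, sigma=np.tanh, dsigma=lambda w: 1 - np.tanh(w) ** 2):
    """Return [dl/dX_1, ..., dl/dX_L] for l = ||m(z; x) - y||^2 with m_i = sigma(X_i @ [m_{i-1}; 1])."""
    # Forward pass: store each layer's (bias-augmented) input and pre-activation w_i.
    inputs, preacts, m = [], [], z
    for X_i in layers:
        m_aug = np.append(m, 1.0)
        w = X_i @ m_aug
        inputs.append(m_aug)
        preacts.append(w)
        m = sigma(w)

    grads = [None] * len(layers)
    dl_dm = 2 * (m - y)                        # dl/dm_L for the squared-error loss
    for i in reversed(range(len(layers))):
        dl_dw = dl_dm * dsigma(preacts[i])     # row vector times diag(sigma'(w_i))
        grads[i] = np.outer(dl_dw, inputs[i])  # dl/dX_i, same shape as X_i
        dl_dm = (layers[i].T @ dl_dw)[:-1]     # propagate to m_{i-1}, dropping the bias entry
    return grads
```

Note how `dl_dm` is exactly the running row vector $\frac{\partial \ell}{\partial m_j}$ handed down from the layer above, and how each $X_j$ only ever receives gradient through its own layer, since the layers below it do not depend on $x_j$.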
The only thing left to compute is $\frac{\partial m_j}{\partial x_j}$ --- this is now just an exercise in calculus, so we'll not work it out fully by hand. Please refer to Backpropagation#Finding the derivative of the error for further information if you are interested.
We apply our chain rule (with $w = X_j \begin{bmatrix} O_{j-1} \\ 1 \end{bmatrix}$) to get
$$\frac{\partial m_j}{\partial x_j} = \frac{\partial}{\partial x_j} \sigma\left(X_j \begin{bmatrix} O_{j-1} \\ 1 \end{bmatrix}\right) = \frac{\partial \sigma}{\partial w} \cdot \frac{\partial w}{\partial x_j}.$$
Now, for $\sigma(w) = \begin{bmatrix} \sigma(w_1) \\ \vdots \\ \sigma(w_{p_j}) \end{bmatrix}$, we have $\frac{\partial \sigma}{\partial w} = \begin{bmatrix} \sigma'(w_1) & & \\ & \ddots & \\ & & \sigma'(w_{p_j}) \end{bmatrix}$. Next, we need to find $\frac{\partial w}{\partial x_j} = \frac{\partial}{\partial x_j}\left(X_j \begin{bmatrix} O_{j-1} \\ 1 \end{bmatrix}\right)$. This can be computed using multilinear algebra (tensors). We won't work it out, but note that it can be found efficiently.
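One practical consequence worth noting: because $\frac{\partial \sigma}{\partial w}$ is diagonal, multiplying a row vector by it is just an elementwise product, so the diagonal matrix never needs to be formed. A short NumPy check (again with $\tanh$ as the placeholder activation):

```python
import numpy as np

w = np.random.randn(5)
v = np.random.randn(5)                   # a row vector, e.g. dl/dm_j
full = v @ np.diag(1 - np.tanh(w) ** 2)  # explicit diagonal Jacobian of sigma
cheap = v * (1 - np.tanh(w) ** 2)        # elementwise product, no matrix formed
print(np.allclose(full, cheap))          # True
```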