A note on the multivariable chain rule

Teaching machine learning, I have found that many students are unprepared for the level of vector calculus required, particularly when it comes to doing backprop calculations, which require the chain rule. Here I attempt to review the chain rule for computing gradients, and related concepts such as derivatives and Jacobians, in a cohesive way.

Review: single-variable chain rule

In single-variable calculus, the chain rule is often written

$$\frac{dz}{dx} = \frac{dz}{dy}\,\frac{dy}{dx}$$

where $z$ is some function of $y$, which is itself a function of $x$. Students can remember this easily by "canceling terms" (although they are reminded that this is not technically correct).
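As a quick numerical sanity check (a minimal NumPy sketch; the particular choice $z = \sin(y)$ with $y = x^2$ is an arbitrary example of mine):

```python
import numpy as np

# Check dz/dx = (dz/dy)(dy/dx) for z = sin(y), y = x^2 at a single point.
x = 1.3
y = x**2
dy_dx = 2 * x        # derivative of y = x^2
dz_dy = np.cos(y)    # derivative of z = sin(y)
chain = dz_dy * dy_dx

# Compare against a centered finite difference of z(x) = sin(x^2).
h = 1e-6
fd = (np.sin((x + h)**2) - np.sin((x - h)**2)) / (2 * h)
print(chain, fd)     # the two agree to several decimal places
```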

We would like to extend this type of relationship to vector-valued functions of several variables. But first, we need to reinterpret the derivative.

The derivative as a linear map

The classical derivative is a number, denoted $f'(x)$, which corresponds to the instantaneous rate of change of $f$ at $x$ if you zoom in infinitely close to $x$. This is conceptually equivalent to drawing a tangent line to $f$ at $x$; the slope of the tangent line is the derivative $f'(x)$. The tangent line can be written as the graph of the function

$$\bar{f}(x+\Delta) = f(x) + f'(x)\,\Delta$$

This leads to a more general notion of a derivative as the best local linear approximation of the change of a function with respect to its input. This is typically called the total derivative or differential of a function, or sometimes just the derivative. For a function $f: X \to Y$ where $X$ and $Y$ are normed vector spaces, we say that $f'(x): X \to Y$ is the total derivative of $f$ at $x$ if it is linear and satisfies

$$\lim_{\Delta \to 0} \frac{\left\| f(x+\Delta) - f(x) - f'(x)(\Delta) \right\|_Y}{\left\| \Delta \right\|_X} = 0,$$

assuming such a map exists. In the special case of $f: \mathbb{R} \to \mathbb{R}$, we abuse notation by writing $f'(x)$ for both the scalar and for the linear map, but it holds that

$$\lim_{\Delta \to 0} \frac{\left| f(x+\Delta) - f(x) - f'(x)\,\Delta \right|}{|\Delta|} = \lim_{\Delta \to 0} \left| \frac{f(x+\Delta) - f(x)}{\Delta} - f'(x) \right| = 0$$

so we recover $f'(x)(\Delta) = f'(x)\,\Delta$.
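This "best local linear approximation" property is easy to see numerically. Here is a small sketch (NumPy, with $f = \exp$ chosen arbitrarily): the approximation error divided by $|\Delta|$ goes to zero as $\Delta$ shrinks, which is exactly the limit above.

```python
import numpy as np

# The error of the linear approximation f(x + D) ~ f(x) + f'(x) D is o(|D|):
# the ratio error / |D| tends to 0 as D -> 0.
f = np.exp
x = 0.5
fprime = np.exp(x)   # derivative of exp at x, viewed as the linear map D -> fprime * D

for delta in [1e-1, 1e-2, 1e-3, 1e-4]:
    err = abs(f(x + delta) - f(x) - fprime * delta)
    print(delta, err / delta)   # the second column shrinks roughly linearly in delta
```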

Jacobian and gradient

For vector-valued functions $f: \mathbb{R}^m \to \mathbb{R}^n$, the matrix representation of the total derivative turns out to be the Jacobian matrix $J_f(x) \in \mathbb{R}^{n \times m}$, which contains all the partial derivatives:

$$f'(x)(\Delta) = J_f(x)\,\Delta \quad \text{where} \quad [J_f(x)]_{ij} = \left.\frac{\partial f_i}{\partial x_j}\right|_{x}$$

In the special case of scalar-valued functions $f: \mathbb{R}^m \to \mathbb{R}$, we can define the gradient $\nabla f(x) \in \mathbb{R}^m$ by

$$[\nabla f(x)]_i = \left.\frac{\partial f}{\partial x_i}\right|_{x}$$

Note that if we interpret $\nabla f(x)$ as a column vector, i.e. an $m \times 1$ matrix, then

$$J_f(x) = \nabla f(x)^\top$$

Or, more generally, for $f: \mathbb{R}^m \to \mathbb{R}^n$,

$$J_f(x) = \begin{bmatrix} \nabla f_1(x)^\top \\ \vdots \\ \nabla f_n(x)^\top \end{bmatrix}$$

This may seem like a coincidence if your only knowledge of the transpose operator is that it flips a matrix across the diagonal. The conceptual interpretation of transposition (operating on a vector) is that it takes a vector and produces a linear functional (also called a linear form), i.e. a linear function that acts on other vectors to produce a scalar. Namely, for any vector $v$ we can define the linear functional $v^\top(w) = v^\top w = \sum_i v_i w_i$. There is essentially no difference between $v$ and $v^\top$; while they are not literally the same "type" of object, they act the same way, and there is a one-to-one correspondence between them.
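To see the stacked-gradient structure of the Jacobian concretely, here is a small finite-difference sketch (NumPy; the function $f: \mathbb{R}^3 \to \mathbb{R}^2$ is an arbitrary example of mine): row $i$ of $J_f(x)$ is the transposed gradient of the component $f_i$.

```python
import numpy as np

# Arbitrary example f: R^3 -> R^2, chosen only for illustration.
def f(x):
    return np.array([x[0] * x[1], np.sin(x[2]) + x[0]**2])

def jacobian_fd(fun, x, h=1e-6):
    """Central finite-difference Jacobian; column j holds d(fun)/d(x_j)."""
    m = len(x)
    cols = []
    for j in range(m):
        e = np.zeros(m)
        e[j] = h
        cols.append((fun(x + e) - fun(x - e)) / (2 * h))
    return np.stack(cols, axis=1)   # shape (n, m)

x = np.array([1.0, 2.0, 0.3])
J = jacobian_fd(f, x)
print(J.shape)   # (2, 3): one row per output component
print(J[0])      # approx gradient of f_1 = x0*x1:        [x1, x0, 0]
print(J[1])      # approx gradient of f_2 = sin(x2)+x0^2: [2*x0, 0, cos(x2)]
```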

From the equations above, we see that the gradient is the vector corresponding to the directional derivative functional

$$f'(x)(v) = \nabla f(x)^\top v = \sum_i v_i \frac{\partial f}{\partial x_i} = v \cdot \nabla f(x)$$

which quantifies how $f$ changes along the direction $v$ when starting from $x$. Note that the directional derivative – considered as a function of the direction – coincides with the total derivative of $f$ when $f$ is scalar-valued.
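A quick check that the directional derivative really is $v \cdot \nabla f(x)$ (NumPy again, with an arbitrary example function):

```python
import numpy as np

# Directional derivative of f(x) = x0^2 + sin(x1) along v, computed two ways.
def f(x):
    return x[0]**2 + np.sin(x[1])

def grad_f(x):
    return np.array([2 * x[0], np.cos(x[1])])

x = np.array([1.0, 0.5])
v = np.array([0.3, -0.7])

# d/dt f(x + t v) at t = 0, estimated by a centered finite difference.
h = 1e-6
directional_fd = (f(x + h * v) - f(x - h * v)) / (2 * h)
print(directional_fd, v @ grad_f(x))   # the two agree to several decimal places
```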

Generalization: gradients in Hilbert spaces

In fact, the gradient can be defined more generally in a Hilbert space $H$: a vector space equipped with an inner product $\langle \cdot, \cdot \rangle_H$ that generalizes the dot product from $\mathbb{R}^n$ (and which is complete with respect to the induced norm). Any Hilbert space is also a normed vector space with norm $\|v\|_H = \sqrt{\langle v, v \rangle_H}$.

The Riesz representation theorem guarantees that, for any bounded linear functional $\ell: H \to \mathbb{R}$, there exists a unique vector $v_\ell \in H$ such that $\ell(x) = \langle x, v_\ell \rangle_H$. Thus, the gradient of $f: H \to \mathbb{R}$ can be defined as the unique vector $\nabla f(x) \in H$ satisfying

$$f'(x)(v) = \langle v, \nabla f(x) \rangle_H$$
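For a finite-dimensional illustration (my own example, not from the discussion above): take $H = \mathbb{R}^2$ with the weighted inner product $\langle u, v \rangle_M = u^\top M v$ for a symmetric positive-definite $M$. The differential $f'(x)(v) = v^\top g$, where $g$ is the usual Euclidean gradient, is then represented by the vector $M^{-1} g$, since $\langle v, M^{-1} g \rangle_M = v^\top g$. A NumPy sketch:

```python
import numpy as np

# Gradient with respect to a weighted inner product <u, v>_M = u^T M v.
M = np.array([[2.0, 0.5],
              [0.5, 1.0]])           # symmetric positive definite

def f(x):
    return x[0]**2 + 3 * x[1]

def euclidean_grad(x):
    return np.array([2 * x[0], 3.0])

x = np.array([1.0, -2.0])
v = np.array([0.4, 0.9])

g = euclidean_grad(x)
grad_M = np.linalg.solve(M, g)       # Riesz representer of f'(x) under <.,.>_M

# Both expressions compute f'(x)(v):
print(v @ g, v @ M @ grad_M)         # equal (up to floating-point error)
```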

The multivariable chain rule

The chain rule generalizes nicely to the total derivative. Let $f: X \to Y$ and $g: Y \to Z$; then their composition $g \circ f: X \to Z$ has derivative

$$(g \circ f)'(x) = g'(f(x)) \circ f'(x)$$

Observe that the order of composition of the derivatives matches the order of composition of the original functions. That is, $g \circ f$ first applies $f$ and then $g$, and similarly $(g \circ f)'(x)$ first applies $f'(x)$ and then $g'(f(x))$.

This can also be expressed in terms of Jacobians:

$$J_{g \circ f}(x) = J_g(f(x))\, J_f(x)$$

Note that if $y = f(x)$ and $z = g(y)$, and we abuse notation by writing $\frac{d\,\cdot}{d\,\cdot}$ for the Jacobian, the above equation becomes

$$\frac{dz}{dx} = \frac{dz}{dy}\,\frac{dy}{dx}$$

just as before.
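It is easy to verify the Jacobian form of the chain rule numerically. Below is a rough finite-difference sketch (NumPy; the maps $f$ and $g$ are arbitrary examples of mine):

```python
import numpy as np

# Verify J_{g o f}(x) = J_g(f(x)) J_f(x) using finite-difference Jacobians.
def f(x):                                      # f: R^3 -> R^2
    return np.array([x[0] * x[1], x[2]**2])

def g(y):                                      # g: R^2 -> R^2
    return np.array([np.sin(y[0]), y[0] + y[1]])

def jac_fd(fun, x, eps=1e-6):
    cols = []
    for j in range(len(x)):
        e = np.zeros_like(x)
        e[j] = eps
        cols.append((fun(x + e) - fun(x - e)) / (2 * eps))
    return np.stack(cols, axis=1)

x = np.array([1.0, 2.0, 0.5])
lhs = jac_fd(lambda t: g(f(t)), x)             # Jacobian of the composition
rhs = jac_fd(g, f(x)) @ jac_fd(f, x)           # product of the Jacobians
print(np.allclose(lhs, rhs, atol=1e-5))        # True
```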

Example: ordinary least squares

A prototypical example in ML classes is ordinary least squares:

$$\min_w L(w) \quad \text{where} \quad L(w) = \|Xw - y\|_2^2$$

As we know, the solution can be found by identifying critical points of $L$, i.e. $w$ where $\nabla L(w) = 0$. The gradient can be found by expanding the square and applying a couple of common matrix calculus identities:

$$\nabla L(w) = \nabla_w \left( (Xw - y)^\top (Xw - y) \right) = \nabla_w \left( w^\top X^\top X w - 2 y^\top X w + y^\top y \right) = 2 X^\top X w - 2 X^\top y$$

so we arrive at the normal equations $X^\top X w = X^\top y$.
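In code, solving the normal equations is a one-liner (a NumPy sketch on synthetic data, assuming $X$ has full column rank so that $X^\top X$ is invertible):

```python
import numpy as np

# Solve ordinary least squares via the normal equations X^T X w = X^T y
# and compare against NumPy's built-in least-squares solver.
rng = np.random.default_rng(0)
X = rng.normal(size=(50, 3))
y = rng.normal(size=50)

w_normal = np.linalg.solve(X.T @ X, X.T @ y)
w_lstsq, *_ = np.linalg.lstsq(X, y, rcond=None)
print(np.allclose(w_normal, w_lstsq))   # True
```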

However, one can also use the chain rule to obtain $\nabla L$. Let's denote $z = f(w) = Xw - y$ and $L = g(z) = \|z\|_2^2$. We know $\nabla g(z) = 2z$, so $\frac{dL}{dz} = J_g(z) = \nabla g(z)^\top = 2z^\top$, and it should be apparent that $\frac{dz}{dw} = J_f = X$, as this is the best linear approximation, so

$$\frac{dL}{dw} = \frac{dL}{dz}\,\frac{dz}{dw} = 2(Xw - y)^\top X$$

Then

$$\nabla L(w) = \left( \frac{dL}{dw} \right)^\top = 2 X^\top (Xw - y)$$

which is exactly what we got before.
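And, as a final sanity check (NumPy, synthetic data), the chain-rule gradient $2 X^\top (Xw - y)$ matches a finite-difference estimate of $\nabla L$:

```python
import numpy as np

# Compare the chain-rule gradient of L(w) = ||Xw - y||^2 with finite differences.
rng = np.random.default_rng(1)
X = rng.normal(size=(20, 4))
y = rng.normal(size=20)
w = rng.normal(size=4)

def L(w):
    return np.sum((X @ w - y)**2)

grad_chain = 2 * X.T @ (X @ w - y)

h = 1e-6
grad_fd = np.array([(L(w + h * e) - L(w - h * e)) / (2 * h) for e in np.eye(4)])
print(np.allclose(grad_chain, grad_fd, atol=1e-4))   # True
```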