Teaching machine learning, I have found that many students are unprepared for the level of vector calculus required, particularly when it comes to doing backprop calculations, which require the chain rule. Here I attempt to review the chain rule for computing gradients, and related concepts such as derivatives and Jacobians, in a cohesive way.
Review: single-variable chain rule
In single-variable calculus, the chain rule is often written

$$\frac{dz}{dx} = \frac{dz}{dy}\,\frac{dy}{dx},$$

where $z$ is some function of $y$, which is itself a function of $x$. Students can remember this easily by "canceling terms" (although they are reminded that this is not technically correct).
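As a quick sanity check, here is a minimal sketch in Python (using a made-up example, $z = \sin(y)$ with $y = x^2$, so $\frac{dz}{dx} = \cos(x^2)\cdot 2x$) comparing the chain-rule derivative with a finite-difference estimate:

```python
import numpy as np

# Example: z = sin(y) with y = x**2, so dz/dx = cos(x**2) * 2x.
x = 1.3
analytic = np.cos(x**2) * 2 * x

# Central finite difference of the composed function z(x) = sin(x**2).
h = 1e-6
numeric = (np.sin((x + h)**2) - np.sin((x - h)**2)) / (2 * h)

print(analytic, numeric)  # the two values agree to roughly 1e-9
```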
We would like to extend this type of relationship to vector-valued functions of several variables. But first, we need to reinterpret the derivative.
The derivative as a linear map
The classical derivative is a number, denoted $f'(x_0)$, which corresponds to the instantaneous rate of change of $f$ at $x_0$ if you zoom in infinitely close to $x_0$. This is conceptually equivalent to drawing a tangent line to $f$ at $x_0$; the slope of the tangent line is the derivative $f'(x_0)$. The tangent line can be written as the graph of the function

$$x \mapsto f(x_0) + f'(x_0)\,(x - x_0).$$
This leads to a more general notion of a derivative as the best local linear approximation of the change of a function with respect to its input. This is typically called the total derivative or differential of a function, or sometimes just derivative. For a function $f: V \to W$ where $V$ and $W$ are normed vector spaces, we say that $df_{x_0}: V \to W$ is the total derivative of $f$ at $x_0$ if it is linear and satisfies

$$\lim_{h \to 0} \frac{\|f(x_0 + h) - f(x_0) - df_{x_0}(h)\|}{\|h\|} = 0,$$

assuming such a map exists. In the special case of $f: \mathbb{R} \to \mathbb{R}$, we abuse notation by writing $f'(x_0)$ for both the scalar and for the linear map, but it holds that

$$df_{x_0}(h) = f'(x_0)\,h,$$
so we recover the classical derivative $f'(x_0)$.
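The defining limit can be checked numerically; the following sketch (my own example, $f(x) = e^x$ at $x_0 = 0.5$) shows the approximation error of the linear map shrinking faster than $|h|$:

```python
import numpy as np

# f(x) = exp(x); the total derivative at x0 is the linear map h -> exp(x0) * h.
x0 = 0.5
df = lambda h: np.exp(x0) * h

for h in [1e-1, 1e-2, 1e-3, 1e-4]:
    error = abs(np.exp(x0 + h) - np.exp(x0) - df(h))
    print(h, error / h)  # error / |h| -> 0 as h -> 0
```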
Jacobian and gradient
For vector-valued functions $f: \mathbb{R}^n \to \mathbb{R}^m$, the concrete matrix representation of the total derivative turns out to be the Jacobian matrix, which contains all the partial derivatives:

$$J_f = \begin{pmatrix} \frac{\partial f_1}{\partial x_1} & \cdots & \frac{\partial f_1}{\partial x_n} \\ \vdots & \ddots & \vdots \\ \frac{\partial f_m}{\partial x_1} & \cdots & \frac{\partial f_m}{\partial x_n} \end{pmatrix} \in \mathbb{R}^{m \times n}.$$

In the special case of scalar-valued functions $f: \mathbb{R}^n \to \mathbb{R}$, we can define the gradient by

$$\nabla f = \left( \frac{\partial f}{\partial x_1}, \ldots, \frac{\partial f}{\partial x_n} \right).$$

Note that if we interpret $\nabla f$ as a column vector, i.e. an $n \times 1$ matrix, then

$$J_f = (\nabla f)^T.$$

Or, more generally, for any $v \in \mathbb{R}^n$,

$$df_{x_0}(v) = J_f(x_0)\,v = \nabla f(x_0)^T v.$$
This may seem like a coincidence if your only knowledge of the transpose operator is that it flips a matrix across the diagonal. The conceptual interpretation of transposition (operating on a vector) is that it takes a vector and produces a linear functional (also called linear form), i.e. a linear function that acts on other vectors to produce a scalar. Namely, for any vector $w \in \mathbb{R}^n$ we can define the linear functional $v \mapsto w^T v$. There is essentially no difference between $w$ and $w^T$; while they are not literally the same "type" of object, they act the same way, and there is a one-to-one correspondence between them.

From the equations above, we see that the gradient is the vector corresponding to the directional derivative functional

$$v \mapsto df_{x_0}(v) = \nabla f(x_0)^T v,$$

which quantifies how $f$ changes along the direction $v$ when starting from $x_0$. Note that the directional derivative – considered as a function of the direction $v$ – coincides with the total derivative of $f$ when $f$ is scalar-valued.
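Here is a small Python sketch (the function $f(x_1, x_2) = x_1^2 + 3x_1x_2$ is my own example) checking that $\nabla f(x_0)^T v$ matches a finite-difference estimate of the directional derivative:

```python
import numpy as np

# f: R^2 -> R, f(x) = x1**2 + 3*x1*x2, with gradient (2*x1 + 3*x2, 3*x1).
f = lambda x: x[0]**2 + 3 * x[0] * x[1]
grad_f = lambda x: np.array([2 * x[0] + 3 * x[1], 3 * x[0]])

x0 = np.array([1.0, -2.0])
v = np.array([0.5, 2.0])

# Directional derivative as an inner product with the gradient ...
analytic = grad_f(x0) @ v

# ... compared with the finite-difference rate of change of f along v.
h = 1e-6
numeric = (f(x0 + h * v) - f(x0 - h * v)) / (2 * h)
print(analytic, numeric)
```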
Generalization: gradients in Hilbert spaces
In fact, the gradient can be defined more generally in a Hilbert space $H$, which is a vector space equipped with an inner product $\langle \cdot, \cdot \rangle$ that generalizes the dot product from $\mathbb{R}^n$ (and which is complete with respect to the induced norm). Any Hilbert space is also a normed vector space with norm $\|v\| = \sqrt{\langle v, v \rangle}$.
The Riesz representation theorem guarantees that, for any continuous linear functional $\varphi: H \to \mathbb{R}$, there exists a unique vector $u \in H$ such that $\varphi(v) = \langle u, v \rangle$ for all $v \in H$. Thus, the gradient of $f: H \to \mathbb{R}$ at $x_0$ can be defined as the unique vector $\nabla f(x_0) \in H$ satisfying

$$df_{x_0}(v) = \langle \nabla f(x_0), v \rangle \quad \text{for all } v \in H.$$
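To make this concrete, here is a sketch (my own construction, not from the original argument) on $\mathbb{R}^2$ with a weighted inner product $\langle u, v \rangle_M = u^T M v$: the Riesz representer of $df_{x_0}$ is then $M^{-1}$ times the usual vector of partial derivatives, yet it reproduces the same directional derivatives.

```python
import numpy as np

# R^2 equipped with the weighted inner product <u, v>_M = u @ M @ v,
# where M is symmetric positive-definite.
M = np.array([[2.0, 0.5],
              [0.5, 1.0]])

f = lambda x: x[0]**2 + 3 * x[0] * x[1]
partials = lambda x: np.array([2 * x[0] + 3 * x[1], 3 * x[0]])  # Euclidean gradient

x0 = np.array([1.0, -2.0])

# Gradient w.r.t. <.,.>_M: the unique g with <g, v>_M = df_{x0}(v) for all v,
# which works out to g = M^{-1} @ partials(x0).
g = np.linalg.solve(M, partials(x0))

# Both sides compute the same directional derivative df_{x0}(v).
v = np.array([0.3, -1.1])
print(g @ M @ v, partials(x0) @ v)
```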
The multivariable chain rule
The chain rule generalizes nicely to the total derivative. Let $f: \mathbb{R}^n \to \mathbb{R}^m$ and $g: \mathbb{R}^m \to \mathbb{R}^k$, then their composition $g \circ f$ has derivative

$$d(g \circ f)_{x_0} = dg_{f(x_0)} \circ df_{x_0}.$$

Observe that the order of composition of the derivatives matches the order of composition of the original functions. That is, $g \circ f$ first applies $f$ then $g$, and similarly $dg_{f(x_0)} \circ df_{x_0}$ first applies $df_{x_0}$ then $dg_{f(x_0)}$.
This can also be expressed in terms of Jacobians:

$$J_{g \circ f}(x_0) = J_g(f(x_0))\,J_f(x_0).$$

Note that if $f: \mathbb{R} \to \mathbb{R}$ and $g: \mathbb{R} \to \mathbb{R}$, and we abuse notation by writing $f'$ for the $1 \times 1$ Jacobian, the above equation becomes

$$(g \circ f)'(x_0) = g'(f(x_0))\,f'(x_0),$$
just as before.
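The Jacobian form of the chain rule is easy to verify numerically; the sketch below (two made-up maps $f, g: \mathbb{R}^2 \to \mathbb{R}^2$) compares $J_g(f(x_0))\,J_f(x_0)$ with a finite-difference Jacobian of the composition.

```python
import numpy as np

# Two maps R^2 -> R^2 and their hand-computed Jacobians.
f = lambda x: np.array([x[0] * x[1], np.sin(x[0])])
g = lambda y: np.array([y[0] + y[1]**2, np.exp(y[0])])
J_f = lambda x: np.array([[x[1], x[0]],
                          [np.cos(x[0]), 0.0]])
J_g = lambda y: np.array([[1.0, 2 * y[1]],
                          [np.exp(y[0]), 0.0]])

x0 = np.array([0.7, -1.2])

# Chain rule: J_{g o f}(x0) = J_g(f(x0)) @ J_f(x0).
analytic = J_g(f(x0)) @ J_f(x0)

# Finite-difference Jacobian of the composition, built column by column.
h = 1e-6
numeric = np.column_stack([(g(f(x0 + h * e)) - g(f(x0 - h * e))) / (2 * h)
                           for e in np.eye(2)])
print(np.max(np.abs(analytic - numeric)))  # ~1e-9
```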
Example: ordinary least squares
A prototypical example in ML classes is ordinary least squares:

$$\min_{w \in \mathbb{R}^d} \; L(w) = \|Xw - y\|^2,$$

where $X \in \mathbb{R}^{n \times d}$ is the data matrix and $y \in \mathbb{R}^n$ is the vector of targets.
As we know, the solution can be found by identifying critical points of $L$, i.e. where $\nabla L(w) = 0$. The gradient can be found by expanding the square and applying a couple of common matrix calculus identities:

$$L(w) = w^T X^T X w - 2 y^T X w + y^T y, \qquad \nabla L(w) = 2 X^T X w - 2 X^T y,$$

so we arrive at the normal equations $X^T X w = X^T y$.
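A short sketch with synthetic data (my own, purely for illustration) confirming that solving the normal equations recovers the least-squares solution:

```python
import numpy as np

# Synthetic regression problem.
rng = np.random.default_rng(0)
X = rng.normal(size=(50, 3))
y = rng.normal(size=50)

# Solve the normal equations X^T X w = X^T y ...
w_normal = np.linalg.solve(X.T @ X, X.T @ y)

# ... and compare with numpy's least-squares solver.
w_lstsq, *_ = np.linalg.lstsq(X, y, rcond=None)
print(np.allclose(w_normal, w_lstsq))  # True
```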
However, one can also use the chain rule to obtain $\nabla L$. Let's denote $r(w) = Xw - y$ and $g(r) = \|r\|^2$, so that $L = g \circ r$. We know $\nabla g(r) = 2r$, so $dg_r(v) = 2r^T v$, and it should be apparent that $dr_w(v) = Xv$, as this is the best linear approximation of the affine map $r$, so

$$dL_w(v) = dg_{r(w)}(dr_w(v)) = 2(Xw - y)^T X v = \langle 2X^T(Xw - y), v \rangle,$$

which again gives $\nabla L(w) = 2X^T(Xw - y)$ and hence the same normal equations.
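Finally, a sketch (reusing the same kind of synthetic data) checking the chain-rule gradient $2X^T(Xw - y)$ against finite differences of $L(w) = \|Xw - y\|^2$:

```python
import numpy as np

rng = np.random.default_rng(1)
X = rng.normal(size=(50, 3))
y = rng.normal(size=50)
w = rng.normal(size=3)

L = lambda w: np.sum((X @ w - y)**2)
grad_L = lambda w: 2 * X.T @ (X @ w - y)   # chain-rule gradient

# Finite-difference check of each partial derivative of L at w.
h = 1e-6
numeric = np.array([(L(w + h * e) - L(w - h * e)) / (2 * h) for e in np.eye(3)])
print(np.max(np.abs(numeric - grad_L(w))))  # ~1e-6 or smaller
```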