Math Insight

Special cases of the multivariable chain rule

The general statement of the multivariable chain rule is the following.

Chain Rule: For differentiable functions $\vc{g}: \R^m \rightarrow \R^k$ and $\vc{f}: \R^k \rightarrow \R^n$, the derivative matrix of the composition $\vc{h}=\vc{f} \circ \vc{g}$ (i.e., $\vc{h}(\vc{x}) = \vc{f}(\vc{g}(\vc{x}))$) at the point $\vc{a}$ is the product of the derivative matrices for $\vc{f}$ and $\vc{g}$: \begin{gather} D\vc{h}(\vc{a})= D(\vc{f} \circ \vc{g})(\vc{a}) = D{\vc{f}}\bigl(\vc{g}(\vc{a})\bigr) D{\vc{g}}(\vc{a}). \label{general_chain_rule} \end{gather}

In this form, the multivariable chain rule looks similar to the one-variable chain rule: $$\diff{}{x}(f \circ g)(x) = \diff{}{x}f(g(x)) = f'(g(x))g'(x).$$ The biggest difference in the multivariable case is that the ordinary derivative has been replaced with the derivative matrix. One important fact to remember is that the matrix of partial derivatives for $\vc{f}$ is evaluated at $\vc{g}(\vc{a})$, not at $\vc{a}$.
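
As a quick check on the matrix product in \eqref{general_chain_rule}, it can help to track the dimensions of each factor. This is only a bookkeeping sketch based on the definitions of $\vc{g}: \R^m \rightarrow \R^k$ and $\vc{f}: \R^k \rightarrow \R^n$ given above: \begin{align*} \underbrace{D\vc{h}(\vc{a})}_{n \times m} = \underbrace{D\vc{f}\bigl(\vc{g}(\vc{a})\bigr)}_{n \times k} \, \underbrace{D\vc{g}(\vc{a})}_{k \times m}. \end{align*} The inner dimension $k$ matches, so the product is defined and has the $n \times m$ shape required of the derivative matrix of $\vc{h}: \R^m \rightarrow \R^n$. With the factors in the opposite order, the product would in general not even be defined.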

Using the above general form may be the easiest way to learn the chain rule. If you are comfortable forming derivative matrices, multiplying matrices, and using the one-variable chain rule, then using the chain rule \eqref{general_chain_rule} doesn't require memorizing a series of formulas and determining which formula applies to a given problem.

On the other hand, having a few special case formulas available can save some work. For a given type of problem, you can form the matrices and calculate their product once to obtain a formula valid for that particular type of problem. Then, for each example of that problem type, you can plug the particular functions into the special case formula and get to the final result more quickly.

In the following, we derive formulas for a few special cases. In each case, the outer function $f$ is a scalar-valued function. We look at two groups of special cases: when $\vc{g}$ is a function of one variable and when $\vc{g}$ is a function of two variables.

$\vc{g}$ is a function of one variable

When the inner function $\vc{g}$ is a function of one variable, $\vc{g}: \R \to \R^n$, and $f$ is a scalar-valued function, $f: \R^n \to \R$, then the composition $h(t)=f(\vc{g}(t))$ is just a scalar-valued function of a single variable $h: \R \to \R$. Its derivative is just a single number $h'(t)$. We show how to express this single number as a dot product of two vectors.

Since $\vc{g}$ is a function of a single variable, we can view $\vc{g}$ as parametrizing a curve. We can write the derivative of the parametrized curve as a vector \begin{align*} D\vc{g}(t) = \left[ \begin{array}{c} \diff{g_1}{t}(t)\\ \diff{g_2}{t}(t)\\ \vdots\\ \diff{g_n}{t}(t) \end{array} \right] = (g_1'(t),g_2'(t),\ldots,g_n'(t)) =\vc{g}'(t). \end{align*}

Since we are considering the case where $f$ is a scalar-valued function, its derivative matrix can be viewed as the gradient vector: $$\nabla f(\vc{x}) = \left(\pdiff{f}{x_1}(\vc{x}), \pdiff{f}{x_2}(\vc{x}), \ldots, \pdiff{f}{x_n}(\vc{x}) \right).$$ The matrix product of the general chain rule $Df(\vc{g}(t))D\vc{g}(t)$ can then be viewed as a dot product between the gradient vector $\nabla f$, evaluated at $\vc{x}=\vc{g}(t)$, and the vector $\vc{g}'(t)$. We can write the chain rule as \begin{align} h'(t) = \nabla f(\vc{g}(t)) \cdot \vc{g}'(t). \label{chainruledotproduct} \end{align}

If $\vc{g}(t)$ is a one-dimensional function $x=g(t)$ and $f(x)$ is a function of a single variable, then we are back to the single-variable chain rule and equation \eqref{chainruledotproduct} becomes $ h'(t) = f'(g(t))g'(t).$ If we wrote $g(t)=x(t)$, we could also write this chain rule as $\diff{h}{t} = \diff{f}{x} \diff{x}{t}$, where we neglect to write the arguments of each function.

If, on the other hand, $\vc{g}(t)$ is a two-dimensional function $(x,y)=\vc{g}(t)=(g_1(t),g_2(t))$ and $f(\vc{x})$ is a function of two variables, $f(x,y)$, then we can multiply out the dot product of equation \eqref{chainruledotproduct} to write it as \begin{align} h'(t) = \pdiff{f}{x}(\vc{g}(t))g_1'(t) + \pdiff{f}{y}(\vc{g}(t))g_2'(t). \tag{2a} \end{align} This special case is exactly equation (4) of the chain rule introduction.

If we write $\vc{g}(t)=(g_1(t),g_2(t))=(x(t),y(t))$ and its derivative as $\vc{g}'(t)=\left(\diff{x}{t},\diff{y}{t}\right)$, then we can write this formula in a way that some people find easier to memorize: \begin{align} \diff{h}{t} = \pdiff{f}{x}\diff{x}{t} + \pdiff{f}{y}\diff{y}{t}. \tag{2a'} \end{align} This formula looks especially simple since we didn't write the arguments of each function.
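
To see formula (2a) in action, consider a small example with functions chosen here purely for illustration (they are not part of the derivation above): $f(x,y)=x^2y$ and $\vc{g}(t)=(\cos t, \sin t)$. Then $\pdiff{f}{x}=2xy$, $\pdiff{f}{y}=x^2$, $g_1'(t)=-\sin t$, and $g_2'(t)=\cos t$, so that \begin{align*} h'(t) &= \pdiff{f}{x}(\vc{g}(t))g_1'(t) + \pdiff{f}{y}(\vc{g}(t))g_2'(t)\\ &= (2\cos t \sin t)(-\sin t) + (\cos^2 t)(\cos t)\\ &= -2\cos t \sin^2 t + \cos^3 t. \end{align*} As a check, differentiating $h(t)=\cos^2 t \sin t$ directly with the product rule gives the same expression. Note that the partial derivatives of $f$ are evaluated at $(x,y)=(\cos t, \sin t)$, which is exactly the information hidden when the arguments are dropped.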

We can write a similar expression for three-dimensional $\vc{g}(t)=(g_1(t),g_2(t),g_3(t))=(x(t),y(t),z(t))$ and $f(x,y,z)$: \begin{align} h'(t) = \pdiff{f}{x}(\vc{g}(t))g_1'(t) + \pdiff{f}{y}(\vc{g}(t))g_2'(t) + \pdiff{f}{z}(\vc{g}(t))g_3'(t). \tag{2b} \end{align} The simplified version of the 3D formula is \begin{align} \diff{h}{t} = \pdiff{f}{x}\diff{x}{t} + \pdiff{f}{y}\diff{y}{t} + \pdiff{f}{z}\diff{z}{t}. \tag{2b'} \end{align} (Some people even write such a formula as $\diff{f}{t} = \pdiff{f}{x}\diff{x}{t} + \pdiff{f}{y}\diff{y}{t} + \pdiff{f}{z}\diff{z}{t}$ where $f$ is viewed as both a function of $t$ and as a function of $\vc{x}=(x,y,z)$.)
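
A short illustration of the three-dimensional formula, again with functions picked only as an example: let $f(x,y,z)=xyz$ and $\vc{g}(t)=(t,t^2,t^3)$. The partial derivatives are $\pdiff{f}{x}=yz$, $\pdiff{f}{y}=xz$, and $\pdiff{f}{z}=xy$, and the derivative of the curve is $\vc{g}'(t)=(1,2t,3t^2)$. Substituting $(x,y,z)=(t,t^2,t^3)$ gives \begin{align*} \diff{h}{t} &= \pdiff{f}{x}\diff{x}{t} + \pdiff{f}{y}\diff{y}{t} + \pdiff{f}{z}\diff{z}{t}\\ &= (t^2 \cdot t^3)(1) + (t \cdot t^3)(2t) + (t \cdot t^2)(3t^2)\\ &= t^5 + 2t^5 + 3t^5 = 6t^5, \end{align*} which agrees with differentiating $h(t)=t \cdot t^2 \cdot t^3 = t^6$ directly.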

$\vc{g}$ is a function of two variables

If $\vc{g}(s,t)$ is a vector-valued function of two variables, $\vc{g}: \R^2 \to \R^n$, then we can no longer write its derivative as a vector. Instead, the derivative $D\vc{g}(s,t)$ will be a matrix of partial derivatives with two columns, i.e., an $n \times 2$ matrix. Since the derivative of $f: \R^n \to \R$ is a $1 \times n$ matrix, the derivative of the composition $h(s,t)=f(\vc{g}(s,t))$ will be a $1 \times 2$ matrix: \begin{align*} Dh(s,t) = \left[\pdiff{h}{s}(s,t) \quad \pdiff{h}{t}(s,t)\right]. \end{align*} In a similar manner to the above procedure, we can write down component formulas for both $\pdiff{h}{s}$ and $\pdiff{h}{t}$, depending on the dimension $n$.

If $g$ is a scalar-valued function, $x(s,t)=g(s,t)$, then its derivative is the $1 \times 2$ matrix \begin{align*} Dg(s,t) = \left[\pdiff{g}{s}(s,t) \quad \pdiff{g}{t}(s,t)\right] = \left[\pdiff{x}{s} \quad \pdiff{x}{t}\right], \end{align*} where in the second shortcut form, we neglect function arguments. In this case, $f(x)$ must be a function of a single variable and its derivative is the scalar $f'(x)$. By the chain rule \eqref{general_chain_rule}, the derivative of $h$ is \begin{align*} \left[\pdiff{h}{s}(s,t) \quad \pdiff{h}{t}(s,t)\right] = f'(g(s,t)) \left[\pdiff{g}{s}(s,t) \quad \pdiff{g}{t}(s,t)\right]. \end{align*} Writing each component separately, we can write this special case of the chain rule as \begin{align} \pdiff{h}{s}(s,t) &= f'(g(s,t))\pdiff{g}{s}(s,t)\notag\\ \pdiff{h}{t}(s,t) &= f'(g(s,t))\pdiff{g}{t}(s,t)\tag{3a} \end{align} or writing $g$ as $x$, we can write the simplified form as \begin{align} \pdiff{h}{s} &= \diff{f}{x}\pdiff{x}{s}\notag\\ \pdiff{h}{t} &= \diff{f}{x}\pdiff{x}{t}.\tag{3a'} \end{align}
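
As an illustration of (3a), with functions chosen here only for the sake of example, take $f(x)=\sin x$ and $g(s,t)=s^2t$, so that $h(s,t)=\sin(s^2t)$. Then $f'(x)=\cos x$, $\pdiff{g}{s}=2st$, and $\pdiff{g}{t}=s^2$, and formula (3a) gives \begin{align*} \pdiff{h}{s}(s,t) &= \cos(s^2t)\, 2st\\ \pdiff{h}{t}(s,t) &= \cos(s^2t)\, s^2. \end{align*} Both components share the same factor $f'(g(s,t))$, since the outer function has only one input through which $s$ and $t$ can act.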

If $\vc{g}$ is a two-dimensional function $(x(s,t),y(s,t))=\vc{g}(s,t)=(g_1(s,t),g_2(s,t))$ and $f$ is a function of two variables, $f(x,y)$, then $D\vc{g}(s,t)$ is a $2 \times 2$ matrix and $Df(x,y)$ is a $1 \times 2$ matrix. The chain rule \eqref{general_chain_rule} can be written as \begin{align*} \left[\pdiff{h}{s}(s,t) \quad \pdiff{h}{t}(s,t)\right] = \left[\pdiff{f}{x}(\vc{g}(s,t))\quad \pdiff{f}{y}(\vc{g}(s,t))\right] \left[\begin{array}{cc}\pdiff{g_1}{s}(s,t) & \pdiff{g_1}{t}(s,t)\\ \pdiff{g_2}{s}(s,t) & \pdiff{g_2}{t}(s,t) \end{array}\right]. \end{align*} We multiply the right two matrices and write each component separately to write the result as \begin{align} \pdiff{h}{s}(s,t) &= \pdiff{f}{x}(\vc{g}(s,t))\pdiff{g_1}{s}(s,t) +\pdiff{f}{y}(\vc{g}(s,t))\pdiff{g_2}{s}(s,t)\notag\\ \pdiff{h}{t}(s,t) &= \pdiff{f}{x}(\vc{g}(s,t))\pdiff{g_1}{t}(s,t) +\pdiff{f}{y}(\vc{g}(s,t))\pdiff{g_2}{t}(s,t). \label{chainrule22}\tag{3b} \end{align} We can simplify by writing $(g_1,g_2)$ as $(x,y)$: \begin{align} \pdiff{h}{s} &= \pdiff{f}{x}\pdiff{x}{s} +\pdiff{f}{y}\pdiff{y}{s}\notag\\ \pdiff{h}{t} &= \pdiff{f}{x}\pdiff{x}{t} +\pdiff{f}{y}\pdiff{y}{t}.\tag{3b'} \end{align}
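
A familiar setting for this two-variable case is a change to polar-style coordinates. As a sketch, with functions picked here only for illustration, let $f(x,y)=x^2+y^2$ and $\vc{g}(s,t)=(s\cos t, s\sin t)$, so that $h(s,t)=s^2$. We have $\pdiff{f}{x}=2x$, $\pdiff{f}{y}=2y$, $\pdiff{x}{s}=\cos t$, $\pdiff{y}{s}=\sin t$, $\pdiff{x}{t}=-s\sin t$, and $\pdiff{y}{t}=s\cos t$. Substituting $(x,y)=(s\cos t, s\sin t)$ into the simplified form (3b') gives \begin{align*} \pdiff{h}{s} &= (2s\cos t)(\cos t) + (2s\sin t)(\sin t) = 2s\\ \pdiff{h}{t} &= (2s\cos t)(-s\sin t) + (2s\sin t)(s\cos t) = 0, \end{align*} in agreement with differentiating $h(s,t)=s^2$ directly.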

Lastly, in three dimensions, with functions $\vc{g}: \R^2 \to \R^3$ and $f: \R^3 \to \R$, we can write the chain rule in matrix form as \begin{align*} \left[\pdiff{h}{s}(s,t) \quad \pdiff{h}{t}(s,t)\right] = \left[\pdiff{f}{x}(\vc{g}(s,t))\quad \pdiff{f}{y}(\vc{g}(s,t)) \quad \pdiff{f}{z}(\vc{g}(s,t))\right] \left[\begin{array}{cc}\pdiff{g_1}{s}(s,t) & \pdiff{g_1}{t}(s,t)\\ \pdiff{g_2}{s}(s,t) & \pdiff{g_2}{t}(s,t)\\ \pdiff{g_3}{s}(s,t) & \pdiff{g_3}{t}(s,t) \end{array}\right], \end{align*} and in component form as \begin{align} \pdiff{h}{s}(s,t) &= \pdiff{f}{x}(\vc{g}(s,t))\pdiff{g_1}{s}(s,t) +\pdiff{f}{y}(\vc{g}(s,t))\pdiff{g_2}{s}(s,t) +\pdiff{f}{z}(\vc{g}(s,t))\pdiff{g_3}{s}(s,t)\notag\\ \pdiff{h}{t}(s,t) &= \pdiff{f}{x}(\vc{g}(s,t))\pdiff{g_1}{t}(s,t) +\pdiff{f}{y}(\vc{g}(s,t))\pdiff{g_2}{t}(s,t) +\pdiff{f}{z}(\vc{g}(s,t))\pdiff{g_3}{t}(s,t).\tag{3c} \end{align} Writing $\vc{g}(s,t)=(g_1(s,t),g_2(s,t),g_3(s,t))$ as $(x(s,t),y(s,t),z(s,t))$, we obtain the simplified form \begin{align} \pdiff{h}{s} &= \pdiff{f}{x}\pdiff{x}{s} +\pdiff{f}{y}\pdiff{y}{s}+\pdiff{f}{z}\pdiff{z}{s}\notag\\ \pdiff{h}{t} &= \pdiff{f}{x}\pdiff{x}{t} +\pdiff{f}{y}\pdiff{y}{t}+\pdiff{f}{z}\pdiff{z}{t}.\tag{3c'} \end{align}
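
To illustrate the matrix form in this last case, here is a small example with functions chosen only for illustration: $f(x,y,z)=x+yz$ and $\vc{g}(s,t)=(s,t,st)$, so that $h(s,t)=s+st^2$. The derivative of $f$ is $\left[1 \quad z \quad y\right]$, which at $(x,y,z)=(s,t,st)$ becomes $\left[1 \quad st \quad t\right]$. The matrix product is then \begin{align*} \left[\pdiff{h}{s}(s,t) \quad \pdiff{h}{t}(s,t)\right] = \left[1 \quad st \quad t\right] \left[\begin{array}{cc} 1 & 0\\ 0 & 1\\ t & s \end{array}\right] = \left[1+t^2 \quad 2st\right], \end{align*} which matches the partial derivatives of $h(s,t)=s+st^2$ computed directly. Reading off the two entries reproduces the component formulas (3c).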