8th February 2009, 02:14 pm

In What is automatic differentiation, and why does it work?, I gave a semantic model that explains what automatic differentiation (AD) accomplishes. Correct implementations then flowed from that model, by applying the principle of type class morphisms. (An instance’s interpretation is the interpretation’s instance).

I’ve had a nagging discomfort about the role of the chain rule in AD, with an intuition that the chain rule can play a more central role in the specification and implementation. This post gives a variation on the previous AD post that carries the chain rule further into the reasoning and implementation, leading to simpler correctness proofs and a nearly unaltered implementation.

Finally, as a bonus, I’ll show how GHC rewrite rules enable an even simpler and more modular implementation.

I’ve included some optional content, including exercises. You can see my answers to the exercises by examining the HTML.

As before, I’ll start with a limited form of differentiation that works for functions of a scalar (1D) domain, where one can identify derivative values with regular values:

deriv :: Num a => (a -> a) -> (a -> a) -- simplification

The development below extends to higher-order derivatives and higher-dimensional domains.

The chain rule

At the heart of AD is the chain rule:

deriv (g . f) x == deriv g (f x) * deriv f x

Equivalently,

deriv (g . f) == (deriv g . f) * deriv f

where this (*) is on functions: (*) = liftA2 (*) == \h k x -> h x * k x .
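As a quick numeric sanity check (not part of the development), we can compare the chain rule's prediction against a finite-difference approximation, writing the derivatives by hand in place of the non-executable deriv:

```haskell
-- Sanity check of the chain rule for g . f, with g = sin and f = sqr.
-- The derivatives f' and g' are written by hand, standing in for deriv.
f, f', g, g' :: Double -> Double
f  x = x * x     -- f = sqr
f' x = 2 * x     -- deriv f
g    = sin
g'   = cos       -- deriv sin

-- Crude finite-difference stand-in for deriv:
derivApprox :: (Double -> Double) -> (Double -> Double)
derivApprox h x = (h (x + e) - h (x - e)) / (2 * e)  where e = 1e-6

-- The chain rule's prediction for deriv (g . f):
chain :: Double -> Double
chain x = g' (f x) * f' x
```

At, say, x = 1.3, derivApprox (g . f) 1.3 and chain 1.3 agree to several decimal places.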

The traditional forward AD formulation is based on the chain rule but is not as symmetric as the chain rule. In the function compositions g . f considered, g is always simple, while f may be arbitrarily complex. (I think the reverse is true for reverse-mode AD.) What might we find if we delay introducing this asymmetry?

A direct implementation of the chain rule

The chain rule applies to functions and their derivatives, so let’s formulate a direct implementation. Start with a type for holding two functions:

data FD a = FD (a -> a) (a -> a)

FD is used to hold functions and their derivatives:

toFD :: (a -> a) -> FD a
toFD f = FD f (deriv f)

We do not have an implementation of deriv , so toFD here is part of the specification only, not the implementation.

Now we can specify a composition operator on FD :

(~.~) :: Num a => FD a -> FD a -> FD a

We’ll want (~.~) to represent composition of functions, where we have access to the derivatives as well. That is, (~.~) must satisfy:

toFD (g . f) == toFD g ~.~ toFD f

The implementation and its correctness follow from the chain rule:

FD g g' ~.~ FD f f' = FD (g . f) ((g' . f) * f')

Exercise: Fill in the proof that (~.~) satisfies its specification.

toFD (g . f)
== {- def of toFD -}
FD (g . f) (deriv (g . f))
== {- chain rule -}
FD (g . f) ((deriv g . f) * deriv f)
== {- def of (~.~) -}
FD g (deriv g) ~.~ FD f (deriv f)
== {- def of toFD -}
toFD g ~.~ toFD f

(Exercise solutions are in the post’s HTML.)
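Although toFD is not executable, (~.~) itself is. Here is an illustrative sketch, hand-supplying the derivatives that toFD would compute, and writing the derivative product pointwise to avoid needing a Num instance on functions:

```haskell
data FD a = FD (a -> a) (a -> a)

-- Composition per the chain rule; the derivative product is written
-- pointwise rather than via (*) on functions:
(~.~) :: Num a => FD a -> FD a -> FD a
FD g g' ~.~ FD f f' = FD (g . f) (\x -> g' (f x) * f' x)

-- Hand-built values, standing in for toFD sin and toFD sqr:
sinFD, sqrFD :: FD Double
sinFD = FD sin cos
sqrFD = FD (\x -> x * x) (\x -> 2 * x)

-- Sample the composite and its derivative at a point:
sampleFD :: Double -> (Double, Double)
sampleFD x = let FD h h' = sinFD ~.~ sqrFD in (h x, h' x)
```

For each x, sampleFD x matches (sin (x * x), cos (x * x) * (2 * x)).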

From function to values

The FD type and its composition operator implement the chain rule quite directly. However, they are not suitable for AD, which operates on the values (range) of a function and of its derivative.

Let’s start over with the usual AD value representation:

data D a = D a a

Previously, I defined this ideal construction function:

toD :: (a -> a) -> a -> D a
toD f x = D (f x) (deriv f x)    -- or toD f == liftA2 D f (deriv f)

Instead, let’s now define toD in terms of toFD . First, how do the FD and D representations relate?

fdAt :: FD a -> (a -> D a)
fdAt (FD f f') = liftA2 D f f'

Then we can define toD :

toD = fdAt . toFD

Exercise: Show that these definitions of toD are equivalent.

Expanding,

toD f
== fdAt (toFD f)
== fdAt (FD f (deriv f))
== liftA2 D f (deriv f)
== \x -> D (f x) (deriv f x)

Again, toD isn’t executable with this definition, because toFD isn’t (because deriv isn’t). As before, toD must be eliminated in our journey from specification to implementation.
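For a concrete (if toy) illustration, we can hand-build an FD value and sample it with fdAt to get D values, again supplying the derivative ourselves:

```haskell
import Control.Applicative (liftA2)

data FD a = FD (a -> a) (a -> a)
data D  a = D a a deriving (Eq, Show)

-- Relate the two representations, as in the text:
fdAt :: FD a -> (a -> D a)
fdAt (FD f f') = liftA2 D f f'

-- Standing in for toFD sin (derivative supplied by hand):
sinFD :: FD Double
sinFD = FD sin cos
```

For example, fdAt sinFD 0 yields D 0.0 1.0, since sin 0 == 0 and cos 0 == 1.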

Optional: We can also define an odd sort of inverse for toD :

fromD :: D a -> a -> (a -> a)

fromD must satisfy, for all x ,

toD (fromD d x) x == d

It's more convenient to relate flipped versions of toD and fromD :

toD'   :: a -> (a -> a) -> D a
fromD' :: a -> D a -> (a -> a)

toD'   = flip toD
fromD' = flip fromD

Then fromD must satisfy, for all x ,

toD' x . fromD' x == id

Exercise: Give a simple definition for fromD and show that it's correct (satisfies its specification).

Define

fromD (D a a') x = \t -> a + a' * (t - x)

Then

toD (fromD (D a a') x) x
== D (fromD (D a a') x x) (deriv (fromD (D a a') x) x)
== D ((\t -> a + a' * (t - x)) x) (deriv (\t -> a + a' * (t - x)) x)
== D a (deriv (\t -> a + a' * (t - x)) x)
== D a ((\t -> a') x)
== D a a'
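A quick numeric check of this fromD: the tangent-line function fromD (D a a') x has value a and slope a' at x. For instance, with the sample value D 3 5 at x = 2, and a finite-difference approximation standing in for deriv:

```haskell
data D a = D a a deriving (Eq, Show)

-- The tangent line through (x, a) with slope a':
fromD :: Num a => D a -> a -> (a -> a)
fromD (D a a') x = \t -> a + a' * (t - x)

-- Crude finite-difference stand-in for deriv:
derivApprox :: (Double -> Double) -> (Double -> Double)
derivApprox h x = (h (x + e) - h (x - e)) / (2 * e)  where e = 1e-6
```

Then fromD (D 3 5) 2 2 == 3, and derivApprox (fromD (D 3 5) 2) 2 is 5 up to the approximation error, so toD (fromD (D 3 5) 2) 2 recovers D 3 5.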

A general, value-friendly chain rule

In What is AD …?, I defined correctness of the numeric class instances D by saying that toD must be a type class morphism for each of the numeric classes it implements. For example, let’s take the sin method. The other unary methods will work just like it. The morphism property:

toD (sin u) == sin (toD u)

Because of numeric overloading on functions, this property is equivalent to a more explicit one:

toD (sin . u) == sin . toD u

The sin on the left is on numbers, and the sin on the right is on D a .

Let’s suppose we have a function adiff (for automatic differentiation) such that for all g and f ,

toD (g . f) == adiff g . toD f    -- specification of adiff

adiff :: Num a => (a -> a) -> (D a -> D a)

Then our goal would become

adiff sin . toD u == sin . toD u

and a correct definition of sin would be immediate, as would be the other definitions:

sin  = adiff sin
sqrt = adiff sqrt
...

Note that the adiff specification above implies that for all g ,

toD g == adiff g . toD id

Exercise: Derive the definition of adiff from its specification, showing that the following is both necessary and sufficient:

adiff g (D a a') = D (g a) (deriv g a * a')

Proof details: First, sufficiency. On the left-hand side,

toD (g . f)
== {- toD def -}
liftA2 D (g . f) (deriv (g . f))
== {- chain rule -}
liftA2 D (g . f) ((deriv g . f) * deriv f)
== {- liftA2 on functions -}
\x -> D (g (f x)) (deriv g (f x) * deriv f x)

On the right-hand side,

adiff g . toD f
== {- toD def -}
adiff g . \x -> D (f x) (deriv f x)
== {- (.) def -}
\x -> adiff g (D (f x) (deriv f x))

The definition then falls out mechanically (with second-order pattern matching). Next, necessity. For any x ,

adiff g dd@(D a a')
== {- toD/fromD -}
adiff g (toD (fromD dd x) x)
== {- adiff spec -}
toD (g . fromD dd x) x
== {- toD def -}
D (g (fromD dd x x)) (deriv (g . fromD dd x) x)
== {- chain rule -}
D (g (fromD dd x x)) (deriv g (fromD dd x x) * deriv (fromD dd x) x)
== {- fromD spec -}
D (g a) (deriv g a * a')
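The specification toD (g . f) == adiff g . toD f can be spot-checked numerically. Writing the derivatives by hand (standing in for deriv), take g = sin and f = sqr:

```haskell
data D a = D a a deriving (Eq, Show)

-- adiff sin per the closed form above, with deriv sin = cos by hand:
adiffSin :: D Double -> D Double
adiffSin (D a a') = D (sin a) (cos a * a')

-- toD sqr and toD (sin . sqr), written out by hand:
toDsqr, toDsinSqr :: Double -> D Double
toDsqr    x = D (x * x) (2 * x)
toDsinSqr x = D (sin (x * x)) (cos (x * x) * (2 * x))
```

For each x, toDsinSqr x equals (adiffSin . toDsqr) x, as the specification requires.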

The adiff function satisfies a more symmetric property as well. It distributes over composition:

adiff (h . g) == adiff h . adiff g

Exercise: Prove this property from the specification.

Here’s a proof from the definition. Start from the right-hand side:

(adiff h . adiff g) (D a a')
== {- (.) def -}
adiff h (adiff g (D a a'))
== {- adiff def -}
adiff h (D (g a) (deriv g a * a'))
== {- adiff def -}
D (h (g a)) (deriv h (g a) * (deriv g a * a'))
== {- associativity of (*) -}
D (h (g a)) ((deriv h (g a) * deriv g a) * a')
== {- chain rule -}
D (h (g a)) (deriv (h . g) a * a')
== {- (.) def -}
D ((h . g) a) (deriv (h . g) a * a')
== {- adiff def -}
adiff (h . g) (D a a')

Alternatively, prove from the specification alone. For all `h`, `g`, `f`,

adiff (h . g) . toD f
== {- adiff spec -}
toD ((h . g) . f)
== {- (.) associativity -}
toD (h . (g . f))
== {- adiff spec -}
adiff h . toD (g . f)
== {- adiff spec -}
adiff h . (adiff g . toD f)
== {- (.) associativity -}
(adiff h . adiff g) . toD f

So `adiff (h . g)` and `adiff h . adiff g` agree on `toD f x` for all `f` and `x`. Given an arbitrary `D` value `dd`, choose an arbitrary `x`. Then

adiff (h . g) dd
== adiff (h . g) (toD (fromD dd x) x)
== (adiff h . adiff g) (toD (fromD dd x) x)
== (adiff h . adiff g) dd

Moreover, adiff maps the identity to the identity: adiff id = id .

Exercise: Show that for any definition of adiff , if for all g ,

toD g == adiff g . toD id

and if for all h and g ,

adiff (h . g) == adiff h . adiff g

then our adiff specification holds, i.e., for all g and f ,

adiff g . toD f == toD (g . f)

toD (g . f)
== {- first assumption -}
adiff (g . f) . toD id
== {- second assumption -}
(adiff g . adiff f) . toD id
== {- associativity of (.) -}
adiff g . (adiff f . toD id)
== {- first assumption -}
adiff g . toD f

Back to an implementation

We’re still not quite done, since adiff depends on deriv , which doesn’t have an implementation. Let’s separate out the problematic deriv by refactoring adiff :

adiff g = g >-< deriv g

where

infix 0 >-<

(>-<) :: Num a => (a -> a) -> (a -> a) -> (D a -> D a)
(g >-< g') (D a a') = D (g a) (g' a * a')

After inlining this definition of adiff , the method definitions are

sin  = sin  >-< deriv sin
sqrt = sqrt >-< deriv sqrt
...

Every remaining use of deriv is applied to a function whose derivative is known, so we can replace each use.

sin  = sin  >-< cos
sqrt = sqrt >-< recip (2 * sqrt)
...

Now we have an executable implementation again. These method definitions and the definition of (>-<) are exactly as in What is automatic differentiation, and why does it work?.
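Putting the pieces together, here is a minimal runnable sketch of the resulting implementation: the D type, a Num instance (with the product rule in (*), as in the earlier post), and Fractional/Floating methods in the (>-<) style. It is an illustrative reconstruction covering only part of the numeric classes, not the full implementation.

```haskell
data D a = D a a deriving (Eq, Show)

sqr :: Num a => a -> a
sqr x = x * x

infix 0 >-<
(>-<) :: Num a => (a -> a) -> (a -> a) -> (D a -> D a)
(g >-< g') (D a a') = D (g a) (g' a * a')

instance Num a => Num (D a) where
  fromInteger n   = D (fromInteger n) 0
  D a a' + D b b' = D (a + b) (a' + b')
  D a a' * D b b' = D (a * b) (a' * b + a * b')  -- product rule
  negate          = negate >-< const (-1)
  abs             = abs    >-< signum
  signum          = signum >-< const 0

instance Fractional a => Fractional (D a) where
  fromRational r = D (fromRational r) 0
  recip          = recip >-< negate . sqr . recip  -- deriv recip = - sqr recip

instance Floating a => Floating (D a) where
  pi    = D pi 0
  exp   = exp  >-< exp
  log   = log  >-< recip
  sqrt  = sqrt >-< \x -> recip (2 * sqrt x)
  sin   = sin  >-< cos
  cos   = cos  >-< negate . sin
  asin  = asin >-< \x -> recip (sqrt (1 - sqr x))
  acos  = acos >-< \x -> negate (recip (sqrt (1 - sqr x)))
  atan  = atan >-< \x -> recip (1 + sqr x)
  sinh  = sinh >-< cosh
  cosh  = cosh >-< sinh
  asinh = asinh >-< \x -> recip (sqrt (1 + sqr x))
  acosh = acosh >-< \x -> recip (sqrt (sqr x - 1))
  atanh = atanh >-< \x -> recip (1 - sqr x)

-- Differentiate \x -> sin (x * x) by feeding in toD id x = D x 1:
sinSqr :: Double -> D Double
sinSqr x = sin (d * d)  where d = D x 1
```

Then sinSqr x gives D (sin (x * x)) (cos (x * x) * (2 * x)), just as the chain rule predicts.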

Fun with rules

Let’s back up to our more elegant method definitions using adiff :

sin  = adiff sin
sqrt = adiff sqrt
...

We made these definitions executable in spite of their appeal to the non-executable deriv by (a) refactoring adiff to split the deriv from the residual function (>-<) , (b) inlining adiff , and (c) rewriting applications of deriv with known derivative rules.

Now let’s get GHC to do these steps for us.

List the derivatives of known functions:

{-# RULES
 "deriv negate"  deriv negate = -1
 "deriv abs"     deriv abs    = signum
 "deriv signum"  deriv signum = 0
 "deriv recip"   deriv recip  = - sqr recip
 "deriv exp"     deriv exp    = exp
 "deriv log"     deriv log    = recip
 "deriv sqrt"    deriv sqrt   = recip (2 * sqrt)
 "deriv sin"     deriv sin    = cos
 "deriv cos"     deriv cos    = - sin
 "deriv asin"    deriv asin   = recip (sqrt (1-sqr))
 "deriv acos"    deriv acos   = recip (- sqrt (1-sqr))
 "deriv atan"    deriv atan   = recip (1+sqr)
 "deriv sinh"    deriv sinh   = cosh
 "deriv cosh"    deriv cosh   = sinh
 "deriv asinh"   deriv asinh  = recip (sqrt (1+sqr))
 "deriv acosh"   deriv acosh  = recip (- sqrt (sqr-1))
 "deriv atanh"   deriv atanh  = recip (1-sqr)
  #-}

Notice that these definitions are simpler and more modular than the standard differentiation rules, as they do not have the chain rule mixed in. For instance, compare (a) deriv sin = cos , (b) deriv (sin u) == cos u * deriv u , and (c) deriv (sin u) x == cos u x * deriv u x .

Now we can use the incredibly simple adiff -based definitions of our methods, e.g., asin = adiff asin .

The definition of adiff must get inlined so as to reveal the deriv applications, which then get rewritten according to the rules. Fortunately, the adiff definition is tiny, which encourages its inlining. We could add an INLINE pragma as a reminder. GHC requires that deriv be given a definition, even if all uses are rewritten away, so use the following:

deriv = error "deriv: undefined. Missing rewrite rule?"
