[Haskell] Typeful symbolic differentiation of compiled functions

Jacques Carette wrote on LtU on Wed, 11/24/2004
] One quick (cryptic) example: the same difficulties in being able to
] express partial evaluation in a typed setting occurs in a CAS
] [computer algebra system]. Of course I mean to have a partial
] evaluator written in a language X for language X, and have the partial
] evaluator 'never go wrong'. Cheating by encoding language X as an
] algebraic datastructure in X is counter-productive as it entails huge
] amounts of useless reflection/ reification. One really wants to be
] able to deal with object-level terms simply and directly. But of
] course, that way lies the land of paradoxes (in set theory, type
] theory, logic)
]
] And while I am at it: consider symbolic differentiation. If I call
] that 'function' diff, and have things like diff(sin(x),x) == cos(x),
] what is the type of diff? More interestingly, what if I have D(\x ->
] sin(x) ) == \x -> cos(x) What is the type of D ? Is it implementable
] in Ocaml or Haskell? [Answer: as far as I know, it is not. But that
] is because as far as I can tell, D can't even exist in System F. You
] can't have something like D operating on opaque lambda terms.]. But
] both Maple and Mathematica can. And I can write that in LISP or Scheme
] too.

In this message, we develop the `symbolic' differentiator for a subset
of Haskell functions (which covers arithmetics and a bit of
trigonometry). We can write

    test1f x = x * x + fromInteger 1
    test1 = test1f (2.0::Float)
    test2f = diff_fn test1f
    test2 = test2f (3.0::Float)

We can evaluate our functions _numerically_ -- and differentiate them
_symbolically_. Partial derivatives are supported as well.

To answer Jacques Carette's question: the type of the derivative
operator (which is just a regular function) is

    diff_fn :: (Num b, D b) => (forall a. D a => a -> a) -> b -> b

where the class D includes Floats. One can add exact reals and other
similar things.

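
To get a feel for that type before the real development, here is a
minimal, self-contained sketch. It is not the technique of this message:
it cheats by using an ordinary algebraic datatype for the symbolic terms,
which is exactly what the rest of the message avoids. The names Dom, add,
mul, lit, Expr, deriv, eval, diffFn and poly are all hypothetical and
exist only for this illustration. The point is the rank-2 argument:
because the function is polymorphic over its domain, the differentiator
may run it on symbolic terms and then evaluate the result numerically.

```haskell
{-# LANGUAGE RankNTypes #-}
module Main where

-- Hypothetical, cut-down stand-in for the class D of this message.
class Dom a where
  add :: a -> a -> a
  mul :: a -> a -> a
  lit :: Integer -> a

-- Numeric interpretation.
instance Dom Double where
  add = (+)
  mul = (*)
  lit = fromInteger

-- Symbolic interpretation: a plain algebraic datatype. This is an
-- assumption of the sketch; the message itself avoids this encoding.
data Expr = X | Lit Integer | Add Expr Expr | Mul Expr Expr
  deriving (Eq, Show)

instance Dom Expr where
  add = Add
  mul = Mul
  lit = Lit

-- Textbook rules for the variable, constants, sums and products.
deriv :: Expr -> Expr
deriv X         = Lit 1
deriv (Lit _)   = Lit 0
deriv (Add a b) = Add (deriv a) (deriv b)
deriv (Mul a b) = Add (Mul a (deriv b)) (Mul (deriv a) b)

-- Interpret a symbolic term in any domain, at the point v.
eval :: Dom b => Expr -> b -> b
eval X         v = v
eval (Lit n)   _ = lit n
eval (Add a b) v = add (eval a v) (eval b v)
eval (Mul a b) v = mul (eval a v) (eval b v)

-- Same shape as diff_fn: the argument must be polymorphic in its
-- domain, so it can be applied to X here and evaluated in b afterwards.
diffFn :: Dom b => (forall a. Dom a => a -> a) -> b -> b
diffFn f = eval (deriv (f X))

-- x*x + 1, written once against the class.
poly :: Dom a => a -> a
poly x = add (mul x x) (lit 1)

main :: IO ()
main = do
  print (poly (2 :: Double))         -- 5.0
  print (deriv (poly X))             -- the unsimplified symbolic derivative
  print (diffFn poly (3 :: Double))  -- 6.0
```

A monomorphic Double -> Double would not do as the argument of diffFn:
the differentiator has to choose the domain itself, which is what the
forall in the argument position expresses.
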
The key insight is that Haskell98 supports a sort of reflection -- or,
to be precise, type-directed partial evaluation and hence term
reconstruction. The very types that are assumed to be a great hindrance
to computer algebra and reflective systems turn out to be indispensable
for operating on even *compiled* terms _symbolically_. We must point
out that we specifically do _not_ represent our terms as algebraic
datatypes. Our terms are regular Haskell terms, and can be compiled!
That is in stark contrast with Scheme, for example: although Scheme may
permit term reconstruction under notable restrictions, that ability is
not present in the compiled code. In Scheme we cannot, in general, take
a _compiled_ function Float->Float and compute its derivative
symbolically, yielding another Float->Float function. Incidentally,
R5RS does not guarantee the success of type-directed partial evaluation
even in the interpreted code.

Jacques Carette has mentioned `useless reflection/reification'. The
paper `Tag Elimination and Jones-Optimality' by Walid Taha, Henning
Makholm and John Hughes introduced a novel tag elimination analysis as
a way to remove all interpretative overhead. In this message, we do
_not_ use that technique. We exploit a different idea, whose roots can
be traced back to Forth. It is remarkable how Haskell allows that
technique.

Other features of our approach are: an extensible differentiation rule
database; emulation of GADT with type classes.

This message is the complete code.

> {-# OPTIONS -fglasgow-exts #-}
> -- We only need existentials. In the rest, it is Haskell98!
> -- Tested with GHC 6.2.1 and 6.3.20041106-snapshot
>
> module Diff where
> import Prelude hiding ((+), (-), (*), (/), (^), sin, cos, fromInteger)
> import qualified Prelude

First we declare the domain of `differentiable' (by us) functions

> class D a where
>     (+):: a -> a -> a
>     (*):: a -> a -> a
>     (-):: a -> a -> a
>     (/):: a -> a -> a
>     (^):: a -> Int -> a
>     sin:: a -> a
>     cos:: a -> a
>     fromInteger:: Integer -> a

and inject floats into that domain

> instance D Float where
>     (+) = (Prelude.+)
>     (-) = (Prelude.-)
>     (*) = (Prelude.*)
>     (/) = (Prelude./)
>     (^) = (Prelude.^)
>     sin = Prelude.sin
>     cos = Prelude.cos
>     fromInteger = Prelude.fromInteger

For symbolic manipulation, we need a representation for (reconstructed)
terms

> -- Here, reflect is the tag eliminator -- or `compiler'
> class Term t a | t -> a where
>     reflect :: t -> a -> a

We should point out that the terms are fully typeful.

> newtype Const a = Const a deriving Show
> data Var a = Var deriving Show
> data Add x y = Add x y deriving Show
> data Sub x y = Sub x y deriving Show
> data Mul x y = Mul x y deriving Show
> data Div x y = Div x y deriving Show
> data Pow x = Pow x Int deriving Show
> newtype Sin x = Sin x deriving Show
> newtype Cos x = Cos x deriving Show

We can now describe the grammar of our term representation in the
following straightforward way:

> instance Term (Const a) a where reflect (Const a) = const a
>
> instance Term (Var a) a where reflect _ = id
>
> instance (D a, Term x a, Term y a) => Term (Add x y) a
>     where
>     reflect (Add x y) = \a -> (reflect x a) + (reflect y a)
>
> instance (D a, Term x a) => Term (Sin x) a
>     where
>     reflect (Sin x) = sin . reflect x

The other instances are given in the Appendix. This is the
straightforward emulation of GADT. The function `reflect' removes the
`tags' after the symbolic differentiation. Actually, `Sin' is a newtype
constructor, so there is no run-time tag to eliminate in this case.

We must stress that there is no `reify' function.
One may say it is built into Haskell already. We only need to declare
the datatype for the reified code

> data Code a = forall t. (Show t, Term t a, DiffRules t a) => Code t
> instance Show a => Show (Code a) where show (Code t) = show t
> reflect_code (Code c) = reflect c

inject the reified code in the D domain

> instance (Num a, D a) => D (Code a) where
>     Code x + Code y = Code $ Add x y
>     Code x - Code y = Code $ Sub x y
>     Code x * Code y = Code $ Mul x y
>     Code x / Code y = Code $ Div x y
>     (Code x) ^ n = Code $ Pow x n
>     sin (Code x) = Code $ Sin x
>     cos (Code x) = Code $ Cos x
>     fromInteger n = Code $ Const (fromInteger n)

and we're done with the first part: We can define a function

> test1f x = x * x + fromInteger 1
> test1 = test1f (2.0::Float)

we can even compile it. At any point, we can reify it

> test1c = test1f (Code Var :: Code Float)

and reflect it back:

> test1f' = reflect_code test1c
> test1' = test1f' (2.0::Float)

    *Diff> test1
    5.0
    *Diff> test1'
    5.0
    *Diff> test1c
    Add (Mul Var Var) (Const 1.0)

The differentiation part is quite straightforward. We declare a class
for differentiation rules

> class (Term t a,D a) => DiffRules t a | t -> a where
>     diff :: t -> Code a

The rules are the instances of the class DiffRules

> instance (Num a, D a) => DiffRules (Const a) a where
>     diff _ = Code $ Const 0
>
> instance (Num a, D a) => DiffRules (Var a) a where
>     diff _ = Code $ Const 1
>
> instance (Show x, Show y, DiffRules x a, DiffRules y a)
>     => DiffRules (Mul x y) a where
>     diff (Mul x y) = case (diff x,diff y) of
>         (Code x'::Code a,Code y') ->
>             Code $ Add (Mul (x::x) y') (Mul x' (y::y))
>
>
> instance (Num a, Show x, DiffRules x a)
>     => DiffRules (Sin x) a where
>     diff (Sin x) = case diff x of
>         (Code x'::Code a) ->
>             Code $ Mul x' (Cos x)

The other instances are in the Appendix. The approach is scalable -- we
may add more rules later, in other modules.

And that's about it:

> diff_code (Code c) = diff c
>
> diff_fn :: (Num b, D b) => (forall a. D a => a -> a) -> b -> b
> diff_fn f =
>     let code = f (Code Var)
>     in reflect_code $ diff_code code

the differentiation operator could not be any simpler. We can try

> test2f = diff_fn test1f
> test2 = test2f (3.0::Float)

we can even see the differentiation result, symbolically:

    *Diff> diff_code test1c
    Add (Add (Mul Var (Const 1.0)) (Mul (Const 1.0) Var)) (Const 0.0)

True, simplifications are direly needed. Well, the full computer
algebra system is a little bit too big to be developed over one
evening. Besides, I wanted to go home three hours ago.

Here's a slightly more complex example:

> test5f x = sin (fromInteger 5*x) + cos(fromInteger 1/x)
> test5c = test5f (Code Var :: Code Float)
>
> test5 = test5f (pi::Float)
> test5d = diff_code test5c
>
> test6 = diff_fn test5f (pi::Float)

One can evaluate the function test5f numerically, differentiate it
symbolically, check the result of differentiation -- and evaluate it
numerically right away.

We can even do partial derivatives:

> test3f x y = (x*y + ((fromInteger 5)*(x^2))) / y
>
> test3c1 = test3f (Code Var :: Code Float) (fromInteger 10)
>
> test4x y = diff_fn (\x -> test3f x (fromInteger y))
> test4y x = diff_fn (test3f (fromInteger x))

-- *Diff> test4x 1 (2::Float) -- partial derivative with respect to x
-- 21.0
-- *Diff> test4y 5 (5::Float) -- partial derivative with respect to y
-- -5.0

Appendix:

> instance (D a, Term x a, Term y a) => Term (Sub x y) a
>     where
>     reflect (Sub x y) = \a -> (reflect x a) - (reflect y a)
>
> instance (D a, Term x a, Term y a) => Term (Mul x y) a
>     where
>     reflect (Mul x y) = \a -> (reflect x a) * (reflect y a)
>
> instance (D a, Term x a, Term y a) => Term (Div x y) a
>     where
>     reflect (Div x y) = \a -> (reflect x a) / (reflect y a)
>
> instance (D a, Term x a) => Term (Pow x) a
>     where
>     reflect (Pow x n) = (^ n) . reflect x
>
> instance (D a, Term x a) => Term (Cos x) a
>     where
>     reflect (Cos x) = cos . reflect x

> instance (Show x, Show y, DiffRules x a, DiffRules y a)
>     => DiffRules (Add x y) a where
>     diff (Add x y) = case (diff x,diff y) of
>         (Code x'::Code a,Code y') ->
>             Code $ Add x' y'
>
> instance (Show x, Show y, DiffRules x a, DiffRules y a)
>     => DiffRules (Sub x y) a where
>     diff (Sub x y) = case (diff x,diff y) of
>         (Code x'::Code a,Code y') ->
>             Code $ Sub x' y'
>
> instance (Num a, Show x, Show y, DiffRules x a, DiffRules y a)
>     => DiffRules (Div x y) a where
>     diff (Div x y) = case (diff x,diff y) of
>         (Code x'::Code a,Code y') ->
>             Code $
>                 Div (Sub (Mul x' y) (Mul x y'))
>                     (Pow y 2)
>
> instance (Num a, Show x, DiffRules x a)
>     => DiffRules (Pow x) a where
>     diff (Pow x n) = case diff x of
>         (Code x'::Code a) ->
>             Code $ Mul (Const (fromInteger $ toInteger n))
>                        (Mul x' (Pow x (n Prelude.- 1)))

> instance (Num a, Show x, DiffRules x a)
>     => DiffRules (Cos x) a where
>     diff (Cos x) = case diff x of
>         (Code x'::Code a) ->
>             Code $ Mul x' (Sub (Const 0) (Sin x))
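
The unsimplified derivative printed earlier for test1c suggests what a
first simplification pass might look like. The following is only a
sketch, and it rests on a loud assumption: the terms are mirrored in an
ordinary algebraic datatype (the hypothetical Expr below), which is not
how this message represents them. A simplifier for the typeful terms
would presumably be one more type class in the style of DiffRules; that
is not attempted here. The names Expr, simplify, addS and mulS are made
up for this illustration.

```haskell
module Main where

-- Hypothetical plain-ADT mirror of the printed terms. The message itself
-- uses one type per constructor; this is a simplification for the sketch.
data Expr
  = Var
  | Const Double
  | Add Expr Expr
  | Mul Expr Expr
  deriving (Eq, Show)

-- One bottom-up pass: simplify the subterms first, then rebuild each
-- node with a "smart constructor" that applies the local identities.
simplify :: Expr -> Expr
simplify (Add a b) = addS (simplify a) (simplify b)
simplify (Mul a b) = mulS (simplify a) (simplify b)
simplify e         = e

-- 0 + e = e, e + 0 = e, and constant folding.
addS :: Expr -> Expr -> Expr
addS (Const 0) e         = e
addS e (Const 0)         = e
addS (Const a) (Const b) = Const (a + b)
addS a b                 = Add a b

-- 0 * e = 0, 1 * e = e (and symmetric cases), and constant folding.
-- These are symbolic identities; they ignore floating-point corner
-- cases such as 0 * infinity.
mulS :: Expr -> Expr -> Expr
mulS (Const 0) _         = Const 0
mulS _ (Const 0)         = Const 0
mulS (Const 1) e         = e
mulS e (Const 1)         = e
mulS (Const a) (Const b) = Const (a * b)
mulS a b                 = Mul a b

main :: IO ()
main =
  -- The derivative of x*x + 1 as printed by diff_code test1c.
  print (simplify (Add (Add (Mul Var (Const 1.0)) (Mul (Const 1.0) Var))
                       (Const 0.0)))
  -- prints: Add Var Var
```

The pass does not collect like terms, so x + x is left as it is rather
than rewritten to 2*x.
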