Tricking GHC into evaluating recursive functions at compile time

Here is a trick I came up with for a project of mine. Suppose you have a GADT like this very simple one:

```haskell
data T a where
  TInt  :: Int -> T Int
  TPair :: T a -> T b -> T (a,b)
```

and a function which does something with it:

```haskell
sumT :: T a -> Int
sumT (TInt n)    = n
sumT (TPair l r) = sumT l + sumT r
```

Now, let’s use the two:

```haskell
term = TPair (TPair (TInt 1) (TInt 2)) (TInt 3)
foo  = sumT term
```

Since foo is constant, we would expect GHC to evaluate it at compile time and just bind it to 6 in the compiled code, right?

Wrong! For this to happen, GHC would have to inline sumT. But sumT is recursive, and GHC never inlines recursive functions, because otherwise it might get into an infinite loop. This means that it won't optimise foo at all, which was absolutely unacceptable in my program. I spent about two days fiddling with inline pragmas, rewrite rules and other unpleasant things until I found a satisfactory solution.
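To check this yourself, the definitions above can be assembled into a small module. This is a sketch; the type signatures for term and foo are added here for clarity. Compiling it with `ghc -O -ddump-simpl -dsuppress-all Main.hs` prints the optimised Core, in which, according to the explanation above, foo should still be defined as a call to sumT rather than as the literal 6 (the exact dump varies between GHC versions).

```haskell
{-# LANGUAGE GADTs #-}
module Main where

-- The GADT from the text: the type index records the shape of the term.
data T a where
  TInt  :: Int -> T Int
  TPair :: T a -> T b -> T (a, b)

-- Plain recursive sum. Being self-recursive, it is a loop breaker,
-- so GHC will not inline it at its call sites.
sumT :: T a -> Int
sumT (TInt n)    = n
sumT (TPair l r) = sumT l + sumT r

term :: T ((Int, Int), Int)
term = TPair (TPair (TInt 1) (TInt 2)) (TInt 3)

-- A constant one might hope to see folded to 6 in the Core dump.
foo :: Int
foo = sumT term

main :: IO ()
main = print foo
```

Running the compiled program prints 6 either way; the question is only whether that 6 is computed at compile time or at run time.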

My first attempt was to inline sumT only when it is applied to a constructor. We could try adding a couple of rewrite rules:

```haskell
{-# RULES
"sumT/TInt"  forall n.   sumT (TInt n)    = n
"sumT/TPair" forall l r. sumT (TPair l r) = sumT l + sumT r
  #-}
```
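A sketch of this first attempt as a complete module follows. The rules are only consulted when optimising (for example with `ghc -O`), and adding the GHC flag `-ddump-rule-firings` reports which of them, if any, actually fired. Since both rules are valid equations about sumT, the program prints 6 whether or not they fire; only the amount of work done at compile time differs.

```haskell
{-# LANGUAGE GADTs #-}
module Main where

data T a where
  TInt  :: Int -> T Int
  TPair :: T a -> T b -> T (a, b)

sumT :: T a -> Int
sumT (TInt n)    = n
sumT (TPair l r) = sumT l + sumT r

-- Unfold sumT by one step whenever its argument is a visible constructor.
-- Each rule is a true equation about sumT, so firing it never changes results.
{-# RULES
"sumT/TInt"  forall n.   sumT (TInt n)    = n
"sumT/TPair" forall l r. sumT (TPair l r) = sumT l + sumT r
  #-}

foo :: Int
foo = sumT (TPair (TPair (TInt 1) (TInt 2)) (TInt 3))

main :: IO ()
main = print foo
```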

Alas, this doesn't work most of the time. Basically, trying to match on non-trivial constructors in rewrite rules is never a good idea. We could introduce “virtual” constructors, use them everywhere instead of the real ones and match on them:

```haskell
tInt :: Int -> T Int
{-# NOINLINE CONLIKE tInt #-}
tInt = TInt

tPair :: T a -> T b -> T (a,b)
{-# NOINLINE CONLIKE tPair #-}
tPair = TPair

{-# RULES
"sumT/tInt"  forall n.   sumT (tInt n)    = n
"sumT/tPair" forall l r. sumT (tPair l r) = sumT l + sumT r
  #-}
```
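Assembled into one module, the virtual-constructor variant might look like the sketch below. Whether the rules actually fire depends on the shape of the Core that GHC produces for a given program; compiling with `ghc -O -ddump-rule-firings` is one way to check. Note that terms now have to be built with tInt and tPair rather than with the real constructors, otherwise the rules have nothing to match.

```haskell
{-# LANGUAGE GADTs #-}
module Main where

data T a where
  TInt  :: Int -> T Int
  TPair :: T a -> T b -> T (a, b)

sumT :: T a -> Int
sumT (TInt n)    = n
sumT (TPair l r) = sumT l + sumT r

-- "Virtual" constructors. NOINLINE keeps their names visible to the rule
-- matcher; CONLIKE tells GHC that applications of them are cheap, like
-- constructor applications, so the matcher may look through let-bindings.
tInt :: Int -> T Int
{-# NOINLINE CONLIKE tInt #-}
tInt = TInt

tPair :: T a -> T b -> T (a, b)
{-# NOINLINE CONLIKE tPair #-}
tPair = TPair

{-# RULES
"sumT/tInt"  forall n.   sumT (tInt n)    = n
"sumT/tPair" forall l r. sumT (tPair l r) = sumT l + sumT r
  #-}

-- Built with the virtual constructors so that the rules can see them.
foo :: Int
foo = sumT (tPair (tPair (tInt 1) (tInt 2)) (tInt 3))

main :: IO ()
main = print foo
```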

This works much better but, unfortunately, still fails in my program. There, I make extensive use of type families, so the Core generated by GHC has casts all over the place. Casts make rule matching highly unreliable because the rule matcher doesn't look through them (a particularly ugly wart that I keep running into). So what to do?

The solution I came up with requires adding a unit component to every recursive constructor.

```haskell
data T a where
  TInt  :: Int -> T Int
  TPair :: () -> T a -> T b -> T (a,b)
```

Where we previously wrote TPair, we will now write TPair (). In fact, let's provide a convenience function for that:

```haskell
tPair :: T a -> T b -> T (a,b)
tPair = TPair ()
```

Now, we define a non-recursive version of sumT which is parametrised by the function it should apply to the components of a pair:

```haskell
-- The rank-2 argument type requires the RankNTypes extension.
sumT_cont :: (forall a. () -> T a -> Int) -> T a -> Int
{-# INLINE sumT_cont #-}
sumT_cont cont (TInt n)      = n
sumT_cont cont (TPair u l r) = cont u l + cont u r
```

Note that since sumT_cont isn't recursive, it can be freely inlined. Note also that we pass the unit value from the constructor to cont. This is absolutely essential.
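A small sketch can make the role of cont concrete. It uses the unit-augmented T, tPair and sumT_cont from above, with the RankNTypes extension switched on for the polymorphic continuation; shallow and deep are hypothetical names for illustration only. A continuation that ignores the components stops after one layer, while one that recurses back into sumT_cont recovers the full sum.

```haskell
{-# LANGUAGE GADTs, RankNTypes #-}
module Main where

-- Every recursive constructor carries an extra unit field.
data T a where
  TInt  :: Int -> T Int
  TPair :: () -> T a -> T b -> T (a, b)

-- Convenience constructor that supplies the unit.
tPair :: T a -> T b -> T (a, b)
tPair = TPair ()

-- One layer of the sum; recursion is delegated to cont, which also
-- receives the unit stored in the constructor.
sumT_cont :: (forall a. () -> T a -> Int) -> T a -> Int
{-# INLINE sumT_cont #-}
sumT_cont _    (TInt n)      = n
sumT_cont cont (TPair u l r) = cont u l + cont u r

-- Hypothetical: a continuation that stops after one layer,
-- so every pair component contributes 0.
shallow :: T a -> Int
shallow = sumT_cont (\_ _ -> 0)

-- Hypothetical: a continuation that keeps unfolding, giving the full sum.
deep :: T a -> Int
deep = sumT_cont (\_ t -> deep t)

main :: IO ()
main = do
  let t = tPair (tPair (TInt 1) (TInt 2)) (TInt 3)
  print (shallow (TInt 7), shallow t, deep t)
```

Running it prints (7,0,6). Open recursion of this kind is what lets the caller, rather than sumT_cont itself, decide how far unfolding goes.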

The actual recursive sum is defined via sumT_cont. Of course, it has to take a () argument (which it ignores).

```haskell
sumT' :: () -> T a -> Int
sumT' _ = sumT_cont sumT'

sumT :: T a -> Int
{-# INLINE sumT #-}
sumT = sumT' ()
```
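Before any rule enters the picture, the knot-tying can be checked at run time. The sketch below assembles the definitions so far into one module; without a rewrite rule, sumT' is just an ordinary recursive function, so this simply computes the sum at run time. The last line of main illustrates that the unit argument really is ignored.

```haskell
{-# LANGUAGE GADTs, RankNTypes #-}
module Main where

data T a where
  TInt  :: Int -> T Int
  TPair :: () -> T a -> T b -> T (a, b)

tPair :: T a -> T b -> T (a, b)
tPair = TPair ()

sumT_cont :: (forall a. () -> T a -> Int) -> T a -> Int
{-# INLINE sumT_cont #-}
sumT_cont _    (TInt n)      = n
sumT_cont cont (TPair u l r) = cont u l + cont u r

-- The knot: sumT' hands itself to sumT_cont as the continuation.
-- Its unit argument is never inspected.
sumT' :: () -> T a -> Int
sumT' _ = sumT_cont sumT'

-- Public entry point: start the recursion with the literal ().
sumT :: T a -> Int
{-# INLINE sumT #-}
sumT = sumT' ()

foo :: Int
foo = sumT (tPair (tPair (TInt 1) (TInt 2)) (TInt 3))

main :: IO ()
main = do
  print foo
  -- Harmless: the first argument of sumT' is never forced.
  print (sumT' undefined (TInt 5))
```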

The final missing piece that makes the whole thing work is this simple rewrite rule:

```haskell
{-# RULES
"sumT'" sumT' () = sumT_cont sumT'
  #-}
```
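With the rule inside a RULES pragma, the whole scheme fits in one module. This is a sketch assembled from the fragments in this post. Compiling it with, for example, `ghc -O -ddump-rule-firings -ddump-simpl -dsuppress-all` should show which rules fired and the resulting Core, in which, if the reasoning here holds for your GHC version, foo ends up as the literal 6. Without optimisation the rule is not used and the same result is computed at run time.

```haskell
{-# LANGUAGE GADTs, RankNTypes #-}
module Main where

data T a where
  TInt  :: Int -> T Int
  TPair :: () -> T a -> T b -> T (a, b)

tPair :: T a -> T b -> T (a, b)
tPair = TPair ()

sumT_cont :: (forall a. () -> T a -> Int) -> T a -> Int
{-# INLINE sumT_cont #-}
sumT_cont _    (TInt n)      = n
sumT_cont cont (TPair u l r) = cont u l + cont u r

sumT' :: () -> T a -> Int
sumT' _ = sumT_cont sumT'

sumT :: T a -> Int
{-# INLINE sumT #-}
sumT = sumT' ()

-- Unfold sumT' by one step, but only when its first argument is
-- syntactically (). The equation is sound because sumT' ignores that
-- argument, so both sides compute the same value.
{-# RULES
"sumT'" sumT' () = sumT_cont sumT'
  #-}

term :: T ((Int, Int), Int)
term = tPair (tPair (TInt 1) (TInt 2)) (TInt 3)

foo :: Int
foo = sumT term

main :: IO ()
main = print foo
```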

It “inlines” sumT', but only when it is applied to (). Why is this useful? Let's see what happens if we apply sumT to a term which GHC knows nothing about:

```
  sumT x
= {inline sumT}
  sumT' () x
= {apply rule "sumT'"}
  sumT_cont sumT' x
= {inline sumT_cont}
  case x of
    TInt n      -> n
    TPair u l r -> sumT' u l + sumT' u r
```
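The end result of this derivation can be written out by hand and compared against sumT. In the sketch below, sumTUnknown is a hypothetical name for what the simplifier is described as producing for an unknown argument: one unfolded layer that then calls back into the recursive sumT', passing along the case-bound variable u rather than the literal ().

```haskell
{-# LANGUAGE GADTs, RankNTypes #-}
module Main where

data T a where
  TInt  :: Int -> T Int
  TPair :: () -> T a -> T b -> T (a, b)

sumT_cont :: (forall a. () -> T a -> Int) -> T a -> Int
{-# INLINE sumT_cont #-}
sumT_cont _    (TInt n)      = n
sumT_cont cont (TPair u l r) = cont u l + cont u r

sumT' :: () -> T a -> Int
sumT' _ = sumT_cont sumT'

sumT :: T a -> Int
{-# INLINE sumT #-}
sumT = sumT' ()

-- Hypothetical name: a hand transcription of the last line of the
-- derivation. The recursive calls receive the variable u, so a rule
-- keyed on the literal () has nothing to match.
sumTUnknown :: T a -> Int
sumTUnknown x = case x of
  TInt n      -> n
  TPair u l r -> sumT' u l + sumT' u r

main :: IO ()
main = do
  let t = TPair () (TPair () (TInt 1) (TInt 2)) (TInt 3)
  print (sumTUnknown t, sumT t)
```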

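Rewriting sumT' to sumT_cont sumT' again would be a disaster, as it would put us into an infinite rewriting loop: each unfolding produces new calls to sumT', which could be unfolded again, forever. This is precisely the reason why GHC won't inline recursive functions. But our rule doesn't match here, because the first argument of sumT' is now the variable u, which is not guaranteed to be ()!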

So what happens if we apply sumT to a term that is at least partially static?

```
  sumT (tPair (tInt 1) y)
= {inline sumT and tPair}
  sumT' () (TPair () (TInt 1) y)
= {apply rule "sumT'"}
  sumT_cont sumT' (TPair () (TInt 1) y)
= {inline sumT_cont}
  case TPair () (TInt 1) y of
    TInt n      -> n
    TPair u l r -> sumT' u l + sumT' u r
= {eliminate case}
  sumT' () (TInt 1) + sumT' () y
= {apply rule "sumT'" twice}
  sumT_cont sumT' (TInt 1) + sumT_cont sumT' y
= {inline sumT_cont, eliminate case}
  1 + case y of
        TInt n      -> n
        TPair u l r -> sumT' u l + sumT' u r
```
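The residual program can likewise be transcribed and checked against sumT on a few inputs. In this sketch, residual is a hypothetical name for the last line of the derivation viewed as a function of the unknown y; since nothing is known about the index of y, it is polymorphic in b. TInt is used directly where the derivation starts from tInt.

```haskell
{-# LANGUAGE GADTs, RankNTypes #-}
module Main where

data T a where
  TInt  :: Int -> T Int
  TPair :: () -> T a -> T b -> T (a, b)

tPair :: T a -> T b -> T (a, b)
tPair = TPair ()

sumT_cont :: (forall a. () -> T a -> Int) -> T a -> Int
{-# INLINE sumT_cont #-}
sumT_cont _    (TInt n)      = n
sumT_cont cont (TPair u l r) = cont u l + cont u r

sumT' :: () -> T a -> Int
sumT' _ = sumT_cont sumT'

sumT :: T a -> Int
{-# INLINE sumT #-}
sumT = sumT' ()

-- Hypothetical name: the last line of the derivation. The static left
-- component has already been folded to 1; only y is inspected at run time.
residual :: T b -> Int
residual y = 1 + case y of
  TInt n      -> n
  TPair u l r -> sumT' u l + sumT' u r

main :: IO ()
main = do
  -- Each pair holds (residual program, original sumT) for the same input.
  print [ (residual y, sumT (tPair (TInt 1) y)) | y <- [TInt 5, TInt (-2)] ]
  let z = tPair (TInt 2) (TInt 3)
  print (residual z, sumT (tPair (TInt 1) z))
```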

This looks good! In effect, GHC executed sumT for the statically known portion of the term at compile time and deferred the rest to run time. This worked because, when it eliminated the case on TPair, it bound u in the case alternative to (). This allowed it to apply the "sumT'" rule again and thus get rid of the TInt constructor in the left component. The right component is unknown, though, so rewriting stops there.

In general, after “inlining” sumT' once (via the rewrite rule), GHC will only apply the rule again if it eliminates the case, thus binding u to (). This, in turn, is only possible if the head of the term is a known constructor, so GHC will keep rewriting and inlining until it has consumed all known constructors, but it will not get into an infinite loop. For foo from my first example, which is fully constant, it will perform the entire computation at compile time and reduce it to 6.

A word of warning: it is possible to get GHC into an infinite loop with this approach by constructing infinite but statically known terms. For instance, we could apply the same technique to this type.

```haskell
data U = UInt Int | UPair U U
```

But now, this term gets us into trouble:

```haskell
x = UPair (UInt 1) x
```
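To see that x is an infinite but perfectly well-defined value, one can traverse it only to a bounded depth. In the sketch below, sumUpTo is a hypothetical helper for this illustration. A full sum over x would never terminate at run time, and for the same reason a compile-time unfolding of it would never run out of known constructors to consume.

```haskell
module Main where

data U = UInt Int | UPair U U

-- An infinite term: the right component refers back to x itself.
x :: U
x = UPair (UInt 1) x

-- Hypothetical helper: sum the integers reachable within d layers of
-- UPair; anything deeper contributes 0.
sumUpTo :: Int -> U -> Int
sumUpTo _ (UInt n) = n
sumUpTo d (UPair l r)
  | d <= 0    = 0
  | otherwise = sumUpTo (d - 1) l + sumUpTo (d - 1) r

main :: IO ()
main = print (map (`sumUpTo` x) [0, 1, 2, 3])
```

Running it prints [0,1,2,3]: every extra layer exposes one more UInt 1, and there is no last layer.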

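This technique works best with GADTs like T that do not admit infinite terms. There, the type index of a term determines its shape, so a definition like bad = TPair () bad bad is rejected by the type checker: it would require the index a to equal (a,a).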