Traversing syntax trees

When working with syntax trees (such as in a type theory interpreter) you often want to apply some operation to all subtrees of a node, or to all nodes of a certain type. Of course you can do this easily by writing a recursive function. But then you would need to have a case for every constructor, and there can be many constructors.

Instead of writing a big recursive function for each operation, it is often easier to use a traversal function, which is what this post is about. In particular, I will describe my favorite way to handle such traversals, in the hope that it is useful to others as well.

As a running example we will use the following data type, which represents expressions in a simple lambda calculus

    data Exp
      = Var !Int
      | Lam Exp
      | App Exp Exp
      | Global String
      deriving Show

    example1 :: Exp
    example1 = Lam $ Var 0

    example2 :: Exp
    example2 = Lam $ Lam $ Var 1

    example3 :: Exp
    example3 = Lam $ Lam $ Lam $ App (Var 2) (App (Var 1) (Var 0))

Now, what do I mean by a traversal function? The base library comes with the Traversable class, but that doesn't quite fit our purposes, because that class is designed for containers that can contain any type a. But expressions can only contain other sub-expressions. Instead we need a monomorphic variant of traverse for our expression type:

    traverseExp :: Applicative f => (Exp -> f Exp) -> (Exp -> f Exp)

The idea is that traverseExp applies a given function to all direct children of an expression.
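To make "all direct children" concrete, here is a minimal sketch that defines traverseExp directly by pattern matching, without any depth tracking. The helpers children and checkChildren are hypothetical names introduced only for this illustration, and the Eq instance is added so results can be compared.

```haskell
import Data.Functor.Const (Const (..))

data Exp
  = Var !Int
  | Lam Exp
  | App Exp Exp
  | Global String
  deriving (Eq, Show)

-- Apply f to each direct child and rebuild the node.
-- Var and Global are leaves, so they are returned unchanged.
traverseExp :: Applicative f => (Exp -> f Exp) -> (Exp -> f Exp)
traverseExp _ (Var i)    = pure (Var i)
traverseExp f (Lam x)    = Lam <$> f x
traverseExp f (App x y)  = App <$> f x <*> f y
traverseExp _ (Global x) = pure (Global x)

-- Hypothetical helper: list the direct children, left to right,
-- by collecting them in the Const applicative.
children :: Exp -> [Exp]
children = getConst . traverseExp (\x -> Const [x])

-- Hypothetical helper: return the node only if every direct child
-- passes the check, using the Maybe applicative.
checkChildren :: (Exp -> Bool) -> Exp -> Maybe Exp
checkChildren ok = traverseExp (\x -> if ok x then Just x else Nothing)
```

Note that only one level is visited: children (Lam (Lam (Var 1))) yields the inner Lam, not the variable. Recursion is left to the function that is passed in.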

The uniplate package defines a similar function, descendM . But it has two problems: 1) descendM has a Monad constraint instead of Applicative , and 2) the class actually requires you to implement a uniplate method, which is more annoying to do.

The ever intimidating lens package has a closer match in plate . But aside from the terrible name, that function also lacks a way to keep track of bound variables.

For a language with binders, like the lambda calculus, many operations need to know which variables are bound. In particular, when working with de Bruijn indices, it is necessary to keep track of the number of bound variables. To do that we define

    type Depth = Int

    traverseExpD :: Applicative f => (Depth -> Exp -> f Exp) -> (Depth -> Exp -> f Exp)
    traverseExpD _ _ (Var i)    = pure (Var i)
    traverseExpD f d (Lam x)    = Lam <$> f (d + 1) x
    traverseExpD f d (App x y)  = App <$> f d x <*> f d y
    traverseExpD _ _ (Global x) = pure (Global x)
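Because the constraint is only Applicative, the same function works with effects such as Either. As a sketch of why the depth argument matters, the following scope check fails on the first variable that is not bound by an enclosing Lam. The name scopeCheck is hypothetical and not part of the post; Eq is derived only so the results can be compared.

```haskell
data Exp
  = Var !Int
  | Lam Exp
  | App Exp Exp
  | Global String
  deriving (Eq, Show)

type Depth = Int

-- Apply f to each direct child, incrementing the depth under a Lam.
traverseExpD :: Applicative f => (Depth -> Exp -> f Exp) -> (Depth -> Exp -> f Exp)
traverseExpD _ _ (Var i)    = pure (Var i)
traverseExpD f d (Lam x)    = Lam <$> f (d + 1) x
traverseExpD f d (App x y)  = App <$> f d x <*> f d y
traverseExpD _ _ (Global x) = pure (Global x)

-- Hypothetical check: Left i reports the index of the leftmost variable
-- that is not bound at its position; Right returns the expression unchanged.
scopeCheck :: Depth -> Exp -> Either Int Exp
scopeCheck d (Var i)
  | i < d     = Right (Var i)
  | otherwise = Left i
scopeCheck d x = traverseExpD scopeCheck d x
```

For example, scopeCheck 0 (Lam (Var 0)) succeeds, while scopeCheck 0 (Lam (App (Var 0) (Var 2))) stops with Left 2.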

Once we have written this function, other traversals can be defined in terms of traverseExpD

    traverseExp :: Applicative f => (Exp -> f Exp) -> (Exp -> f Exp)
    traverseExp f = traverseExpD (const f) 0

And map and fold are just traversals with a specific applicative functor, Identity and Const a respectively. Recent versions of GHC are smart enough to know that it is safe to coerce from a traversal function to a mapping or folding one.

    -- Requires ScopedTypeVariables, plus coerce (Data.Coerce),
    -- Identity (Data.Functor.Identity) and Const (Data.Functor.Const).
    mapExpD :: (Depth -> Exp -> Exp) -> (Depth -> Exp -> Exp)
    mapExpD = coerce (traverseExpD :: (Depth -> Exp -> Identity Exp) -> (Depth -> Exp -> Identity Exp))

    mapExp :: (Exp -> Exp) -> (Exp -> Exp)
    mapExp = coerce (traverseExp :: (Exp -> Identity Exp) -> (Exp -> Identity Exp))

    foldExpD :: forall a. Monoid a => (Depth -> Exp -> a) -> (Depth -> Exp -> a)
    foldExpD = coerce (traverseExpD :: (Depth -> Exp -> Const a Exp) -> (Depth -> Exp -> Const a Exp))

    foldExp :: forall a. Monoid a => (Exp -> a) -> (Exp -> a)
    foldExp = coerce (traverseExp :: (Exp -> Const a Exp) -> (Exp -> Const a Exp))
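If the coerce trick looks opaque, the same map and fold can be spelled out by wrapping and unwrapping Identity and Const by hand. This is only a sketch of an equivalent formulation; the primed names mapExpD' and foldExpD' are introduced here to avoid clashing with the definitions above.

```haskell
import Data.Functor.Const (Const (..))
import Data.Functor.Identity (Identity (..))

data Exp
  = Var !Int
  | Lam Exp
  | App Exp Exp
  | Global String
  deriving (Eq, Show)

type Depth = Int

traverseExpD :: Applicative f => (Depth -> Exp -> f Exp) -> (Depth -> Exp -> f Exp)
traverseExpD _ _ (Var i)    = pure (Var i)
traverseExpD f d (Lam x)    = Lam <$> f (d + 1) x
traverseExpD f d (App x y)  = App <$> f d x <*> f d y
traverseExpD _ _ (Global x) = pure (Global x)

-- Map: traverse in Identity, which adds no effect, then unwrap the result.
mapExpD' :: (Depth -> Exp -> Exp) -> (Depth -> Exp -> Exp)
mapExpD' f d = runIdentity . traverseExpD (\d' x -> Identity (f d' x)) d

-- Fold: traverse in Const a, which discards the rebuilt tree and
-- combines the per-child results with the Monoid instance of a.
foldExpD' :: Monoid a => (Depth -> Exp -> a) -> (Depth -> Exp -> a)
foldExpD' f d = getConst . traverseExpD (\d' x -> Const (f d' x)) d
```

Both versions visit only the direct children: mapExpD' (\_ _ -> Global "hole") 0 (Lam (Var 0)) replaces the body of the lambda, and foldExpD' (\d _ -> [d]) 0 (Lam (Var 0)) reports the depth 1 at which that body is visited.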

After doing all this work, it is easy to answer questions like "how often is a variable used?"

    varCount :: Depth -> Exp -> Sum Int
    varCount i (Var j)
      | i == j = Sum 1
    varCount i x = foldExpD varCount i x

or "what is the set of all free variables?"

    freeVars :: Depth -> Exp -> Set Int
    freeVars d (Var i)
      | i < d     = Set.empty
      | otherwise = Set.singleton (i - d)
    freeVars d x = foldExpD freeVars d x

Or to perform (silly) operations like changing all globals to lower case

    lowerCase :: Exp -> Exp
    lowerCase (Global x) = Global (map toLower x)
    lowerCase x          = mapExp lowerCase x

These functions follow a common pattern: they specify how a particular constructor, in this case Var or Global, is handled, and for all other constructors they simply traverse over the child expressions.
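The same pattern can be sketched for two more queries. The functions globals and lamCount are hypothetical examples, not from the post. To keep the block free of language extensions, traverseExp is defined directly here and foldExp is written with getConst instead of coerce.

```haskell
import Data.Functor.Const (Const (..))
import Data.Monoid (Sum (..))

data Exp
  = Var !Int
  | Lam Exp
  | App Exp Exp
  | Global String
  deriving (Eq, Show)

-- Depth-free traversal of the direct children.
traverseExp :: Applicative f => (Exp -> f Exp) -> (Exp -> f Exp)
traverseExp _ (Var i)    = pure (Var i)
traverseExp f (Lam x)    = Lam <$> f x
traverseExp f (App x y)  = App <$> f x <*> f y
traverseExp _ (Global x) = pure (Global x)

-- Combine the results of f over the direct children.
foldExp :: Monoid a => (Exp -> a) -> (Exp -> a)
foldExp f = getConst . traverseExp (Const . f)

-- Special case for Global, default case for everything else.
globals :: Exp -> [String]
globals (Global x) = [x]
globals x          = foldExp globals x

-- Special case for Lam: count it, then continue into its body.
lamCount :: Exp -> Sum Int
lamCount (Lam x) = Sum 1 <> lamCount x
lamCount x       = foldExp lamCount x
```

In both functions only one constructor is mentioned explicitly; adding a new constructor to Exp would require changing traverseExp, but neither of these.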

As another example, consider substitution, a very important operation on syntax trees. In its most general form, we can combine substitution with raising expressions to a larger context (also called weakening). And we should also consider leaving the innermost, bound, variables alone. This means that there are three possibilities for what to do with a variable.

    substRaiseByAt :: [Exp] -> Int -> Depth -> Exp -> Exp
    substRaiseByAt ss r d (Var i)
      | i < d             = Var i
      | i - d < length ss = raiseBy d (ss !! (i - d))
      | otherwise         = Var (i - length ss + r)
    substRaiseByAt ss r d x = mapExpD (substRaiseByAt ss r) d x

Similarly to varCount , we use mapExpD to handle all constructors besides variables. Plain substitution and raising are just special cases.

    substRaiseBy :: [Exp] -> Int -> Exp -> Exp
    substRaiseBy ss r = substRaiseByAt ss r 0

    raiseBy :: Int -> Exp -> Exp
    raiseBy r = substRaiseBy [] r

    subst :: [Exp] -> Exp -> Exp
    subst ss = substRaiseBy ss 0

    λ> raiseBy 2 (App (Var 1) (Var 2))
    App (Var 3) (Var 4)
    λ> subst [Global "x"] (App (Var 0) (Lam (Var 0)))
    App (Global "x") (Lam (Var 0))
    λ> substRaiseBy [App (Global "x") (Var 0)] 2 $ App (Lam (App (Var 1) (Var 0))) (Var 2)
    App (Lam (App (App (Global "x") (Var 1)) (Var 0))) (Var 3)
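One place where subst is typically used is beta reduction. The sketch below repeats the substitution functions so that it is self-contained, and adds a hypothetical helper betaStep (not part of the post) that contracts a redex at the root of an expression. mapExpD is written with runIdentity rather than coerce to avoid needing extensions.

```haskell
import Data.Functor.Identity (Identity (..))

data Exp
  = Var !Int
  | Lam Exp
  | App Exp Exp
  | Global String
  deriving (Eq, Show)

type Depth = Int

traverseExpD :: Applicative f => (Depth -> Exp -> f Exp) -> (Depth -> Exp -> f Exp)
traverseExpD _ _ (Var i)    = pure (Var i)
traverseExpD f d (Lam x)    = Lam <$> f (d + 1) x
traverseExpD f d (App x y)  = App <$> f d x <*> f d y
traverseExpD _ _ (Global x) = pure (Global x)

mapExpD :: (Depth -> Exp -> Exp) -> (Depth -> Exp -> Exp)
mapExpD f d = runIdentity . traverseExpD (\d' x -> Identity (f d' x)) d

-- Bound variables stay, the next (length ss) variables are replaced,
-- and the remaining free variables are renumbered.
substRaiseByAt :: [Exp] -> Int -> Depth -> Exp -> Exp
substRaiseByAt ss r d (Var i)
  | i < d             = Var i
  | i - d < length ss = raiseBy d (ss !! (i - d))
  | otherwise         = Var (i - length ss + r)
substRaiseByAt ss r d x = mapExpD (substRaiseByAt ss r) d x

raiseBy :: Int -> Exp -> Exp
raiseBy r = substRaiseByAt [] r 0

subst :: [Exp] -> Exp -> Exp
subst ss = substRaiseByAt ss 0 0

-- Hypothetical helper: contract a beta-redex at the root, if there is one.
betaStep :: Exp -> Maybe Exp
betaStep (App (Lam body) arg) = Just (subst [arg] body)
betaStep _                    = Nothing
```

For instance, betaStep (App (Lam (Lam (Var 1))) (Var 0)) gives Just (Lam (Var 1)): the argument Var 0 is raised by one because it ends up under the remaining lambda.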

As a slight generalization, it can also make sense to put traverseExpD into a type class. That way we can traverse over the subexpressions inside other data types. For instance, if the language uses a separate data type for case alternatives, we might write

    data Exp = ... | Case [Alt]
    data Alt = Alt Pat Exp

    class TraverseExp a where
      traverseExpD :: Applicative f => (Depth -> Exp -> f Exp) -> (Depth -> a -> f a)

    instance TraverseExp a => TraverseExp [a] where
      traverseExpD f d = traverse (traverseExpD f d)

    instance TraverseExp Exp where
      traverseExpD f d ...
      traverseExpD f d (Case xs) = Case <$> traverseExpD f d xs

    instance TraverseExp Alt where
      -- The body is a direct child of the alternative, so f is applied to it directly.
      traverseExpD f d (Alt x y) = Alt x <$> f (d + varsBoundByPat x) y
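A complete, runnable version of this idea might look as follows. The Pat type and the definition of varsBoundByPat are assumptions made up for this sketch, since the post does not define them, and scopeCheck is a hypothetical client. The Alt instance here applies f to the body directly, treating the body as the alternative's direct child.

```haskell
type Depth = Int

-- Hypothetical pattern type: PVar binds one variable,
-- PCon matches a constructor applied to sub-patterns.
data Pat = PVar | PCon String [Pat]
  deriving (Eq, Show)

data Alt = Alt Pat Exp
  deriving (Eq, Show)

data Exp
  = Var !Int
  | Lam Exp
  | App Exp Exp
  | Global String
  | Case [Alt]
  deriving (Eq, Show)

-- Assumed definition: count the variables a pattern binds.
varsBoundByPat :: Pat -> Int
varsBoundByPat PVar        = 1
varsBoundByPat (PCon _ ps) = sum (map varsBoundByPat ps)

class TraverseExp a where
  traverseExpD :: Applicative f => (Depth -> Exp -> f Exp) -> (Depth -> a -> f a)

instance TraverseExp a => TraverseExp [a] where
  traverseExpD f d = traverse (traverseExpD f d)

instance TraverseExp Exp where
  traverseExpD _ _ (Var i)    = pure (Var i)
  traverseExpD f d (Lam x)    = Lam <$> f (d + 1) x
  traverseExpD f d (App x y)  = App <$> f d x <*> f d y
  traverseExpD _ _ (Global x) = pure (Global x)
  traverseExpD f d (Case xs)  = Case <$> traverseExpD f d xs

instance TraverseExp Alt where
  -- The body sits under the variables bound by the pattern.
  traverseExpD f d (Alt p y) = Alt p <$> f (d + varsBoundByPat p) y

-- Hypothetical client: Nothing if some variable is unbound.
scopeCheck :: Depth -> Exp -> Maybe Exp
scopeCheck d (Var i)
  | i < d     = Just (Var i)
  | otherwise = Nothing
scopeCheck d x = traverseExpD scopeCheck d x
```

The client code does not mention Case or Alt at all, yet variables inside case alternatives are checked at the right depth.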

Another variation is to track other things besides the number of bound variables. For example we might track the names and types of bound variables for better error messages. And with a type class it is possible to track different aspects of bindings as needed,

    class Env env where
      extend :: VarBinding -> env -> env

    instance Env Depth where
      extend _ = (+ 1)

    instance Env [VarBinding] where
      extend = (:)

    instance Env () where
      extend _ _ = ()

    traverseExpEnv :: (Env env, Applicative f) => (env -> Exp -> f Exp) -> (env -> Exp -> f Exp)
    traverseExpEnv f env (Lam name x) = Lam name <$> f (extend name env) x
    traverseExpEnv f env ...
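To show how one traversal can serve several environments, here is a self-contained sketch. It assumes that VarBinding is simply a String and that Lam carries the binder's name; both are guesses for illustration, as are the client functions usedNames and binderDepths. FlexibleInstances is enabled because the instances are given for a type synonym and for a list of strings.

```haskell
{-# LANGUAGE FlexibleInstances #-}
import Data.Functor.Const (Const (..))

type Depth = Int
type VarBinding = String -- assumption: a binding is just a name

data Exp
  = Var !Int
  | Lam VarBinding Exp
  | App Exp Exp
  | Global String
  deriving (Eq, Show)

class Env env where
  extend :: VarBinding -> env -> env

instance Env Depth where
  extend _ = (+ 1)

instance Env [VarBinding] where
  extend = (:)

instance Env () where
  extend _ _ = ()

-- Apply f to each direct child, extending the environment under a Lam.
traverseExpEnv :: (Env env, Applicative f) => (env -> Exp -> f Exp) -> (env -> Exp -> f Exp)
traverseExpEnv _ _   (Var i)      = pure (Var i)
traverseExpEnv f env (Lam name x) = Lam name <$> f (extend name env) x
traverseExpEnv f env (App x y)    = App <$> f env x <*> f env y
traverseExpEnv _ _   (Global x)   = pure (Global x)

-- Hypothetical client: the name each variable occurrence refers to.
usedNames :: [VarBinding] -> Exp -> Const [String] Exp
usedNames env (Var i) = Const [name]
  where
    name = case drop i env of
      (n : _) -> n
      []      -> "<free>"
usedNames env x = traverseExpEnv usedNames env x

-- Hypothetical client: the binder depth at each variable occurrence.
binderDepths :: Depth -> Exp -> Const [Int] Exp
binderDepths d (Var _) = Const [d]
binderDepths d x       = traverseExpEnv binderDepths d x
```

Both clients share traverseExpEnv; only the type of the environment they pass in differs.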

Overall, I have found that after writing traverseExpD once, I rarely have to look at all constructors again. I can just handle the default cases by traversing the children.