If there’s one thing compiler writers spill a lot of ink over, it’s tree traversals. Tree traversals are infuriatingly simple conceptually, but can be a lot of boilerplate to actually write. This post covers a couple of tricks that I’ve found useful recently using tools from lens.

Let’s suppose we have a simple lambda calculus with some primitive integer operations.

{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE RankNTypes #-}
module LensesForTreeTraversals where
import Control.Lens
import Data.Monoid
import Data.Functor.Foldable hiding (fold)
import Data.Foldable
import Data.Fix hiding (cata)

type Name = String

data Type = IntegerType | FunType Type Type

data Term = 
    Var Name
    | Lam Name Type Term
    | App Term Term
    | Plus Term Term
    | Constant Integer

We’d like to write a simple constant-folding pass over this AST. The algorithm is extremely simple, in prose:

Recursively transform all the subterms, and if the resulting node is then the sum of two constants, replace it with a constant equal to their sum.

We would really like to reach this level of clarity in our code. The subtlety, of course, is how we talk about “all the subterms”, and that will be the theme of this post.

A naive implementation of the constant folding function is as follows:

-- | Do the local part of the constant folding transformation.
cf :: Term -> Term
cf = \case
    Plus (Constant i1) (Constant i2) -> Constant (i1 + i2)
    x -> x

constantFold :: Term -> Term
-- Do the local transformation after recursively calling ourselves 
-- on the subterms, if any
constantFold t = cf $ case t of 
    Plus t1 t2 -> Plus (constantFold t1) (constantFold t2)
    Lam n ty t -> Lam n ty $ constantFold t
    App t1 t2 -> App (constantFold t1) (constantFold t2)
    x -> x

Tediously, we have to explicitly pull out each subterm and call the function on it recursively, which is not only boilerplate, but error-prone (nothing will tell us if we’ve missed a subterm!).
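To make the "nothing will tell us" point concrete, here is a minimal sketch of what a forgotten case looks like. The name constantFoldBuggy and the example term are hypothetical, and the data types are repeated (with Show and Eq added) only so that the block stands alone.

```haskell
{-# LANGUAGE LambdaCase #-}

type Name = String

data Type = IntegerType | FunType Type Type
    deriving (Show, Eq)

data Term =
    Var Name
    | Lam Name Type Term
    | App Term Term
    | Plus Term Term
    | Constant Integer
    deriving (Show, Eq)

-- The local part of the transformation, as in the post.
cf :: Term -> Term
cf = \case
    Plus (Constant i1) (Constant i2) -> Constant (i1 + i2)
    x -> x

-- Hypothetical buggy variant: the 'App' case has been forgotten, so the
-- catch-all silently returns applications without visiting their subterms.
constantFoldBuggy :: Term -> Term
constantFoldBuggy t = cf $ case t of
    Plus t1 t2 -> Plus (constantFoldBuggy t1) (constantFoldBuggy t2)
    Lam n ty body -> Lam n ty $ constantFoldBuggy body
    x -> x

-- Illustrative term: a foldable sum hidden inside an application.
example :: Term
example = App (Var "f") (Plus (Constant 1) (Constant 2))

-- constantFoldBuggy example == example
-- i.e. the sum inside the application is left unfolded.
```

Because of the catch-all pattern, this version is still total and the pattern match is still exhaustive, so the compiler has nothing to complain about.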


Now, we want to do something for every subterm, which sounds like the sort of thing we should be able to do with traverse. But it’s a pain to make Traversable work with our AST. In particular Traversable expects your type to have kind * -> *, i.e. to be able to “contain” values of any type. This isn’t really right for us: a term is like a monomorphic container of its subterms.

Enter Traversals from lens. A traversal is very closely related to Traversable.traverse:

-- traverse :: (Traversable t, Applicative f) => (a -> f b) -> t a -> f (t b)
-- type Traversal s t a b = forall f. Applicative f => (a -> f b) -> s -> f t

A Traversal is a traverse-like function, but it can be more specific if that’s appropriate. In particular, we can define a Traversal that only traverses the subterms of a term.

(Maybe there’s a clever way to write traversals, but I do it the stupid way: take the effectful function and the value, and apply it to the subparts.)

-- Traversal' a b = Traversal a a b b, useful if you're not doing clever stuff
termSubterms :: Traversal' Term Term
termSubterms f = \case
    Lam n ty t -> Lam n ty <$> f t 
    App t1 t2 -> App <$> f t1 <*> f t2
    Plus t1 t2 -> Plus <$> f t1 <*> f t2
    -- Terms without subterms. Note that you should *not* put 'f x' as the 
    -- RHS: that would say that the term was its own subterm!
    x -> pure x

What does termSubterms do? It’s like traverse: you give it a function that does stuff to subterms, and it will give you one that does that to all the immediate subterms of a particular term. It’s also usable with a lot of lens functions. For example, every Traversal is also a Setter, so you can write to all of its targets.
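One way to see what is going on is to pick the applicative by hand. The sketch below avoids importing lens: it re-declares the Traversal' synonym locally and defines two illustrative stand-ins, overT and toListOfT, which are assumptions of this example and not lens functions (lens provides the more general over and toListOf).

```haskell
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}

import Data.Functor.Const (Const (..))
import Data.Functor.Identity (Identity (..))

type Name = String

data Type = IntegerType | FunType Type Type
    deriving (Show, Eq)

data Term =
    Var Name
    | Lam Name Type Term
    | App Term Term
    | Plus Term Term
    | Constant Integer
    deriving (Show, Eq)

-- Local copy of the lens synonym, so that this sketch only needs base.
type Traversal' s a = forall f. Applicative f => (a -> f a) -> s -> f s

termSubterms :: Traversal' Term Term
termSubterms f = \case
    Lam n ty t -> Lam n ty <$> f t
    App t1 t2 -> App <$> f t1 <*> f t2
    Plus t1 t2 -> Plus <$> f t1 <*> f t2
    x -> pure x

-- Stand-in for 'over': choose Identity as the applicative, so the
-- traversal just rebuilds the structure with modified targets.
overT :: Traversal' s a -> (a -> a) -> s -> s
overT l g = runIdentity . l (Identity . g)

-- Stand-in for 'toListOf': choose Const [a], so the traversal
-- collects its targets instead of rebuilding anything.
toListOfT :: Traversal' s a -> s -> [a]
toListOfT l = getConst . l (\a -> Const [a])

example :: Term
example = Plus (Var "x") (App (Var "f") (Constant 1))

-- toListOfT termSubterms example
--     == [Var "x", App (Var "f") (Constant 1)]
-- overT termSubterms (const (Constant 0)) example
--     == Plus (Constant 0) (Constant 0)
```

Note that only the two direct children of the Plus node are visited. The traversal itself is not recursive; the recursion comes from how we use it.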

Now we can write our constant folder more cleanly:

constantFold2 :: Term -> Term
-- 'over' applies a function to all targets of an optic. Used with a
-- 'Traversal', it applies the function to all of the things which it
-- traverses (the subterms).
constantFold2 t = cf $ over termSubterms constantFold2 t

I think this is pretty close to our original English specification!

We can abstract this a little bit further to factor out the local part of the transformation. Then we get a nice little function transformOf which is already defined for us in Control.Lens.Plated.

-- The real version is a bit more general than this
-- transformOf :: Traversal' a a -> (a -> a) -> (a -> a)
-- transformOf l f = go where go = f . over l go

constantFold3 :: Term -> Term
constantFold3 = transformOf termSubterms cf
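The order matters here: the local step runs after the subterms have been transformed. The sketch below shows this with base only. The names overT, transformOfT and topDownOfT are local stand-ins invented for this example; transformOfT copies the simplified definition above, and topDownOfT is a hypothetical variant that applies the local step first.

```haskell
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}

import Data.Functor.Identity (Identity (..))

type Name = String

data Type = IntegerType | FunType Type Type
    deriving (Show, Eq)

data Term =
    Var Name
    | Lam Name Type Term
    | App Term Term
    | Plus Term Term
    | Constant Integer
    deriving (Show, Eq)

-- Local copy of the lens synonym, so that this sketch only needs base.
type Traversal' s a = forall f. Applicative f => (a -> f a) -> s -> f s

termSubterms :: Traversal' Term Term
termSubterms f = \case
    Lam n ty t -> Lam n ty <$> f t
    App t1 t2 -> App <$> f t1 <*> f t2
    Plus t1 t2 -> Plus <$> f t1 <*> f t2
    x -> pure x

cf :: Term -> Term
cf = \case
    Plus (Constant i1) (Constant i2) -> Constant (i1 + i2)
    x -> x

-- Stand-in for 'over'.
overT :: Traversal' s a -> (a -> a) -> s -> s
overT l g = runIdentity . l (Identity . g)

-- Bottom-up: transform the subterms, then apply the local step.
transformOfT :: Traversal' a a -> (a -> a) -> a -> a
transformOfT l f = go where go = f . overT l go

-- Hypothetical top-down variant: apply the local step, then recurse.
topDownOfT :: Traversal' a a -> (a -> a) -> a -> a
topDownOfT l f = go where go = overT l go . f

example :: Term
example = Plus (Plus (Constant 1) (Constant 2)) (Constant 3)

-- transformOfT termSubterms cf example == Constant 6
-- topDownOfT  termSubterms cf example == Plus (Constant 3) (Constant 3)
```

The top-down variant looks at the outer Plus before its left child has been folded, so it misses the second opportunity to fold.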

Folds and mixing Traversals

As a bonus, we can also do folds. For example, let’s count the number of term nodes in a term.

countTerms :: Term -> Sum Integer
countTerms t = 
    -- The number of terms is 1...
    Sum 1 
    -- ... plus the number of terms in all the subterms
    <> foldMapOf termSubterms countTerms t

Note that we don’t need to do any case analysis since the algorithm happens to be completely generic over the different kinds of term.

This problem also gives us the opportunity to show off another very useful feature of Traversals: we can mix and match different Traversals over the same structure. In particular, we haven’t counted the types within our terms: if we care about the number of AST nodes then we should probably count them too!

We need a couple more traversals: one to get the types within a term, and one to get the types within a type.

termSubtypes :: Traversal' Term Type
termSubtypes f = \case
    Lam n ty t -> Lam n <$> f ty <*> pure t 
    x -> pure x

typeSubtypes :: Traversal' Type Type
typeSubtypes f = \case
    FunType ty1 ty2 -> FunType <$> f ty1 <*> f ty2
    x -> pure x

Now we can say how to count the nodes in a type, and how to count the nodes in a term. Again, this is completely generic and doesn’t need to do any case analysis.

countTypeNodes :: Type -> Sum Integer
countTypeNodes t = 
    -- The number of nodes is 1...
    Sum 1 
    -- ... plus the number of nodes in all the subtypes
    <> foldMapOf typeSubtypes countTypeNodes t

countTermNodes :: Term -> Sum Integer
countTermNodes t = 
    -- The number of nodes is 1...
    Sum 1 
    -- ... plus the number of nodes in all the subterms 
    <> foldMapOf termSubterms countTermNodes t
    -- ... plus the number of nodes in all the subtypes
    <> foldMapOf termSubtypes countTypeNodes t
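As a quick sanity check, here is a worked example, written against base only. The helper foldMapOfT is a local stand-in for lens's foldMapOf (an assumption of this sketch), and the example term is made up for illustration.

```haskell
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}

import Data.Functor.Const (Const (..))
import Data.Monoid (Sum (..))

type Name = String

data Type = IntegerType | FunType Type Type

data Term =
    Var Name
    | Lam Name Type Term
    | App Term Term
    | Plus Term Term
    | Constant Integer

-- Local copy of the lens synonym, so that this sketch only needs base.
type Traversal' s a = forall f. Applicative f => (a -> f a) -> s -> f s

-- Stand-in for 'foldMapOf': Const m accumulates a monoidal summary.
foldMapOfT :: Monoid m => Traversal' s a -> (a -> m) -> s -> m
foldMapOfT l g = getConst . l (Const . g)

termSubterms :: Traversal' Term Term
termSubterms f = \case
    Lam n ty t -> Lam n ty <$> f t
    App t1 t2 -> App <$> f t1 <*> f t2
    Plus t1 t2 -> Plus <$> f t1 <*> f t2
    x -> pure x

termSubtypes :: Traversal' Term Type
termSubtypes f = \case
    Lam n ty t -> Lam n <$> f ty <*> pure t
    x -> pure x

typeSubtypes :: Traversal' Type Type
typeSubtypes f = \case
    FunType ty1 ty2 -> FunType <$> f ty1 <*> f ty2
    x -> pure x

countTerms :: Term -> Sum Integer
countTerms t = Sum 1 <> foldMapOfT termSubterms countTerms t

countTypeNodes :: Type -> Sum Integer
countTypeNodes t = Sum 1 <> foldMapOfT typeSubtypes countTypeNodes t

countTermNodes :: Term -> Sum Integer
countTermNodes t =
    Sum 1
    <> foldMapOfT termSubterms countTermNodes t
    <> foldMapOfT termSubtypes countTypeNodes t

-- \(x : Integer -> Integer). x + 1 (not well-typed, but fine for counting)
example :: Term
example =
    Lam "x" (FunType IntegerType IntegerType)
        (Plus (Var "x") (Constant 1))

-- getSum (countTerms example)     == 4  (Lam, Plus, Var, Constant)
-- getSum (countTermNodes example) == 7  (4 terms + FunType and 2 IntegerTypes)
```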

There are a bunch of other nice tools in lens, although as ever that depends on your willingness to wade through it.

Addendum: why not recursion-schemes?

You can achieve some of these goals with recursion-schemes. The style there is to define your type as a fixpoint of a “one-level” functor.

data Term2F a = 
    Var2F Name
    | Lam2F Name Type a
    | App2F a a
    | Plus2F a a
    | Constant2F Integer
    deriving (Functor, Foldable, Traversable)
type Term2 = Fix Term2F

Now, recursion schemes is all about folds, so we can go ahead and “fold” (cata) our term into another term.

cf' :: Term2F Term2 -> Term2F Term2
cf' = \case
    Plus2F (Fix (Constant2F i1)) (Fix (Constant2F i2)) -> Constant2F (i1 + i2)
    x -> x
constantFold4 :: Term2 -> Term2
constantFold4 = cata (embed . cf') 

Also, Term2F is of the right shape for Traversable, so we can use traverse as well!

This is fine as far as it goes, but there are two major problems:

  1. recursion-schemes does badly with mutually recursive types. If this is a problem for you, you’ll realise pretty quickly.
  2. recursion-schemes is good at dealing with the sub-parts of the same type, but not those of different types.

For example, let’s try and write countTermNodes.

countTermNodes2 :: Term2 -> Sum Integer
countTermNodes2 = cata f where
   f = \case
       Lam2F _ ty tc -> Sum 1 <> countTypeNodes ty <> tc
       x -> Sum 1 <> fold x

The normal case is similarly concise, but there’s no way to handle the types generically. So we have to handle the Lam case manually, and we would have to do the same for every other constructor that contains a type, if we had any.
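For comparison, here is a self-contained sketch of the same count in this style. To avoid depending on recursion-schemes or data-fix, it defines minimal local versions of Fix and cata; these are assumptions of the example and are less general than the library versions. countTypeNodes is also written out by hand here.

```haskell
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE LambdaCase #-}

import Data.Foldable (fold)
import Data.Monoid (Sum (..))

type Name = String

data Type = IntegerType | FunType Type Type

data Term2F a =
    Var2F Name
    | Lam2F Name Type a
    | App2F a a
    | Plus2F a a
    | Constant2F Integer
    deriving (Functor, Foldable)

-- Minimal local stand-ins for the library's 'Fix' and 'cata'.
newtype Fix f = Fix (f (Fix f))

cata :: Functor f => (f a -> a) -> Fix f -> a
cata alg (Fix x) = alg (fmap (cata alg) x)

type Term2 = Fix Term2F

countTypeNodes :: Type -> Sum Integer
countTypeNodes = \case
    FunType ty1 ty2 -> Sum 1 <> countTypeNodes ty1 <> countTypeNodes ty2
    IntegerType -> Sum 1

countTermNodes2 :: Term2 -> Sum Integer
countTermNodes2 = cata $ \case
    -- The type lives outside the functor's parameter, so 'fold' cannot
    -- see it and this constructor needs its own case.
    Lam2F _ ty tc -> Sum 1 <> countTypeNodes ty <> tc
    -- Every other constructor: this node plus its already-counted children.
    x -> Sum 1 <> fold x

-- The same illustrative term as a fixpoint: \(x : Integer -> Integer). x + 1
example2 :: Term2
example2 =
    Fix (Lam2F "x" (FunType IntegerType IntegerType)
        (Fix (Plus2F (Fix (Var2F "x")) (Fix (Constant2F 1)))))

-- getSum (countTermNodes2 example2) == 7
```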

RESPONSES: This post by Oleg Grenrus tackles the mutually recursive traversals problem, and this comment by Chris Penner gives an alternative way of doing the recursion.