Update: Friday, August 23, 2013

This post is from 2011, but has seen a lot of traffic lately and drawn some comments that the solution given here is incomplete or “doesn’t work”. This post is definitely incomplete, but the solution does work. For the most up-to-date code, please see my paper Stackless Scala with Free Monads as well as the source code for scalaz.Free.

Consider a simple reader monad:

    case class IntReader[A](run: Int => A) {
      def map[B](f: A => B): IntReader[B] =
        IntReader(i => f(run(i)))
      def flatMap[B](f: A => IntReader[B]): IntReader[B] =
        IntReader(i => f(run(i)).run(i))
    }

Now say we have a chain of flatMaps of arbitrary length. Let’s say 100,000. Let’s mock that up using a list:

    List.range(0, 100000).foldLeft(
      IntReader(List(_)))(
      (a, e) => a.flatMap(xs => IntReader(_ => e :: xs)))

This results in a single function that crashes with a StackOverflowError when it is run. The reason is that each flatMap calls the run function of the reader before it, and that call is not in tail position, so every link in the chain adds frames to the call stack. The stack trace shows a call to apply in an anonymous function, repeated over and over.
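As a minimal sketch of the failure, the following wraps the reader and the fold in a hypothetical NaiveReader object (the names chain and overflows are mine, added for illustration). Whether a given chain length actually overflows depends on the JVM's stack size, so the helper reports the outcome rather than assuming it.

```scala
object NaiveReader {
  case class IntReader[A](run: Int => A) {
    def map[B](f: A => B): IntReader[B] =
      IntReader(i => f(run(i)))
    def flatMap[B](f: A => IntReader[B]): IntReader[B] =
      IntReader(i => f(run(i)).run(i))
  }

  // Builds a chain of n flatMaps. Building never recurses deeply;
  // only running the resulting reader does.
  def chain(n: Int): IntReader[List[Int]] =
    List.range(0, n).foldLeft(IntReader((i: Int) => List(i))) {
      (a, e) => a.flatMap(xs => IntReader(_ => e :: xs))
    }

  // Returns true if running a chain of length n blows the stack.
  // The threshold is an assumption about the JVM's stack settings.
  def overflows(n: Int): Boolean =
    try { chain(n).run(0); false }
    catch { case _: StackOverflowError => true }
}
```

A short chain runs fine, for example NaiveReader.chain(5).run(7) gives List(4, 3, 2, 1, 0, 7). With default JVM settings, NaiveReader.overflows(100000) will typically be true.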

CPS Transformation

The classical way of avoiding the call stack in this situation is to transform the program to continuation-passing style (CPS). The CPS-transformed version of our reader monad looks like this:

    trait IntReader[A] {
      def apply[R](k: A => R, i: Int): R
      def map[B](f: A => B): IntReader[B] =
        new IntReader[B] {
          def apply[R](k: B => R, i: Int): R =
            IntReader.this(a => k(f(a)), i)
        }
      def flatMap[B](f: A => IntReader[B]): IntReader[B] =
        new IntReader[B] {
          // Pass k straight through to the next reader,
          // so that this inner apply is itself a tail call.
          def apply[R](k: B => R, i: Int): R =
            IntReader.this(a => f(a)(k, i), i)
        }
    }

Instead of returning A directly from apply, we take a continuation k that receives the A. You can see how this would be an improvement. The calls to apply are now all in tail position, and the calls to the continuation k at every point are also in tail position. Unfortunately, Scala has very limited tail call elimination: it can eliminate a tail call only if it is a recursive call to the current method. But note that apply above actually calls a different apply method: that of the enclosing IntReader instance. And since k is different from the apply method of the function it’s called from, the call to the continuation cannot be eliminated either. So if we traverse our list with this CPS-transformed reader monad, we will still get a StackOverflowError.
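The limitation can be seen in isolation with the standard @tailrec annotation. This is only an illustrative sketch; the object TailCallDemo and its method names are hypothetical and not part of the reader code.

```scala
import scala.annotation.tailrec

object TailCallDemo {
  // A self-recursive call in tail position: scalac compiles this to a loop,
  // and @tailrec makes the compiler verify that it did.
  @tailrec
  def sumTo(n: Int, acc: Long = 0L): Long =
    if (n == 0) acc else sumTo(n - 1, acc + n)

  // Mutually recursive tail calls: each call targets a different method,
  // so scalac leaves them as ordinary calls that consume stack.
  // Annotating either method with @tailrec would be a compile error.
  def isEven(n: Int): Boolean = if (n == 0) true else isOdd(n - 1)
  def isOdd(n: Int): Boolean = if (n == 0) false else isEven(n - 1)
}
```

Here sumTo(1000000) runs in constant stack, while isEven and isOdd use one frame per step and will eventually overflow for a large enough argument.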

Trampolining

What we must do is exchange stack for heap. The idea is simple. Instead of making a tail call, we return a data structure representing what to do next.

    sealed trait Trampoline[+A] {
      final def run: A = this match {
        case Done(a) => a
        case More(t) => t().run
      }
    }
    case class Done[A](a: A) extends Trampoline[A]
    case class More[A](a: () => Trampoline[A]) extends Trampoline[A]

Note that the run method is tail recursive. It is marked final because scalac only eliminates self-recursive tail calls in methods that cannot be overridden. We can now use our trampoline to turn mutual recursion into tail recursion (thanks, Rich):

    def even(n: Int): Trampoline[Boolean] = {
      if (n == 0) Done(true)
      else More(() => odd(n - 1))
    }
    def odd(n: Int): Trampoline[Boolean] = {
      if (n == 0) Done(false)
      else More(() => even(n - 1))
    }

No matter how deep the mutual recursion, calling either of these methods simply returns a Trampoline that we can unwind tail-recursively with run:

    scala> val b = odd(100000001).run
    b: Boolean = true
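For comparison, the Scala standard library ships a trampoline of its own in scala.util.control.TailCalls. The sketch below restates even and odd with it; done plays the role of Done, tailcall the role of More, and result the role of run. The wrapping object StdTrampoline is a hypothetical name used only to keep the example self-contained.

```scala
import scala.util.control.TailCalls._

object StdTrampoline {
  // done(v) is a finished computation; tailcall(...) suspends the next step
  // on the heap instead of calling it directly.
  def even(n: Int): TailRec[Boolean] =
    if (n == 0) done(true) else tailcall(odd(n - 1))

  def odd(n: Int): TailRec[Boolean] =
    if (n == 0) done(false) else tailcall(even(n - 1))
}
```

Calling StdTrampoline.odd(1000001).result unwinds the suspended steps in a loop, just as run does above.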

Trampolines of Trampolines

Now let’s say we wanted to transform a binary recursion in the same way. For example, the (terribly inefficient) recursive function to find the nth Fibonacci number:

    def fib(n: Int): Int =
      if (n < 2) n else fib(n - 1) + fib(n - 2)

There’s a bit of a problem here. If we change this to use our Trampoline, the result of fib will be Trampoline[Int]. So then how do we add two trampolines together? One way is to simply call run:

    def fib(n: Int): Trampoline[Int] =
      if (n < 2) Done(n)
      else More(() => fib(n - 1).run + fib(n - 2).run)

But this defeats the purpose! Neither call to run is in tail position here, so we’re back to getting stack overflows.

Another idea is to make Trampoline a monad, by just adding a flatMap method to it. Then we can just say:

    for {
      x <- fib(n - 1)
      y <- fib(n - 2)
    } yield x + y

But with only Done and More to work with, there is no way of implementing flatMap without calling run:

    def flatMap[B](f: A => Trampoline[B]): Trampoline[B] =
      More(() => f(this.run).run)

Delimited Continuations

The solution, as so often with continuations, is found in delimited control. We bake monadicity into the Trampoline data type, with an additional case. Again, instead of making a call to the continuation, we return a data structure representing what we’re doing currently together with what to do with the result:

    case class Cont[A, B](a: Trampoline[A], f: A => Trampoline[B])
      extends Trampoline[B]

Note that the arguments to this constructor are exactly the arguments to flatMap. The idea is that we can now implement map and flatMap like this:

    def map[B](f: A => B): Trampoline[B] =
      flatMap(a => More(() => Done(f(a))))
    def flatMap[B](f: A => Trampoline[B]): Trampoline[B] =
      Cont(this, f)

The implementation of run becomes a tad more complicated now:

    def run: A = {
      var cur: Trampoline[_] = this
      var stack: List[Any => Trampoline[A]] = List()
      var result: Option[A] = None
      while (result.isEmpty) {
        cur match {
          case Done(a) => stack match {
            case Nil => result = Some(a.asInstanceOf[A])
            case c :: cs => {
              cur = c(a)
              stack = cs
            }
          }
          case More(t) => cur = t()
          case Cont(a, f) => {
            cur = a
            stack = f.asInstanceOf[Any => Trampoline[A]] :: stack
          }
        }
      }
      result.get
    }

We’re essentially breaking out of Scala here and dropping into a Java-level loop. Firstly, a Cont has two parts: an intermediate computation whose type is not known, and a continuation for which we only know the return type. Secondly, we have to keep our own stack of continuations as we descend into the monadic binds. So we must cast, just as if we were working in a language without generics. Don’t worry, the continuation type matches the intermediate computation by construction. Lastly, I’m using my own while loop instead of relying on Scala to translate tail recursion into a loop for me.

It’s possible to write this code with better types, using existentials and heterogeneous lists (left as an exercise for the hardened type-level programmer). But this is pretty self-contained, and we can be confident that it’s well typed without Scala’s help. It’s also possible to use the Delimited Continuations compiler plugin (also an exercise for the reader) to hide the casts, but that plugin makes these exact same casts.

Here is the whole trampoline code again:

    sealed trait Trampoline[A] {
      def map[B](f: A => B): Trampoline[B] =
        flatMap(a => More(() => Done(f(a))))
      def flatMap[B](f: A => Trampoline[B]): Trampoline[B] =
        Cont(this, f)
      def run: A = {
        var cur: Trampoline[_] = this
        var stack: List[Any => Trampoline[A]] = List()
        var result: Option[A] = None
        while (result.isEmpty) {
          cur match {
            case Done(a) => stack match {
              case Nil => result = Some(a.asInstanceOf[A])
              case c :: cs => {
                cur = c(a)
                stack = cs
              }
            }
            case More(t) => cur = t()
            case Cont(a, f) => {
              cur = a
              stack = f.asInstanceOf[Any => Trampoline[A]] :: stack
            }
          }
        }
        result.get
      }
    }
    case class Done[A](a: A) extends Trampoline[A]
    case class More[A](a: () => Trampoline[A]) extends Trampoline[A]
    case class Cont[A, B](a: Trampoline[A], f: A => Trampoline[B])
      extends Trampoline[B]

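Now we can write a binary-recursive Fibonacci function that uses constant stack. Each recursive call is suspended in More, so that building the trampoline does not itself recurse:

    def fib(n: Int): Trampoline[Int] =
      if (n < 2) Done(n)
      else for {
        x <- More(() => fib(n - 1))
        y <- More(() => fib(n - 2))
      } yield (x + y)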

Even with millions of recursive calls, we don’t overflow the stack:

    scala> fib(40).run
    res23: Int = 102334155
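The same shape can be sketched with the standard library's scala.util.control.TailCalls, whose TailRec type provides map and flatMap (in Scala 2.11 and later) alongside done, tailcall and result. This is a stand-in for the Trampoline defined above, not the code from this post, and the object name StdFib is hypothetical.

```scala
import scala.util.control.TailCalls._

object StdFib {
  // Binary recursion expressed as monadic binds on TailRec.
  // tailcall suspends each recursive call, so neither building
  // nor running the computation grows the JVM stack with n.
  def fib(n: Int): TailRec[Int] =
    if (n < 2) done(n)
    else for {
      x <- tailcall(fib(n - 1))
      y <- tailcall(fib(n - 2))
    } yield x + y
}
```

For example, StdFib.fib(20).result evaluates to 6765.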

Trampolining other monads

We’re now ready to come back to our original problem, which was tail call elimination in arbitrary monads. Remember that original reader monad?

As long as there exists a monad transformer version of the monad in question, we can transform our Trampoline monad, resulting in a new tail-recursive monad. For example, type IntReader[A] = Int => A is a monad, but IntReaderT[M[_], A] = Int => M[A] is also a monad, for any monad M, including Trampoline.

To illustrate this, I will use the Kleisli monad transformer from scalaz. Here, Kleisli[M, A, B] is isomorphic to the type A => M[B], and so our trampolined IntReader[A] will be written Kleisli[Trampoline, Int, A].

This lets us traverse, in our reader monad, a list with millions of elements:

    val x = List.range(0, 1000000).
      foldLeft[Kleisli[Trampoline, Int, List[Int]]](
        kleisli(i => Done(List(i))))(
        (a, e) => kleisli(r => for {
          x <- More(() => a.apply(r))
          y <- Done(e + r :: x)
        } yield y))

And it’s all in constant stack:

    scala> x(0).run
    res28: List[Int] = List(999999, 999998, 999997, 999996, 999995, ...

This can be employed with any monad transformer, e.g. StateT, WriterT, Iteratees, etc.
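To make the transformer idea concrete without depending on scalaz, here is a rough sketch of a reader transformer specialised to a trampoline. It assumes the standard library's TailRec as the underlying trampoline rather than the Trampoline from this post, and IntReaderT, TrampolinedReader and chain are hypothetical names introduced only for this example.

```scala
import scala.util.control.TailCalls._

object TrampolinedReader {
  // Int => TailRec[A]: the reader from the start of the post,
  // with its result wrapped in a trampoline.
  case class IntReaderT[A](run: Int => TailRec[A]) {
    def map[B](f: A => B): IntReaderT[B] =
      IntReaderT(i => tailcall(run(i)).map(f))
    // The call to the previous reader is suspended with tailcall,
    // so running a long chain never nests JVM stack frames.
    def flatMap[B](f: A => IntReaderT[B]): IntReaderT[B] =
      IntReaderT(i => tailcall(run(i)).flatMap(a => f(a).run(i)))
  }

  // The same fold as in the opening example, now over the trampolined reader.
  def chain(n: Int): IntReaderT[List[Int]] =
    List.range(0, n).foldLeft(IntReaderT[List[Int]](i => done(List(i)))) {
      (a, e) => a.flatMap(xs => IntReaderT[List[Int]](_ => done(e :: xs)))
    }
}
```

Running TrampolinedReader.chain(100000).run(0).result completes without a StackOverflowError, where the untrampolined reader with the same chain length would typically overflow.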