Tail-recursive tree traversal example in Scala

This is a demonstration, in Scala, of a tree-traversal algorithm that calculates the size of a binary tree (the number of nodes, counting both leaves and branches). Both implementations are recursive, as one usually tries to do in functional programming, but only the second one is tail-recursive.

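If a recursive function is not tail-recursive, each call pushes a new frame onto the call stack, and that frame has to stay there until the nested call returns. So if there are “n” nested calls before the function starts returning, you need “n” frames on the stack at the same time. That means you must be careful about how deep your recursion might go, otherwise you can get a stack overflow error. But if your function is tail-recursive, the recursive call is the last thing the function does: the value returned by the upper call is exactly the value returned by the lower call, with nothing left to compute afterwards. The compiler can then run the nested calls without pushing anything onto the call stack. In Scala, it rewrites a self-recursive call in tail position into a jump back to the start of the method (this only works for methods that cannot be overridden, such as local functions, private or final methods, and methods of an object). That way the recursive calls end up working like a “for” or “while” loop.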
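
To make the difference concrete before getting to trees, here is a minimal sketch on a plain list sum. The names sumNaive and sumAcc are hypothetical, made up for this illustration. The first version still has to add x after the recursive call returns; the second passes the running total along, so the recursive call is the last thing evaluated and @tailrec is accepted.

```scala
import scala.annotation.tailrec

// Not tail-recursive: after sumNaive(tl) returns, this call still has to add x,
// so every call must keep its own stack frame alive until the end.
def sumNaive(xs: List[Int]): Int = xs match {
  case Nil     => 0
  case x :: tl => x + sumNaive(tl)
}

// Tail-recursive: the running total travels in `acc`, and the recursive call's
// result is returned unchanged, so the compiler can turn it into a loop.
@tailrec
def sumAcc(xs: List[Int], acc: Int = 0): Int = xs match {
  case Nil     => acc
  case x :: tl => sumAcc(tl, acc + x)
}
```

Both return 6 for List(1, 2, 3). With a list of a million elements, sumAcc runs in constant stack space, while sumNaive would most likely throw a StackOverflowError under the JVM's default thread stack size.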

In this demonstration we first define a binary tree class. I took this from the “Functional Programming in Scala” book. After that we define the two function implementations.

// MyTree.scala file
package fpinscala.datastructures

import scala.annotation.tailrec

sealed trait Tree[+A]
case class Leaf[A](value: A) extends Tree[A]
case class Branch[A](left: Tree[A], right: Tree[A]) extends Tree[A]

object Tree {

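  // @tailrec /* This is not tail-recursive!!... */
  // Size = 1 for this node + the sizes of both subtrees. The additions happen
  // after the recursive calls return, so every call keeps its stack frame.
  def size_bad(t: Tree[Int]): Int =
    t match {
      case Leaf(_) => 1
      case Branch(a, b) => size_bad(a) + size_bad(b) + 1
    }

  def size(t: Tree[Int]): Int = {
    // l: subtrees still to visit, used as a stack; acc: nodes counted so far.
    @tailrec
    def inner_size(l: List[Tree[Int]], acc: Int): Int =
      l match {
        case Nil => acc // nothing left to visit: acc is the size
        case Leaf(_) :: ls => inner_size(ls, acc + 1) // count the leaf, move on
        case Branch(a, b) :: ls => inner_size(a :: b :: ls, acc + 1) // count it, push both children
      }
    inner_size(List(t), 0)
  }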
}

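In the first implementation we start from the root node and return the sum of the sizes of its two subtrees, plus one to account for the node itself. If the node is a leaf, the size is just one, since there are no children. The sizes of the subtrees are calculated with new recursive calls to size_bad. Because the addition can only happen after both recursive calls have returned, those calls are not in tail position, and that is exactly why this version is not tail-recursive.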

It is interesting to note that in practice this calculation performs a depth-first traversal, with the size of each “left” branch calculated before the “right” one (Scala evaluates the operands of + from left to right).
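
One way to see this visiting order is a sketch with the same recursive shape as size_bad that collects leaf values instead of counting nodes. The function name leavesDfs is hypothetical, and it is made generic over the element type only for convenience.

```scala
sealed trait Tree[+A]
case class Leaf[A](value: A) extends Tree[A]
case class Branch[A](left: Tree[A], right: Tree[A]) extends Tree[A]

// Same recursion pattern as size_bad: the left subtree is fully explored
// before the right one, so leaves come out in left-to-right, depth-first order.
// (Not tail-recursive, and ++ copies the left list; fine for a small example.)
def leavesDfs[A](t: Tree[A]): List[A] = t match {
  case Leaf(v)      => List(v)
  case Branch(l, r) => leavesDfs(l) ++ leavesDfs(r)
}
```

On the example tree used at the end of this post, the leaves come out as 1, 3, 4, 5, 7, 9, 10, 3, 0, 2, the same order in which they appear in the printed tree.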

Now, if we fear our tree might be so deep that this function would blow up the call stack, we need a tail-recursive implementation. But it is not very easy to modify this function for that. It’s so much easier to write down the “sum of the sizes of each branch”… In a tail-recursive version we need to keep our own stack of branches still to be visited, passed along as an argument.
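
Here is a sketch of such a deep tree, to check that the tail-recursive size copes with it. The helper vine is hypothetical: it builds a degenerate tree that is `depth` branches deep, with a leaf hanging on the right of every branch, and it is itself built with a tail-recursive loop so that constructing the tree needs no deep recursion. The size function below is a copy of the one from MyTree.scala, so the block compiles on its own.

```scala
import scala.annotation.tailrec

sealed trait Tree[+A]
case class Leaf[A](value: A) extends Tree[A]
case class Branch[A](left: Tree[A], right: Tree[A]) extends Tree[A]

// Hypothetical helper: a left-leaning "vine" with `depth` branches and
// depth + 1 leaves, i.e. 2 * depth + 1 nodes in total. Built bottom-up.
def vine(depth: Int): Tree[Int] = {
  @tailrec
  def go(remaining: Int, acc: Tree[Int]): Tree[Int] =
    if (remaining == 0) acc
    else go(remaining - 1, Branch(acc, Leaf(remaining)))
  go(depth, Leaf(0))
}

// Copy of the tail-recursive size from MyTree.scala.
def size(t: Tree[Int]): Int = {
  @tailrec
  def inner_size(l: List[Tree[Int]], acc: Int): Int = l match {
    case Nil                => acc
    case Leaf(_) :: ls      => inner_size(ls, acc + 1)
    case Branch(a, b) :: ls => inner_size(a :: b :: ls, acc + 1)
  }
  inner_size(List(t), 0)
}
```

size(vine(100000)) returns 200001, because the pending subtrees live in a List on the heap rather than in stack frames. Calling size_bad on the same tree would most likely throw a StackOverflowError with default JVM settings, and so would printing it, since the generated toString of a case class is recursive too.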

The second implementation works through the helper function inner_size, which takes two arguments. One is an accumulator, acc, that we increment at each call; each increment corresponds to one node of the tree. The other argument, l, is a list of nodes still to visit (or subtrees still to traverse).

Each time the function is called, we “pop” the first node off the list. If that node is a leaf, the next call simply moves on to the next node to be visited. If the node has children, we push them onto the front of the list (which we are using as a stack), so the next node to be visited will be its left child.

The traversal goes on until the list is empty, in which case we just return the value from the accumulator.
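
To watch the list work as a stack, here is a sketch of the same algorithm that also records, at every call, how many subtrees are waiting in the list and the current value of the accumulator. The name sizeWithTrace and the shape of the trace are hypothetical choices for this illustration.

```scala
import scala.annotation.tailrec

sealed trait Tree[+A]
case class Leaf[A](value: A) extends Tree[A]
case class Branch[A](left: Tree[A], right: Tree[A]) extends Tree[A]

// Same algorithm as inner_size, plus a trace of (pending.length, acc) per call.
// pending.length is O(n) at each step, which is fine for a small illustration.
def sizeWithTrace(t: Tree[Int]): (Int, List[(Int, Int)]) = {
  @tailrec
  def go(pending: List[Tree[Int]], acc: Int,
         trace: List[(Int, Int)]): (Int, List[(Int, Int)]) = {
    val step = (pending.length, acc)
    pending match {
      case Nil                  => (acc, (step :: trace).reverse)
      case Leaf(_) :: rest      => go(rest, acc + 1, step :: trace)
      case Branch(l, r) :: rest => go(l :: r :: rest, acc + 1, step :: trace)
    }
  }
  go(List(t), 0, Nil)
}
```

For Branch(Leaf(1), Branch(Leaf(2), Leaf(3))) the result is 5, with the trace (1,0), (2,1), (1,2), (2,3), (1,4), (0,5): popping a branch grows the list by one (it removes one subtree and pushes two), popping a leaf shrinks it by one, and the traversal stops when the list is empty.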

One interesting detail of this algorithm is that if we used the list as a queue instead of a stack (taking nodes from the front but adding the children at the back), we would end up with a breadth-first traversal.
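
Here is a sketch of that variation, assuming Scala's immutable Queue from the standard library. The names sizeBfs and leavesBfs are hypothetical; the second function collects leaf values only to make the visiting order visible.

```scala
import scala.annotation.tailrec
import scala.collection.immutable.Queue

sealed trait Tree[+A]
case class Leaf[A](value: A) extends Tree[A]
case class Branch[A](left: Tree[A], right: Tree[A]) extends Tree[A]

// Same counting idea as inner_size, but pending subtrees go into a FIFO queue:
// children are added at the back, so nodes are visited level by level.
def sizeBfs(t: Tree[Int]): Int = {
  @tailrec
  def go(pending: Queue[Tree[Int]], acc: Int): Int =
    pending.dequeueOption match {
      case None                       => acc
      case Some((Leaf(_), rest))      => go(rest, acc + 1)
      case Some((Branch(l, r), rest)) => go(rest.enqueue(l).enqueue(r), acc + 1)
    }
  go(Queue(t), 0)
}

// Records leaf values in the order the breadth-first traversal reaches them.
def leavesBfs(t: Tree[Int]): List[Int] = {
  @tailrec
  def go(pending: Queue[Tree[Int]], seen: List[Int]): List[Int] =
    pending.dequeueOption match {
      case None                       => seen.reverse
      case Some((Leaf(v), rest))      => go(rest, v :: seen)
      case Some((Branch(l, r), rest)) => go(rest.enqueue(l).enqueue(r), seen)
    }
  go(Queue(t), Nil)
}
```

For Branch(Branch(Leaf(1), Leaf(2)), Leaf(3)), the size is still 5, but the leaves are reached in the order 3, 1, 2, because the shallow Leaf(3) is dequeued before the children of the left branch. The depth-first version would reach them as 1, 2, 3.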

So there you have it: a tail-recursive tree traversal, which even has a “@tailrec” annotation to assure us that the compiler will indeed run the function that way (if the function were not tail-recursive, the annotation would make compilation fail with an error). But for that we had to write it as a “helper” function inside another one, so that callers can still simply call size(t) without passing the initial list and accumulator. And in this version it is not so obvious how everything works.

I’m still not sure what my opinion about all this is, as an FP newbie coming from years of procedural programming. This accumulator, an argument that is passed along at each call and then returned at the end, just looks a bit kludgy to me… Incrementing a mutable variable still looks better! But the pattern matching looks way cooler than any procedural implementation. And it is interesting that we managed to write the helper as a single expression: it’s just the “match” with its cases, no composite expressions.
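
For comparison, here is a sketch of the procedural style mentioned above: a mutable counter and a mutable reference to the list of pending subtrees, driven by a while loop. The name sizeImperative is hypothetical. It performs the same depth-first walk as inner_size; the only difference is that the state is updated in place instead of being passed to the next call.

```scala
sealed trait Tree[+A]
case class Leaf[A](value: A) extends Tree[A]
case class Branch[A](left: Tree[A], right: Tree[A]) extends Tree[A]

// Imperative counterpart of inner_size: `count` plays the role of acc and
// `pending` the role of the list argument, both updated in a loop.
def sizeImperative(t: Tree[Int]): Int = {
  var count = 0
  var pending: List[Tree[Int]] = List(t)
  while (pending.nonEmpty) {
    val node = pending.head
    pending = pending.tail
    count += 1
    node match {
      case Leaf(_)      => ()                          // nothing more to visit
      case Branch(l, r) => pending = l :: r :: pending // visit left child next
    }
  }
  count
}
```

The mutation stays local to the function, so from the caller's point of view it behaves exactly like size.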

And to finish, this is an example program to test these methods we discussed:

import fpinscala.datastructures._

object MyApp extends App {  
  val example = 
    Branch(
      Branch(
        Branch(
          Leaf(1), 
          Leaf(3)),
        Branch(
          Branch(
            Leaf(4),
            Leaf(5)),
          Leaf(7))),
      Branch(
        Branch(
          Branch(
            Leaf(9),
            Branch(
              Leaf(10), 
              Leaf(3))),
          Leaf(0)),
        Leaf(2)))

  println("size example")
  println(example)
  println("size (bad): " + Tree.size_bad(example))
  println("size (good): " + Tree.size(example))
}

And the output:

> scala MyApp
size example
Branch(Branch(Branch(Leaf(1),Leaf(3)),Branch(Branch(Leaf(4),Leaf(5)),Leaf(7))),Branch(Branch(Branch(Leaf(9),Branch(Leaf(10),Leaf(3))),Leaf(0)),Leaf(2)))
size (bad): 19
size (good): 19