I have been practicing coding tests again. One nice exercise I saw the other day was a sample from a HackerRank test: find the number of prime numbers less than N, in other words, count how many primes lie below a given limit. The problem is easily adapted to return the primes themselves instead of just counting them.

It is such a classic problem that I decided to dedicate myself to writing a nice solution, because I couldn’t find a really good “functional style” implementation in Scala, strictly with immutable collections and using `foldLeft` if possible. There is also an important detail about this problem that is worth talking about: the best-known solution is not actually the “best” one, and the difference between them isn’t even that big. So it really is a subject we should all be studying a little more carefully.

This article discusses four different solutions to the problem. First is the most common one, the so-called *unfaithful sieve.* Then we show the most iconic solution to this problem, the sieve of Eratosthenes. Both can be written recursively, but the first can also be implemented with a fold. We then show a rather different implementation of the sieve that I came up with myself (though I doubt it is new), which uses a map to store known composite numbers. We finish by looking at implementations of the unfaithful sieve and of the map-based sieve using mutable collections.

## The unfaithful sieve

I probably learned about the sieve of Eratosthenes when I was a kid, but I had never played with it after I started to code. That’s probably why I never really questioned the usual algorithm that people give you if you ask how to find primes, and that many call the sieve of Eratosthenes. The article *The Genuine Sieve of Eratosthenes* by Melissa O’Neill discusses this mistake, and even raises the hypothesis that it came from the famous Abelson and Sussman book, SICP (which I never really read).

The *unfaithful sieve* is probably better known than the genuine one because it is very straightforward and intuitive. It loops through all numbers from 2 to the desired *N* and performs an explicit primality test on each of them, and that test depends on the list of primes that have already been found. If the current number is prime, it is added to the list, and we move on to the next number.

The following is an implementation of the algorithm in Scala, using `foldLeft` to iterate over the numbers and update an immutable sorted set of primes. There is also an auxiliary `isPrime` function that explicitly tests the target number against the given set of primes, checking only the primes up to its square root.

```scala
import scala.collection.immutable.SortedSet

object FoldPrimes extends App {
  def unfaithfulSieve(limit: Int): Int = {
    // Trial division against the known primes up to sqrt(i)
    def isPrime(primes: SortedSet[Int], i: Int): Boolean =
      primes.takeWhile(_ <= math.sqrt(i).toInt).forall(i % _ != 0)

    // Fold over 2 until limit, adding n to the set whenever it passes the test
    val found = (2 until limit).foldLeft(SortedSet.empty[Int]) { (primes, n) =>
      primes ++ (if (isPrime(primes, n)) Some(n) else None)
    }
    found.size
  }

  val input = 10000000
  val startTime = System.currentTimeMillis()
  val nPrimes = unfaithfulSieve(input)
  println(s"pi($input)=$nPrimes [total ${System.currentTimeMillis() - startTime} ms]")
}
```

I think understanding `foldLeft` was one of the most important things I learned when I started getting into functional programming, so it’s probably worth speaking a little more about it. We have a list of numbers to iterate over, and we know all of those numbers at the start of the iteration. As we iterate over them, we update the state of some variable, and each new state is a function of the previous state and the current element. A fold can be implemented with a recursive function, which kind of hides away the fact that you are iterating over a list. It can also be implemented with a `for` loop that updates some mutable variable in its body.
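To make that equivalence concrete, here is a small sketch that sums a range in the three ways just described: with `foldLeft`, with a tail-recursive function that carries the state and the position as parameters, and with a `for` loop that mutates a `var`. The object and function names are made up for illustration.

```scala
import scala.annotation.tailrec

// Hypothetical example: the same left fold written three ways
object FoldThreeWays {
  // 1. foldLeft: the elements are known up front, the state is threaded for us
  def sumFold(xs: Range): Long =
    xs.foldLeft(0L)((acc, x) => acc + x)

  // 2. Tail recursion: the iteration control (index i) becomes an explicit parameter
  def sumRec(xs: Range): Long = {
    @tailrec
    def loop(acc: Long, i: Int): Long =
      if (i >= xs.length) acc else loop(acc + xs(i), i + 1)
    loop(0L, 0)
  }

  // 3. Imperative: a for loop that updates a local var
  def sumLoop(xs: Range): Long = {
    var acc = 0L
    for (x <- xs) acc += x
    acc
  }

  def main(args: Array[String]): Unit = {
    val xs = 1 to 100
    println((sumFold(xs), sumRec(xs), sumLoop(xs))) // (5050,5050,5050)
  }
}
```

All three compute the same thing; the fold simply packages the "walk the elements, update the state" pattern so the iteration control never has to be written out.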

Next is the recursive version of the previous algorithm. Notice how we use the same `isPrime` function, and the same expression to update the set of primes conditional on the test. The only difference is that this expression now goes into an argument of the recursive function, and we have to perform all the iteration control ourselves using a new parameter. I prefer the fold version, because the iteration logic is trivial and there is nothing to gain by writing it out by hand. Recursive functions really should be used only when necessary.

```scala
import scala.annotation.tailrec
import scala.collection.immutable.SortedSet

object RecPrimes extends App {
  def unfaithfulSieve(limit: Int): Int = {
    // The iteration state (the current n) is now an explicit parameter
    @tailrec
    def rec(primes: SortedSet[Int], n: Int): Int =
      if (n >= limit) primes.size
      else rec(primes ++ (if (isPrime(primes, n)) Some(n) else None), n + 1)

    // Same trial-division test as in the fold version
    def isPrime(primes: SortedSet[Int], i: Int): Boolean =
      primes.takeWhile(_ <= math.sqrt(i).toInt).forall(i % _ != 0)

    rec(SortedSet.empty[Int], 2)
  }

  val input = 1000000
  val startTime = System.currentTimeMillis()
  val nPrimes = unfaithfulSieve(input)
  println(s"pi($input)=$nPrimes [total ${System.currentTimeMillis() - startTime} ms]")
}
```

## The Sieve of Eratosthenes

The proper implementation of the sieve of Eratosthenes does require a recursive function, or at least what I consider to be the “proper” implementation does. The idea is that you start with a table of all numbers from 2 to N. You pick the first “available” number, which is 2, and that means it is a prime. Then you cross out every multiple of 2 from the table: 4, 6, 8, etc. Then you pick the next available number, 3, which is our next prime, and cross out 6, 9, 12, etc. The number 4 has already been crossed out, so it is composite and we skip it. The next available number, and therefore the next prime, is 5, so we cross out its multiples… And so on until the end of the table, although once the current prime exceeds √N there is nothing left to cross out, so you can stop early.
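The table described above maps naturally onto a mutable array of flags. The following is a sketch of that textbook array formulation, not one of the implementations measured in this article; `ArraySieve` and `countPrimesBelow` are illustrative names.

```scala
// Hypothetical sketch: the classic array-based sieve of Eratosthenes
object ArraySieve {
  /** Counts the primes strictly below n. */
  def countPrimesBelow(n: Int): Int =
    if (n <= 2) 0
    else {
      // composite(i) == true means i has been crossed out
      val composite = new Array[Boolean](n)
      var p = 2
      // Only primes with p * p < n can cross out anything below n
      while (p.toLong * p < n) {
        if (!composite(p)) {
          // Smaller multiples of p were already crossed out by smaller primes
          var m = p * p
          while (m < n) {
            composite(m) = true
            m += p
          }
        }
        p += 1
      }
      // Whatever was never crossed out is prime
      (2 until n).count(i => !composite(i))
    }

  def main(args: Array[String]): Unit =
    println(countPrimesBelow(100)) // 25
}
```

Starting each crossing-out at p * p and stopping the outer loop at √N are the same two shortcuts the prose describes.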

Here is my “proper” implementation of the algorithm. The parameter `rem` is the set of remaining numbers from the table, the ones that were not crossed out. An explicit function, `crossOut`, calculates the new table by crossing out the multiples of the given number. Notice how we actually keep the whole table in memory in this implementation. The recursive function always picks the first available number in the table, crosses out its multiples, and increments the count of primes. We don’t even need to keep a list of the primes like in the other algorithm: they are irrelevant once their multiples have been crossed out.

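```scala
import scala.annotation.tailrec
import scala.collection.immutable.SortedSet

object EratRecPrimes extends App {
  def sieveOfEratosthenes(n: Int): Int = {
    // rem holds the numbers of the table that have not been crossed out yet
    @tailrec
    def recSieve(rem: SortedSet[Int], nPrimes: Int = 0): Int =
      if (rem.isEmpty) nPrimes // only happens when n <= 2
      // Once the smallest remaining number exceeds sqrt(n), all of rem is prime
      else if (rem.head > math.sqrt(n)) nPrimes + rem.size
      // rem.head is prime: count it and cross out its multiples
      else recSieve(crossOut(rem.tail, rem.head), nPrimes + 1)

    // Remove p * p, p * p + p, ... up to the largest remaining number
    def crossOut(nums: SortedSet[Int], p: Int): SortedSet[Int] =
      nums -- (p * p to nums.last by p).iterator

    recSieve(SortedSet(2 until n: _*))
  }

  val input = 2000000
  val startTime = System.currentTimeMillis()
  val nPrimes = sieveOfEratosthenes(input)
  println(s"pi($input)=$nPrimes [total ${System.currentTimeMillis() - startTime} ms]")
}
```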

Notice how we really need a recursive function in this case, because we are iterating over the primes, and we cannot compute them ahead of time. This recursive function could be replaced by an imperative “while” loop, but not by a proper `for` loop or a fold over a range known in advance.
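For comparison, here is a sketch of that `while`-loop alternative, keeping the same immutable `SortedSet` in a `var`. It is meant to mirror `recSieve` step by step; `EratWhilePrimes` is a made-up name.

```scala
import scala.collection.immutable.SortedSet

// Hypothetical sketch: the recursive genuine sieve rewritten as a while loop
object EratWhilePrimes {
  def sieveOfEratosthenes(n: Int): Int = {
    var rem = SortedSet(2 until n: _*) // numbers not crossed out yet
    var nPrimes = 0
    // Same stopping rule: stop once the next prime exceeds sqrt(n)
    while (rem.nonEmpty && rem.head <= math.sqrt(n)) {
      val p = rem.head
      nPrimes += 1
      rem = rem.tail -- (p * p until n by p) // cross out multiples of p
    }
    nPrimes + rem.size // everything left is prime
  }

  def main(args: Array[String]): Unit =
    println(sieveOfEratosthenes(100)) // 25
}
```

The loop condition plays the role of the base case, and reassigning `rem` and `nPrimes` replaces passing them to the next recursive call.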

Another important detail in this implementation is that we stop once we pick a prime larger than the square root of the limit. Every composite m < N has a prime factor p with p * p ≤ m, so p < √N and m was already crossed out when p was processed. The remaining numbers are therefore all prime, and the algorithm can terminate early.

## The lazy Eratosthenes

The genuine sieve of Eratosthenes is said to be more efficient than the unfaithful one, for reasons explained in O’Neill’s article. So I was a bit disappointed when I found out that my implementation did not beat the performance of the alternative algorithm! In my tests it tended to be twice as slow. I couldn’t find an optimization that would make it faster, so that we could observe the expected lower complexity.
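Rough operation counts, sketched here as back-of-the-envelope estimates rather than quoted from O’Neill’s article, show where the expected advantage comes from. With trial division, composites usually fail after a few divisions, so the cost is dominated by the primes p < N, each of which is divided by every prime up to √p (π(x) denotes the number of primes up to x). The genuine sieve instead touches each multiple of each prime p ≤ √N about once:

```latex
\[
\underbrace{\sum_{p < N} \pi\!\left(\sqrt{p}\right)}_{\text{trial division}}
  \approx \sum_{p < N} \frac{2\sqrt{p}}{\ln p}
  = \Theta\!\left(\frac{N\sqrt{N}}{(\ln N)^2}\right),
\qquad
\underbrace{\sum_{p \le \sqrt{N}} \frac{N}{p}}_{\text{genuine sieve}}
  = \Theta\left(N \ln \ln N\right)
\]
```

Neither count includes the cost of the data structure itself. In the `SortedSet`-based listings, which are backed by balanced trees, each insertion or removal adds a logarithmic factor, which may be part of why the asymptotic advantage is hard to see at these input sizes.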

One thing I tried was to generate lazy streams of the multiples and `diff` them from the stream of all integers, but that did not quite work: I got stack overflows that I couldn’t see how to avoid. Scala streams can be tricky sometimes…

So my next attempt to improve the algorithm was something more drastic. I modified the sieve so that it still crosses out composites instead of performing primality tests, but without eagerly generating the lists of multiples to be removed from a stored table. This way we go back to iterating over all the numbers, but we keep a dictionary of known composites above n instead of the list of primes.

The idea now is that we move over all numbers, and our primality test just checks whether the number belongs to the map of known composites we are building. If it is there, we take the number along with the prime factors stored alongside it, and for each of those factors p we compute the next multiple, n + p, and store it in the map with p as its generating factor. If that multiple is already in the map, we just add p to its set of known factors. The number n itself can be discarded from the map once it is visited. And if the current number is prime, we just store its first multiple, 2 * n, in the map.
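To see the bookkeeping on small numbers, here is a sketch of a single step of that procedure as a pure function, printing the map after each n from 2 to 10. The names `DictSieveStep` and `step` are hypothetical; the rules are the same as in the fold in the listing that follows (a prime n schedules 2n, a composite n moves each of its factors p to n + p).

```scala
// Hypothetical sketch: one step of the composites-map sieve
object DictSieveStep {
  type Composites = Map[Int, Set[Int]] // known composite -> its generating primes

  def step(composites: Composites, n: Int): Composites =
    composites.get(n) match {
      // n was never reached as a multiple, so it is prime: schedule 2n
      case None => composites + (2 * n -> Set(n))
      // n is composite: drop it and move each factor p on to n + p
      case Some(factors) =>
        factors.foldLeft(composites - n) { (cc, p) =>
          cc + (n + p -> (cc.getOrElse(n + p, Set.empty[Int]) + p))
        }
    }

  def main(args: Array[String]): Unit =
    // Map contents after each n (Map and Set print order may vary)
    (2 to 10).scanLeft(Map.empty[Int, Set[Int]])(step).tail.zip(2 to 10).foreach {
      case (m, n) => println(s"after $n: $m")
    }
}
```

After n = 10 the map holds 12 → {2, 3}, 14 → {7} and 15 → {5}. The composites 4, 6, 8, 9 and 10 were visited and dropped, and the primes 2, 3, 5 and 7 are exactly the values of n that were missing from the map when they were reached.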

Here is the code:

```scala
object EratDictFoldPrimes extends App {
  def sieveOfEratosthenes(limit: Int): Int = {
    // State: (known composite -> its generating primes, number of primes so far)
    val result = (3 until limit).foldLeft((Map(4 -> Set(2)), 1)) { (state, n) =>
      val (composites, nPrimes) = state
      val isPrime = !composites.contains(n)
      def newComposites =
        if (isPrime) composites + (2 * n -> Set(n)) // first multiple of a new prime
        else composites(n).foldLeft(composites) { (cc, p) =>
          // Move each factor p forward to its next multiple, n + p
          cc + (n + p -> (cc.getOrElse(n + p, Set.empty[Int]) + p))
        } - n // n has been visited, so drop it
      def newPrimes = if (isPrime) nPrimes + 1 else nPrimes
      (newComposites, newPrimes)
    }
    result._2
  }

  val input = 1000000
  val startTime = System.currentTimeMillis()
  val nPrimes = sieveOfEratosthenes(input)
  println(s"pi($input)=$nPrimes [total ${System.currentTimeMillis() - startTime} ms]")
}
```

In my tests this implementation did perform better than the recursive genuine sieve, but only comparably to the unfaithful one.

## The mutable Eratosthenes

My last experiment was to forget about what the Functional Programming Patrol might say and implement the algorithms using mutable collections. They do look like algorithms that might benefit from mutable structures, since we are updating things all the time, iteration after iteration… This is also a nice demonstration of Scala’s ability to become an imperative language when you need it to.

First, here is the mutable version of the unfaithful sieve.

```scala
import scala.collection.mutable

object ForMutaPrimes extends App {
  def unfaithfulSieve(limit: Int): Int = {
    val primes = mutable.ArrayBuffer[Int](2)
    // Same trial-division test, now reading the mutable buffer directly
    def isPrime(i: Int): Boolean =
      primes.takeWhile(_ <= math.sqrt(i).toInt).forall(i % _ != 0)
    for (n <- 3 until limit) if (isPrime(n)) primes += n
    primes.size
  }

  val input = 2000000
  val startTime = System.currentTimeMillis()
  val nPrimes = unfaithfulSieve(input)
  println(s"pi($input)=$nPrimes [total ${System.currentTimeMillis() - startTime} ms]")
}
```

It becomes just a loop over n that appends to the buffer of primes whenever the primality test passes. The test is essentially the same as before, except that it now reads the mutable buffer directly instead of receiving the primes as an argument.

The mutable lazy Eratosthenes has a similar `for` loop, and we again write an explicit primality test. The difference from the other code is that the test is performed on the structure containing the known composites, and instead of updating the list of primes we update this dictionary of sets, which requires a bit of work, performed by the new `update` function.

```scala
import scala.collection.mutable

object EratMutaDictPrimes extends App {
  def sieveOfEratosthenes(limit: Int): Int = {
    // Known composites above the current n, each mapped to its generating primes
    val composites = mutable.HashMap[Int, mutable.Set[Int]]()
    var nPrimes = 0
    def isPrime(n: Int): Boolean = !(composites contains n)
    // Record v as a generating prime of the composite k
    def update(k: Int, v: Int): Unit =
      if (composites contains k) composites(k) += v
      else composites(k) = mutable.Set(v)
    for (n <- 2 until limit)
      if (isPrime(n)) {
        nPrimes += 1
        update(2 * n, n)
      } else
        // n is composite: advance each of its factors to its next multiple
        for (p <- composites.remove(n).get) update(n + p, p)
    nPrimes
  }

  val input = 4000000
  val startTime = System.currentTimeMillis()
  val nPrimes = sieveOfEratosthenes(input)
  println(s"pi($input)=$nPrimes [total ${System.currentTimeMillis() - startTime} ms]")
}
```

Now, for reasons beyond my understanding, this mutable implementation of the unfaithful sieve actually performed almost 3 times worse than the recursive immutable version, even worse than the first genuine implementation! As for the mutable lazy Eratosthenes, it reached the best overall performance, running in approximately half the time of the immutable recursive unfaithful implementation.
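The cause of that gap isn’t established here, but two things may be worth checking. First, `takeWhile` on an `ArrayBuffer` is strict, so every primality test allocates a new buffer holding all the primes up to √n. Second, as printed, `RecPrimes` uses `input = 1000000` while `ForMutaPrimes` uses `input = 2000000`, so timings taken from these exact listings would not be on the same limit. The sketch below (all names hypothetical) times the `takeWhile` version against a variant that scans a plain `Array[Int]` by index, both on the same limit. A single wall-clock run like this is only indicative, since JIT warm-up and GC can dominate.

```scala
import scala.collection.mutable

// Hypothetical sketch: two mutable unfaithful sieves timed on the same limit
object MutaCompare {
  // As in ForMutaPrimes (plus a guard for limit <= 2): takeWhile on an
  // ArrayBuffer builds a new buffer on every primality test
  def withTakeWhile(limit: Int): Int = {
    val primes = mutable.ArrayBuffer[Int](2)
    def isPrime(i: Int): Boolean =
      primes.takeWhile(_ <= math.sqrt(i).toInt).forall(i % _ != 0)
    for (n <- 3 until limit) if (isPrime(n)) primes += n
    if (limit > 2) primes.size else 0
  }

  // Same trial division, scanning a plain Array[Int] by index: no allocation per test
  def withIndexScan(limit: Int): Int =
    if (limit <= 2) 0
    else {
      val primes = new Array[Int](limit) // loose upper bound on the prime count
      primes(0) = 2
      var count = 1
      var n = 3
      while (n < limit) {
        val root = math.sqrt(n.toDouble).toInt
        var i = 0
        var isPrime = true
        // Stop at the first divisor or at the first prime above sqrt(n)
        while (isPrime && i < count && primes(i) <= root) {
          if (n % primes(i) == 0) isPrime = false
          i += 1
        }
        if (isPrime) { primes(count) = n; count += 1 }
        n += 1
      }
      count
    }

  // Crude wall-clock timing of a single run
  def time(label: String)(body: => Int): Unit = {
    val start = System.nanoTime()
    val result = body
    println(s"$label: pi=$result [${(System.nanoTime() - start) / 1000000} ms]")
  }

  def main(args: Array[String]): Unit = {
    val input = 2000000
    time("takeWhile")(withTakeWhile(input))
    time("index scan")(withIndexScan(input))
  }
}
```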

## Conclusion

I don’t know right now what could be done to improve any of these implementations, but it all goes to show that it is not trivial to predict when one implementation will outperform another, for instance when you use mutable structures in imperative code instead of immutable ones in more FP-friendly code. I would love to figure out what is happening here in detail, and how to write better implementations of all these algorithms in Scala.

Using the most adequate data structures is apparently quite critical when solving the primes problem, and the differences between the two algorithms are so subtle that other implementation details can become more relevant. There are also lots of other smart optimizations that can be applied to the code that we didn’t talk about here. I suggest this and this Stack Overflow thread for more information.
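As one example of those optimizations, a common trick is to keep only odd numbers in the table, since 2 is the only even prime; this halves both the memory and the crossing-out work. Below is a minimal sketch with hypothetical names, using `java.util.BitSet` to mark crossed-out odd numbers (bit k stands for the odd number 2k + 1).

```scala
import java.util.BitSet

// Hypothetical sketch: odds-only sieve of Eratosthenes over a BitSet
object OddSieve {
  /** Counts the primes strictly below n. Bit k represents the odd number 2k + 1. */
  def countPrimesBelow(n: Int): Int =
    if (n <= 2) 0
    else {
      val size = n / 2 // how many odd numbers 1, 3, 5, ... are below n
      val composite = new BitSet(size)
      composite.set(0) // 1 is not prime
      var p = 3
      while (p.toLong * p < n) {
        if (!composite.get(p / 2)) {
          // Odd multiples of p from p * p on are 2p apart, i.e. p bits apart
          var j = (p * p) / 2
          while (j < size) {
            composite.set(j)
            j += p
          }
        }
        p += 2
      }
      // Unmarked bits are the odd primes; add 1 for the prime 2
      1 + (size - composite.cardinality())
    }

  def main(args: Array[String]): Unit =
    println(countPrimesBelow(1000000)) // 78498
}
```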

My final advice for solving this problem in coding tests is to stick to the unfaithful sieve: the actual sieve of Eratosthenes seems to require some more thinking to get to really good code… But be aware of what you are doing: that is only the “real world” algorithm, and not the one you would really like to be implementing! It’s a bit like Smoothsort: it is such a nice algorithm, but you should not try to implement it in a coding test if someone just asks you to write whatever sorting algorithm you like. Go with the easy-but-still-good one.

Wayne: Have you tried an approach like:

```scala
def sift (n: Int) : Int = {
  def sift2 (a: Vector[Int], b: Vector[Int]) : Vector[Int] = {
    if ((a.length > 0) && (a.last * a.last >= b.last)) a ++ b
    else sift2 (a :+ b.head, b.tail.filterNot (_ % b.head == 0))
  }
  sift2 (Vector (), Vector (2 to n: _*)).length
}
```