Ed’s discrimination package seemed very interesting to me. I was vaguely aware of sorting algorithms not based on comparison but I didn’t realize that you could achieve such impressive asymptotics with them. Radix sort seemed quite simple so I wanted to see how well it would perform.

There’s no point in explaining the algorithm here because I doubt I would do a better job than other people on the internet. Here’s a simple explanation with buckets being decimal digits.
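Just to give a rough idea of the decimal-digit version, here's a minimal sketch of my own (the names `digitAt`, `pass`, and `decimalRadixSort` are mine, not from any package):

```haskell
-- A minimal LSD radix sort over decimal digits, for non-negative Ints.
-- Ten passes cover numbers up to ten digits long.

digitAt :: Int -> Int -> Int
digitAt pos x = (x `div` 10 ^ pos) `mod` 10

-- One pass: distribute into buckets 0..9, then read the buckets back
-- in order. Each pass is stable, which is what makes the sort work.
pass :: Int -> [Int] -> [Int]
pass pos xs = concat [ [ x | x <- xs, digitAt pos x == d ] | d <- [0 .. 9] ]

-- Run one pass per digit, least significant digit first.
decimalRadixSort :: [Int] -> [Int]
decimalRadixSort xs = foldl (flip pass) xs [0 .. 9]
```

The real thing below does the same distribution, just with bit slices instead of decimal digits and mutable buckets instead of list comprehensions.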

I went through four iterations of the algorithm, and I'll present them in order.

The common idea

All the approaches share the same mold: a sort function of type [Int] -> [Int] that serves as the interface and handles initialization and final result extraction; a function gather that takes the output of the previous step of the algorithm and produces a list-like structure that can be fed back into the algorithm; and a function toList that takes the final state of the algorithm and produces an actual list.

Finally, there's the radix function, which performs a single step of the sort.

Different approaches use different structures for buckets. Each bucket structure consists of two parts: the "bucket holder" and the buckets themselves. All the approaches seemed to perform best with 256 buckets.
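With 256 buckets, each pass looks at 8 bits, so a 64-bit Int is fully sorted in 64 / 8 = 8 passes. Extracting the bucket index for a given pass is just a shift and a mask; a small sketch of that arithmetic (this is the same expression all the iterations below use for `bts`):

```haskell
import Data.Bits (shiftR, (.&.))

bucketBits, numBuckets :: Int
bucketBits = 8
numBuckets = 2 ^ bucketBits  -- 256

-- Bucket index of x for pass `offset` (0-based): shift the relevant
-- byte down to the bottom, then mask off everything above it.
bucketOf :: Int -> Int -> Int
bucketOf offset x = shiftR x (offset * bucketBits) .&. (numBuckets - 1)
```

For example, `bucketOf 0 0x1234` picks out the low byte 0x34, and `bucketOf 1 0x1234` picks out 0x12.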

The list bucket problem

The algorithm requires that we be able to insert numbers into buckets and then retrieve them from oldest to newest. This is a basic queue, with one important difference: we do all the inserting first, then all the iterating. That means a simple list is good enough for insertions; we just need to reverse it at the end.
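That "list as a write-only queue" trick fits in one line; a tiny illustration (`enqueueAll` is my name for it):

```haskell
-- Insert by consing (O(1) per element), then reverse once at the end
-- to recover first-in-first-out order.
enqueueAll :: [a] -> [a]
enqueueAll = reverse . foldl (flip (:)) []
```

Consing builds the list newest-first, and the single reverse restores insertion order, so the whole sequence of n inserts plus one traversal stays O(n).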

This reverse is what I wanted to optimize away.

Plain list solution

The first iteration just reversed the list. I used mutable vectors in the ST monad as my "bucket holders".

Here’s the code.

module Main where

import qualified Data.List
import Control.Monad
import Control.Monad.ST
import Data.Vector.Mutable ( STVector )
import qualified Data.Vector.Mutable as Vec
import Data.Foldable ( foldlM )
import Data.Bits
import System.TimeIt
import System.Random

totalBits, numBuckets, bucketBits :: Int
totalBits = 64
numBuckets = 2 ^ bucketBits
bucketBits = 8

bucketList :: [Int]
bucketList = reverse [0 .. numBuckets - 1]

sort :: [Int] -> [Int]
sort original = runST $ do
  initial <- Vec.new numBuckets
  reset initial
  Vec.write initial 0 original
  -- put the list in bucket 0 so we can
  -- express the algorithm as repeated iteration
  mapM_ (radix initial) [0 .. (totalBits `div` bucketBits) - 1]
  toList initial

gather :: STVector s [a] -> ST s [a]
gather vec = fmap (reverse . concat) $! mapM (Vec.read vec) bucketList

toList :: STVector s [a] -> ST s [a]
toList = gather

reset :: STVector s [a] -> ST s ()
reset vec = Vec.set vec []

radix :: STVector s [Int] -> Int -> ST s ()
radix vec offset = do
  ll <- gather vec
  reset vec
  mapM_ ins ll
  where
    ins x = do
      let bts = shiftR x (offset * bucketBits) .&. (numBuckets - 1)
      l <- Vec.read vec bts
      Vec.write vec bts (x : l)

main :: IO ()
main = do
  gen <- newStdGen
  let list = take 1000000 $ randoms gen
  print $ sum list
  print "Radix"
  timeIt $ print $ sum $ sort list
  print "Standard"
  timeIt $ print $ sum $ Data.List.sort list

This first iteration already showed great results, beating the default sorting algorithm. I’ll discuss that later.

I used sum when testing to force the lists. Here are the times from one measurement (compiled with -O2):

-6508836477411096561
"Radix"
-6508836477411096561
CPU time: 2.61s
"Standard"
-6508836477411096561
CPU time: 3.72s

Now to tackle the actual reversal problem. Since I was already in the ST monad, I thought: why not implement my own specialized linked lists? (I had tried swapping the list for a Seq before this; it performed worse.)

Custom linked lists

My implementation is nothing to write home about: each node is just a value plus an STRef pointing to the next element.

{-# LANGUAGE ViewPatterns #-}
module LinkedListSpecial where

import Prelude hiding ( mapM_ )
import Control.Monad.ST
import Data.STRef
import Data.Foldable ( mapM_, foldlM, forM_, foldl' )
import System.TimeIt
import qualified Data.DList as DList

data LLN s a
  = Stub (STRef s (Maybe (LLN s a)))
  | LLN a (STRef s (Maybe (LLN s a)))

getRef :: LLN s a -> STRef s (Maybe (LLN s a))
getRef (Stub ref)  = ref
getRef (LLN _ ref) = ref

emptyNode :: ST s (LLN s a)
emptyNode = fmap Stub (newSTRef Nothing)

makeNode :: a -> ST s (LLN s a)
makeNode x = fmap (LLN x) $! newSTRef Nothing

append :: LLN s a -> a -> ST s (LLN s a)
append (getRef -> ref) x = do
  new <- makeNode x
  writeSTRef ref (Just new)
  return new

iter :: (a -> ST s ()) -> LLN s a -> ST s ()
iter f (Stub ref) = do
  next <- readSTRef ref
  mapM_ (iter f) next
iter f (LLN x ref) = do
  f x
  next <- readSTRef ref
  mapM_ (iter f) next

iterAll :: (a -> ST s ()) -> [LLN s a] -> ST s ()
iterAll f = mapM_ (iter f)

fromList :: [a] -> ST s (LLN s a, LLN s a)
fromList xs = do
  f <- emptyNode
  l <- foldlM append f xs
  return (f, l)

collect :: LLN s a -> ST s [a]
collect (Stub ref) = do
  next <- readSTRef ref
  case next of
    Nothing -> return []
    Just n  -> collect n
collect (LLN x ref) = do
  next <- readSTRef ref
  case next of
    Nothing -> return [x]
    Just n  -> do
      xs <- collect n
      return $! x : xs

collectAll :: [LLN s a] -> ST s [a]
collectAll = fmap concat . mapM collect

I have to say, the preliminary tests didn’t show great results. Generating a linked list and iterating through it consistently performed worse than making a normal list, reversing, then iterating.

Apparently this is because GHC's runtime is tuned with the expectation that boxed mutable references are rarely written: every write to an STRef has to go through the garbage collector's write barrier so that pointers from the old generation into the young one can be tracked, and when you invalidate that expectation you pay the price. Or something along those lines, I think.

In any case, here’s the second iteration.

module Main where

import qualified Data.List
import Control.Monad
import Control.Monad.ST
import Data.Vector.Mutable ( STVector )
import qualified Data.Vector.Mutable as Vec
import Data.Foldable ( foldlM )
import Data.Bits
import System.TimeIt
import System.Random
import LinkedListSpecial ( LLN )
import qualified LinkedListSpecial as LL

totalBits, numBuckets, bucketBits :: Int
totalBits = 64
numBuckets = 2 ^ bucketBits
bucketBits = 8

bucketList :: [Int]
bucketList = [0 .. numBuckets - 1]

sort :: [Int] -> [Int]
sort original = runST $ do
  initial <- Vec.new numBuckets
  reset initial
  LL.fromList original >>= Vec.write initial 0
  -- put the list in bucket 0 so we can
  -- express the algorithm as repeated iteration
  mapM_ (radix initial) [0 .. (totalBits `div` bucketBits) - 1]
  toList initial

gather :: STVector s (LLN s a, LLN s a) -> ST s [LLN s a]
gather vec = mapM (fmap fst . Vec.read vec) bucketList

toList :: STVector s (LLN s a, LLN s a) -> ST s [a]
toList vec = do
  lls <- gather vec
  LL.collectAll lls

reset :: STVector s (LLN s a, LLN s a) -> ST s ()
reset vec = forM_ [0 .. Vec.length vec - 1] $ \i -> do
  node <- LL.emptyNode
  Vec.write vec i (node, node)

radix :: STVector s (LLN s Int, LLN s Int) -> Int -> ST s ()
radix vec offset = do
  ll <- gather vec
  reset vec
  LL.iterAll ins ll
  where
    ins x = do
      let bts = shiftR x (offset * bucketBits) .&. (numBuckets - 1)
      (f, l) <- Vec.read vec bts
      newL <- LL.append l x
      Vec.write vec bts (f, newL)

And the results

-4051746686150878325
"Radix"
-4051746686150878325
CPU time: 5.78s
"Standard"
-4051746686150878325
CPU time: 3.88s

sadface

LUCKILY, I was informed that there exists a kind of list that's optimized for appending: the difference list. The concept is simple: you suspend the modifications as functions that you compose together in any order you want, and then run them all at once when you finally want to produce a list. I used a package that provides them.
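The core idea fits in a few lines; here's a from-scratch sketch of mine (the real Data.DList is more polished, and `DL`, `snocDL`, `toListDL` are my own names): a difference list represents a list as a function [a] -> [a] that prepends the list's elements, so appending to the end becomes function composition, which is O(1).

```haskell
-- A list represented as "prepend my elements to whatever follows".
newtype DL a = DL ([a] -> [a])

emptyDL :: DL a
emptyDL = DL id

-- O(1) append to the end: compose with a function that conses x
-- onto whatever comes after it.
snocDL :: DL a -> a -> DL a
snocDL (DL f) x = DL (f . (x :))

-- Materialize the list once, at the very end.
toListDL :: DL a -> [a]
toListDL (DL f) = f []
```

The point is that snocDL never traverses the accumulated structure; the single toListDL at the end runs all the suspended conses in the right order, so no reverse is ever needed.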

DLists

Straight to the implementation.

module Main where

import Prelude hiding ( mapM_ )
import qualified Data.List
import Control.Monad hiding ( mapM_ )
import Control.Applicative
import Control.Monad.ST
import Data.Vector.Mutable ( STVector )
import qualified Data.Vector.Mutable as Vec
import Data.Foldable ( foldlM, mapM_ )
import Data.Bits
import System.TimeIt
import System.Random
import Data.DList ( DList )
import qualified Data.DList as DList

totalBits, numBuckets, bucketBits :: Int
totalBits = 64
numBuckets = 2 ^ bucketBits
bucketBits = 8

bucketList :: [Int]
bucketList = [0 .. numBuckets - 1]

sort :: [Int] -> [Int]
sort original = runST $ do
  initial <- Vec.new numBuckets
  reset initial
  Vec.unsafeWrite initial 0 (DList.fromList original)
  -- put the list in bucket 0 so we can
  -- express the algorithm as repeated iteration
  mapM_ (radix initial) [0 .. (totalBits `div` bucketBits) - 1]
  toList initial

gather :: STVector s (DList a) -> ST s (DList a)
gather vec = DList.concat <$> mapM (Vec.unsafeRead vec) bucketList

toList :: STVector s (DList a) -> ST s [a]
toList vec = DList.toList <$> gather vec

reset :: STVector s (DList a) -> ST s ()
reset vec = Vec.set vec DList.empty

radix :: STVector s (DList Int) -> Int -> ST s ()
radix vec offset = do
  ll <- gather vec
  reset vec
  mapM_ ins ll
  where
    ins x = do
      let bts = shiftR x (offset * bucketBits) .&. (numBuckets - 1)
      l <- Vec.unsafeRead vec bts
      Vec.unsafeWrite vec bts $ l `DList.snoc` x

And the results are pretty amazing:

6741553814578555192
"Radix"
6741553814578555192
CPU time: 1.91s
"Standard"
6741553814578555192
CPU time: 3.72s

Twice as fast! Nice!

Finally, I tried ditching the mutable vectors and going fully immutable with IntMaps. Spoiler alert: it's another sadface, unfortunately.

Maximum immutability

module Main where

import qualified Data.List
import Prelude hiding ( mapM_ )
import Data.Foldable ( foldl' )
import Data.Bits
import System.TimeIt
import System.Random
import Data.DList ( DList )
import qualified Data.DList as DList
import Data.IntMap ( IntMap )
import qualified Data.IntMap as Map

totalBits, numBuckets, bucketBits :: Int
totalBits = 64
numBuckets = 2 ^ bucketBits
bucketBits = 8

bucketList :: [Int]
bucketList = [0 .. numBuckets - 1]

sort :: [Int] -> [Int]
sort original = toList $ foldl' radix start [0 .. (totalBits `div` bucketBits) - 1]
  where
    -- put the list in bucket 0 so we can
    -- express the algorithm as repeated iteration
    start = Map.insert 0 (DList.fromList original) initial

gather :: IntMap (DList a) -> DList a
gather m = DList.concat $! map (m Map.!) bucketList

toList :: IntMap (DList a) -> [a]
toList m = DList.toList $! gather m

initial :: IntMap (DList a)
initial = foldl' (\m i -> Map.insert i DList.empty m) Map.empty bucketList

radix :: IntMap (DList Int) -> Int -> IntMap (DList Int)
radix m offset = foldl' ins initial list
  where
    list = gather m
    ins m' x = Map.adjust (`DList.snoc` x) bts m'
      where
        bts = shiftR x (offset * bucketBits) .&. (numBuckets - 1)

It's the definite winner in terms of conciseness, and it genuinely felt like a breath of fresh air when my functions finally started returning things instead of units. It's not very performant, though it's still faster than my linked lists.

1830962452316129604
"Radix"
1830962452316129604
CPU time: 3.97s
"Standard"
1830962452316129604
CPU time: 3.80s

Comparing to Data.List.sort

At first glance it might seem like the default sort function is pretty bad, but there are a couple of very important tradeoffs being made here. Firstly, my sorting only works for Ints. Though it could be extended to work for anything with a Bits instance, that's still much less general than being able to sort anything that's in Ord.
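To gesture at that generalization: the only Int-specific part is the bucket extraction, which could be abstracted over the Bits class. A hedged sketch (the name `bucketOfBits` and its signature are my own, not from the code above; signed types would still need their sign bit handled separately, which is one reason this stays a sketch):

```haskell
import Data.Bits (Bits, bit, shiftR, (.&.))

-- Hypothetical generalization of the `bts` expression: bucket index
-- for pass `offset` of any value with Bits and Integral instances.
bucketOfBits :: (Bits b, Integral b) => Int -> Int -> b -> Int
bucketOfBits bucketBits offset x =
  fromIntegral (shiftR x (offset * bucketBits) .&. (bit bucketBits - 1))
```

With this, the per-pass logic would work unchanged for Word, Word32, and so on; for example `bucketOfBits 8 1 (0x1234 :: Word)` picks out 0x12.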

Secondly, it has a much larger memory footprint; there's a lot of allocation happening. Sorting a million numbers allocated around 50 megabytes (don't quote me on that number).

Thirdly, it's not lazy: whether you want only the first 10 numbers from the list or all of them, it takes the same amount of time. That isn't true of Data.List.sort, which is nice and lazy.

All in all, I'm still pretty happy that I managed to write something that outperforms the default implementation. I'll probably come back to this subject later and see if I can implement some other algorithms or improve the implementations above.