Refactoring Haskell: A Case Study

Posted on 12 February 2019

Many people claim that refactoring Haskell is a joy. I’ve certainly found this to be the case, but what does that mean in practice? I thought it might be useful to demonstrate by refactoring some of my own code.

The code we’re looking at today is an implementation of Tarjan’s Strongly Connected Components algorithm, used to determine whether a given 2-SAT problem is satisfiable. I wrote it to complete an online course that is now offered in a different form. I’ve written about Tarjan’s algorithm previously. It can decide the satisfiability of a 2-SAT problem because each clause u ∨ v contributes the implications ¬u → v and ¬v → u as edges of a graph, and we can then check whether any SCC of that graph contains both a variable and its negation. If one does, we have a contradiction and the problem is unsatisfiable; otherwise, the problem is satisfiable.
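To make the reduction concrete, here is a small independent check that relies on the scc function from the containers package’s Data.Graph instead of a hand-written Tarjan. It is a sketch, not the post’s code: the name satisfiable and the (u, v) pair encoding of clauses are assumptions made for illustration, while the shift from literal x to vertex n + x mirrors the post’s normalise and other functions. A check like this could act as an oracle when testing each refactoring step.

```haskell
import qualified Data.Graph as G
import qualified Data.Set as S
import Data.Tree (flatten)

-- Hypothetical oracle (not the post's code): decide a 2-SAT instance using
-- the SCCs that containers' Data.Graph.scc computes for us.
-- Variables are numbered 1..n, a literal is x or -x, and a clause (u, v)
-- means (u OR v).
satisfiable :: Int -> [(Int, Int)] -> Bool
satisfiable n clauses = all consistent components
  where
    -- Shift literal x in [-n, n] to vertex n + x in [0, 2n]; the negation of
    -- vertex vx is then 2n - vx. Vertex n (literal 0) is never used.
    vertex x = n + x
    negated vx = 2 * n - vx

    -- (u OR v) is equivalent to (not u -> v) and (not v -> u).
    edges = concat [[(vertex (-u), vertex v), (vertex (-v), vertex u)] | (u, v) <- clauses]
    graph = G.buildG (0, 2 * n) edges

    components = map (S.fromList . flatten) (G.scc graph)

    -- A component is contradictory if it holds some literal and its
    -- negation; skip the unused vertex n, which is its own "negation".
    consistent comp = not (any (\vx -> vx /= n && S.member (negated vx) comp) (S.toList comp))
```

For example, satisfiable 2 [(1, 2), (-1, 2)] is True (set x2), while the four clauses covering every sign combination of x1 and x2 are unsatisfiable.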

This code isn’t particularly elegant or easy to follow, and it’s lousy with mutable state. Despite these drawbacks, it is still relatively straightforward to refactor.

If you’d like to follow along, I have the code (and some test data) available at this gist with each revision representing a refactoring step.

The initial version of the code is as follows:

Initial 2SAT.hs

```haskell
{-# LANGUAGE LambdaCase #-}

import qualified Data.Graph as G
import qualified Data.Map.Strict as M
import qualified Data.Set as S
import qualified Data.Array as A
import qualified Prelude as P
import Prelude hiding (lookup)
import Control.Monad.ST
import Data.STRef
import Control.Monad (forM_, when)
import Data.Maybe (isJust, isNothing, fromJust)

tarjan :: Int -> G.Graph -> Maybe [S.Set Int]
tarjan n graph = runST $ do
  index <- newSTRef 0
  stack <- newSTRef []
  stackSet <- newSTRef S.empty
  indices <- newSTRef M.empty
  lowlinks <- newSTRef M.empty
  output <- newSTRef (Just [])

  forM_ (G.vertices graph) $ \v -> do
    vIndex <- M.lookup v <$> readSTRef indices
    when (isNothing vIndex) $
      strongConnect n v graph index stack stackSet indices lowlinks output

  readSTRef output

strongConnect
  :: Int
  -> Int
  -> G.Graph
  -> STRef s Int
  -> STRef s [Int]
  -> STRef s (S.Set Int)
  -> STRef s (M.Map Int Int)
  -> STRef s (M.Map Int Int)
  -> STRef s (Maybe [S.Set Int])
  -> ST s ()
strongConnect n v graph index stack stackSet indices lowlinks output = do
  i <- readSTRef index
  insert v i indices
  insert v i lowlinks
  modifySTRef' index (+ 1)
  push stack stackSet v

  forM_ (graph A.! v) $ \w -> lookup w indices >>= \case
    Nothing -> do
      strongConnect n w graph index stack stackSet indices lowlinks output
      vLowLink <- fromJust <$> lookup v lowlinks
      wLowLink <- fromJust <$> lookup w lowlinks
      insert v (min vLowLink wLowLink) lowlinks
    Just wIndex -> do
      wOnStack <- S.member w <$> readSTRef stackSet
      when wOnStack $ do
        vLowLink <- fromJust <$> lookup v lowlinks
        insert v (min vLowLink wIndex) lowlinks

  vLowLink <- fromJust <$> lookup v lowlinks
  vIndex <- fromJust <$> lookup v indices
  when (vLowLink == vIndex) $ do
    scc <- addSCC n v S.empty stack stackSet
    modifySTRef' output $ \sccs -> (:) <$> scc <*> sccs
  where
    lookup value hashMap = M.lookup value <$> readSTRef hashMap
    insert key value hashMap = modifySTRef' hashMap (M.insert key value)

addSCC
  :: Int
  -> Int
  -> S.Set Int
  -> STRef s [Int]
  -> STRef s (S.Set Int)
  -> ST s (Maybe (S.Set Int))
addSCC n v scc stack stackSet = pop stack stackSet >>= \w ->
  if (other n w) `S.member` scc
    then return Nothing
    else
      let scc' = S.insert w scc
      in if w == v
           then return (Just scc')
           else addSCC n v scc' stack stackSet

push :: STRef s [Int] -> STRef s (S.Set Int) -> Int -> ST s ()
push stack stackSet e = do
  modifySTRef' stack (e :)
  modifySTRef' stackSet (S.insert e)

pop :: STRef s [Int] -> STRef s (S.Set Int) -> ST s Int
pop stack stackSet = do
  e <- head <$> readSTRef stack
  modifySTRef' stack tail
  modifySTRef' stackSet (S.delete e)
  return e

denormalise = subtract
normalise = (+)
other n v = 2 * n - v

clauses n [u,v] = [(other n u, v), (other n v, u)]

checkSat :: String -> IO Bool
checkSat name = do
  p <- map (map P.read . words) . lines <$> readFile name
  let pNo = head $ head p
      pn = map (map (normalise pNo)) $ tail p
      pGraph = G.buildG (0, 2 * pNo) $ concatMap (clauses pNo) pn
  return $ (Nothing /=) $ tarjan pNo pGraph
```

I’ve included 2SAT-specific functionality for completeness, but I’ll only be changing the tarjan function and the functions it depends on (strongConnect, addSCC, push, and pop).

The first change is to use more suitable data structures. Tarjan’s algorithm is only linear in the size of the graph when operations such as checking whether w is on the stack and looking up indices happen in constant time (O(1)). I’m currently using Data.Map and Data.Set, which are both implemented as balanced binary trees and take O(log n) time for these operations. A better choice is Data.Vector.Mutable from the vector package, whose reads and writes take constant time.

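This refactoring mostly consists of initialising vectors with a known length (one slot per vertex, derived from the graph’s bounds) and replacing calls to lookup and insert with calls to read and write. Along the way, addSCC also switches from building each SCC as a Set to building it as a plain list.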
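In isolation, the Map-to-vector pattern looks something like the following sketch, which hands out DFS preorder numbers using a mutable vector indexed by vertex, the same shape as the post’s indices vector. The function name preorder is hypothetical, and the code assumes the graph’s vertices start at 0 (as with G.buildG (0, k)), so that the upper bound plus one is a valid vector length.

```haskell
import Control.Monad (forM_, when)
import Control.Monad.ST (runST)
import Data.STRef (modifySTRef', newSTRef, readSTRef)
import qualified Data.Array as A
import qualified Data.Graph as G
import qualified Data.Vector.Mutable as MV

-- Hypothetical example: number vertices in DFS preorder, storing each
-- vertex's number in slot v of a mutable vector instead of in a Map.
-- Assumes vertices are 0 .. hi, so a vector of length hi + 1 covers them.
preorder :: G.Graph -> [Maybe Int]
preorder graph = runST $ do
  let size = snd (A.bounds graph) + 1
  counter <- newSTRef 0
  indices <- MV.replicate size Nothing -- one slot per vertex, all unvisited

  let visit v = do
        i <- readSTRef counter
        MV.write indices v (Just i) -- O(1), where M.insert was O(log n)
        modifySTRef' counter (+ 1)
        forM_ (graph A.! v) $ \w -> do
          seen <- MV.read indices w -- O(1), where M.lookup was O(log n)
          when (seen == Nothing) (visit w)

  forM_ (G.vertices graph) $ \v -> do
    seen <- MV.read indices v
    when (seen == Nothing) (visit v)

  mapM (MV.read indices) [0 .. size - 1]
```

For instance, preorder (G.buildG (0, 3) [(0, 2), (2, 3)]) gives [Just 0, Just 3, Just 1, Just 2]: the search from 0 reaches 2 and 3 before vertex 1 is started.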

2SAT.hs using vector

```haskell
{-# LANGUAGE LambdaCase #-}

import qualified Data.Graph as G
import qualified Data.Array as A
import qualified Prelude as P
import Prelude hiding (lookup, read, replicate)
import Control.Monad.ST
import Data.STRef
import Control.Monad (forM_, when)
import Data.Maybe (isJust, isNothing, fromJust)
import Data.Vector.Mutable (STVector, read, replicate, write)

tarjan :: Int -> G.Graph -> Maybe [[Int]]
tarjan n graph = runST $ do
  index <- newSTRef 0
  stack <- newSTRef []
  stackSet <- replicate size False
  indices <- replicate size Nothing
  lowlinks <- replicate size Nothing
  output <- newSTRef (Just [])

  forM_ (G.vertices graph) $ \v -> do
    vIndex <- read indices v
    when (isNothing vIndex) $
      strongConnect n v graph index stack stackSet indices lowlinks output

  readSTRef output
  where
    size = snd (A.bounds graph) + 1

strongConnect
  :: Int
  -> Int
  -> G.Graph
  -> STRef s Int
  -> STRef s [Int]
  -> STVector s Bool
  -> STVector s (Maybe Int)
  -> STVector s (Maybe Int)
  -> STRef s (Maybe [[Int]])
  -> ST s ()
strongConnect n v graph index stack stackSet indices lowlinks output = do
  i <- readSTRef index
  write indices v (Just i)
  write lowlinks v (Just i)
  modifySTRef' index (+ 1)
  push stack stackSet v

  forM_ (graph A.! v) $ \w -> read indices w >>= \case
    Nothing -> do
      strongConnect n w graph index stack stackSet indices lowlinks output
      vLowLink <- fromJust <$> read lowlinks v
      wLowLink <- fromJust <$> read lowlinks w
      write lowlinks v (Just (min vLowLink wLowLink))
    Just wIndex -> do
      wOnStack <- read stackSet w
      when wOnStack $ do
        vLowLink <- fromJust <$> read lowlinks v
        write lowlinks v (Just (min vLowLink wIndex))

  vLowLink <- fromJust <$> read lowlinks v
  vIndex <- fromJust <$> read indices v
  when (vLowLink == vIndex) $ do
    scc <- addSCC n v [] stack stackSet
    modifySTRef' output $ \sccs -> (:) <$> scc <*> sccs

addSCC
  :: Int
  -> Int
  -> [Int]
  -> STRef s [Int]
  -> STVector s Bool
  -> ST s (Maybe [Int])
addSCC n v scc stack stackSet = pop stack stackSet >>= \w ->
  if (other n w) `elem` scc
    then return Nothing
    else
      let scc' = w : scc
      in if w == v
           then return (Just scc')
           else addSCC n v scc' stack stackSet

push :: STRef s [Int] -> STVector s Bool -> Int -> ST s ()
push stack stackSet e = do
  modifySTRef' stack (e :)
  write stackSet e True

pop :: STRef s [Int] -> STVector s Bool -> ST s Int
pop stack stackSet = do
  e <- head <$> readSTRef stack
  modifySTRef' stack tail
  write stackSet e False
  return e

denormalise = subtract
normalise = (+)
other n v = 2 * n - v

clauses n [u,v] = [(other n u, v), (other n v, u)]

checkSat :: String -> IO Bool
checkSat name = do
  p <- map (map P.read . words) . lines <$> readFile name
  let pNo = head $ head p
      pn = map (map (normalise pNo)) $ tail p
      pGraph = G.buildG (0, 2 * pNo) $ concatMap (clauses pNo) pn
  return $ (Nothing /=) $ tarjan pNo pGraph
```

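I didn’t notice a significant difference in speed on my inputs, but it’s good to know that the traversal now has the correct asymptotics! (One caveat: since addSCC now accumulates each SCC in a list, its elem check is linear in the size of the component built so far, so that step can be quadratic in the size of each SCC.)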

Sidenote: a vector of Bools can be represented much more compactly as a sequence of bits packed into machine words. For implementations of this in Haskell, see the bv or bv-little packages. Using one of these could be another possible refactoring.
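As a rough illustration of that idea, the sketch below packs one flag per vertex into unboxed Word64 words and manipulates them with Data.Bits. It is hand-rolled for this example and does not reproduce the API of bv or bv-little; the names BitSet, newBitSet, insertBit, deleteBit, memberBit and demo are all hypothetical.

```haskell
import Control.Monad.ST (ST, runST)
import Data.Bits (clearBit, setBit, testBit)
import qualified Data.Vector.Unboxed.Mutable as UMV
import Data.Word (Word64)

-- Hypothetical mutable bitset: flag i lives in bit (i mod 64) of word (i div 64).
type BitSet s = UMV.MVector s Word64

-- Enough words to hold n flags, all initially cleared.
newBitSet :: Int -> ST s (BitSet s)
newBitSet n = UMV.replicate ((n + 63) `div` 64) 0

-- Split a flag index into (word offset, bit offset within that word).
locate :: Int -> (Int, Int)
locate i = i `divMod` 64

insertBit :: BitSet s -> Int -> ST s ()
insertBit bs i = let (w, b) = locate i in UMV.modify bs (`setBit` b) w

deleteBit :: BitSet s -> Int -> ST s ()
deleteBit bs i = let (w, b) = locate i in UMV.modify bs (`clearBit` b) w

memberBit :: BitSet s -> Int -> ST s Bool
memberBit bs i = let (w, b) = locate i in (`testBit` b) <$> UMV.read bs w

-- Example: set flags 3 and 70, clear flag 3 again, then query 3, 70 and 71.
demo :: [Bool]
demo = runST $ do
  bs <- newBitSet 72
  insertBit bs 3
  insertBit bs 70
  deleteBit bs 3
  mapM (memberBit bs) [3, 70, 71]
```

Here demo evaluates to [False, True, False]. In the post’s terms, stackSet could become such a bitset, with insertBit in push, deleteBit in pop and memberBit for the on-stack test.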

Looking at the code again, I notice some repetition of the form

```haskell
x <- fromJust <$> lookup vectorX i
y <- fromJust <$> lookup vectorY j
write vectorZ k (Just (operation x y))
```

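and with the judicious use of (=<<) and (<*>) this can instead be written as follows (this works whenever operation can act on the Maybe values directly, as min can, since min (Just x) (Just y) is Just (min x y)):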

```haskell
write vectorZ k =<< (operation <$> lookup vectorX i <*> lookup vectorY j)
```
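A small standalone check of that rewrite, using a list-to-vector round trip so the results can be compared directly. The names explicitMin and applicativeMin are hypothetical. Both agree when the two slots hold Just values; the applicative version simply propagates Nothing (the least element in Maybe’s Ord instance) where the explicit one would crash on fromJust.

```haskell
import Control.Monad.ST (runST)
import Data.Maybe (fromJust)
import qualified Data.Vector as V
import qualified Data.Vector.Mutable as MV

-- Explicit version: unwrap both reads, combine, rewrap.
explicitMin :: Int -> Int -> Int -> [Maybe Int] -> [Maybe Int]
explicitMin i j k xs = runST $ do
  vec <- V.thaw (V.fromList xs)
  x <- fromJust <$> MV.read vec i
  y <- fromJust <$> MV.read vec j
  MV.write vec k (Just (min x y))
  V.toList <$> V.freeze vec

-- Applicative version: min works directly on the Maybe values, and
-- min (Just x) (Just y) == Just (min x y), so no unwrapping is needed.
applicativeMin :: Int -> Int -> Int -> [Maybe Int] -> [Maybe Int]
applicativeMin i j k xs = runST $ do
  vec <- V.thaw (V.fromList xs)
  MV.write vec k =<< (min <$> MV.read vec i <*> MV.read vec j)
  V.toList <$> V.freeze vec
```

Both explicitMin 0 1 2 and applicativeMin 0 1 2 turn [Just 5, Just 3, Nothing] into [Just 5, Just 3, Just 3]. In tarjan, the lowlinks and indices slots involved have always been written before they are read, so the Just-only case is the one that matters there.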

There are a couple of other places we could use (<*>):

2SAT.hs using (<*>)

```haskell
{-# LANGUAGE LambdaCase #-}

import qualified Data.Graph as G
import qualified Data.Array as A
import qualified Prelude as P
import Prelude hiding (lookup, read, replicate)
import Control.Monad.ST
import Data.STRef
import Control.Monad (forM_, when)
import Data.Maybe (isJust, isNothing, fromJust)
import Data.Vector.Mutable (STVector, read, replicate, write)

tarjan :: Int -> G.Graph -> Maybe [[Int]]
tarjan n graph = runST $ do
  index <- newSTRef 0
  stack <- newSTRef []
  stackSet <- replicate size False
  indices <- replicate size Nothing
  lowlinks <- replicate size Nothing
  output <- newSTRef (Just [])

  forM_ (G.vertices graph) $ \v -> do
    vIndex <- read indices v
    when (isNothing vIndex) $
      strongConnect n v graph index stack stackSet indices lowlinks output

  readSTRef output
  where
    size = snd (A.bounds graph) + 1

strongConnect
  :: Int
  -> Int
  -> G.Graph
  -> STRef s Int
  -> STRef s [Int]
  -> STVector s Bool
  -> STVector s (Maybe Int)
  -> STVector s (Maybe Int)
  -> STRef s (Maybe [[Int]])
  -> ST s ()
strongConnect n v graph index stack stackSet indices lowlinks output = do
  i <- readSTRef index
  write indices v (Just i)
  write lowlinks v (Just i)
  modifySTRef' index (+ 1)
  push stack stackSet v

  forM_ (graph A.! v) $ \w -> read indices w >>= \case
    Nothing -> do
      strongConnect n w graph index stack stackSet indices lowlinks output
      write lowlinks v =<< (min <$> read lowlinks v <*> read lowlinks w)
    Just {} -> do
      wOnStack <- read stackSet w
      when wOnStack $ do
        write lowlinks v =<< (min <$> read lowlinks v <*> read indices w)

  vLowLink <- fromJust <$> read lowlinks v
  vIndex <- fromJust <$> read indices v
  when (vLowLink == vIndex) $ do
    scc <- addSCC n v [] stack stackSet
    modifySTRef' output $ \sccs -> (:) <$> scc <*> sccs

addSCC
  :: Int
  -> Int
  -> [Int]
  -> STRef s [Int]
  -> STVector s Bool
  -> ST s (Maybe [Int])
addSCC n v scc stack stackSet = pop stack stackSet >>= \w ->
  if (other n w) `elem` scc
    then return Nothing
    else
      let scc' = w : scc
      in if w == v
           then return (Just scc')
           else addSCC n v scc' stack stackSet

push :: STRef s [Int] -> STVector s Bool -> Int -> ST s ()
push stack stackSet e = do
  modifySTRef' stack (e :)
  write stackSet e True

pop :: STRef s [Int] -> STVector s Bool -> ST s Int
pop stack stackSet = do
  e <- head <$> readSTRef stack
  modifySTRef' stack tail
  write stackSet e False
  return e

denormalise = subtract
normalise = (+)
other n v = 2 * n - v

clauses n [u,v] = [(other n u, v), (other n v, u)]

checkSat :: String -> IO Bool
checkSat name = do
  p <- map (map P.read . words) . lines <$> readFile name
  let pNo = head $ head p
      pn = map (map (normalise pNo)) $ tail p
      pGraph = G.buildG (0, 2 * pNo) $ concatMap (clauses pNo) pn
  return $ (Nothing /=) $ tarjan pNo pGraph
```

This is much nicer with the applicative combinators.

I would like to clean up the remaining uses of when as well, since each one binds a monadic result only to test it. For that I’d need a function like

```haskell
whenM :: Monad m => m Bool -> m () -> m ()
```

which is available in Neil Mitchell’s extra package.

I don’t think it’s worth pulling in that dependency though, so I’ll just copy that definition:

2SAT.hs using whenM

```haskell
{-# LANGUAGE LambdaCase #-}

import qualified Data.Graph as G
import qualified Data.Array as A
import qualified Prelude as P
import Prelude hiding (lookup, read, replicate)
import Control.Monad.ST
import Data.STRef
import Control.Monad (forM_)
import Data.Vector.Mutable (STVector, read, replicate, write)

whenM :: Monad m => m Bool -> m () -> m ()
whenM condM block = condM >>= \cond -> if cond then block else return ()

tarjan :: Int -> G.Graph -> Maybe [[Int]]
tarjan n graph = runST $ do
  index <- newSTRef 0
  stack <- newSTRef []
  stackSet <- replicate size False
  indices <- replicate size Nothing
  lowlinks <- replicate size Nothing
  output <- newSTRef (Just [])

  forM_ (G.vertices graph) $ \v ->
    whenM ((==) Nothing <$> read indices v) $
      strongConnect n v graph index stack stackSet indices lowlinks output

  readSTRef output
  where
    size = snd (A.bounds graph) + 1

strongConnect
  :: Int
  -> Int
  -> G.Graph
  -> STRef s Int
  -> STRef s [Int]
  -> STVector s Bool
  -> STVector s (Maybe Int)
  -> STVector s (Maybe Int)
  -> STRef s (Maybe [[Int]])
  -> ST s ()
strongConnect n v graph index stack stackSet indices lowlinks output = do
  i <- readSTRef index
  write indices v (Just i)
  write lowlinks v (Just i)
  modifySTRef' index (+ 1)
  push stack stackSet v

  forM_ (graph A.! v) $ \w -> read indices w >>= \case
    Nothing -> do
      strongConnect n w graph index stack stackSet indices lowlinks output
      write lowlinks v =<< (min <$> read lowlinks v <*> read lowlinks w)
    Just {} -> whenM (read stackSet w) $
      write lowlinks v =<< (min <$> read lowlinks v <*> read indices w)

  whenM ((==) <$> read lowlinks v <*> read indices v) $ do
    scc <- addSCC n v [] stack stackSet
    modifySTRef' output $ \sccs -> (:) <$> scc <*> sccs

addSCC
  :: Int
  -> Int
  -> [Int]
  -> STRef s [Int]
  -> STVector s Bool
  -> ST s (Maybe [Int])
addSCC n v scc stack stackSet = pop stack stackSet >>= \w ->
  if (other n w) `elem` scc
    then return Nothing
    else
      let scc' = w : scc
      in if w == v
           then return (Just scc')
           else addSCC n v scc' stack stackSet

push :: STRef s [Int] -> STVector s Bool -> Int -> ST s ()
push stack stackSet e = do
  modifySTRef' stack (e :)
  write stackSet e True

pop :: STRef s [Int] -> STVector s Bool -> ST s Int
pop stack stackSet = do
  e <- head <$> readSTRef stack
  modifySTRef' stack tail
  write stackSet e False
  return e

denormalise = subtract
normalise = (+)
other n v = 2 * n - v

clauses n [u,v] = [(other n u, v), (other n v, u)]

checkSat :: String -> IO Bool
checkSat name = do
  p <- map (map P.read . words) . lines <$> readFile name
  let pNo = head $ head p
      pn = map (map (normalise pNo)) $ tail p
      pGraph = G.buildG (0, 2 * pNo) $ concatMap (clauses pNo) pn
  return $ (Nothing /=) $ tarjan pNo pGraph
```

Now I don’t even need when (or anything from Data.Maybe) anymore!
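Here is the same whenM in a tiny standalone setting, where the condition itself has to read mutable state. The countRecords example is hypothetical: it counts how many elements are strictly larger than everything before them. With plain when, the comparison would first need its own bind.

```haskell
import Control.Monad (forM_)
import Control.Monad.ST (runST)
import Data.STRef (modifySTRef', newSTRef, readSTRef, writeSTRef)

-- Same definition as in the post: run `block` only if the monadic
-- condition `condM` produces True.
whenM :: Monad m => m Bool -> m () -> m ()
whenM condM block = condM >>= \cond -> if cond then block else return ()

-- Hypothetical example: count elements larger than every earlier element.
countRecords :: [Int] -> Int
countRecords xs = runST $ do
  best <- newSTRef minBound
  count <- newSTRef 0
  forM_ xs $ \x ->
    whenM ((x >) <$> readSTRef best) $ do
      writeSTRef best x
      modifySTRef' count (+ 1)
  readSTRef count
```

For example, countRecords [3, 1, 4, 1, 5, 9, 2, 6] is 4 (the elements 3, 4, 5 and 9).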

Since most of the auxiliary functions aren’t used outside strongConnect, it might make sense to put them under a where clause. This would also make the parameters passed to strongConnect available to these functions. This is one place where the ScopedTypeVariables language extension is necessary: otherwise GHC can’t tell that the s in the type signature of strongConnect is the same s as the one in each type signature under the where clause. The extension only scopes type variables that are bound by an explicit forall, which is why the signature of strongConnect now begins with forall s.
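A stripped-down illustration of the same situation: a hypothetical pushAll keeps its stack in an STRef and defines a local push helper under where. The explicit forall s., together with ScopedTypeVariables, is what makes the s in push’s signature refer to the outer state thread.

```haskell
{-# LANGUAGE ScopedTypeVariables #-}

import Control.Monad.ST (ST, runST)
import Data.STRef (STRef, modifySTRef', newSTRef, readSTRef)

-- The explicit `forall s.` scopes `s` over the whole definition, so the
-- local signature `push :: Int -> ST s ()` talks about the same `s` as
-- `stack`. Without the extension (or without the forall), that `s` would
-- be a fresh type variable and GHC would reject the local signature.
pushAll :: forall s. STRef s [Int] -> [Int] -> ST s ()
pushAll stack = mapM_ push
  where
    push :: Int -> ST s ()
    push e = modifySTRef' stack (e :)

-- Hypothetical usage: push 1, 2 and 3 onto an initially empty stack.
stackAfterPushes :: [Int]
stackAfterPushes = runST $ do
  stack <- newSTRef []
  pushAll stack [1, 2, 3]
  readSTRef stack
```

Because each push conses onto the front, stackAfterPushes is [3, 2, 1].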

2SAT.hs using where

```haskell
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ScopedTypeVariables #-}

import qualified Data.Graph as G
import qualified Data.Array as A
import qualified Prelude as P
import Prelude hiding (lookup, read, replicate)
import Control.Monad.ST
import Data.STRef
import Control.Monad (forM_)
import Data.Vector.Mutable (STVector, read, replicate, write)

whenM :: Monad m => m Bool -> m () -> m ()
whenM condM block = condM >>= \cond -> if cond then block else return ()

tarjan :: Int -> G.Graph -> Maybe [[Int]]
tarjan n graph = runST $ do
  index <- newSTRef 0
  stack <- newSTRef []
  stackSet <- replicate size False
  indices <- replicate size Nothing
  lowlinks <- replicate size Nothing
  output <- newSTRef (Just [])

  forM_ (G.vertices graph) $ \v ->
    whenM ((==) Nothing <$> read indices v) $
      strongConnect n v graph index stack stackSet indices lowlinks output

  readSTRef output
  where
    size = snd (A.bounds graph) + 1

strongConnect
  :: forall s. Int
  -> Int
  -> G.Graph
  -> STRef s Int
  -> STRef s [Int]
  -> STVector s Bool
  -> STVector s (Maybe Int)
  -> STVector s (Maybe Int)
  -> STRef s (Maybe [[Int]])
  -> ST s ()
strongConnect n v graph index stack stackSet indices lowlinks output = do
  i <- readSTRef index
  write indices v (Just i)
  write lowlinks v (Just i)
  modifySTRef' index (+ 1)
  push v

  forM_ (graph A.! v) $ \w -> read indices w >>= \case
    Nothing -> do
      strongConnect n w graph index stack stackSet indices lowlinks output
      write lowlinks v =<< (min <$> read lowlinks v <*> read lowlinks w)
    Just {} -> whenM (read stackSet w) $
      write lowlinks v =<< (min <$> read lowlinks v <*> read indices w)

  whenM ((==) <$> read lowlinks v <*> read indices v) $ do
    scc <- addSCC n v []
    modifySTRef' output $ \sccs -> (:) <$> scc <*> sccs
  where
    addSCC :: Int -> Int -> [Int] -> ST s (Maybe [Int])
    addSCC n v scc = pop >>= \w ->
      if (other n w) `elem` scc
        then return Nothing
        else
          let scc' = w : scc
          in if w == v
               then return (Just scc')
               else addSCC n v scc'

    push :: Int -> ST s ()
    push e = do
      modifySTRef' stack (e :)
      write stackSet e True

    pop :: ST s Int
    pop = do
      e <- head <$> readSTRef stack
      modifySTRef' stack tail
      write stackSet e False
      return e

denormalise = subtract
normalise = (+)
other n v = 2 * n - v

clauses n [u,v] = [(other n u, v), (other n v, u)]

checkSat :: String -> IO Bool
checkSat name = do
  p <- map (map P.read . words) . lines <$> readFile name
  let pNo = head $ head p
      pn = map (map (normalise pNo)) $ tail p
      pGraph = G.buildG (0, 2 * pNo) $ concatMap (clauses pNo) pn
  return $ (Nothing /=) $ tarjan pNo pGraph
```

I think the logic is clearer now that the auxiliary functions take fewer arguments.

Instead of a large number of implicitly related variables, it might be nice to define a single product type containing our entire environment and pass just one value around. With NamedFieldPuns, only minimal code changes are required:

2SAT.hs using NamedFieldPuns

```haskell
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE ScopedTypeVariables #-}

import qualified Data.Graph as G
import qualified Data.Array as A
import qualified Prelude as P
import Prelude hiding (lookup, read, replicate)
import Control.Monad.ST
import Data.STRef
import Control.Monad (forM_)
import Data.Vector.Mutable (STVector, read, replicate, write)

data TarjanEnv s = TarjanEnv
  { index :: STRef s Int
  , stack :: STRef s [Int]
  , stackSet :: STVector s Bool
  , indices :: STVector s (Maybe Int)
  , lowlinks :: STVector s (Maybe Int)
  , output :: STRef s (Maybe [[Int]])
  }

whenM :: Monad m => m Bool -> m () -> m ()
whenM condM block = condM >>= \cond -> if cond then block else return ()

tarjan :: Int -> G.Graph -> Maybe [[Int]]
tarjan n graph = runST $ do
  tarjanEnv <- TarjanEnv
    <$> newSTRef 0
    <*> newSTRef []
    <*> replicate size False
    <*> replicate size Nothing
    <*> replicate size Nothing
    <*> newSTRef (Just [])

  forM_ (G.vertices graph) $ \v ->
    whenM ((==) Nothing <$> read (indices tarjanEnv) v) $
      strongConnect n v graph tarjanEnv

  readSTRef (output tarjanEnv)
  where
    size = snd (A.bounds graph) + 1

strongConnect :: forall s. Int -> Int -> G.Graph -> TarjanEnv s -> ST s ()
strongConnect n v graph tarjanEnv@TarjanEnv {index, stack, stackSet, indices, lowlinks, output} = do
  i <- readSTRef index
  write indices v (Just i)
  write lowlinks v (Just i)
  modifySTRef' index (+ 1)
  push v

  forM_ (graph A.! v) $ \w -> read indices w >>= \case
    Nothing -> do
      strongConnect n w graph tarjanEnv
      write lowlinks v =<< (min <$> read lowlinks v <*> read lowlinks w)
    Just {} -> whenM (read stackSet w) $
      write lowlinks v =<< (min <$> read lowlinks v <*> read indices w)

  whenM ((==) <$> read lowlinks v <*> read indices v) $ do
    scc <- addSCC n v []
    modifySTRef' output $ \sccs -> (:) <$> scc <*> sccs
  where
    addSCC :: Int -> Int -> [Int] -> ST s (Maybe [Int])
    addSCC n v scc = pop >>= \w ->
      if (other n w) `elem` scc
        then return Nothing
        else
          let scc' = w : scc
          in if w == v
               then return (Just scc')
               else addSCC n v scc'

    push :: Int -> ST s ()
    push e = do
      modifySTRef' stack (e :)
      write stackSet e True

    pop :: ST s Int
    pop = do
      e <- head <$> readSTRef stack
      modifySTRef' stack tail
      write stackSet e False
      return e

denormalise = subtract
normalise = (+)
other n v = 2 * n - v

clauses n [u,v] = [(other n u, v), (other n v, u)]

checkSat :: String -> IO Bool
checkSat name = do
  p <- map (map P.read . words) . lines <$> readFile name
  let pNo = head $ head p
      pn = map (map (normalise pNo)) $ tail p
      pGraph = G.buildG (0, 2 * pNo) $ concatMap (clauses pNo) pn
  return $ (Nothing /=) $ tarjan pNo pGraph
```
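For readers who haven’t met NamedFieldPuns, here is the same pattern in miniature: a hypothetical CounterEnv record bundling two STRefs, taken apart by a pun pattern so that the body can use total and calls as if they were ordinary arguments.

```haskell
{-# LANGUAGE NamedFieldPuns #-}

import Control.Monad.ST (ST, runST)
import Data.STRef (STRef, modifySTRef', newSTRef, readSTRef)

-- Hypothetical environment record, much smaller than the post's TarjanEnv:
-- all of the mutable state travels together as one value.
data CounterEnv s = CounterEnv
  { total :: STRef s Int
  , calls :: STRef s Int
  }

-- With NamedFieldPuns, the pattern CounterEnv {total, calls} binds local
-- variables named after the fields.
addValue :: CounterEnv s -> Int -> ST s ()
addValue CounterEnv {total, calls} x = do
  modifySTRef' total (+ x)
  modifySTRef' calls (+ 1)

-- Hypothetical usage: return (sum of the inputs, number of addValue calls).
summarise :: [Int] -> (Int, Int)
summarise xs = runST $ do
  env <- CounterEnv <$> newSTRef 0 <*> newSTRef 0
  mapM_ (addValue env) xs
  (,) <$> readSTRef (total env) <*> readSTRef (calls env)
```

Note that the environment is built applicatively, just as tarjanEnv is, so adding a field means adding one more <*> argument.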

Let’s pause here. Although more refactoring is certainly possible, my last two steps did not reduce the line count and may in fact have made the code harder to understand.

How have we benefited from this refactoring? Aside from the code being shorter and better structured, it’s now easier to make meaningful improvements. For example, this implementation is less efficient than it needs to be, because it doesn’t short-circuit when it finds that the current problem is unsatisfiable. Instead, it works through the rest of the problem, only to throw all that work away. A sophisticated solution might use the ExceptT monad transformer to throw an exception and exit early, but there is a simpler approach: we can store an extra boolean variable recording whether the current problem is still possibly satisfiable, and only continue working if it is. I’ll call this variable possible, set it to False in addSCC when a contradiction is found, and check it before each call to strongConnect in tarjan. It takes more effort to reformat the code than to make this change:

2SAT.hs with short-circuiting

```haskell
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE ScopedTypeVariables #-}

import qualified Data.Graph as G
import qualified Data.Array as A
import qualified Prelude as P
import Prelude hiding (lookup, read, replicate)
import Control.Monad.ST
import Data.STRef
import Control.Monad (forM_)
import Data.Vector.Mutable (STVector, read, replicate, write)

data TarjanEnv s = TarjanEnv
  { index :: STRef s Int
  , stack :: STRef s [Int]
  , stackSet :: STVector s Bool
  , indices :: STVector s (Maybe Int)
  , lowlinks :: STVector s (Maybe Int)
  , output :: STRef s (Maybe [[Int]])
  , possible :: STRef s Bool
  }

whenM :: Monad m => m Bool -> m () -> m ()
whenM condM block = condM >>= \cond -> if cond then block else return ()

tarjan :: Int -> G.Graph -> Maybe [[Int]]
tarjan n graph = runST $ do
  tarjanEnv <- TarjanEnv
    <$> newSTRef 0
    <*> newSTRef []
    <*> replicate size False
    <*> replicate size Nothing
    <*> replicate size Nothing
    <*> newSTRef (Just [])
    <*> newSTRef True

  forM_ (G.vertices graph) $ \v ->
    whenM ((&&) <$> ((==) Nothing <$> read (indices tarjanEnv) v)
                <*> readSTRef (possible tarjanEnv)) $
      strongConnect n v graph tarjanEnv

  readSTRef (output tarjanEnv)
  where
    size = snd (A.bounds graph) + 1

strongConnect :: forall s. Int -> Int -> G.Graph -> TarjanEnv s -> ST s ()
strongConnect n v graph tarjanEnv@TarjanEnv {index, stack, stackSet, indices, lowlinks, output, possible} = do
  i <- readSTRef index
  write indices v (Just i)
  write lowlinks v (Just i)
  modifySTRef' index (+ 1)
  push v

  forM_ (graph A.! v) $ \w -> read indices w >>= \case
    Nothing -> do
      strongConnect n w graph tarjanEnv
      write lowlinks v =<< (min <$> read lowlinks v <*> read lowlinks w)
    Just {} -> whenM (read stackSet w) $
      write lowlinks v =<< (min <$> read lowlinks v <*> read indices w)

  whenM ((==) <$> read lowlinks v <*> read indices v) $ do
    scc <- addSCC n v []
    modifySTRef' output $ \sccs -> (:) <$> scc <*> sccs
  where
    addSCC :: Int -> Int -> [Int] -> ST s (Maybe [Int])
    addSCC n v scc = pop >>= \w ->
      if (other n w) `elem` scc
        then writeSTRef possible False >> return Nothing
        else
          let scc' = w : scc
          in if w == v
               then return (Just scc')
               else addSCC n v scc'

    push :: Int -> ST s ()
    push e = do
      modifySTRef' stack (e :)
      write stackSet e True

    pop :: ST s Int
    pop = do
      e <- head <$> readSTRef stack
      modifySTRef' stack tail
      write stackSet e False
      return e

denormalise = subtract
normalise = (+)
other n v = 2 * n - v

clauses n [u,v] = [(other n u, v), (other n v, u)]

checkSat :: String -> IO Bool
checkSat name = do
  p <- map (map P.read . words) . lines <$> readFile name
  let pNo = head $ head p
      pn = map (map (normalise pNo)) $ tail p
      pGraph = G.buildG (0, 2 * pNo) $ concatMap (clauses pNo) pn
  return $ (Nothing /=) $ tarjan pNo pGraph
```

This change does seem to make a significant difference on unsatisfiable problems (satisfiable ones still have to be processed in full), and it’s good to know we’re not doing useless work.
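The flag technique is easy to see on a toy problem. The hypothetical scanLiterals below walks a list of literals, stops doing real work as soon as it meets a literal whose negation appeared earlier, and reports how many items it actually examined. As in tarjan, the loop still visits every remaining element, but each visit is only a cheap readSTRef.

```haskell
import Control.Monad (forM_)
import Control.Monad.ST (runST)
import Data.STRef (modifySTRef', newSTRef, readSTRef, writeSTRef)

-- Same helper as in the post.
whenM :: Monad m => m Bool -> m () -> m ()
whenM condM block = condM >>= \cond -> if cond then block else return ()

-- Hypothetical example of the flag technique: returns whether the literals
-- are free of any x / -x pair, plus how many literals were examined before
-- the flag was cleared.
scanLiterals :: [Int] -> (Bool, Int)
scanLiterals xs = runST $ do
  possible <- newSTRef True
  seen <- newSTRef []
  examined <- newSTRef 0
  forM_ xs $ \x ->
    -- Once possible is False, each remaining iteration is just this read.
    whenM (readSTRef possible) $ do
      modifySTRef' examined (+ 1)
      prev <- readSTRef seen
      if negate x `elem` prev
        then writeSTRef possible False
        else writeSTRef seen (x : prev)
  (,) <$> readSTRef possible <*> readSTRef examined
```

Here scanLiterals [1, 2, -1, 3, 4] is (False, 3): the last two literals are skipped. An ExceptT-based version, as mentioned above, would abort the loop itself instead of skipping each remaining iteration.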

I think this is a good place to stop, and I hope I’ve been able to demonstrate some of Haskell’s strengths when it comes to refactoring. In my experience, it’s not usually necessary to deeply understand Haskell code in order to attempt a refactoring, especially if it’s backed by well-chosen types and a good test suite. I also find that I’m able to be more daring when writing new code, because bad up-front design is less costly and even the jankiest working code can be gently massaged into something presentable.

Thanks to Joel Burget, Mat Fournier, Robert Klotzner, Tenor, Tom Harding, and Tyler Weir for suggestions and feedback.