Decision Trees Are Free Monads Over the Reader Functor

Clay Thomas

Motivation

Free f a, the free monad over a given functor f, is often described as "trees which branch in the shape of f and are leaf-labeled by a". What exactly does this mean? Well, the definition of Free is:

data Free f a = Pure a | Free (f (Free f a))

So if we have a functor data Pair a = Pair a a deriving Functor, then Free Pair a indeed represents ordinary leaf-labeled, nonempty binary trees.
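To make "leaf-labeled binary tree" concrete, here is a small sketch that builds the tree ((1, 2), 3) as a Free Pair Int and reads its leaves back off. It defines Free locally so that it needs nothing beyond base; the names leaves and example are our own, not part of any library.

```haskell
{-# LANGUAGE DeriveFunctor #-}

-- A local copy of the free monad's data type (the `free` package's
-- Control.Monad.Free has the same shape), so only base is needed.
data Free f a = Pure a | Free (f (Free f a))

-- | A pair functor: every internal node has exactly two children.
data Pair a = Pair a a
  deriving (Show, Functor)

-- | Collect the leaf labels from left to right.
leaves :: Free Pair a -> [a]
leaves (Pure a)          = [a]
leaves (Free (Pair l r)) = leaves l ++ leaves r

-- | The tree ((1, 2), 3): internal nodes are Free, leaves are Pure.
example :: Free Pair Int
example = Free (Pair (Free (Pair (Pure 1) (Pure 2))) (Pure 3))
```

In GHCi, leaves example gives the three leaf labels in order.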

Now, if you know about decision trees, "leaf labeled" should catch your ear. (Binary) decision trees are trees where each internal node tests a yes/no feature of an observation, and the leaves are labeled with distributions. To predict, you descend the tree according to the features of a new observation, and the distribution at the leaf you reach is your best guess at the distribution of its label.

So if we want to use Free f a as a decision tree, a should clearly represent the distribution we are predicting. For the sake of simplicity, we will set a to Bool and just guess yes or no, rather than providing probabilities.

Our choice of f is a little less clear. We need to read in some information from an observation, say of type r, and use it to descend into the next level of the tree. Well, speaking of reading information, what about the reader functor (->) r? If we try f = (->) r, we get

-- The definition of Free with f := (->) r and a := Bool (schematic, not a legal declaration)
data Free ((->) r) Bool = Pure Bool | Free (r -> Free ((->) r) Bool)

This looks like exactly what we want! We fix a row type r; given a new row, we traverse the internal Free nodes, applying each node's function to the row, until we reach a leaf Pure node. So, we define

-- | A model with row type `r` and class label type `c`
type TreeM r c = Free ((->) r) c
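As an illustration of "traverse the internal Free nodes until we reach a Pure leaf", the sketch below hand-builds a tiny TreeM and walks it with an explicitly recursive runTree. The row type (a pair of raining/windy flags), the tree stayInside, and runTree are all hypothetical, chosen only to show the shape; a local Free replaces the library type.

```haskell
data Free f a = Pure a | Free (f (Free f a))

-- | A model with row type `r` and class label type `c`.
type TreeM r c = Free ((->) r) c

-- | Descend the tree: at each internal node, feed the row to the
-- branching function and continue with the subtree it returns.
runTree :: r -> TreeM r c -> c
runTree _   (Pure c)      = c
runTree row (Free branch) = runTree row (branch row)

-- | A hypothetical two-level tree over rows of (raining, windy) flags:
-- predict True ("stay inside") if it is raining, otherwise if it is windy.
stayInside :: TreeM (Bool, Bool) Bool
stayInside = Free $ \(raining, _) ->
  if raining
    then Pure True
    else Free $ \(_, windy) -> Pure windy
```

Note that each node only inspects the part of the row it cares about; nothing in the type forces a node to look at a particular feature.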

Recursion Combinators

The observation presented here is hardly earth-shattering, but it does come with some advantages beyond "hey, a connection!". By adapting some recursion combinators (which typically act on the Fix data type), we can separate the recursive logic of our program from the actual computation. The goal of recursion combinators is to standardize patterns of recursion and make their implementations cleaner.

Free f a is strikingly similar to the type Fix f, the fixed point of the functor f. Recall

newtype Fix f = Fix { unFix :: f (Fix f) }

Essentially, Free allows us to stop our potentially infinite, f-branching tree early and return a Pure value of type a. In practice, this means the functors f used with Fix tend to be more complicated than those used with Free, because with Fix the notion of returning data has to be built into the base functor itself. (For example, we could model Free f a itself using Fix with the base functor data Br f a r = Leaf a | Branch (f r) deriving Functor, and get that Free f a is isomorphic to Fix (Br f a).)
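A minimal sketch of that isomorphism, assuming the two-constructor base functor Br described above (the constructor names Leaf and Branch, and the functions toFix and fromFix, are our own choices). Each function just swaps one constructor layer for the other and recurses through the functor.

```haskell
{-# LANGUAGE DeriveFunctor #-}

newtype Fix f = Fix { unFix :: f (Fix f) }

data Free f a = Pure a | Free (f (Free f a))

-- | Base functor for Free f a: a layer is either a leaf holding an `a`,
-- or a branch of shape f whose children sit in the recursive position r.
data Br f a r = Leaf a | Branch (f r)
  deriving Functor

-- | Free f a -> Fix (Br f a): Pure becomes Leaf, Free becomes Branch.
toFix :: Functor f => Free f a -> Fix (Br f a)
toFix (Pure a) = Fix (Leaf a)
toFix (Free u) = Fix (Branch (fmap toFix u))

-- | The inverse direction.
fromFix :: Functor f => Fix (Br f a) -> Free f a
fromFix (Fix (Leaf a))   = Pure a
fromFix (Fix (Branch u)) = Free (fmap fromFix u)
```

The leaf case is exactly the "returning data" that Fix forces us to bake into the base functor, while Free provides it for free.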

For Fix, catamorphisms and anamorphisms are very useful recursion combinators, which respectively collapse and grow a recursive structure:

-- | Tear down a recursive structure, i.e. an element of Fix f.
-- First we recursively tear down each of the subtrees in the
-- first level of the fixed point. Then we collapse the
-- last layer using the algebra directly.
cata :: Functor f
     => (f a -> a) -- ^ An algebra to collapse a container to a value
     -> Fix f      -- ^ A recursive tree of containers to start with
     -> a
cata alg fix = alg . fmap (cata alg) . unFix $ fix

-- | Build up a recursive structure, i.e. an element of Fix f.
-- We first expand our seed by one step.
-- Then we map over the resultant container,
-- recursively expanding each value along the way.
ana :: Functor f
    => (a -> f a) -- ^ A function to expand an a
    -> a          -- ^ A seed value to start off with
    -> Fix f
ana grow seed = Fix . fmap (ana grow) . grow $ seed
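To see the two combinators cooperate, here is a small sketch over a hypothetical list base functor ListF: ana unfolds a seed n into the list n, n-1, ..., 1, and cata folds that list into its sum. The names ListF, countdown, sumAlg and sumTo are illustrative only.

```haskell
{-# LANGUAGE DeriveFunctor #-}

newtype Fix f = Fix { unFix :: f (Fix f) }

cata :: Functor f => (f a -> a) -> Fix f -> a
cata alg = alg . fmap (cata alg) . unFix

ana :: Functor f => (a -> f a) -> a -> Fix f
ana grow = Fix . fmap (ana grow) . grow

-- | Base functor for lists of Ints: the end of the list, or one element
-- followed by "the rest" in the recursive position r.
data ListF r = Nil | Cons Int r
  deriving Functor

-- | Coalgebra: expand a seed n >= 0 into one layer (n, then seed n - 1).
countdown :: Int -> ListF Int
countdown 0 = Nil
countdown n = Cons n (n - 1)

-- | Algebra: collapse one layer by adding its element to the sum of the rest.
sumAlg :: ListF Int -> Int
sumAlg Nil           = 0
sumAlg (Cons x rest) = x + rest

-- | Grow with ana, then tear down with cata.
sumTo :: Int -> Int
sumTo = cata sumAlg . ana countdown
```

Neither countdown nor sumAlg mentions recursion; all of it lives in cata and ana.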

It is pretty easy to extend cata to work on Free instead of Fix; we just need to add a case for Pure:

cataF :: Functor f => (f a -> a) -> Free f a -> a
cataF alg (Free u) = alg . fmap (cataF alg) $ u
cataF _ (Pure a) = a
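A quick sketch of cataF on the binary trees from earlier: changing only the algebra turns the same traversal into a sum of the leaves or their maximum. Note that the algebra's result type must equal the leaf type a, so something like counting leaves would need an fmap over the leaves first. The helpers sumLeaves and maxLeaf are our own names.

```haskell
{-# LANGUAGE DeriveFunctor #-}

data Free f a = Pure a | Free (f (Free f a))

data Pair a = Pair a a
  deriving Functor

cataF :: Functor f => (f a -> a) -> Free f a -> a
cataF alg (Free u) = alg . fmap (cataF alg) $ u
cataF _ (Pure a)   = a

-- | Sum all leaf labels: each internal node adds its two subtotals.
sumLeaves :: Free Pair Int -> Int
sumLeaves = cataF (\(Pair l r) -> l + r)

-- | The largest leaf, collapsed the same way with a different algebra.
maxLeaf :: Free Pair Int -> Int
maxLeaf = cataF (\(Pair l r) -> max l r)

-- | The tree ((1, 5), 3).
example :: Free Pair Int
example = Free (Pair (Free (Pair (Pure 1) (Pure 5))) (Pure 3))
```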

Extending ana takes a little more thought. The exact same code that worked for Fix also typechecks for Free, but it always builds infinite trees and never produces a Pure leaf. Thus the input function has to allow for the possibility of a Pure value:

anaF :: Functor f => (a -> Either (f a) b) -> a -> Free f b
anaF grow seed = case grow seed of
  Left u  -> Free . fmap (anaF grow) $ u
  Right b -> Pure b
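Here is a small sketch of anaF in action, using a hypothetical seed function halve: it splits an interval [lo, hi] in half (Left, one more Pair layer) until the interval holds a single number (Right, a leaf). The result is a balanced tree whose leaves are lo, lo+1, ..., hi.

```haskell
{-# LANGUAGE DeriveFunctor #-}

data Free f a = Pure a | Free (f (Free f a))

data Pair a = Pair a a
  deriving Functor

anaF :: Functor f => (a -> Either (f a) b) -> a -> Free f b
anaF grow seed = case grow seed of
  Left u  -> Free . fmap (anaF grow) $ u
  Right b -> Pure b

-- | Right stops with a leaf; Left produces one Pair layer of new seeds.
-- Both halves are nonempty and strictly smaller, so this terminates.
halve :: (Int, Int) -> Either (Pair (Int, Int)) Int
halve (lo, hi)
  | lo >= hi  = Right lo
  | otherwise = Left (Pair (lo, mid) (mid + 1, hi))
  where
    mid = (lo + hi) `div` 2

-- | A balanced tree with leaves lo .. hi (assuming lo <= hi).
rangeTree :: Int -> Int -> Free Pair Int
rangeTree lo hi = anaF halve (lo, hi)

leaves :: Free Pair a -> [a]
leaves (Pure a)          = [a]
leaves (Free (Pair l r)) = leaves l ++ leaves r
```

The decision-tree learner below has exactly this shape: the seed is a table, Right is "stop and guess", and Left is "split on a key".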

The Hard Work

We start with a preamble that enables some convenient language extensions and imports the needed libraries. Then we define our table data type and some simplified accessor functions.

What follows is complete, valid Haskell that can be run on a modern GHC with the containers, multiset, and free packages installed. You need only the code below this point, along with our definitions of cataF and anaF above. You can also snag the code here.

{-# LANGUAGE RankNTypes, DeriveFunctor, TupleSections, ScopedTypeVariables, UndecidableInstances #-}

import qualified Data.List as List
import qualified Data.MultiSet as Set
import qualified Data.Map as Map
import Control.Monad.Free

-- | value = Bool will suffice for this code, but
-- more general Tables are certainly reasonable
data Table key value = Table
  { keys :: [key]
    -- ^ For iterating and looping purposes
  , rows :: Set.MultiSet (value, Row key value)
    -- ^ An unordered collection of Rows associated to labels
  } deriving (Show)

type Row key value = Map.Map key value

numKeys :: Table k v -> Int
numKeys = length . keys

numRows :: Table k v -> Int
numRows = Set.size . rows

-- Assume all tables are full: every row has a value for every key
getKey :: Ord k => k -> Row k v -> v
getKey k row = Map.findWithDefault (error "getKey: key missing from row") k row

emptyBinTable :: Table key value
emptyBinTable = Table [] Set.empty

Now, the learning method of decision trees is (roughly) the following:

1. If there are no keys in the table, return a model that predicts the most common class label.
2. If there are keys, find the key that best predicts the class label.
3. Split the data into two new tables based on the value of that key.
4. Recursively grow a decision tree for each of the two new tables.
5. Put the two new trees together into a model that first branches on the best key, then predicts using the corresponding new tree.

The following code implements several tools we will need:

-- | We score the keys by how many labels the key could get correct.
-- This method returns two values: the first is if the model applied a
-- positive correlation, the second is if it assumes a negative correlation.
scores :: Ord k => k -> Table k Bool -> (Int, Int)
scores k tab = Set.fold (indicator k) (0, 0) $ rows tab
  where
    indicator :: Ord k => k -> (Bool, Row k Bool) -> (Int, Int) -> (Int, Int)
    indicator k (label, row) (pos, neg) = case Map.lookup k row of
      Just a  -> (pos + fromEnum (label == a), neg + fromEnum (label /= a))
      Nothing -> (pos, neg)

-- | Loop over all the keys and find the one that predicts with highest accuracy
bestKey :: Ord k => Table k Bool -> k
bestKey tab =
  let bestScore k = let (pos, neg) = scores k tab in (k, max pos neg)
      maxScores = fmap bestScore (keys tab)
      (best, _) = List.maximumBy (\(_, s) (_, s') -> s `compare` s') maxScores
  in best

removeKey :: (Ord k, Ord v) => k -> Table k v -> Table k v
removeKey k tab = emptyBinTable
  { keys = (List.\\) (keys tab) [k]
  , rows = Set.map (\(lab, row) -> (lab, Map.delete k row)) (rows tab)
  }

-- | Split a table into two tables based on the value of one key.
-- Also remove the key from the new tables.
filterOn :: Ord k => k -> Table k Bool -> (Table k Bool, Table k Bool)
filterOn k tab =
  let kTrue = getKey k . snd -- predicate to test if key k is true for some row
      trueRows = Set.filter kTrue (rows tab)
      falseRows = Set.filter (not . kTrue) (rows tab)
  in (removeKey k tab{rows = trueRows}, removeKey k tab{rows = falseRows})

-- | Ignore all the rows and just guess a boolean based on class label
bestGuess :: Ord k => Table k Bool -> Bool
bestGuess tab =
  let nTrue = Set.fold (\(b, _) accum -> accum + fromEnum b) 0 (rows tab) -- count number of true labels
  in 2 * nTrue >= numRows tab

These are all straightforward things that we would probably implement if we were writing this algorithm without recursion combinators.
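To sanity-check the scoring rule by hand, here is a list-based sketch of scores and bestKey (scoresL and bestKeyL are our own names; a plain list stands in for the MultiSet so only containers is needed). On the hypothetical four-row table below, the label always agrees with "rain" and agrees with "wind" only half the time, so "rain" wins.

```haskell
import qualified Data.Map as Map
import Data.List (maximumBy)
import Data.Ord (comparing)

-- | Rows map keys to Bool values, as in the post.
type Row k = Map.Map k Bool

-- | (pos, neg): how many labels "label == value of k" and
-- "label /= value of k" get right. Rows missing k are skipped.
scoresL :: Ord k => k -> [(Bool, Row k)] -> (Int, Int)
scoresL k = foldr step (0, 0)
  where
    step (label, row) (pos, neg) = case Map.lookup k row of
      Just a  -> (pos + fromEnum (label == a), neg + fromEnum (label /= a))
      Nothing -> (pos, neg)

-- | Pick the key whose better correlation (positive or negative) scores highest.
bestKeyL :: Ord k => [k] -> [(Bool, Row k)] -> k
bestKeyL ks rows = fst (maximumBy (comparing snd) [ (k, uncurry max (scoresL k rows)) | k <- ks ])

-- | A hypothetical training set with keys "rain" and "wind".
sample :: [(Bool, Row String)]
sample =
  [ (True,  Map.fromList [("rain", True),  ("wind", True)])
  , (True,  Map.fromList [("rain", True),  ("wind", False)])
  , (False, Map.fromList [("rain", False), ("wind", True)])
  , (False, Map.fromList [("rain", False), ("wind", False)])
  ]
```

Here scoresL "rain" sample is (4, 0) and scoresL "wind" sample is (2, 2).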

Applying our Recursion Combinators

Now that we have some functions to manipulate and extract information out of our tables, we are ready to learn our models and predict with them. We write the function discriminate to fit the type signature of anaF. This function accepts a table and returns one of two things:

If the table contains only class labels, we return a single Bool that represents our best guess of the class label.

If the table still has some keys, we return a mapping. This mapping takes in any row and returns one of the two tables obtained by splitting the data on bestKey, chosen by the row's value for that key.

Recall that anaF :: Functor f => (a -> Either (f a) b) -> a -> Free f b and that type TreeM r c = Free ((->) r) c. When we apply anaF to discriminate, it recurses over the newly created tables, growing more models until it hits the base case of a table with only class labels.

discriminate :: Ord k => Table k Bool -> Either (Row k Bool -> Table k Bool) Bool
discriminate tab
  | numKeys tab == 0 = Right $ bestGuess tab
  | otherwise =
      let key = bestKey tab
          (trueTab, falseTab) = filterOn key tab
      in Left $ \row -> if getKey key row then trueTab else falseTab

-- | Finally time to use this!
type TreeM r c = Free ((->) r) c

learn :: Ord k => Table k Bool -> TreeM (Row k Bool) Bool
learn = anaF discriminate

-- cataF :: Functor f => (f a -> a) -> Free f a -> a
predict :: Ord k => Row k Bool -> TreeM (Row k Bool) Bool -> Bool
predict row model = cataF ($ row) model

Now we are done! All the work paid off with very, very short definitions for learn and predict. Keep in mind that, apart from cataF and anaF themselves, nothing we had written before learn and predict involved recursion at all.
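For readers without the multiset and free packages at hand, here is a self-contained miniature of the whole pipeline that can be run end to end. It is a simplified sketch, not the code above: a plain list of (label, row) pairs replaces the MultiSet, a local Free replaces Control.Monad.Free, and every row is assumed to contain every key. The training data is hypothetical: the label is rain || wind.

```haskell
import qualified Data.Map as Map
import Data.List (delete, maximumBy)
import Data.Ord (comparing)

data Free f a = Pure a | Free (f (Free f a))

type TreeM r c = Free ((->) r) c

cataF :: Functor f => (f a -> a) -> Free f a -> a
cataF alg (Free u) = alg (fmap (cataF alg) u)
cataF _ (Pure a)   = a

anaF :: Functor f => (a -> Either (f a) b) -> a -> Free f b
anaF grow seed = case grow seed of
  Left u  -> Free (fmap (anaF grow) u)
  Right b -> Pure b

type Row k = Map.Map k Bool

-- | Simplified table: the remaining keys plus a list of labelled rows.
data Table k = Table { keys :: [k], rows :: [(Bool, Row k)] }

getKey :: Ord k => k -> Row k -> Bool
getKey = Map.findWithDefault (error "getKey: key missing from row")

-- | Key whose better correlation (positive or negative) gets the most labels right.
bestKey :: Ord k => Table k -> k
bestKey tab = fst (maximumBy (comparing snd) [ (k, score k) | k <- keys tab ])
  where
    score k =
      let agree = length [ () | (label, row) <- rows tab, label == getKey k row ]
      in max agree (length (rows tab) - agree)

-- | Majority label (ties and empty tables guess True, as in the post).
bestGuess :: Table k -> Bool
bestGuess tab = 2 * length (filter fst (rows tab)) >= length (rows tab)

-- | Split on key k and drop k from both halves.
filterOn :: Ord k => k -> Table k -> (Table k, Table k)
filterOn k tab = (keep True, keep False)
  where
    keep b = Table
      { keys = delete k (keys tab)
      , rows = [ (label, Map.delete k row) | (label, row) <- rows tab, getKey k row == b ]
      }

discriminate :: Ord k => Table k -> Either (Row k -> Table k) Bool
discriminate tab
  | null (keys tab) = Right (bestGuess tab)
  | otherwise =
      let key = bestKey tab
          (trueTab, falseTab) = filterOn key tab
      in Left (\row -> if getKey key row then trueTab else falseTab)

learn :: Ord k => Table k -> TreeM (Row k) Bool
learn = anaF discriminate

predict :: Row k -> TreeM (Row k) Bool -> Bool
predict row = cataF ($ row)

-- | Hypothetical training data covering all four combinations.
sample :: Table String
sample = Table
  { keys = ["rain", "wind"]
  , rows = [ (r || w, Map.fromList [("rain", r), ("wind", w)]) | r <- [True, False], w <- [True, False] ]
  }
```

One design consequence, shared with the version above: fmap for the reader functor is function composition, so anaF stores each branch as anaF discriminate composed with the branching function. The subtrees below the root are therefore rebuilt every time predict feeds in a row; caching them would mean growing both subtrees before building the branching function.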

Extensions

By changing discriminate , we change our learning method. Using anaF adds some constraints on how we can learn our model, but still allows some freedom. Here are some possible avenues for extending and improving our models:

We could test whether the inference gained at a given node is statistically significant, for example with a chi-squared test. If the key is not a significant predictor at that level, we can stop growing our decision tree early and return our best guess at that stage. This would help prevent overfitting.
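As a sketch of what such a check might compute, assuming a 2x2 contingency table of counts (a, b: rows where the key is True with label True/False; c, d: rows where the key is False with label True/False), here is Pearson's chi-squared statistic via the standard shortcut formula, without Yates' correction. The names chiSquared2x2 and significant are hypothetical; 3.841 is the 5% critical value of the chi-squared distribution with one degree of freedom.

```haskell
-- | Pearson's chi-squared statistic for the 2x2 table
--
--                label True   label False
--   key True         a             b
--   key False        c             d
--
-- computed as n (ad - bc)^2 / ((a+b)(c+d)(a+c)(b+d)).
-- Returns 0 when a margin is empty, since the split then tells us nothing.
chiSquared2x2 :: Int -> Int -> Int -> Int -> Double
chiSquared2x2 a b c d
  | denom == 0 = 0
  | otherwise  = n * (ad - bc) ^ (2 :: Int) / denom
  where
    fa, fb, fc, fd :: Double
    fa = fromIntegral a
    fb = fromIntegral b
    fc = fromIntegral c
    fd = fromIntegral d
    n     = fa + fb + fc + fd
    ad    = fa * fd
    bc    = fb * fc
    denom = (fa + fb) * (fc + fd) * (fa + fc) * (fb + fd)

-- | Treat the split as significant at the 5% level (one degree of freedom).
significant :: Int -> Int -> Int -> Int -> Bool
significant a b c d = chiSquared2x2 a b c d > 3.841
```

Inside discriminate, one could count these four cells for bestKey and return Right (bestGuess tab) whenever the split is not significant.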

We could add fields to our table data type and set them when we split at each step of the recursion. For example, we could add a counter that prevents the tree from growing past a certain height (again, this combats overfitting). Alternatively, if we have some prior belief that certain variables work well together as predictors, we may want these variables to be close together in the decision tree. By storing the "splitting key" of the parent node, we could implement this in discriminate .
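Instead of adding a field to the table, one hedged way to get a height limit is to thread the depth through the seed. The sketch below wraps any discriminate-shaped step (limitDepth, neverStop, and runTree are hypothetical names); at the maximum depth it stops and uses a fallback guess.

```haskell
data Free f a = Pure a | Free (f (Free f a))

anaF :: Functor f => (a -> Either (f a) b) -> a -> Free f b
anaF grow seed = case grow seed of
  Left u  -> Free . fmap (anaF grow) $ u
  Right b -> Pure b

-- | Descend a reader-branching tree with one row.
runTree :: r -> Free ((->) r) c -> c
runTree _   (Pure c)      = c
runTree row (Free branch) = runTree row (branch row)

-- | Wrap a discriminate-style step so the seed also carries its depth.
-- Once the depth reaches maxDepth we stop and use `fallback` instead.
limitDepth :: Int                      -- ^ maximum depth
           -> (s -> Either (r -> s) c) -- ^ the underlying step
           -> (s -> c)                 -- ^ what to predict when cut off
           -> (Int, s) -> Either (r -> (Int, s)) c
limitDepth maxDepth step fallback (depth, s)
  | depth >= maxDepth = Right (fallback s)
  | otherwise = case step s of
      Right c     -> Right c
      Left branch -> Left (\row -> (depth + 1, branch row))

-- | A hypothetical step that always branches, so on its own it would grow
-- forever. Its seed simply counts how many levels have been descended.
neverStop :: Int -> Either (() -> Int) Int
neverStop n = Left (\_ -> n + 1)
```

With the post's definitions, a depth-limited learner could then be anaF (limitDepth 5 discriminate bestGuess) applied to (0, tab), which still produces a TreeM (Row k Bool) Bool.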

The rows can be extended to hold non-Boolean data without changing the data type of our model very much. discriminate can grow the branchings with any function from rows to new training tables. If the complexity of each step of inference grows, we just need to make this branching function more complicated.

The current model is totally deterministic. Recently, there has been some work in elegantly and efficiently adding probabilistic programming to Haskell (for example here), but these methods seem mostly suited to learning parametric models. Perhaps recursion combinators (or a similar idea) are one way to elegantly learn nonparametric models whose structure needs to be determined dynamically.

It is not clear how to give the learning method control over multiple levels of the hierarchy. We could add information about the parent nodes, but we cannot go back and change them based on new information. A common learning method for decision trees is to grow out a few levels at once and decide between them, and this would be difficult to simulate in the current framework. With a wealth of recursion combinators out there, one of them may well capture this idea.

Appendix: A Printable Interface

It is hard to verify whether the code above is correct because we cannot (sensibly) print out functions in Haskell. The following code fixes this by providing a data type for the branching instead of relying on functions. This provides much less flexibility in how we do branching, but allows us to print things! The following implements the exact same algorithm as above: