MapReduce Implementation for Union-Find

October 23, 2010

A fun problem I had to solve a while back deals with finding all connected components in a very large graph to build a list of disjoint sets. A simple serial solution might involve picking a node at random, running a BFS from it, partitioning out the visited nodes, and repeating until no nodes remain. Unfortunately, that’s not feasible for a graph with billions of nodes, as it requires holding the entire graph in memory. Luckily, I had a Hadoop cluster at my disposal, so once I could formulate a solution in MapReduce, I’d be ready to go.
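For reference, here is a minimal sketch of that serial approach (all names are mine, not from the post's code): pick an unvisited node, BFS from it, peel off the visited set as one component, and repeat. It works fine in memory, which is exactly the limitation described above.

```java
import java.util.*;

// Serial connected components via repeated BFS. Each outer pass starts
// from an unvisited node and collects everything reachable from it.
public class SerialComponents {
    public static List<Set<Integer>> components(Map<Integer, List<Integer>> graph) {
        Set<Integer> visited = new HashSet<>();
        List<Set<Integer>> result = new ArrayList<>();
        for (Integer start : graph.keySet()) {
            if (visited.contains(start))
                continue;
            Set<Integer> component = new HashSet<>();
            Deque<Integer> queue = new ArrayDeque<>();
            queue.add(start);
            while (!queue.isEmpty()) {
                Integer node = queue.poll();
                if (!visited.add(node))
                    continue; // already reached via another path
                component.add(node);
                queue.addAll(graph.getOrDefault(node, Collections.emptyList()));
            }
            result.add(component);
        }
        return result;
    }
}
```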

Before diving into the solution, here’s a better breakdown of the problem. Imagine a graph where connected components form equivalence classes, and in the universe of nodes there are a lot of different classes. In my case, we’re dealing with more than 10 billion nodes, which yield more than one billion disjoint sets. The size of individual components can range from one node to hundreds of thousands. Some clusters within the larger components might be densely connected, with many paths between any two nodes, while others could have a sparse collection of edges. The process needed to account for all of these variations and eventually yield the simple disjoint sets.

The Inputs:

A long list of lists. Each list represents nodes that we know are connected. Initially, each list may contain only one or two nodes.

The Algorithm:

1) Find intersections between sets using set representatives, and merge those that overlap.

2) Determine which sets are isolated (their intersection with every other set is empty), and mark them as disjoint.

3) Repeat 1 and 2 until all sets are isolated.
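To make the iteration concrete, here is a toy in-memory simulation of those steps (plain Java collections standing in for MapReduce; all names are mine). Each round runs an elect step, which emits <R, {Nodes}> and <N, {R}> and unions values per key, followed by a partition step, which swaps singleton values back to their representative and counts appearances.

```java
import java.util.*;

// In-memory sketch of the iteration. Sets whose members all count
// exactly twice in the partition step are disjoint; the rest feed the
// next round.
public class UnionFindSim {
    public static List<Set<String>> run(List<List<String>> sets) {
        List<Set<String>> disjoint = new ArrayList<>();
        List<Collection<String>> open = new ArrayList<>(sets);
        while (!open.isEmpty()) {
            // Elect: the smallest element represents a set; union values per key.
            Map<String, TreeSet<String>> elected = new TreeMap<>();
            for (Collection<String> set : open) {
                TreeSet<String> distinct = new TreeSet<>(set);
                String rep = distinct.first();
                elected.computeIfAbsent(rep, k -> new TreeSet<>()).addAll(distinct);
                for (String node : distinct.tailSet(rep, false))
                    elected.computeIfAbsent(node, k -> new TreeSet<>()).add(rep);
            }
            // Partition map: constituent swap for singletons, pass-through otherwise.
            Map<String, List<Collection<String>>> grouped = new TreeMap<>();
            for (Map.Entry<String, TreeSet<String>> e : elected.entrySet()) {
                boolean single = e.getValue().size() == 1;
                String target = single ? e.getValue().first() : e.getKey();
                Collection<String> payload =
                        single ? Collections.singleton(e.getKey()) : e.getValue();
                grouped.computeIfAbsent(target, k -> new ArrayList<>()).add(payload);
            }
            // Partition reduce: disjoint iff every member was counted exactly twice.
            open = new ArrayList<>();
            for (Map.Entry<String, List<Collection<String>>> e : grouped.entrySet()) {
                Map<String, Integer> counts = new HashMap<>();
                counts.put(e.getKey(), 1); // the key counts itself once extra
                for (Collection<String> value : e.getValue())
                    for (String node : value)
                        counts.merge(node, 1, Integer::sum);
                if (counts.values().stream().allMatch(c -> c == 2))
                    disjoint.add(counts.keySet());
                else
                    open.add(counts.keySet());
            }
        }
        return disjoint;
    }
}
```

Running this on the chain 1-2, 2-3, 3-4 plus an isolated node shows the chain collapsing into one set over a few rounds while the singleton is flagged disjoint immediately.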

A naive approach for growing the frontier of nodes could involve emitting every pair of nodes from a set as a way to connect nodes by way of a third party. However, this would be very inefficient in later stages of the algorithm, as the number of emitted pairs from a set of size n grows as O(n²). A more efficient implementation is to pick a consistent representative, R, from each set (I use the smallest) and emit <R, {Nodes}>. To keep the graph bi-directional, also emit the inverse: for each E in Nodes, emit <E, {R}>. This step serves two purposes. First, it finds intersections between sets and establishes connections between their representatives. Second, it performs unions, as the representatives for overlapping sets improve through those intersections. The reduce phase simply passes the key through unaltered and shrinks the values down to only distinct elements.
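For a concrete feel of the emission pattern, here is a small sketch (plain strings instead of Hadoop writables; names are mine) of what the elect mapper produces for a single input set. For {7, 3, 9}, the smallest element, 3, becomes R, so we get <3, {3, 7, 9}> plus an inverse edge back to 3 from each other node.

```java
import java.util.*;

// Emission pattern of the elect step for one set: <R, {Nodes}> plus
// <N, {R}> for every non-representative node N.
public class ElectSketch {
    public static Map<String, List<String>> emit(Collection<String> nodes) {
        // A tree set uniquifies and makes the smallest element easy to find.
        TreeSet<String> distinct = new TreeSet<>(nodes);
        String representative = distinct.first();
        Map<String, List<String>> out = new LinkedHashMap<>();
        out.put(representative, new ArrayList<>(distinct));   // <R, {Nodes}>
        for (String other : distinct.tailSet(representative, false))
            out.put(other, Arrays.asList(representative));    // <N, {R}>
        return out;
    }
}
```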

Job 2 builds directly off the output from Job 1. The goal is to determine whether a representative for a set of nodes is the only representative for each node. The mapper reads a key and its list of representatives. If there is only one representative, it emits <R, {K}> (a constituent swap). If there are many representatives, it passes the pair through unaltered (either ambiguous representatives or THE representative). Think of it as each node reporting back to its representative when there’s only one. The reducer then looks for keys where each value appears exactly twice: once from the representative pass-through, and once from a constituent swap. Those are flagged as disjoint sets. Any constituents with ambiguous representatives won’t appear in that list, which means there are more iterations to perform.
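Here is the counting rule in isolation (again plain strings, names mine): the reducer for representative R sees the pass-through list of R's nodes plus one singleton per constituent that reported back, and injects an extra 1 for R itself so that everything should total exactly two.

```java
import java.util.*;

// The partition reducer's test: count every node across all incoming
// values, seeding the key with 1. Disjoint iff every count equals 2.
public class PartitionSketch {
    public static boolean isDisjoint(String key, List<List<String>> values) {
        Map<String, Integer> counts = new HashMap<>();
        counts.put(key, 1); // the key appears once in its own node list
        for (List<String> value : values)
            for (String node : value)
                counts.merge(node, 1, Integer::sum);
        for (int count : counts.values())
            if (count != 2)
                return false; // an odd man out: still OPEN
        return true;
    }
}
```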

For simplicity, the output key from the second job is either DISJOINT or OPEN, and the value is a list of all nodes that appeared in a reduce call. This feeds the next iteration, which ignores any DISJOINT records. Once the algorithm completes, a final job is run to collect only the DISJOINT records from each iteration’s output.

Because the frontier of each set grows out from all nodes simultaneously, the maximum number of iterations is O(log n), where n is the longest path between any two nodes in the set. Also, graphs with mostly small components benefit from early isolation of disjoint sets, since they are removed from the inputs to later iterations.

Finally, here’s the code:

package chaser.hadoop;

import java.io.IOException;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.TreeSet;
import java.util.UUID;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.conf.Configured;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.ArrayWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
import org.apache.hadoop.util.Tool;
import org.apache.hadoop.util.ToolRunner;

public class UnionFind extends Configured implements Tool {

    private enum Counter { OPEN, DISJOINT }

    private static final Text DISJOINT = new Text("D");
    private static final Text OPEN = new Text("O");

    /**
     * For potentially overlapping sets, elect a representative.
     *
     * Emits <R, {Nodes}> and <N, {R}> for each N in Nodes.
     * Ignores known DISJOINT sets.
     */
    public static class ElectMap
            extends Mapper<Text, TextArrayWritable, Text, TextArrayWritable> {
        @Override
        protected void map(Text key, TextArrayWritable value, Context context)
                throws IOException, InterruptedException {
            // If it was a disjoint output from the last iteration, then don't
            // continue to propagate it.
            if (key.equals(DISJOINT)) {
                context.getCounter(Counter.DISJOINT).increment(1);
                return;
            }
            context.getCounter(Counter.OPEN).increment(1);

            // Use a tree set so it's easier to find the smallest while uniquifying
            TreeSet<Text> distinct = new TreeSet<Text>(Arrays.asList(value.get()));
            TextArrayWritable all = new TextArrayWritable(distinct);
            Text representative = distinct.pollFirst();
            TextArrayWritable representative_val = new TextArrayWritable(representative);

            context.write(representative, all);
            for (Text other : distinct)
                context.write(other, representative_val);
        }
    }

    /**
     * Emits the union of all incoming array writables for a key.
     */
    public static class ElectReduce
            extends Reducer<Text, TextArrayWritable, Text, TextArrayWritable> {
        @Override
        protected void reduce(Text key, Iterable<TextArrayWritable> values, Context context)
                throws IOException, InterruptedException {
            TreeSet<Text> union = new TreeSet<Text>();
            for (TextArrayWritable value : values)
                union.addAll(Arrays.asList(value.get()));
            context.write(key, new TextArrayWritable(union));
        }
    }

    /**
     * Performs representative pass throughs or constituent swaps.
     */
    public static class PartitionMap
            extends Mapper<Text, TextArrayWritable, Text, TextArrayWritable> {
        @Override
        protected void map(Text key, TextArrayWritable value, Context context)
                throws IOException, InterruptedException {
            // Constituent swap
            if (value.get().length == 1)
                context.write(value.get()[0], new TextArrayWritable(key));
            // Representative pass through
            else
                context.write(key, value);
        }
    }

    /**
     * Count the number of constituents, and label the set as DISJOINT if
     * each element appears twice.
     */
    public static class PartitionReduce
            extends Reducer<Text, TextArrayWritable, Text, TextArrayWritable> {
        @Override
        protected void reduce(Text key, Iterable<TextArrayWritable> values, Context context)
                throws IOException, InterruptedException {
            HashMap<Text, Integer> counts = new HashMap<Text, Integer>();
            // Inject a 1 for the key, so it counts itself twice.
            counts.put(key, 1);
            for (TextArrayWritable value : values)
                for (Text text : value.get())
                    if (counts.containsKey(text))
                        counts.put(text, counts.get(text) + 1);
                    else
                        counts.put(text, 1);

            // Assume it's DISJOINT until we see an odd man
            TextArrayWritable value = new TextArrayWritable(counts.keySet());
            Text label = DISJOINT;
            for (Integer count : counts.values()) {
                if (count != 2) {
                    label = OPEN;
                    break;
                }
            }
            if (label.equals(DISJOINT))
                context.getCounter(Counter.DISJOINT).increment(1);
            else
                context.getCounter(Counter.OPEN).increment(1);
            context.write(label, value);
        }
    }

    /**
     * Simple pass that tags all incoming records with the OPEN key.
     */
    public static class MarkOpenMap
            extends Mapper<Writable, TextArrayWritable, Text, TextArrayWritable> {
        @Override
        protected void map(Writable key, TextArrayWritable value, Context context)
                throws IOException, InterruptedException {
            context.write(OPEN, value);
        }
    }

    /**
     * Simple pass that emits all DISJOINT records.
     */
    public static class EmitDisjointMap
            extends Mapper<Text, TextArrayWritable, Text, TextArrayWritable> {
        @Override
        protected void map(Text key, TextArrayWritable value, Context context)
                throws IOException, InterruptedException {
            if (key.equals(DISJOINT))
                context.write(key, value);
        }
    }

    private String makeTempSpace() throws IOException {
        String temporary = "/tmp/union_find/" + UUID.randomUUID();
        Path temp_path = new Path(temporary);
        FileSystem fs = temp_path.getFileSystem(getConf());
        fs.mkdirs(temp_path);
        fs.deleteOnExit(temp_path);
        return temporary;
    }

    @Override
    public int run(String[] args) throws Exception {
        // Create a temporary work location that gets cleaned up on exit.
        String temporary = makeTempSpace();
        String elect_path = temporary + "/elect.";
        String partition_path = temporary + "/partition.";
        int iteration = 0;

        // This step assumes some prior data setup. Specifically, the input
        // must be in a sequence file of <K, TextArrayWritable>.
        // If IO is very important, the job could be optimized away by tacking the
        // mapper onto the first iteration of the loop below with a ChainMapper.
        Job setup = new Job(getConf());
        setup.setJarByClass(getClass());
        setup.setJobName("Union Find (setup)");
        setup.setMapperClass(MarkOpenMap.class);
        setup.setNumReduceTasks(0);
        setup.setOutputKeyClass(Text.class);
        setup.setOutputValueClass(TextArrayWritable.class);
        setup.setInputFormatClass(SequenceFileInputFormat.class);
        setup.setOutputFormatClass(SequenceFileOutputFormat.class);
        FileInputFormat.addInputPath(setup, new Path(args[0]));
        FileOutputFormat.setOutputPath(setup, new Path(partition_path + iteration));
        setup.waitForCompletion(false);

        while (true) {
            Job elect = new Job(new Configuration(getConf()));
            Job partition = new Job(new Configuration(getConf()));
            elect.setJarByClass(getClass());
            partition.setJarByClass(getClass());

            // Stitch together paths
            // partition.n => elect => elect.(n+1) => partition => partition.(n+1)
            FileInputFormat.addInputPath(elect, new Path(partition_path + (iteration++)));
            FileOutputFormat.setOutputPath(elect, new Path(elect_path + iteration));
            FileInputFormat.addInputPath(partition, new Path(elect_path + iteration));
            FileOutputFormat.setOutputPath(partition, new Path(partition_path + iteration));

            elect.setJobName("Union Find (elect [" + iteration + "])");
            elect.setMapperClass(ElectMap.class);
            elect.setReducerClass(ElectReduce.class);
            elect.setOutputKeyClass(Text.class);
            elect.setOutputValueClass(TextArrayWritable.class);
            elect.setInputFormatClass(SequenceFileInputFormat.class);
            elect.setOutputFormatClass(SequenceFileOutputFormat.class);

            partition.setJobName("Union Find (partition [" + iteration + "])");
            partition.setMapperClass(PartitionMap.class);
            partition.setReducerClass(PartitionReduce.class);
            partition.setOutputKeyClass(Text.class);
            partition.setOutputValueClass(TextArrayWritable.class);
            partition.setInputFormatClass(SequenceFileInputFormat.class);
            partition.setOutputFormatClass(SequenceFileOutputFormat.class);

            elect.waitForCompletion(false);
            if (!elect.isSuccessful())
                throw new RuntimeException("elect failed on iteration " + iteration);

            // All the sets were disjoint. No more work to do.
            // Otherwise, run partition and repeat.
            if (elect.getCounters().findCounter(Counter.OPEN).getValue() == 0)
                break;
            else
                partition.waitForCompletion(false);
        }

        // Collect all the disjoint values.
        Job emit = new Job(getConf());
        emit.setJarByClass(getClass());
        emit.setJobName("Union Find (emit)");
        emit.setMapperClass(EmitDisjointMap.class);
        emit.setNumReduceTasks(0);
        emit.setOutputKeyClass(Text.class);
        emit.setOutputValueClass(TextArrayWritable.class);
        emit.setInputFormatClass(SequenceFileInputFormat.class);
        emit.setOutputFormatClass(SequenceFileOutputFormat.class);
        FileInputFormat.addInputPath(emit, new Path(partition_path + "*"));
        FileOutputFormat.setOutputPath(emit, new Path(args[1]));
        emit.waitForCompletion(true);
        return emit.isSuccessful() ? 0 : 1;
    }

    public static void main(String[] args) throws Exception {
        int result = ToolRunner.run(new UnionFind(), args);
        System.exit(result);
    }

    public static class TextArrayWritable extends ArrayWritable {
        public TextArrayWritable() {
            super(Text.class);
        }

        public TextArrayWritable(Text... elements) {
            super(Text.class, elements);
        }

        public TextArrayWritable(Collection<Text> elements) {
            super(Text.class, elements.toArray(new Text[0]));
        }

        public Text[] get() {
            Writable[] writables = super.get();
            Text[] texts = new Text[writables.length];
            for (int i = 0; i < writables.length; ++i)
                texts[i] = (Text) writables[i];
            return texts;
        }
    }
}