This article is about aggregating data by a key within the data, like a SQL GROUP BY plus an aggregate function, when you want the whole row of data back. It’s easy to do it the right way, but Spark provides lots of wrong ways. I’m going to go over the problem and the right solution, then cover some approaches that didn’t work out and why.

This article assumes that you understand how Spark lays out data in datasets and partitions, and that partition skew is bad. I’ve included links in the various sections to resources that explain the issues in more depth.

Even though I do understand the above, let’s be clear: it took experiments with every method that didn’t work to realise that they weren’t doing what I expected, and some pretty focused reading of docs, source code, and Jacek Laskowski’s indispensable The Internals of Apache Spark to find a solution that does work as I expect.

The problem: user level

Consider data like these, but imagine millions of rows spread over thousands of dates:

Key      Date        Numeric  Text
-------- ----------  -------  ---------
ham      2019-01-01  3        Yah
cheese   2018-12-31  4        Woo
fish     2019-01-02  5        Hah
grain    2019-01-01  6        Community
grain    2019-01-02  7        Community
ham      2019-01-04  3        jamón

What you want is the latest (or earliest, or any other criterion relative to the set of rows) entry for each key, like so:

Key      Date        Numeric  Text
-------- ----------  -------  ---------
cheese   2018-12-31  4        Woo
fish     2019-01-02  5        Hah
grain    2019-01-02  7        Community
ham      2019-01-04  3        jamón
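To make the problem concrete, here is a minimal sketch of those rows as a Scala case class, plus a plain in-memory (non-Spark) version of "latest row per key". The field names and types are assumptions. `key` and `date` match the `_.key` and `_.date` accessors used in the snippets later in this article, `numeric` and `text` are guesses from the column headers, and dates are kept as ISO-8601 strings so that string order is also chronological order.

```scala
// Hypothetical model of the rows in the tables above. Only `key` and `date`
// are grounded in the article's snippets; the other field names are guesses.
// Dates are ISO-8601 strings, so lexicographic order is chronological order.
case class Food(key: String, date: String, numeric: Int, text: String)

object LatestPerKeyInMemory {
  val foods: Seq[Food] = Seq(
    Food("ham", "2019-01-01", 3, "Yah"),
    Food("cheese", "2018-12-31", 4, "Woo"),
    Food("fish", "2019-01-02", 5, "Hah"),
    Food("grain", "2019-01-01", 6, "Community"),
    Food("grain", "2019-01-02", 7, "Community"),
    Food("ham", "2019-01-04", 3, "jamón")
  )

  // Group by key and keep the whole row with the greatest date in each group,
  // sorted by key for stable output. This is the result we want from Spark,
  // minus the problem of the data being spread across a cluster.
  def latestPerKey(rows: Seq[Food]): Seq[Food] =
    rows.groupBy(_.key).values.map(_.maxBy(_.date)).toSeq.sortBy(_.key)
}
```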

The problem: Spark level

The problem with doing this for a very large dataset in Spark is that grouping by key requires a shuffle, which (a) is the enemy of Spark performance (see also), and (b) expands the amount of data that needs to be held (because shuffle data is generally bigger than input data), which makes tuning your job to your cluster parameters (or vice versa) much more important. With big shuffles, you can end up with slow applications whose tasks fail repeatedly and need to be retried.

So, given this problem, what you want to do is shuffle the minimum amount of data. The way to do this is to reduce the amount of data going into the shuffle. In the next section, I’ll talk about how.

As an aside, if you can perform this kind of task incrementally, you can do so faster and with less latency; but sometimes you want to do this as a batch, either because you’re recovering from data loss, you’re ensuring that your stream processing worked (or recovering from it losing some records), or you just don’t want to operate stream infrastructure (and you don’t need low latency).

The solution: Aggregators

Aggregators (and UDAFs, their untyped cousins) are the solution here because they let Spark perform part of the aggregation while it maps over the data to prepare it for the shuffle ("the map side") (see code), so the data actually going into the shuffle is already reduced. Then, on the receiving side, where the data are actually grouped by key (by the groupByKey), the partial results are combined again with the same kind of operation (the Aggregator’s merge).
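The two-phase idea can be illustrated without Spark at all. This is a plain-Scala simulation, not Spark’s actual implementation. Each inner `Seq` stands in for one input partition. The "map side" collapses each partition to at most one row per key, only those partial results are "shuffled", and the "reduce side" merges the partials with the same function. The `Food` fields and the partition layout are assumptions made for illustration.

```scala
case class Food(key: String, date: String, numeric: Int, text: String)

object TwoPhaseSimulation {
  // One "keep the later row" function, used on both sides of the shuffle.
  def later(a: Food, b: Food): Food = if (b.date > a.date) b else a

  // Map side: collapse one partition to at most one row per key.
  def partial(partition: Seq[Food]): Map[String, Food] =
    partition.foldLeft(Map.empty[String, Food]) { (acc, row) =>
      acc.updated(row.key, acc.get(row.key).fold(row)(later(_, row)))
    }

  // Reduce side: merge the partial results that arrive for each key.
  def merge(partials: Seq[Map[String, Food]]): Map[String, Food] =
    partials.flatten.groupBy(_._1).map { case (k, kvs) =>
      k -> kvs.map(_._2).reduce(later)
    }

  // Records that would cross the shuffle: (without, with) map-side combining.
  def shuffledRecords(partitions: Seq[Seq[Food]]): (Int, Int) =
    (partitions.map(_.size).sum, partitions.map(p => partial(p).size).sum)
}
```

The saving grows with the number of rows per key within each partition. With all-distinct keys there is nothing to combine, so map-side aggregation pays off most when keys repeat a lot.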

Another part of what makes this work well is that if you’re selecting a fixed number of records per key (e.g. the n latest), you also remove partition skew, because each key contributes at most that many records from each partition, however many rows it has in the input. This again makes the reduce-side operations much more reliable.
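A bounded buffer is what makes the "fixed number of records per key" argument work. If a key’s buffer never holds more than `n` rows, a key with millions of input rows still sends at most `n` rows per partition into the shuffle. This plain-Scala sketch shows such an "n latest" buffer. `n`, the `Food` fields and the object and method names are hypothetical.

```scala
case class Food(key: String, date: String, numeric: Int, text: String)

object NLatest {
  // Add one row and keep only the n latest rows, newest first.
  // The buffer never grows beyond n, no matter how many rows a key has.
  def add(n: Int)(buffer: Vector[Food], row: Food): Vector[Food] =
    (buffer :+ row).sortBy(_.date)(Ordering[String].reverse).take(n)

  // Merging two bounded buffers yields another buffer of at most n rows.
  def merge(n: Int)(a: Vector[Food], b: Vector[Food]): Vector[Food] =
    (a ++ b).sortBy(_.date)(Ordering[String].reverse).take(n)
}
```

Re-sorting on every `add` is fine for small `n`. For larger `n`, a priority queue would avoid the repeated sort.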

Aggregator: example code

In this specific example, I’ve chosen to focus on aggregating whole rows. Aggregators and UDAFs can also aggregate just part of the data, in various ways; and again, the more you cut down the width of the data going into a shuffle, the faster it will be.

See this notebook and the API docs for more examples.

Full example of using Aggregator

Using real data, this took 1.2 hours over 1 billion rows.
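As a stand-in for that example, here is a minimal sketch of an `Aggregator` for "latest row per key". It is not the author’s code. It assumes the hypothetical `Food` case class with non-null, non-empty ISO-8601 string dates. The object name `LatestFood` and the empty-row sentinel used as `zero` are also assumptions. `reduce` folds input rows into a partial buffer before the shuffle, `merge` combines partial buffers for the same key after it, and `finish` returns the buffer unchanged.

```scala
import org.apache.spark.sql.{Dataset, Encoder, Encoders, SparkSession}
import org.apache.spark.sql.expressions.Aggregator

case class Food(key: String, date: String, numeric: Int, text: String)

// Keeps the whole row with the greatest date. Buffer and output are both Food.
object LatestFood extends Aggregator[Food, Food, Food] {
  // Sentinel "no row yet": the empty string sorts before any ISO-8601 date.
  def zero: Food = Food("", "", 0, "")

  // Map side: fold one input row into the partial buffer.
  def reduce(buffer: Food, row: Food): Food =
    if (row.date > buffer.date) row else buffer

  // Reduce side: combine two partial buffers for the same key.
  def merge(a: Food, b: Food): Food = if (b.date > a.date) b else a

  def finish(result: Food): Food = result

  def bufferEncoder: Encoder[Food] = Encoders.product[Food]
  def outputEncoder: Encoder[Food] = Encoders.product[Food]
}

object LatestFoodJob {
  // Hypothetical entry point: one whole row per key.
  def latest(spark: SparkSession, foods: Dataset[Food]): Dataset[Food] = {
    import spark.implicits._
    foods.groupByKey(_.key).agg(LatestFood.toColumn).map(_._2)
  }
}
```

The sentinel `zero` keeps the encoders simple. If dates can be null or empty, an `Option`-based buffer would be the safer choice.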

The solution to a different problem: Aggregate Functions, UDAFs, and Sketches

UDAFs are the untyped equivalent of Aggregators, and I won’t say much more about them. The one exception: if you use one (or a built-in aggregate function) to extract exactly what you need instead of the whole row, you get functionality very similar to a SQL GROUP BY. The columns in the group-by clause come through as they are, and the rest of the columns in your results are aggregates. It’s as fast as an Aggregator, for the same reasons, including that you narrow the data going into the shuffle.

The code looks a little bit like this:

import org.apache.spark.sql.functions.{max, sum}

foods.groupBy("key").agg(max("date"), sum("numeric")).show()

Aggregate functions are simply built in (as above), and UDAFs are used in the same way.
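In Spark 3.0 and later, one way to get a UDAF is to write a typed `Aggregator` over a single column’s values and wrap it with `functions.udaf`. After that, it goes into `agg` exactly like the built-in `max` or `sum`. This sketch assumes Spark 3.x and a `date` column of ISO-8601 strings. `MaxString` and `UdafExample` are hypothetical names, and the aggregator deliberately duplicates the built-in `max` so that only the wiring is new.

```scala
import org.apache.spark.sql.{DataFrame, Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions.{col, sum, udaf}

// Typed aggregator over one String column that keeps the greatest value.
// (It duplicates the built-in max, purely to show the UDAF mechanics.)
object MaxString extends Aggregator[String, String, String] {
  def zero: String = ""
  // Skip nulls so the buffer is never null.
  def reduce(buffer: String, value: String): String =
    if (value != null && value > buffer) value else buffer
  def merge(a: String, b: String): String = if (b > a) b else a
  def finish(result: String): String = result
  def bufferEncoder: Encoder[String] = Encoders.STRING
  def outputEncoder: Encoder[String] = Encoders.STRING
}

object UdafExample {
  // Wrap the Aggregator as an untyped UDAF (Spark 3.0+).
  val maxString = udaf(MaxString)

  // Same shape as the built-in version: group-by columns plus aggregates.
  def summarise(foods: DataFrame): DataFrame =
    foods.groupBy("key").agg(maxString(col("date")), sum("numeric"))
}
```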

Sketches are probabilistic (i.e. not fully accurate) but fast ways of producing certain types of results. Spark has limited support for sketches, but you can read more at Apache Data Sketches and ZetaSketches.

Non-Solution: groupByKey + reduceGroups

For some reason, this takes forever, and doesn’t do the map-side aggregation you’d expect:



// version 1
in.groupByKey(_.key).reduceGroups((a: Food, b: Food) => Seq(a, b).maxBy(_.date)).rdd.values

// version 2
in.groupByKey(_.key).reduceGroups((a: Food, b: Food) => Seq(a, b).maxBy(_.date)).map(_._2)

// version 3 - as above but with rdd
in.rdd.keyBy(_.key).reduceByKey((a: Food, b: Food) => Seq(a, b).maxBy(_.date)).values.toDS

I haven’t reproduced the query plan diagrams for any of these solutions, largely because none of them look distinctively crazy. Unfortunately, I also haven’t retained the statistics from the failed runs.

Using real data, over 1 billion rows, version 1 took 4.4 hours, version 2 took 4.9 hours, and version 3 failed after 4.9 hours.

Non-Solution: mapPartitions + groupByKey + reduceGroups

This ought to work. Maybe it can even be made to work. The idea is to do the map-side aggregation yourself before the grouping and reducing. This is what I tried, and it didn’t work for me. I suspect I would have needed to avoid accumulating the whole map before returning the iterator (maybe yield Options, then flatMap the Option away)?



import scala.collection.mutable

// assumes spark.implicits._ is in scope, for the Food encoder
in.mapPartitions((i: Iterator[Food]) => {
  // latest row seen so far for each key in this partition
  val latestRows = mutable.HashMap.empty[String, Food]
  i.foreach { (r: Food) =>
    val latestFood = Seq(latestRows.get(r.key), Some(r)).flatten.maxBy(_.date)
    latestRows.put(r.key, latestFood)
  }
  // emit the rows, not (key, row) pairs, so groupByKey(_.key) still sees Food
  latestRows.valuesIterator
}).groupByKey(_.key).reduceGroups((a: Food, b: Food) => Seq(a, b).maxBy(_.date)).rdd.values

No timing result, as this fell over almost immediately.

Non-Solution: combineByKey

This one is kind of disappointing, because it has all the same elements as Aggregator; it just didn’t work well. I tried variants that salted the keys and so on to reduce skew, but had no luck. It fell over after 7.2 hours.
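For comparison, the three functions `combineByKey` takes line up with the `Aggregator` methods. `createCombiner` starts a buffer from the first value (roughly `reduce(zero, v)`), `mergeValue` plays the part of `reduce` on the map side, and `mergeCombiners` plays the part of `merge` after the shuffle. This is a minimal sketch over an RDD of the hypothetical `Food` rows keyed by `key`, not the code behind the timing above, and without the salting variants.

```scala
import org.apache.spark.rdd.RDD

case class Food(key: String, date: String, numeric: Int, text: String)

object CombineByKeyLatest {
  // Start a combiner from the first row seen for a key in a partition.
  val createCombiner: Food => Food = row => row
  // Fold another row from the same partition into the combiner (map side).
  val mergeValue: (Food, Food) => Food = (c, row) => if (row.date > c.date) row else c
  // Merge combiners for the same key from different partitions (after the shuffle).
  val mergeCombiners: (Food, Food) => Food = (a, b) => if (b.date > a.date) b else a

  def latest(foods: RDD[Food]): RDD[Food] =
    foods.keyBy(_.key).combineByKey(createCombiner, mergeValue, mergeCombiners).values
}
```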