A Case Study of Implementing an Efficient Shuffling Stream/Spliterator in Java

Sorting a Stream instance is straightforward and involves just a single API method call – achieving the opposite is not that easy.

In this article, we’ll see how to shuffle a Stream in Java, both eagerly and lazily, using Stream Collectors factories and custom Spliterators.

Eager Shuffle Collector

One of the most pragmatic solutions to the above problem was already described by Heinz in this article.

Essentially, it encapsulates a compound operation that collects the whole stream into a list, shuffles it with Collections#shuffle, and converts it back into a Stream:

```java
public static <T> Collector<T, ?, Stream<T>> toEagerShuffledStream() {
    return Collectors.collectingAndThen(
      toCollection(ArrayList::new), // Collections.shuffle needs a mutable list
      list -> {
          Collections.shuffle(list);
          return list.stream();
      });
}
```

This solution works well if we want to process all stream elements in random order, but it can bite back if we want to process only a small subset of them, because the whole collection gets shuffled in advance even if we request only a single element.
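To make the trade-off tangible, here is a self-contained sketch of the eager collector that can be run as-is. The class name EagerShuffleDemo and its pick helper are hypothetical; the collector collects into an ArrayList via toCollection(ArrayList::new) so that Collections.shuffle is guaranteed a mutable list. Note that even with limit(3) downstream, every element has already been shuffled by the time the first one is emitted.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collector;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;

// Hypothetical demo class wrapping the eager shuffling collector.
class EagerShuffleDemo {

    static <T> Collector<T, ?, Stream<T>> toEagerShuffledStream() {
        return Collectors.collectingAndThen(
          Collectors.toCollection(ArrayList::new), // mutable list for Collections.shuffle
          list -> {
              Collections.shuffle(list); // shuffles all N elements up front
              return list.stream();
          });
    }

    // Returns the first `limit` elements of a shuffled 0..size-1 range.
    static List<Integer> pick(int size, int limit) {
        return IntStream.range(0, size).boxed()
          .collect(toEagerShuffledStream())
          .limit(limit)
          .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        // Three distinct values from 0..9, in random order.
        System.out.println(pick(10, 3));
    }
}
```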

Let’s have a look at a simple benchmark and the results it generated:

```java
@State(Scope.Benchmark)
public class RandomSpliteratorBenchmark {

    private List<String> source;

    @Param({"1", "10", "100", "1000", "10000", "100000"})
    public int limit;

    @Param({"100000"})
    public int size;

    @Setup(Level.Iteration)
    public void setUp() {
        source = IntStream.range(0, size)
          .boxed()
          .map(Object::toString)
          .collect(Collectors.toList());
    }

    @Benchmark
    public List<String> eager() {
        return source.stream()
          .collect(toEagerShuffledStream())
          .limit(limit)
          .collect(Collectors.toList());
    }
}
```

```
Benchmark (limit)   Mode  Cnt    Score    Error  Units
eager           1  thrpt    5  467.796 ±  9.074  ops/s
eager          10  thrpt    5  467.694 ± 17.166  ops/s
eager         100  thrpt    5  459.765 ±  8.048  ops/s
eager        1000  thrpt    5  467.934 ± 43.095  ops/s
eager       10000  thrpt    5  449.471 ±  5.549  ops/s
eager      100000  thrpt    5  331.111 ±  5.626  ops/s
```

As we can see, throughput barely depends on the number of elements consumed from the resulting Stream (it only drops noticeably when all 100_000 are consumed). Unfortunately, the absolute numbers are not impressive when only a few elements are consumed: in such situations, shuffling the whole collection beforehand turns out to be quite wasteful.

Let’s see what we can do about it.

Lazy Shuffle Collector

To spare precious CPU cycles, instead of pre-shuffling the whole collection, we can fetch only as many elements as the downstream demand requires.

To achieve that, we need to implement a custom Spliterator that will allow us to iterate through objects in random order, and then we’ll be able to construct a Stream instance by using a helper method from the StreamSupport class:

```java
public class RandomSpliterator<T> implements Spliterator<T> {

    // ...

    public static <T> Collector<T, ?, Stream<T>> toLazyShuffledStream() {
        return Collectors.collectingAndThen(
          toCollection(ArrayList::new),
          list -> StreamSupport.stream(
            new RandomSpliterator<>(list, Random::new), false));
    }
}
```
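A runnable sketch of this removal-based approach can show that a downstream limit() really pulls only what it needs. It reuses the article’s name RandomSpliterator and the Supplier<? extends Random> constructor parameter, but the advanceCalls counter and the main method are hypothetical additions for illustration, and an empty source simply yields an empty stream instead of throwing.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.RandomAccess;
import java.util.Spliterator;
import java.util.function.Consumer;
import java.util.function.Supplier;
import java.util.stream.Collector;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

// Sketch of the removal-based lazy shuffler.
class RandomSpliterator<T> implements Spliterator<T> {
    private final List<T> source;
    private final Random random;
    int advanceCalls; // hypothetical instrumentation, not part of the article's class

    RandomSpliterator(List<T> source, Supplier<? extends Random> random) {
        // Copy only when the list lacks O(1) index-based access.
        this.source = source instanceof RandomAccess ? source : new ArrayList<>(source);
        this.random = random.get();
    }

    @Override
    public boolean tryAdvance(Consumer<? super T> action) {
        advanceCalls++;
        int remaining = source.size();
        if (remaining == 0) {
            return false; // nothing left: end of the stream
        }
        // Pick a random index and physically remove it (shifts the tail).
        action.accept(source.remove(random.nextInt(remaining)));
        return true;
    }

    @Override
    public Spliterator<T> trySplit() {
        return null; // no parallel splitting
    }

    @Override
    public long estimateSize() {
        return source.size(); // shrinks with every removal
    }

    @Override
    public int characteristics() {
        return SIZED;
    }

    static <T> Collector<T, ?, Stream<T>> toLazyShuffledStream() {
        return Collectors.collectingAndThen(
          Collectors.toCollection(ArrayList::new),
          list -> StreamSupport.stream(new RandomSpliterator<>(list, Random::new), false));
    }

    public static void main(String[] args) {
        List<Integer> source = IntStream.range(0, 1_000).boxed()
          .collect(Collectors.toCollection(ArrayList::new));
        RandomSpliterator<Integer> spliterator = new RandomSpliterator<>(source, Random::new);
        List<Integer> picked = StreamSupport.stream(spliterator, false)
          .limit(3)
          .collect(Collectors.toList());
        System.out.println(picked + " after " + spliterator.advanceCalls + " tryAdvance() calls");
    }
}
```

With limit(3), only a handful of tryAdvance() calls happen, so only three elements are ever removed from the 1_000-element list.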

Implementation Details

We can’t avoid evaluating the whole Stream even if we want to pick a single random element (which means there’s no support for infinite sequences), so it’s perfectly fine to initialize our RandomSpliterator<T> with a List<T>, but there’s a catch…

If a particular List implementation doesn’t support constant-time random access, this solution can turn out to be much slower than the eager approach. To protect ourselves from this scenario, we can perform a simple check when instantiating the Spliterator:

```java
private RandomSpliterator(
  List<T> source, Supplier<? extends Random> random) {
    if (source.isEmpty()) { ... } // throw
    this.source = source instanceof RandomAccess
      ? source
      : new ArrayList<>(source);
    this.random = random.get();
}
```

Copying the elements into a new ArrayList costs O(n), but that’s negligible compared to the cost of index-based access on implementations that don’t provide O(1) random access, such as LinkedList, where each get(i) or remove(i) has to walk the list.
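The RandomAccess marker interface is what makes this check possible: ArrayList implements it, LinkedList does not. A minimal sketch of the guard in isolation, where the class name RandomAccessGuard and the ensureRandomAccess helper are hypothetical:

```java
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import java.util.RandomAccess;

class RandomAccessGuard {
    // Returns the list itself if it supports fast index access, otherwise an O(n) copy.
    static <T> List<T> ensureRandomAccess(List<T> source) {
        return source instanceof RandomAccess ? source : new ArrayList<>(source);
    }

    public static void main(String[] args) {
        List<String> linked = new LinkedList<>(List.of("a", "b", "c"));
        List<String> guarded = ensureRandomAccess(linked);
        System.out.println(guarded.getClass().getSimpleName()); // ArrayList
    }
}
```

An ArrayList passes through untouched, so the copy is paid only by the lists that would otherwise be slow.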

And now we can override the most important method – tryAdvance().

In this case, it’s fairly straightforward: in each iteration, we pick a random element, remove it from the source collection, and pass it to the consumer.

We don’t need to worry about mutating the source, since we don’t publish the RandomSpliterator itself, only a Collector built on top of it, which owns the intermediate list:

```java
@Override
public boolean tryAdvance(Consumer<? super T> action) {
    int remaining = source.size();
    if (remaining > 0) {
        // pick a random index, remove that element and hand it to the consumer
        action.accept(source.remove(random.nextInt(remaining)));
        return true;
    } else {
        return false;
    }
}
```
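Because the spliterator receives a Supplier<? extends Random> instead of creating its own Random, callers can inject a seeded generator and get a reproducible order, which is handy in tests. Below is a standalone sketch of the same pick-and-remove step; the drawRandom helper, the SeededDraw class, and the seed 42 are illustrative assumptions:

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.function.Supplier;

class SeededDraw {
    // Mirrors tryAdvance(): pick a random index among the remaining ones and remove it.
    static <T> List<T> drawRandom(List<T> items, int count, Supplier<? extends Random> random) {
        List<T> pool = new ArrayList<>(items); // work on a private, mutable copy
        Random rnd = random.get();
        List<T> drawn = new ArrayList<>();
        while (drawn.size() < count && !pool.isEmpty()) {
            drawn.add(pool.remove(rnd.nextInt(pool.size())));
        }
        return drawn;
    }

    public static void main(String[] args) {
        List<Integer> items = List.of(0, 1, 2, 3, 4, 5, 6, 7, 8, 9);
        // Same seed, same sequence of nextInt() results, hence the same order twice.
        System.out.println(drawRandom(items, 5, () -> new Random(42)));
        System.out.println(drawRandom(items, 5, () -> new Random(42)));
    }
}
```

In production code, Random::new gives a fresh, differently seeded generator per collector invocation.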

Besides this, we need to implement three other methods:

```java
@Override
public Spliterator<T> trySplit() {
    return null; // to indicate that split is not possible
}

@Override
public long estimateSize() {
    return source.size();
}

@Override
public int characteristics() {
    return SIZED;
}
```
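Reporting SIZED is a promise: estimateSize() must then return the exact number of remaining elements, and the default getExactSizeIfKnown() simply forwards to it. RandomSpliterator keeps that promise because source.size() shrinks with every removal. The JDK’s own spliterators illustrate the expected behaviour (the class name SizedContractDemo is hypothetical):

```java
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Spliterator;
import java.util.Spliterators;

class SizedContractDemo {
    public static void main(String[] args) {
        List<Integer> list = new ArrayList<>(List.of(1, 2, 3));

        // ArrayList's spliterator reports SIZED and tracks what is left.
        Spliterator<Integer> sized = list.spliterator();
        System.out.println(sized.hasCharacteristics(Spliterator.SIZED)); // true
        System.out.println(sized.getExactSizeIfKnown());                 // 3
        sized.tryAdvance(x -> { });                                      // consume one element
        System.out.println(sized.getExactSizeIfKnown());                 // 2

        // Without SIZED, the exact size is reported as unknown (-1).
        Iterator<Integer> it = list.iterator();
        Spliterator<Integer> unsized = Spliterators.spliteratorUnknownSize(it, 0);
        System.out.println(unsized.getExactSizeIfKnown());               // -1
    }
}
```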

Now let’s try it and see that it indeed works:

```java
IntStream.range(0, 10).boxed()
  .collect(toLazyShuffledStream())
  .forEach(System.out::println);
```

And the result (the order will differ from run to run):

```
3
4
8
1
7
6
5
0
2
9
```

Performance Considerations

In this implementation, we replaced N array element swaps with M lookups/removals, where:

N – the collection size

M – the number of picked items

Generally, a single removal at a random index of an ArrayList is much more expensive than a single element swap, because it shifts all the elements that follow the removed one. This makes the solution scale poorly as M grows, but perform significantly better for relatively low M values.
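A back-of-the-envelope cost model makes this concrete. On an ArrayList, Collections#shuffle performs N - 1 swaps. Removing a uniformly random index from a list of r elements shifts (r - 1) / 2 elements on average, so M removals starting from N elements shift about the sum of (r - 1) / 2 for r from N - M + 1 to N. The sketch below (class and method names are illustrative) evaluates that sum; it counts element moves only and deliberately ignores constant factors such as cache behaviour or the per-element speed of System#arraycopy.

```java
class ShuffleCostModel {
    // Swaps performed by Collections.shuffle on n elements.
    static long eagerSwaps(long n) {
        return Math.max(0, n - 1);
    }

    // Expected number of elements shifted by m removals at uniformly random indices,
    // starting from a list of n elements: sum of (r - 1) / 2 for r = n - m + 1 .. n.
    static double expectedShifts(long n, long m) {
        double total = 0;
        for (long r = n - m + 1; r <= n; r++) {
            total += (r - 1) / 2.0;
        }
        return total;
    }

    public static void main(String[] args) {
        long n = 100_000;
        for (long m : new long[] {1, 100, 10_000, 100_000}) {
            System.out.printf("M=%d: ~%.0f shifted elements vs %d swaps%n",
                m, expectedShifts(n, m), eagerSwaps(n));
        }
    }
}
```

For M = N the sum is N(N - 1) / 4, which for 100_000 elements is about 2.5 billion element moves against 99_999 swaps.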

Let’s now see how this solution compares to the eager approach showcased at the beginning (both measured on a collection of 100_000 objects):

```
Benchmark (limit)   Mode  Cnt     Score     Error  Units
eager           1  thrpt    5   467.796 ±   9.074  ops/s
eager          10  thrpt    5   467.694 ±  17.166  ops/s
eager         100  thrpt    5   459.765 ±   8.048  ops/s
eager        1000  thrpt    5   467.934 ±  43.095  ops/s
eager       10000  thrpt    5   449.471 ±   5.549  ops/s
eager      100000  thrpt    5   331.111 ±   5.626  ops/s
lazy            1  thrpt    5  1530.763 ±  72.096  ops/s
lazy           10  thrpt    5  1462.305 ±  23.860  ops/s
lazy          100  thrpt    5   823.212 ± 119.771  ops/s
lazy         1000  thrpt    5   166.786 ±  16.306  ops/s
lazy        10000  thrpt    5    19.475 ±   4.052  ops/s
lazy       100000  thrpt    5     4.097 ±   0.416  ops/s
```

As we can see, this solution outperforms the former if the number of processed Stream items is relatively low, but as the processed/collection_size ratio increases, the throughput drops drastically.

That’s all because of the additional overhead of removing elements from the ArrayList holding the remaining objects: each removal shifts every element after the removed index one position to the left using System#arraycopy, which costs O(n) per removal.

We can notice a similar pattern for much bigger collections (10_000_000 elements):

```
Benchmark (limit)    (size)   Mode  Cnt  Score  Units
eager           1  10000000  thrpt    5  0.915  ops/s
eager          10  10000000  thrpt    5  0.783  ops/s
eager         100  10000000  thrpt    5  0.965  ops/s
eager        1000  10000000  thrpt    5  0.936  ops/s
eager       10000  10000000  thrpt    5  0.860  ops/s
lazy            1  10000000  thrpt    5  4.338  ops/s
lazy           10  10000000  thrpt    5  3.149  ops/s
lazy          100  10000000  thrpt    5  2.060  ops/s
lazy         1000  10000000  thrpt    5  0.370  ops/s
lazy        10000  10000000  thrpt    5   0.05  ops/s
```

…and much smaller ones (128 elements, mind the scale!):

```
Benchmark (limit)  (size)   Mode  Cnt       Score  Units
eager           2     128  thrpt    5  246439.459  ops/s
eager           4     128  thrpt    5  333866.936  ops/s
eager           8     128  thrpt    5  340296.188  ops/s
eager          16     128  thrpt    5  345533.673  ops/s
eager          32     128  thrpt    5  231725.156  ops/s
eager          64     128  thrpt    5  314324.265  ops/s
eager         128     128  thrpt    5  270451.992  ops/s
lazy            2     128  thrpt    5  765989.718  ops/s
lazy            4     128  thrpt    5  659421.041  ops/s
lazy            8     128  thrpt    5  652685.515  ops/s
lazy           16     128  thrpt    5  470346.570  ops/s
lazy           32     128  thrpt    5  324174.691  ops/s
lazy           64     128  thrpt    5  186472.090  ops/s
lazy          128     128  thrpt    5  108105.699  ops/s
```

But, could we do better than this?

Further Performance Improvements

Unfortunately, the scalability of the existing solution is quite disappointing. Let’s try to improve it, but before we do, we should measure first by profiling the lazy implementation.

As expected, ArrayList#remove turns out to be one of the hot spots; in other words, the CPU spends a noticeable amount of time removing elements from an ArrayList.

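Why is that? ArrayList keeps its elements in a contiguous backing array, and arrays in Java can’t shrink in place. To close the gap left by a removed element, each removal shifts all subsequent elements one position to the left using System#arraycopy. No new array is allocated, but the amount of copying grows with the number of elements after the removed index, as the JDK’s ArrayList#fastRemove shows: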

```java
private void fastRemove(Object[] es, int i) {
    modCount++;
    final int newSize;
    if ((newSize = size - 1) > i)
        System.arraycopy(es, i + 1, es, i, newSize - i); // shift the tail left by one
    es[size = newSize] = null; // clear the now-unused last slot
}
```

What can we do about this? Avoid removing elements from an ArrayList.

To do that, we can skip shrinking the list physically and shrink it logically instead, by tracking its size separately:

```java
class ImprovedRandomSpliterator<T, LIST extends RandomAccess & List<T>>
  implements Spliterator<T> {

    private final Random random;
    private final List<T> source;
    private int size; // logical size: elements at index >= size were already consumed

    ImprovedRandomSpliterator(
      LIST source, Supplier<? extends Random> random) {
        Objects.requireNonNull(source, "source can't be null");
        Objects.requireNonNull(random, "random can't be null");
        this.source = source;
        this.random = random.get();
        this.size = this.source.size();
    }

    // ...
}
```

Luckily, this extra mutable state doesn’t raise concurrency issues, since Spliterator instances are not supposed to be shared between threads.

And now, whenever we take an element out, we don’t need to physically remove it from the list and shift everything after it. Instead, we can decrement our size tracker and ignore the part of the list beyond it.

But right before that, we need to copy the last element of the active range into the slot of the returned element, so that it stays within the shrunken range:

```java
@Override
public boolean tryAdvance(Consumer<? super T> action) {
    if (size > 0) {
        int nextIdx = random.nextInt(size);
        int lastIdx = --size;                 // shrink the list logically
        T last = source.get(lastIdx);
        T elem = source.set(nextIdx, last);   // move the last element into the picked slot
        action.accept(elem);
        return true;
    } else {
        return false;
    }
}
```
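To see why this still yields a proper shuffle, it helps to trace the step by hand with fixed indices instead of random ones. Drawing an index from the active range, emitting that element and moving the last active element into its slot is essentially the Fisher-Yates (Durstenfeld) shuffle performed lazily, one element per tryAdvance() call. The replay helper and the SwapWithLastTrace class below are hypothetical; replay takes a predetermined sequence of indices in place of random.nextInt(size):

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

class SwapWithLastTrace {
    // Replays the tryAdvance() logic with predetermined indices.
    // picks[k] must lie in [0, source.size() - k), just like random.nextInt(size).
    static <T> List<T> replay(List<T> source, int[] picks) {
        List<T> list = new ArrayList<>(source);
        List<T> emitted = new ArrayList<>();
        int size = list.size();
        for (int nextIdx : picks) {
            int lastIdx = --size;              // shrink the active range logically
            T last = list.get(lastIdx);
            T elem = list.set(nextIdx, last);  // overwrite the picked slot, keep the old value
            emitted.add(elem);
        }
        return emitted;
    }

    public static void main(String[] args) {
        List<String> source = Arrays.asList("A", "B", "C", "D");
        // size 4: pick 1 (B); size 3: pick 0 (A); size 2: pick 1 (D); size 1: pick 0 (C)
        System.out.println(replay(source, new int[] {1, 0, 1, 0})); // [B, A, D, C]
    }
}
```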

If we profile it now, we can see that the expensive call is gone.

We’re ready to rerun benchmarks and compare:

```
Benchmark (limit)  (size)   Mode  Cnt     Score    Error  Units
eager           1  100000  thrpt    5   454.396 ± 11.738  ops/s
eager          10  100000  thrpt    5   441.602 ± 40.503  ops/s
eager         100  100000  thrpt    5   456.167 ± 11.420  ops/s
eager        1000  100000  thrpt    5   443.149 ±  7.590  ops/s
eager       10000  100000  thrpt    5   431.375 ± 12.116  ops/s
eager      100000  100000  thrpt    5   328.376 ±  4.156  ops/s
lazy            1  100000  thrpt    5  1419.514 ± 58.778  ops/s
lazy           10  100000  thrpt    5  1336.452 ± 34.525  ops/s
lazy          100  100000  thrpt    5   926.438 ± 65.923  ops/s
lazy         1000  100000  thrpt    5   165.967 ± 17.135  ops/s
lazy        10000  100000  thrpt    5    19.673 ±  0.375  ops/s
lazy       100000  100000  thrpt    5     4.002 ±  0.305  ops/s
optimized       1  100000  thrpt    5  1478.069 ± 32.923  ops/s
optimized      10  100000  thrpt    5  1477.618 ± 72.917  ops/s
optimized     100  100000  thrpt    5  1448.584 ± 42.205  ops/s
optimized    1000  100000  thrpt    5  1435.818 ± 38.505  ops/s
optimized   10000  100000  thrpt    5   1060.88 ± 15.238  ops/s
optimized  100000  100000  thrpt    5   332.096 ±  7.071  ops/s
```

As you can see, we ended up with an implementation whose performance is far less sensitive to the number of elements we consume.

Actually, even in the pessimistic scenario (consuming all 100_000 elements), the improved implementation performs on par with the Collections#shuffle-based one (332.096 vs 328.376 ops/s, within the margin of error). Our work here is done.

And to put a small cherry on top, notice how the intersection type bound LIST extends RandomAccess & List<T> lets the compiler ensure that only List implementations that also implement RandomAccess get passed to the constructor!
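The same bound can be tried out in isolation. In the sketch below (the firstOf method and the IntersectionBoundDemo class are hypothetical), passing an ArrayList compiles, while passing a LinkedList, or even an ArrayList held in a variable whose static type is plain List, is rejected at compile time, because the check is made against the static type:

```java
import java.util.ArrayList;
import java.util.List;
import java.util.RandomAccess;

class IntersectionBoundDemo {
    // Accepts only arguments whose static type implements both RandomAccess and List<T>.
    static <T, L extends RandomAccess & List<T>> T firstOf(L list) {
        return list.get(0);
    }

    public static void main(String[] args) {
        ArrayList<String> arrayList = new ArrayList<>(List.of("a", "b"));
        System.out.println(firstOf(arrayList)); // a

        // Neither of these compiles:
        // firstOf(new java.util.LinkedList<>(List.of("a")));  // LinkedList lacks RandomAccess
        // List<String> asList = arrayList; firstOf(asList);   // static type List lacks RandomAccess
    }
}
```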

The Complete Example

…can also be found on GitHub.

```java
package com.pivovarit.stream;

import java.util.List;
import java.util.Objects;
import java.util.Random;
import java.util.RandomAccess;
import java.util.Spliterator;
import java.util.function.Consumer;
import java.util.function.Supplier;

class ImprovedRandomSpliterator<T, LIST extends RandomAccess & List<T>>
  implements Spliterator<T> {

    private final Random random;
    private final List<T> source;
    private int size;

    ImprovedRandomSpliterator(
      LIST source, Supplier<? extends Random> random) {
        Objects.requireNonNull(source, "source can't be null");
        Objects.requireNonNull(random, "random can't be null");
        this.source = source;
        this.random = random.get();
        this.size = this.source.size();
    }

    @Override
    public boolean tryAdvance(Consumer<? super T> action) {
        if (size > 0) {
            int nextIdx = random.nextInt(size);
            int lastIdx = --size;
            T last = source.get(lastIdx);
            T elem = source.set(nextIdx, last);
            action.accept(elem);
            return true;
        } else {
            return false;
        }
    }

    @Override
    public Spliterator<T> trySplit() {
        return null;
    }

    @Override
    public long estimateSize() {
        return size; // remaining elements; source.size() never shrinks here
    }

    @Override
    public int characteristics() {
        return SIZED;
    }
}
```

```java
package com.pivovarit.stream;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Random;
import java.util.stream.Collector;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

import static java.util.stream.Collectors.toCollection;

public final class RandomCollectors {

    private RandomCollectors() {
    }

    public static <T> Collector<T, ?, Stream<T>> toOptimizedLazyShuffledStream() {
        return Collectors.collectingAndThen(
          toCollection(ArrayList::new),
          list -> StreamSupport.stream(
            new ImprovedRandomSpliterator<>(list, Random::new), false));
    }

    public static <T> Collector<T, ?, Stream<T>> toLazyShuffledStream() {
        return Collectors.collectingAndThen(
          toCollection(ArrayList::new),
          list -> StreamSupport.stream(
            new RandomSpliterator<>(list, Random::new), false));
    }

    public static <T> Collector<T, ?, Stream<T>> toEagerShuffledStream() {
        return Collectors.collectingAndThen(
          toCollection(ArrayList::new),
          list -> {
              Collections.shuffle(list);
              return list.stream();
          });
    }
}
```



