Unit Testing with PySpark

By David Illes, Vice President at Morgan Stanley and Big Data Teaching Fellow at Cambridge Spark

I’m a big fan of testing in general, but especially unit testing. Don’t get me wrong, I don’t particularly enjoy writing tests, but having a proper testing suite is one of the fundamental building blocks that differentiate hacking from software engineering. It’s a bit like sending your application to the gym: if you do it right, it might not be a pleasant experience, but you’ll reap the benefits continuously. At work we are especially big fans of the testing pyramid, and having dozens of unit tests gives us the support we need to deliver high-quality software to production rapidly.

Most recently I had the pleasure of working on a project for one of Cambridge Spark’s project partners, which relied heavily on PySpark, and I was faced with the question of how to write effective unit tests for my PySpark jobs. To answer it, we must start right at the beginning: how we structure our code.

When people start out writing PySpark jobs (especially Data Scientists), they tend to create one massive function that looks something like this:

def my_application_main():
    # Code here to create my spark contexts

    # Code here to fetch initial tables/rdds

    # Data Science magic code...
    # ...
    # ...

    # Persist/send results
    pass

The minute you start to write tests for code like this, you instantly realise that it’s impossible: the context creation is coupled to the job, and the local environment that runs your unit tests will not be able to create it. The solution is to break the program up into many smaller units and functions that we can test individually, and to avoid coupling ourselves to the real spark context; instead we create a testing spark context that works on our local machine.

A better application would be structured like this:



def my_logic_step_1(my_rdd, my_dataframe):
    # some processing
    pass


def my_logic_step_2(my_dataframe):
    # some processing
    pass


def persist_results(my_dataframe):
    # persist back results
    pass


def logic_main(ctx):
    # get the tables/rdds of interest from the context
    x = ...  # fetched via ctx
    y = ...  # fetched via ctx
    data = my_logic_step_1(x, y)
    data = my_logic_step_2(data)
    persist_results(data)


def my_application_main():
    ctx = ...  # create spark context as you see fit
    logic_main(ctx)



The key idea here is to have small functions that take the rdds and dataframes they work on as inputs, so they are easily testable, and to have a function that composes these units as the main logic of the job (logic_main in this example) while still expecting the spark context to be provided as an input. This also enables us to write an integration test for the entire job, as shown below. Finally, we have a separate main method that does nothing but create our context and pass it in to our logic. That main method is the only piece of our code that will not be covered by unit tests, and if we did our job right, it should not be longer than 10–20 lines.
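For example, a minimal integration test for the whole job could look like the sketch below. It assumes the create_testing_pyspark_session helper we define in the next snippet, and that logic_main can run against a local session; the final assertion is left as a placeholder since it depends on what persist_results writes:

import unittest


class JobIntegrationTest(unittest.TestCase):

    def test_whole_job(self):
        # A local testing session stands in for the real production context
        ctx = create_testing_pyspark_session()
        logic_main(ctx)
        # ...assert here on whatever persist_results wrote in the test environment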

Once we split our job into testable units, we can create a local spark session for testing purposes. I would also recommend changing the logging level for py4j, to get rid of log noise that is irrelevant to our testing.



import logging

from pyspark.sql import SparkSession


def suppress_py4j_logging():
    logger = logging.getLogger('py4j')
    logger.setLevel(logging.WARN)


def create_testing_pyspark_session():
    return (SparkSession.builder
            .master('local[2]')
            .appName('my-local-testing-pyspark-context')
            .enableHiveSupport()
            .getOrCreate())



Assuming we use unittest, it’s really easy to create a pyspark testing base class for our test suites:

import unittest
import logging

from pyspark.sql import SparkSession


class PySparkTest(unittest.TestCase):

    @classmethod
    def suppress_py4j_logging(cls):
        logger = logging.getLogger('py4j')
        logger.setLevel(logging.WARN)

    @classmethod
    def create_testing_pyspark_session(cls):
        return (SparkSession.builder
                .master('local[2]')
                .appName('my-local-testing-pyspark-context')
                .enableHiveSupport()
                .getOrCreate())

    @classmethod
    def setUpClass(cls):
        cls.suppress_py4j_logging()
        cls.spark = cls.create_testing_pyspark_session()

    @classmethod
    def tearDownClass(cls):
        cls.spark.stop()

We can extend this class to write our test cases and leverage the local spark session:

from operator import add


class SimpleTest(PySparkTest):

    def test_basic(self):
        test_rdd = self.spark.sparkContext.parallelize(['cat dog mouse', 'cat cat dog'], 2)
        results = (test_rdd.flatMap(lambda line: line.split())
                           .map(lambda word: (word, 1))
                           .reduceByKey(add)
                           .collect())
        expected_results = [('cat', 3), ('dog', 2), ('mouse', 1)]
        self.assertEqual(set(results), set(expected_results))
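From here the suite runs like any other unittest-based test suite. Assuming the classes above live in a file called test_wordcount.py (a hypothetical name), something like this runs it from the command line:

python -m unittest test_wordcount

Just bear in mind that setUpClass spins up a local Spark session, so expect these tests to take a few seconds longer than plain Python unit tests.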

So far so good: we have a job that’s broken up into many testable units, we know how to create a local spark session to be used while running our unit tests, and we know how to test simple methods that work on RDDs. One crucial thing is still missing: how to test methods operating on DataFrames.

Creating and testing Spark DataFrames seems problematic, because these dataframes usually originate from an underlying Hive table (or some other big data store), and asserting the equality of 2 Spark DataFrames is not as trivial as it should be. My argument is that unit tests should focus on a single functionality or edge case, so manually crafting a small dataset (one that fits into the memory of our python process) that demonstrates that functionality should be easy, and is in fact preferable to loading huge amounts of random production data.

Luckily we already have a technology that we know and love that lets us do just that: Python’s Pandas library. It turns out to be quite easy to prepare our input data with the help of Pandas, either from CSV or by creating the data inline (this post is not intended to cover the Pandas library; I suggest you take a look at its excellent documentation: https://pandas.pydata.org/pandas-docs/stable/). And not just that: it is also easy to convert our results back to Pandas dataframes and gain access to the carefully crafted testing functions that pandas has to offer.

Let’s take a quick look at how that works in practice. Let’s assume we want to test this function, which filters out rows with a ‘year’ value smaller than 2000:

from pyspark.sql.functions import col


def my_spark_function(df):
    return df.filter(col('year') >= 2000)

In this case we can prepare our data by hand with the help of pandas, run the function, and assert the results against another dataframe prepared by us:

def test_data_frame(self):
    import pandas as pd

    # Create the test data; with larger examples this can come from a CSV file
    # and we can use pd.read_csv(...)
    data_pandas = pd.DataFrame({'make': ['Jaguar', 'MG', 'MINI', 'Rover', 'Lotus'],
                                'registration': ['AB98ABCD', 'BC99BCDF', 'CD00CDE', 'DE01DEF', 'EF02EFG'],
                                'year': [1998, 1999, 2000, 2001, 2002]})

    # Turn the data into a Spark DataFrame; self.spark comes from our PySparkTest base class
    data_spark = self.spark.createDataFrame(data_pandas)

    # Invoke the unit we'd like to test
    results_spark = my_spark_function(data_spark)

    # Turn the results back to Pandas
    results_pandas = results_spark.toPandas()

    # Our expected results crafted by hand; again, this could come from a CSV
    # in case of a bigger example
    expected_results = pd.DataFrame({'make': ['Rover', 'Lotus', 'MINI'],
                                     'registration': ['DE01DEF', 'EF02EFG', 'CD00CDE'],
                                     'year': [2001, 2002, 2000]})

    # Assert that the 2 results are the same. We'll cover this function in a bit
    assert_frame_equal_with_sort(results_pandas, expected_results, ['registration'])

The only tricky bit here is assert_frame_equal_with_sort, which relies on Pandas’ built-in assert_frame_equal method; since ordering in Spark is not guaranteed (due to its distributed nature), we usually need to sort our dataframes before comparison. The method is implemented like this:

from pandas.testing import assert_frame_equal


def assert_frame_equal_with_sort(results, expected, keycolumns):
    results_sorted = results.sort_values(by=keycolumns).reset_index(drop=True)
    expected_sorted = expected.sort_values(by=keycolumns).reset_index(drop=True)
    assert_frame_equal(results_sorted, expected_sorted)
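As a quick standalone illustration (with made-up data, not taken from the example above), the helper makes two dataframes compare equal regardless of row order:

import pandas as pd

df_a = pd.DataFrame({'registration': ['B', 'A'], 'year': [2, 1]})
df_b = pd.DataFrame({'registration': ['A', 'B'], 'year': [1, 2]})

# Passes: after sorting by 'registration' and resetting the index, the frames match
assert_frame_equal_with_sort(df_a, df_b, ['registration'])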

And that’s really it for the basics, which in my opinion follow the 80/20 rule quite nicely and can be applied to almost all of your PySpark testing needs. In this small post we have touched on structuring PySpark applications, setting up a local spark session for unit testing, getting rid of logging noise in our tests, unit testing functions that operate on simple RDDs, and unit testing functions that operate on Spark DataFrames.