This blog post shows how to write some Spark code with the Java API and run a simple test.

The code snippets in this post are from this GitHub repo.

Project setup

Start by creating a pom.xml file for Maven.

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">

  <modelVersion>4.0.0</modelVersion>

  <groupId>mrpowers</groupId>
  <artifactId>JavaSpark</artifactId>
  <version>1.0-SNAPSHOT</version>

  <build>
    <plugins>
      <plugin>
        <groupId>org.apache.maven.plugins</groupId>
        <artifactId>maven-compiler-plugin</artifactId>
        <version>3.7.0</version>
        <configuration>
          <source>1.8</source>
          <target>1.8</target>
        </configuration>
      </plugin>
    </plugins>
  </build>

  <dependencies>
    <dependency>
      <groupId>org.apache.spark</groupId>
      <artifactId>spark-sql_2.11</artifactId>
      <version>2.4.0</version>
    </dependency>
  </dependencies>

</project>

This build file adds Spark SQL as a dependency and configures the Maven compiler plugin to target Java 8, which supports the language features (like lambda expressions) we'll need when creating DataFrames.

Write some code

Let’s create a Transformations class with a myCounter method that returns the number of rows in a DataFrame. myCounter would not ever be useful in a real project, but it’s best to get started with a simple example.

Create the src/main/java/mrpowers/javaspark/Transformations.java file.

package mrpowers.javaspark;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;

public class Transformations {

  public long myCounter(Dataset<Row> df) {
    return df.count();
  }

}

The Transformations class lives in the mrpowers.javaspark package. Namespacing is important to prevent class name collisions (we don’t want Java to get our Transformations class confused with another library that has a class with the same name).

We import the Dataset and Row classes from Spark so they can be accessed in the myCounter function.

We could have imported all of the Spark SQL classes, including Dataset and Row, with a single wildcard import: import org.apache.spark.sql.*;. Wildcard imports make it harder to identify where classes are defined, so it's generally best to avoid them.
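To see why namespacing matters, here's a small standalone sketch (using JDK classes rather than Spark, so it runs on its own): java.util.List and java.awt.List are two unrelated classes that share a simple name, and the fully qualified package names are what keep them apart. The class name CollisionDemo is just for illustration.

```java
public class CollisionDemo {
  public static void main(String[] args) {
    // java.util.List and java.awt.List share the simple name "List";
    // explicit package prefixes remove the ambiguity. Wildcard imports
    // of both packages would make the bare name List ambiguous.
    java.util.List<String> items = new java.util.ArrayList<>();
    items.add("spark");
    System.out.println(items.size()); // prints 1
    System.out.println(java.awt.List.class.getSimpleName()); // prints List
  }
}
```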

Write a test

Let’s use junit to test the myCounter function.

Add junit as a dependency in the pom.xml file.

<dependency>
  <groupId>junit</groupId>
  <artifactId>junit</artifactId>
  <version>4.13-beta-1</version>
  <scope>test</scope>
</dependency>

Our test will create a DataFrame with two rows and verify that the myCounter function returns 2 when it's passed our DataFrame as an input.

Here's our test logic at a high level:

// create a DataFrame called df

Transformations transformations = new Transformations();
long result = transformations.myCounter(df);
assertEquals(2, result);

The junit assertEquals function is where we actually make our assertion to verify that the actual output and expected output match.

Let's take a look at the whole test file in src/test/java/mrpowers/javaspark/TransformationsTest.java.

Brace yourself for some verbose code!

package mrpowers.javaspark;

import org.junit.Test;
import static org.junit.Assert.*;

import java.util.List;
import java.util.ArrayList;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;

public class TransformationsTest implements SparkSessionTestWrapper {

  @Test
  public void testMyCounter() {
    List<String[]> stringAsList = new ArrayList<>();
    stringAsList.add(new String[] { "bar1.1", "bar2.1" });
    stringAsList.add(new String[] { "bar1.2", "bar2.2" });

    JavaSparkContext sparkContext = new JavaSparkContext(spark.sparkContext());

    JavaRDD<Row> rowRDD = sparkContext
        .parallelize(stringAsList)
        .map((String[] row) -> RowFactory.create(row));

    // Create schema
    StructType schema = DataTypes
        .createStructType(new StructField[] {
            DataTypes.createStructField("foe1", DataTypes.StringType, false),
            DataTypes.createStructField("foe2", DataTypes.StringType, false)
        });

    Dataset<Row> df = spark.sqlContext().createDataFrame(rowRDD, schema).toDF();

    Transformations transformations = new Transformations();
    long result = transformations.myCounter(df);
    assertEquals(2, result);
  }

}

Let’s create a SparkSessionTestWrapper interface to access the Spark session in our test. The SparkSession is defined in an interface so multiple test files can use the same SparkSession.

package mrpowers.javaspark;

import org.apache.spark.sql.SparkSession;

public interface SparkSessionTestWrapper {

  SparkSession spark = SparkSession
      .builder()
      .appName("Build a DataFrame from Scratch")
      .master("local[*]")
      .getOrCreate();

}
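As an aside, the test could be made less verbose by building a List of Rows directly and passing it to SparkSession.createDataFrame, which skips the JavaRDD step entirely. Here's a sketch of that variant, assuming the same Transformations class and SparkSessionTestWrapper interface; the class name TransformationsShortTest is our own invention.

```java
package mrpowers.javaspark;

import java.util.Arrays;
import java.util.List;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.junit.Test;
import static org.junit.Assert.assertEquals;

public class TransformationsShortTest implements SparkSessionTestWrapper {

  @Test
  public void testMyCounterWithRowList() {
    // Build the rows directly instead of going through a JavaRDD
    List<Row> rows = Arrays.asList(
        RowFactory.create("bar1.1", "bar2.1"),
        RowFactory.create("bar1.2", "bar2.2"));

    StructType schema = DataTypes.createStructType(new StructField[] {
        DataTypes.createStructField("foe1", DataTypes.StringType, false),
        DataTypes.createStructField("foe2", DataTypes.StringType, false)
    });

    // createDataFrame accepts a List<Row> plus a schema
    Dataset<Row> df = spark.createDataFrame(rows, schema);

    assertEquals(2, new Transformations().myCounter(df));
  }

}
```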

Run the tests with the mvn test command.

Next steps

This tutorial gives us a great foundation to explore more features that all Java Spark programmers need to master. Here are the next steps:

- Building JAR files with Maven (similar to building JAR files with SBT)
- Chaining custom transformations (we already know how to do this with the Scala API and with PySpark)
- Making DataFrame comparisons in the test suite with spark-fast-tests
- Using spark-daria in application code

P.S. This is the first Java code I’ve ever written. Please post a comment or email me if you have any suggestions on how to make this code better.