Here, I’ve explained how to get the first row, minimum, and maximum of each group in a Spark DataFrame using Spark SQL window functions, with a Scala example. Though the examples here are in Scala, the same approach works with PySpark and Python.

Preparing Data & DataFrame

Before we start, let’s create a DataFrame from a sequence of data to work with. This DataFrame contains three columns, “employee_name”, “department”, and “salary”; the “department” column holds the different departments we will group by.

We will use this Spark DataFrame to select the first row for each group, the minimum salary for each group, and the maximum salary for each group. Finally, we will also see how to get the sum and the average salary for each department group.

import spark.implicits._

val simpleData = Seq(
  ("James","Sales",3000),
  ("Michael","Sales",4600),
  ("Robert","Sales",4100),
  ("Maria","Finance",3000),
  ("Raman","Finance",3000),
  ("Scott","Finance",3300),
  ("Jen","Finance",3900),
  ("Jeff","Marketing",3000),
  ("Kumar","Marketing",2000)
)
val df = simpleData.toDF("employee_name","department","salary")
df.show()

This outputs the table below.

+-------------+----------+------+
|employee_name|department|salary|
+-------------+----------+------+
|        James|     Sales|  3000|
|      Michael|     Sales|  4600|
|       Robert|     Sales|  4100|
|        Maria|   Finance|  3000|
|        Raman|   Finance|  3000|
|        Scott|   Finance|  3300|
|          Jen|   Finance|  3900|
|         Jeff| Marketing|  3000|
|        Kumar| Marketing|  2000|
+-------------+----------+------+

Spark DataFrame – Select the first row from a group

We can select the first row from a group using either SQL or the DataFrame API. In this section, we will use the DataFrame API with the window function row_number and partitionBy.

val w2 = Window.partitionBy("department").orderBy(col("salary"))
df.withColumn("row", row_number.over(w2))
  .where($"row" === 1).drop("row")
  .show()

In the snippet above, we first partition on the department column, which groups all rows of the same department together, and then order each partition by the salary column. We then apply the row_number function over this window. This snippet outputs the following.

The row_number function returns a sequential number starting from 1 within each window partition.

+-------------+----------+------+
|employee_name|department|salary|
+-------------+----------+------+
|        James|     Sales|  3000|
|        Maria|   Finance|  3000|
|        Kumar| Marketing|  2000|
+-------------+----------+------+
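As mentioned above, the same first-row-per-group query can also be written in Spark SQL instead of the DataFrame API. Below is a minimal, self-contained sketch; the view name EMP, the subquery alias tmp, and the trimmed sample data are illustrative choices, not part of the original example.

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .master("local[1]")
  .appName("FirstRowSql")
  .getOrCreate()
import spark.implicits._

val df = Seq(
  ("James","Sales",3000), ("Michael","Sales",4600),
  ("Robert","Sales",4100), ("Maria","Finance",3000),
  ("Kumar","Marketing",2000)
).toDF("employee_name","department","salary")

// Register a temporary view so SQL can reference the DataFrame
df.createOrReplaceTempView("EMP")

// Same logic as the DataFrame version: number rows per department by salary,
// then keep only the first row of each partition
val firstRows = spark.sql(
  """SELECT employee_name, department, salary
    |FROM (SELECT *,
    |        ROW_NUMBER() OVER (PARTITION BY department ORDER BY salary) AS rn
    |      FROM EMP) tmp
    |WHERE rn = 1""".stripMargin)
firstRows.show()
```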

Retrieve Employee who earns the highest salary

To retrieve the employee with the highest salary in each department, we order by “salary” in descending order and retrieve the first row.

val w3 = Window.partitionBy("department").orderBy(col("salary").desc)
df.withColumn("row", row_number.over(w3))
  .where($"row" === 1).drop("row")
  .show()

This outputs the following.

+-------------+----------+------+
|employee_name|department|salary|
+-------------+----------+------+
|      Michael|     Sales|  4600|
|          Jen|   Finance|  3900|
|         Jeff| Marketing|  3000|
+-------------+----------+------+
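One thing worth noting: row_number assigns distinct numbers even to tied salaries, so when two employees tie for first place (like Maria and Raman, who both earn 3000 in Finance), only one of them survives the row = 1 filter, and which one is not guaranteed. If you want to keep all tied rows, rank assigns equal rows the same number. A minimal sketch of the difference, with sample data trimmed to the Finance group:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

val spark = SparkSession.builder()
  .master("local[1]")
  .appName("RankVsRowNumber")
  .getOrCreate()
import spark.implicits._

val df = Seq(
  ("Maria","Finance",3000), ("Raman","Finance",3000), ("Scott","Finance",3300)
).toDF("employee_name","department","salary")

val w = Window.partitionBy("department").orderBy(col("salary"))

// rank gives tied salaries the same number, so both 3000 rows get rank 1
// (row_number would number them 1 and 2 and drop one of them)
val tied = df.withColumn("rnk", rank().over(w))
  .where($"rnk" === 1).drop("rnk")
tied.show()
```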

Select the Highest, Lowest, Average and Total salary for each department group

Here, we will retrieve the highest, average, total, and lowest salary for each group. The snippet below uses partitionBy and row_number along with the aggregation functions avg, sum, min, and max.

val w4 = Window.partitionBy("department")
val aggDF = df.withColumn("row", row_number.over(w3))
  .withColumn("avg", avg(col("salary")).over(w4))
  .withColumn("sum", sum(col("salary")).over(w4))
  .withColumn("min", min(col("salary")).over(w4))
  .withColumn("max", max(col("salary")).over(w4))
  .where(col("row") === 1)
  .select("department", "avg", "sum", "min", "max")
aggDF.show()

Outputs the following aggregated values for each group.

+----------+------+-----+----+----+
|department|   avg|  sum| min| max|
+----------+------+-----+----+----+
|     Sales|3900.0|11700|3000|4600|
|   Finance|3300.0|13200|3000|3900|
| Marketing|2500.0| 5000|2000|3000|
+----------+------+-----+----+----+
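Since this last query keeps only one row per department anyway, the same aggregates can also be computed with a plain groupBy and agg, without window functions, which is often simpler. A minimal sketch (sample data trimmed for brevity):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

val spark = SparkSession.builder()
  .master("local[1]")
  .appName("GroupByAgg")
  .getOrCreate()
import spark.implicits._

val df = Seq(
  ("James","Sales",3000), ("Michael","Sales",4600), ("Robert","Sales",4100),
  ("Jeff","Marketing",3000), ("Kumar","Marketing",2000)
).toDF("employee_name","department","salary")

// One aggregation pass per department instead of window functions + row filter
val aggDF = df.groupBy("department")
  .agg(
    avg("salary").as("avg"),
    sum("salary").as("sum"),
    min("salary").as("min"),
    max("salary").as("max"))
aggDF.show()
```

The window-function version is still the right tool when you also need per-row values (such as employee_name) alongside the group aggregates.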

Complete program for reference

package com.sparkbyexamples.spark.dataframe.functions

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

object WindowGroupbyFirst extends App {

  val spark: SparkSession = SparkSession.builder()
    .master("local[1]")
    .appName("SparkByExamples.com")
    .getOrCreate()

  spark.sparkContext.setLogLevel("ERROR")
  import spark.implicits._

  val simpleData = Seq(
    ("James","Sales",3000),
    ("Michael","Sales",4600),
    ("Robert","Sales",4100),
    ("Maria","Finance",3000),
    ("Raman","Finance",3000),
    ("Scott","Finance",3300),
    ("Jen","Finance",3900),
    ("Jeff","Marketing",3000),
    ("Kumar","Marketing",2000)
  )
  val df = simpleData.toDF("employee_name","department","salary")
  df.show()

  // Get the first row from each group
  val w2 = Window.partitionBy("department").orderBy(col("salary"))
  df.withColumn("row", row_number.over(w2))
    .where($"row" === 1).drop("row")
    .show()

  // Retrieve the highest salary in each group
  val w3 = Window.partitionBy("department").orderBy(col("salary").desc)
  df.withColumn("row", row_number.over(w3))
    .where($"row" === 1).drop("row")
    .show()

  // Maximum, minimum, average, and total salary for each window group
  val w4 = Window.partitionBy("department")
  val aggDF = df.withColumn("row", row_number.over(w3))
    .withColumn("avg", avg(col("salary")).over(w4))
    .withColumn("sum", sum(col("salary")).over(w4))
    .withColumn("min", min(col("salary")).over(w4))
    .withColumn("max", max(col("salary")).over(w4))
    .where(col("row") === 1)
    .select("department", "avg", "sum", "min", "max")
  aggDF.show()
}

Conclusion

In this article, you have learned how to retrieve the first row of each group, and the minimum, maximum, average, and sum for each group in a Spark DataFrame.

Happy Learning !!