Efficient Spark DataFrame Transforms

If you work with Spark, you will almost certainly write transforms on DataFrames. DataFrame exposes the obvious method df.withColumn(colName, colExpression) for adding a column with a specified expression. Since DataFrames are immutable, each call returns a newly created copy of the DataFrame with the added column (if you look at the source code of withColumn, you will also see additional checks being performed, such as whether the column already exists, a check that is unnecessary in most cases). This is very inefficient when we have to add many columns, for example in a wide transform of our DataFrame such as a pivot. (Note: there was also a bug limiting how wide your transformation could be, fixed in Spark 2.3.0.)
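To see the difference concretely, here is a minimal sketch (the column names c1..c5 and the sample values are made up for illustration) contrasting repeated withColumn calls, each of which produces a new DataFrame and re-analyzes the plan, with a single select over pre-built expressions:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, lit}

object WithColumnVsSelect {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("sketch").getOrCreate()
    import spark.implicits._

    val df = Seq((1, 2), (3, 4)).toDF("a", "b")

    // Slow: each withColumn call returns a new DataFrame, so adding
    // N columns triggers N rounds of plan analysis.
    val slow = (1 to 5).foldLeft(df) { (acc, i) =>
      acc.withColumn(s"c$i", lit(i))
    }

    // Fast: build all column expressions first, then project once.
    val exprs = col("*") +: (1 to 5).map(i => lit(i).as(s"c$i"))
    val fast = df.select(exprs: _*)

    // Both approaches produce the same schema.
    assert(slow.columns.sameElements(fast.columns))
    spark.stop()
  }
}
```

With five columns the difference is negligible; it is at hundreds or thousands of columns that the single-select approach pays off.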

Here is an optimized version of a pivot method. Rather than calling df.withColumn repeatedly, we collect all column expressions in a mutable ListBuffer and then apply them all at once via df.select(colExprs: _*). This is phenomenally fast: repeated df.withColumn calls can hang the driver process even for a transform over a few hundred columns (you can observe the hung threads and locks with jVisualVM), whereas the optimized version handles thousands of columns with ease.

Optimized Version:

    import org.apache.spark.sql.{Column, Dataset, Encoder, Row}
    import org.apache.spark.sql.functions.{col, when}
    import org.apache.spark.sql.types.NumericType
    import scala.collection.mutable.ListBuffer

    // `logger` is assumed to be an slf4j Logger already in scope.

    /**
     * Pivots the DataFrame by the pivot column. It is better to specify the
     * distinct values, as otherwise the distinct values need to be calculated.
     *
     * @param groupBy  the columns to group by
     * @param pivot    the pivot column
     * @param distinct an optional Array of distinct pivot values
     * @param agg      the aggregate function to apply. Default = "sum"
     * @param df       the DataFrame to transpose and return
     * @param ev       the implicit encoder to use
     * @tparam A the type of the pivot column
     * @return the transposed DataFrame
     */
    def doPivotTF[A](groupBy: Seq[String],
                     pivot: String,
                     distinct: Option[Array[A]],
                     agg: String = "sum")
                    (df: Dataset[Row])
                    (implicit ev: Encoder[A]): Dataset[Row] = {
      // Keep the pivot column, the groupBy columns and every non-numeric
      // column; only the numeric columns get transposed.
      val colsToFilter = (Seq(pivot) ++ groupBy ++ df.schema.filter(_.dataType match {
        case _: NumericType => false
        case _ => true
      }).map(_.name)).distinct
      val colsToTranspose = df.columns.filter(!colsToFilter.contains(_)).toSeq
      if (logger.isDebugEnabled) {
        logger.debug(s"colsToFilter $colsToFilter")
        logger.debug(s"colsToTranspose $colsToTranspose")
      }
      val distinctValues = distinct match {
        case Some(v) => v
        case None =>
          df.select(col(pivot)).map(row => row.getAs[A](pivot)).distinct().collect()
      }
      // Collect all column expressions first, then apply them in one select.
      val colExprs = new ListBuffer[Column]()
      colExprs += col("*")
      for (colName <- colsToTranspose) {
        for (index <- distinctValues) {
          val colExpr = when(col(pivot) === index, col(colName)).otherwise(0.0)
          val colNameToUse = s"${colName}_TN$index"
          colExprs += colExpr.as(colNameToUse)
        }
      }
      val transposedDF = df.select(colExprs: _*)
      // Drop all original columns except the columns in groupBy
      val colsToDrop = colsToFilter.filter(!groupBy.contains(_)) ++ colsToTranspose
      val dfBeforeGroupBy = transposedDF.drop(colsToDrop: _*)
      val finalDF = dfBeforeGroupBy
        .groupBy(groupBy.map(col): _*)
        .agg(dfBeforeGroupBy.columns.filter(!groupBy.contains(_)).map(_ -> agg).toMap)
      // Remove the Spark-generated "$agg(...)" wrappers from the column names
      val finalColNames = finalDF.columns.map(_.stripPrefix(s"$agg(").stripSuffix(")"))
      if (logger.isDebugEnabled) {
        logger.debug(s"Final set of column names $finalColNames")
      }
      finalDF.toDF(finalColNames: _*)
    }
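For comparison, Spark's built-in pivot also lets you pass the distinct values up front, which skips the extra job that computes them; a quick sketch with made-up sample data:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.sum

object BuiltInPivot {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("pivot").getOrCreate()
    import spark.implicits._

    val sales = Seq(
      ("east", "2020", 10.0),
      ("east", "2021", 20.0),
      ("west", "2020", 5.0)
    ).toDF("region", "year", "amount")

    // Supplying the pivot values avoids an extra pass over the data
    // to compute them.
    val pivoted = sales
      .groupBy("region")
      .pivot("year", Seq("2020", "2021"))
      .agg(sum("amount"))

    // One column per supplied pivot value.
    assert(pivoted.columns.sameElements(Array("region", "2020", "2021")))
    spark.stop()
  }
}
```

The built-in pivot is convenient for modest widths; the hand-rolled version above remains useful when the transform has to scale to thousands of generated columns.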