I have a PySpark DataFrame with a column "group", plus several feature columns and a label column. I want to split the DataFrame by group, train a model on each split, and end up with a dictionary whose keys are the "group" names and whose values are the trained models.

This question essentially answers this problem, but the method it uses is inefficient.

The obvious problem here is that it requires a full data scan for each level (i.e. each distinct group), so it is an expensive operation.
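To make the cost concrete, here is a minimal stdlib-only sketch (no Spark involved). The helper names `split_by_filtering` and `split_in_one_pass` and the "rows read" counter are hypothetical illustrations, and counting rows is only a rough stand-in for Spark's real scan cost. The first helper mimics one `df.where(...)` per group; the second buckets every row by its key in a single pass.

```python
from collections import defaultdict

# Toy rows mirroring the (group, feature_1, feature_2, label) example data.
ROWS = [
    ("A", 1, 0, True),
    ("A", 3, 0, False),
    ("B", 2, 2, True),
    ("B", 3, 3, True),
    ("B", 5, 2, False),
]


def split_by_filtering(rows):
    """Hypothetical helper: one full scan per distinct group, like df.where(...) per group."""
    groups = sorted({r[0] for r in rows})  # first scan: find the distinct groups
    rows_read = len(rows)
    parts = {}
    for g in groups:
        parts[g] = [r for r in rows if r[0] == g]  # another full scan for each group
        rows_read += len(rows)
    return parts, rows_read


def split_in_one_pass(rows):
    """Hypothetical helper: bucket every row by its group key in a single scan."""
    parts = defaultdict(list)
    for r in rows:
        parts[r[0]].append(r)
    return dict(parts), len(rows)


if __name__ == "__main__":
    print(split_by_filtering(ROWS)[1])  # rows read when filtering per group
    print(split_in_one_pass(ROWS)[1])   # rows read in a single pass
```

With G distinct groups and N rows, the filtering version reads (G + 1) × N rows versus N for the single pass, so with 10k groups the data is scanned roughly 10,001 times.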

The answer is old, and I am hoping there have been improvements in PySpark since then. In my use case I have 10k groups with heavy skew in their sizes: the largest group can have 1 billion records and the smallest can have just 1.

Edit: As suggested, here is a small reproducible example.

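```python
from pyspark.sql import SparkSession

# createDataFrame lives on SparkSession, not on SparkContext (sc).
spark = SparkSession.builder.getOrCreate()

df = spark.createDataFrame(
    [
        ('A', 1, 0, True),
        ('A', 3, 0, False),
        ('B', 2, 2, True),
        ('B', 3, 3, True),
        ('B', 5, 2, False),
    ],
    ('group', 'feature_1', 'feature_2', 'label'),
)
```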

I can split the data as suggested in the above link:

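```python
from itertools import chain

from pyspark.sql.functions import col

# Collect the distinct group values to the driver; chain(*rows) unpacks each Row.
groups = chain(*df.select("group").distinct().collect())

# train_model is my own training function (not shown).
# Each iteration filters df down to one group, which triggers a separate scan.
df_by_group = {
    group: train_model(df.where(col("group").eqNullSafe(group)))
    for group in groups
}
```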