How to group by multiple columns and collect in list in PySpark?


Solution 1

I finally found a solution. It is not the best way, but it lets me keep working...

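from pyspark.sql.functions import udf, col, concat, lit, collect_set
from pyspark.sql.types import ArrayType, StringType

def example(lista):
    # split every collected "x@y" string back into its two original values
    return [elem.split("@") for elem in lista]

# the declared return type must match what example() returns: a list of lists of strings
example_udf = udf(example, ArrayType(ArrayType(StringType())))

a = [[u'PNR1',u'TKT1',u'TEST',u'a2',u'a3'],[u'PNR1',u'TKT1',u'TEST',u'a5',u'a6'],[u'PNR1',u'TKT1',u'TEST',u'a8',u'a9']]

rdd = sc.parallelize(a)

df = rdd.toDF(["col1","col2","col3","col4","col5"])

# glue col4 and col5 into one string column, then collect it per group
# (collect_set drops duplicates and does not guarantee element order)
df2 = (df.withColumn('col6', concat(col('col4'), lit('@'), col('col5')))
         .drop("col4").drop("col5")
         .groupBy("col1", "col2", "col3")
         .agg(collect_set(col("col6")).alias("col6")))

# DataFrames have no .map() in Spark 2+, so go through the underlying RDD
df2.rdd.map(lambda x: (x[0], x[1], x[2], example(x[3]))).collect()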

And it gives:

[(u'PNR1', u'TKT1', u'TEST', [[u'a2', u'a3'], [u'a5', u'a6'], [u'a8', u'a9']])]

I hope this solution helps someone else.

Thanks for all your answers.
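Solution 1 relies on "@" never appearing inside col4 or col5: the two values are glued together with concat and later recovered with split("@"). The plain-Python sketch below (no Spark needed) mimics that round trip so the assumption is easy to see; encode_pair and decode_pairs are hypothetical helper names, not PySpark APIs.

```python
# Plain-Python sketch of the concat/split round trip used in Solution 1.
# encode_pair and decode_pairs are hypothetical helpers, not PySpark APIs.

SEP = "@"  # the separator passed to lit('@') in Solution 1

def encode_pair(a, b):
    # mirrors concat(col('col4'), lit('@'), col('col5'))
    return a + SEP + b

def decode_pairs(values):
    # mirrors example(): split every collected string back into its parts
    return [v.split(SEP) for v in values]

collected = [encode_pair("a2", "a3"), encode_pair("a5", "a6"), encode_pair("a8", "a9")]
print(decode_pairs(collected))

# If a value itself contains the separator, the original pair is not recovered:
broken = decode_pairs([encode_pair("x@y", "z")])
print(broken)
```

The second call returns three elements instead of two, so with real data it is worth picking a separator that cannot occur in the values.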

Solution 2

This might do your job (or give you some ideas to proceed further)...

One idea is to convert your col4 to a primitive data type, i.e. a string:

from pyspark.sql.functions import collect_list
import pandas as pd

a = [[u'PNR1',u'TKT1',u'TEST',u'a2',u'a3'],[u'PNR1',u'TKT1',u'TEST',u'a5',u'a6'],[u'PNR1',u'TKT1',u'TEST',u'a8',u'a9']]
rdd = sc.parallelize(a)

df = rdd.map(lambda x: (x[0],x[1],x[2], '(' + ' '.join(str(e) for e in x[3:]) + ')')).toDF(["col1","col2","col3","col4"])

df.groupBy("col1","col2","col3").agg(collect_list("col4")).toPandas().values.tolist()[0]
#[u'PNR1', u'TKT1', u'TEST', [u'(a2 a3)', u'(a5 a6)', u'(a8 a9)']]
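If you keep the parenthesised strings from this first version, each collected element can still be turned back into a list by stripping the parentheses before splitting. A minimal plain-Python sketch, using the row printed above as input; decode_parenthesised is a hypothetical helper, and it assumes the individual values contain no spaces or parentheses.

```python
# decode_parenthesised is a hypothetical helper; it assumes values contain
# no spaces or parentheses (the encoding above joins them with one space).

def decode_parenthesised(items):
    # '(a2 a3)' -> 'a2 a3' -> ['a2', 'a3']
    return [item.strip("()").split(" ") for item in items]

row = [u'PNR1', u'TKT1', u'TEST', [u'(a2 a3)', u'(a5 a6)', u'(a8 a9)']]
result = row[:-1] + [decode_parenthesised(row[-1])]
print(result)
```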

UPDATE (after your own answer):

I thought the point I had reached above was enough for you to adapt it to your needs, and I didn't have time to do it myself at the moment. So here it is: after modifying my df definition to get rid of the parentheses, it is just a matter of a single list comprehension:

df = rdd.map(lambda x: (x[0],x[1],x[2], ' '.join(str(e) for e in x[3:]))).toDF(["col1","col2","col3","col4"])

# temp list:
ff = df.groupBy("col1","col2","col3").agg(collect_list("col4")).toPandas().values.tolist()[0]
ff
# [u'PNR1', u'TKT1', u'TEST', [u'a2 a3', u'a5 a6', u'a8 a9']]

# final list of lists:
ll = ff[:-1] + [[x.split(' ') for x in ff[-1]]]
ll

which gives your initially requested result:

[u'PNR1', u'TKT1', u'TEST', [[u'a2', u'a3'], [u'a5', u'a6'], [u'a8', u'a9']]]  # requested output
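Note that .values.tolist()[0] keeps only the first aggregated row, which is enough for this sample because there is a single (col1, col2, col3) group. With more groups, the same list comprehension can be applied to every row. A plain-Python sketch: the rows list stands in for the toPandas().values.tolist() result, its PNR2 row is hypothetical data added for illustration, and split_collected is a hypothetical helper name.

```python
# Stand-in for df.groupBy(...).agg(collect_list("col4")).toPandas().values.tolist();
# the PNR2 row is hypothetical, added only to show more than one group.
rows = [
    [u'PNR1', u'TKT1', u'TEST', [u'a2 a3', u'a5 a6', u'a8 a9']],
    [u'PNR2', u'TKT2', u'TEST', [u'b1 b2']],
]

def split_collected(row):
    # keep the grouping columns and turn each 'x y' string into ['x', 'y']
    return row[:-1] + [[x.split(' ') for x in row[-1]]]

# apply the list comprehension to every aggregated row, not just rows[0]
result = [split_collected(r) for r in rows]
print(result)
```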

This approach has certain advantages compared with the one provided in your own answer:

  • It avoids PySpark (Python) UDFs, which are known to be slow because every row has to be serialized between the JVM and the Python worker
  • All the processing is done in the final (and hopefully much smaller) aggregated data, instead of adding and removing columns and performing map functions and UDFs in the initial (presumably much bigger) data
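To check either approach on small samples, it can help to have a reference for the result the question asks for: group the rows by their first three columns and collect the remaining columns of each row as a list. A pure-Python sketch (no Spark); group_and_collect and its n_keys parameter are hypothetical. It keeps duplicates and the input order, whereas collect_set in Solution 1 drops duplicates, and Spark's collect_list does not guarantee element order after a shuffle, so compare order-insensitively if needed.

```python
from collections import OrderedDict

def group_and_collect(rows, n_keys=3):
    # key: tuple of the first n_keys columns; value: list of the remaining columns
    groups = OrderedDict()
    for row in rows:
        key = tuple(row[:n_keys])
        groups.setdefault(key, []).append(list(row[n_keys:]))
    return [list(key) + [values] for key, values in groups.items()]

a = [[u'PNR1',u'TKT1',u'TEST',u'a2',u'a3'],[u'PNR1',u'TKT1',u'TEST',u'a5',u'a6'],[u'PNR1',u'TKT1',u'TEST',u'a8',u'a9']]
print(group_and_collect(a))
```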
Author: Carlos Lopez Sobrino

Updated on June 08, 2022

Comments

  • Carlos Lopez Sobrino
    Carlos Lopez Sobrino almost 2 years

    Here is my problem: I've got this RDD:

    a = [[u'PNR1',u'TKT1',u'TEST',u'a2',u'a3'],[u'PNR1',u'TKT1',u'TEST',u'a5',u'a6'],[u'PNR1',u'TKT1',u'TEST',u'a8',u'a9']]
    
    rdd = sc.parallelize(a)
    

    Then I try:

    from pyspark.sql.functions import collect_list

    (rdd.map(lambda x: (x[0], x[1], x[2], list(x[3:])))
        .toDF(["col1", "col2", "col3", "col4"])
        .groupBy("col1", "col2", "col3")
        .agg(collect_list("col4"))
        .show())
    

    Finally I should find this:

    [col1,col2,col3,col4]=[u'PNR1',u'TKT1',u'TEST',[[u'a2',u'a3'],[u'a5',u'a6'],[u'a8',u'a9']]]
    

    But the problem is that I can't collect a list.

    If anyone can help me, I would appreciate it.

  • Carlos Lopez Sobrino
    Carlos Lopez Sobrino over 6 years
    Actually I need a list of lists in col4. In your answer I get a string, e.g. (a2 a3), but I need [[a2,a3],[a5,a6],[a8,a9]]
  • desertnaut
    desertnaut over 6 years
    @CarlosLopezSobrino isn't the updated answer exactly what you asked for?