How to retrieve all columns using the pyspark collect_list function


Solution 1

Use struct to combine the columns before calling groupBy

Suppose you have a dataframe:

import pyspark.sql.functions as f

df = spark.createDataFrame(sc.parallelize([(0,1,2),(0,4,5),(1,7,8),(1,8,7)])).toDF("a","b","c")

df = df.select("a", f.struct(["b","c"]).alias("newcol"))
df.show()
+---+------+
|  a|newcol|
+---+------+
|  0| [1,2]|
|  0| [4,5]|
|  1| [7,8]|
|  1| [8,7]|
+---+------+
df = df.groupBy("a").agg(f.collect_list("newcol").alias("collected_col"))
df.show()
+---+--------------+
|  a| collected_col|
+---+--------------+
|  0|[[1,2], [4,5]]|
|  1|[[7,8], [8,7]]|
+---+--------------+

Aggregation operations such as collect_list can only be applied to single columns, which is why the columns are combined into a struct first.

After aggregation, you can either collect the result and iterate over it on the driver to separate the combined columns into a dict (a sketch of this follows the output below), or you can write a udf to separate the combined columns:

from pyspark.sql.types import StructType, StructField, ArrayType, LongType
from pyspark.sql.functions import udf, col

def foo(x):
    # x is a list of (b, c) structs; split it into one list per field
    x1 = [y[0] for y in x]
    x2 = [y[1] for y in x]
    return (x1, x2)

st = StructType([StructField("b", ArrayType(LongType())), StructField("c", ArrayType(LongType()))])
udf_foo = udf(foo, st)

df = df.withColumn("ncol", udf_foo("collected_col")).select(
    "a",
    col("ncol").getItem("b").alias("b"),
    col("ncol").getItem("c").alias("c"))
df.show()

+---+------+------+
|  a|     b|     c|
+---+------+------+
|  0|[1, 4]|[2, 5]|
|  1|[7, 8]|[8, 7]|
+---+------+------+
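For completeness, here is a minimal sketch of the driver-side option mentioned above (collecting the result and iterating over it to build a dict). The name grouped is hypothetical and stands for the DataFrame produced by the groupBy/agg step, before the udf is applied; this only makes sense when the aggregated result is small enough to fit on the driver.

# Driver-side alternative: collect the grouped rows and split the struct
# fields into plain Python lists. `grouped` is assumed to be the DataFrame
# with columns (a, collected_col) from the groupBy/agg step above.
rows = grouped.collect()

split = {}
for row in rows:
    # row.collected_col is a list of Row(b=..., c=...) structs
    split[row.a] = {
        "b": [pair["b"] for pair in row.collected_col],
        "c": [pair["c"] for pair in row.collected_col],
    }

print(split)  # e.g. {0: {'b': [1, 4], 'c': [2, 5]}, 1: {'b': [7, 8], 'c': [8, 7]}}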

Solution 2

Actually, we can do it in pyspark 2.2.

First we create a constant column ("Temp"), group by that column, and apply agg with an unpacked list of expressions (*exprs), where each expression is a collect_list over one of the original columns.

Below is the code:

import pyspark.sql.functions as ftions

def groupColumnData(df, columns):
    # add a constant column so a single group spans the whole DataFrame
    df = df.withColumn("Temp", ftions.lit(1))
    # build one collect_list expression per requested column
    exprs = [ftions.collect_list(colName) for colName in columns]
    df = df.groupby('Temp').agg(*exprs)
    df = df.drop("Temp")
    # restore the original column names
    df = df.toDF(*columns)
    return df

Input Data:

df.show()
+---+---+---+
|  a|  b|  c|
+---+---+---+
|  0|  1|  2|
|  0|  4|  5|
|  1|  7|  8|
|  1|  8|  7|
+---+---+---+

Output Data:

df.show()

+------------+------------+------------+
|           a|           b|           c|
+------------+------------+------------+
|[0, 0, 1, 1]|[1, 4, 7, 8]|[2, 5, 8, 7]|
+------------+------------+------------+
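For reference, a short usage sketch that reproduces the output above (it assumes a SparkSession named spark, as in Solution 1):

df = spark.createDataFrame([(0, 1, 2), (0, 4, 5), (1, 7, 8), (1, 8, 7)]).toDF("a", "b", "c")

# collect every column of df into a single row of lists
result = groupColumnData(df, ["a", "b", "c"])
result.show()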

Solution 3

In Spark 2.4.4 and Python 3.7 (I guess it is also relevant for previous Spark and Python versions):
my suggestion is based on pauli's answer;
instead of creating the struct and then using the agg function, create the struct inside collect_list:

from pyspark.sql.functions import collect_list, struct

df = spark.createDataFrame([(0,1,2),(0,4,5),(1,7,8),(1,8,7)]).toDF("a","b","c")
df.groupBy("a").agg(collect_list(struct(["b","c"])).alias("res")).show()

Result:

+---+-----------------+
|  a|res              |
+---+-----------------+
|  0|[[1, 2], [4, 5]] |
|  1|[[7, 8], [8, 7]] |
+---+-----------------+
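If separate array columns per field are needed afterwards (as in the final output of Solution 1), the struct fields can be pulled back out of the collected array without a udf. This is a sketch built on top of Solution 3, not part of the original answer; the variable name grouped is introduced only for illustration:

from pyspark.sql import functions as f

grouped = df.groupBy("a").agg(f.collect_list(f.struct("b", "c")).alias("res"))

# selecting a field from an array of structs returns an array of that field's values
grouped.select(
    "a",
    f.col("res.b").alias("b"),
    f.col("res.c").alias("c"),
).show()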
Author: Python Learner, updated on June 09, 2022

Comments

  • Python Learner, almost 2 years

    I have pyspark 2.0.1. I'm trying to group by my data frame and retrieve the values for all the fields from it. I found that

    z=data1.groupby('country').agg(F.collect_list('names')) 
    

    will give me values for the country and names attributes, and for the names attribute the column header will be collect_list(names). But for my job I have a dataframe with around 15 columns; I will run a loop, change the groupby field each time inside the loop, and need the output for all of the remaining fields. Can you please suggest how to do it using collect_list() or any other pyspark function?

    I tried this code too

    from pyspark.sql import functions as F
    fieldnames = data1.schema.names
    names1 = list()
    for item in names:
        if item != 'names':
            names1.append(item)
    z = data1.groupby('names').agg(F.collect_list(names1))
    z.show()
    

    but got this error message:

    Py4JError: An error occurred while calling z:org.apache.spark.sql.functions.collect_list. Trace: py4j.Py4JException: Method collect_list([class java.util.ArrayList]) does not exist
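
    For reference, the loop over group-by fields described here could be sketched along the lines of the answers above (one collect_list per remaining column, passed unpacked to agg). The names data1 and results are taken from the question's wording and are otherwise hypothetical:

    from pyspark.sql import functions as F

    results = {}
    for group_col in data1.columns:
        # collect every other column into a list, one row per value of group_col
        other_cols = [c for c in data1.columns if c != group_col]
        exprs = [F.collect_list(c).alias(c) for c in other_cols]
        results[group_col] = data1.groupby(group_col).agg(*exprs)

    results["country"].show()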