Adding new columns based on an aggregation on an existing column in a Spark DataFrame using Scala

Solution 1

groupBy col1 and aggregate to get the count and the max. Then you can join the result back to the original DataFrame to get your desired result:

import org.apache.spark.sql.functions.{count, max}
val df2 = df1.groupBy("col1").agg(count("col2") as "col3", max("col2") as "col4")

val df3 = df1.join(df2, "col1")
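
Note that col3 produced this way is the raw group count rather than the 0/1 flag shown in the expected output. A minimal follow-up sketch to map it with when/otherwise (df4 is just an illustrative name, and spark.implicits._ is assumed to be in scope for the $ syntax):

import org.apache.spark.sql.functions.when

// replace the raw count with the 0/1 flag described in the question
val df4 = df3.withColumn("col3", when($"col3" > 1, 1).otherwise(0))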

Solution 2

A Spark DataFrame has a method called withColumn. You can add as many derived columns as you want, but the column is not added to the existing DataFrame; withColumn returns a new DataFrame with the added column.

e.g. adding a static batch date to the data:

import org.apache.spark.sql.functions.udf
// define the UDF first, then use it; it returns a static batch date regardless of input
val addBatchDate = udf { (batchDate: String) => "20160101" }
val myFormattedData = myData.withColumn("batchdate", addBatchDate(myData("batchdate")))
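
Because each call returns a new DataFrame, derived columns can be chained; a minimal sketch of that behaviour (enriched and col2Doubled are made-up names, and myData is assumed to have a numeric col2):

import org.apache.spark.sql.functions.lit

// each withColumn returns a new DataFrame; myData itself is left unchanged
val enriched = myData
  .withColumn("batchdate", lit("20160101"))
  .withColumn("col2Doubled", myData("col2") * 2)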

Solution 3

To add col3 you can use withColumn + when/otherwise on the per-group count (col3 depends on how many rows share the same col1, not on the value of col2; max and when come from org.apache.spark.sql.functions, and spark.implicits._ provides the $ syntax):

val df2 = df.join(df.groupBy("col1").count(), "col1")
  .withColumn("col3", when($"count" > 1, 1).otherwise(0)).drop("count")

To add col4 the groupBy/max + join already mentioned should do the job:

val df3 = df2.join(df.groupBy("col1").agg(max("col2") as "col4"), "col1")
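
As a small usage sketch (the select/orderBy calls are only for presentation and are not part of the original answer), you can put the columns into the question's order:

df3.select("col1", "col2", "col3", "col4").orderBy("col1", "col2").show()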

Solution 4

To achieve this without a join, you need to use count and max as window functions. This requires creating a window using Window and telling count and max to operate over this window.

from pyspark.sql import Window, functions as fn

df = sc.parallelize([
    {'col1': 'a', 'col2': 1},
    {'col1': 'a', 'col2': 2},
    {'col1': 'b', 'col2': 1},
    {'col1': 'c', 'col2': 1},
    {'col1': 'd', 'col2': 1},
    {'col1': 'd', 'col2': 2}
]).toDF()

col1_window = Window.partitionBy('col1')
df = df.withColumn('col3', fn.when(fn.count('col1').over(col1_window) > 1, 1).otherwise(0))
df = df.withColumn('col4', fn.max('col2').over(col1_window))
df.orderBy(['col1', 'col2']).show()
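
Since the question asks for the Scala DSL, a rough Scala equivalent of the same window approach would be (a sketch, assuming the same df with columns col1 and col2):

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{count, max, when}

// partition by col1 so count/max are evaluated per group, no join needed
val col1Window = Window.partitionBy("col1")

val result = df
  .withColumn("col3", when(count("col1").over(col1Window) > 1, 1).otherwise(0))
  .withColumn("col4", max("col2").over(col1Window))

result.orderBy("col1", "col2").show()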
Author by John

Updated on June 04, 2022

Comments

  • John almost 2 years

    I have a DataFrame like the one below. I need to create new columns based on existing columns.

    col1 col2
    a      1
    a      2
    b      1
    c      1
    d      1
    d      2
    

    The output DataFrame should look like this:

    col1  col2 col3 col4
    a      1   1      2
    a      2   1      2
    b      1   0      1
    c      1   0      1
    d      1   1      2
    d      2   1      2
    

    The logic I have used: col3 is 1 if the count of col1 is greater than 1 (otherwise 0), and col4 is the max value of col2 within each col1 group.

    I am familiar with how to do it in SQL, but it's hard to find a solution with the DataFrame DSL. Any help would be appreciated. Thanks

  • John almost 8 years
    +1 for the join and group concept. Just for clarification, col3 is not the sum of col2; it is the count of col2, and if that count is > 1 it should be 1, otherwise zero. Is there any way without a join? When I am using a join on huge data I am facing memory errors. Thanks
  • Shrikant Prabhu over 6 years
    Yes, I would also like to know the solution without a join