Cumulative sum in Spark

To get the cumulative sum with the DataFrame API, use the rowsBetween window method. In Spark 2.1 and newer, create the window as follows:

import org.apache.spark.sql.expressions.Window

val w = Window.partitionBy($"product_id", $"ack")
  .orderBy($"date_time")
  .rowsBetween(Window.unboundedPreceding, Window.currentRow)

This tells Spark to aggregate the values from the beginning of the partition up to and including the current row. In older versions of Spark, use rowsBetween(Long.MinValue, 0) for the same effect.
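
For example, on a pre-2.1 cluster the same window definition can be sketched with literal frame bounds, where Long.MinValue stands for the unbounded start and 0 for the current row:

import org.apache.spark.sql.expressions.Window

// Pre-Spark-2.1 variant: numeric frame bounds instead of the named constants
val wLegacy = Window.partitionBy($"product_id", $"ack")
  .orderBy($"date_time")
  .rowsBetween(Long.MinValue, 0)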

Then use the window the same way as before:

import org.apache.spark.sql.functions.sum

val newDf = inputDF.withColumn("val_sum", sum($"val1").over(w))
  .withColumn("val2_sum", sum($"val2").over(w))

Author by lucy

Updated on July 24, 2022

Comments

  • lucy, almost 2 years ago

    I want to compute a cumulative sum in Spark. Here is the registered table (input):

    +---------------+-------------------+----+----+----+
    |     product_id|          date_time| ack|val1|val2|
    +---------------+-------------------+----+----+----+
    |4008607333T.upf|2017-12-13:02:27:01|3-46|  53|  52|
    |4008607333T.upf|2017-12-13:02:27:03|3-47|  53|  52|
    |4008607333T.upf|2017-12-13:02:27:08|3-46|  53|  52|
    |4008607333T.upf|2017-12-13:02:28:01|3-47|  53|  52|
    |4008607333T.upf|2017-12-13:02:28:07|3-46|  15|   1|
    +---------------+-------------------+----+----+----+
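
    For reference, the input above can be created with something like the following (column types are my assumption: strings, with val1 and val2 as Int):

    import spark.implicits._

    // Schema assumed from the table above: string columns plus integer val1/val2
    val inputDF = Seq(
      ("4008607333T.upf", "2017-12-13:02:27:01", "3-46", 53, 52),
      ("4008607333T.upf", "2017-12-13:02:27:03", "3-47", 53, 52),
      ("4008607333T.upf", "2017-12-13:02:27:08", "3-46", 53, 52),
      ("4008607333T.upf", "2017-12-13:02:28:01", "3-47", 53, 52),
      ("4008607333T.upf", "2017-12-13:02:28:07", "3-46", 15, 1)
    ).toDF("product_id", "date_time", "ack", "val1", "val2")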
    

    Hive query:

    SELECT *,
           SUM(val1) OVER (PARTITION BY product_id, ack ORDER BY date_time
                           ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS val1_sum,
           SUM(val2) OVER (PARTITION BY product_id, ack ORDER BY date_time
                           ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) AS val2_sum
    FROM test
    

    Output:

    +---------------+-------------------+----+----+----+-------+--------+
    |     product_id|          date_time| ack|val1|val2|val_sum|val2_sum|
    +---------------+-------------------+----+----+----+-------+--------+
    |4008607333T.upf|2017-12-13:02:27:01|3-46|  53|  52|     53|      52|
    |4008607333T.upf|2017-12-13:02:27:08|3-46|  53|  52|    106|     104|
    |4008607333T.upf|2017-12-13:02:28:07|3-46|  15|   1|    121|     105|
    |4008607333T.upf|2017-12-13:02:27:03|3-47|  53|  52|     53|      52|
    |4008607333T.upf|2017-12-13:02:28:01|3-47|  53|  52|    106|     104|
    +---------------+-------------------+----+----+----+-------+--------+
    

    Using the following Spark code, I get the same output as above:

    import org.apache.spark.sql.expressions.Window
    import org.apache.spark.sql.functions._

    val w = Window.partitionBy('product_id, 'ack).orderBy('date_time)

    val newDf = inputDF
      .withColumn("val_sum", sum('val1) over w)
      .withColumn("val2_sum", sum('val2) over w)
    newDf.show
    

    However, when I run this logic on a Spark cluster, the val_sum value is sometimes half of the expected cumulative sum and sometimes a different value altogether. I don't know why this happens on the cluster. Is it due to partitioning?

    How can I compute the cumulative sum of a column on a Spark cluster?