Spark SQL window function with complex condition


Solution 1

Spark >= 3.2

Recent Spark releases provide native support for session windows in both batch and structured streaming queries (see SPARK-10816 and its sub-tasks, especially SPARK-34893).

The official documentation provides a nice usage example.
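To make the session-window idea concrete, here is a plain-Python model (not Spark code) of gap-based merging, as I understand the semantics of the built-in `session_window` function: each event opens a window [t, t + gap), and overlapping windows are merged into one session. The function name `session_windows` is hypothetical.

```python
from datetime import date, timedelta


def session_windows(times, gap):
    """Merge per-event windows [t, t + gap) into sessions.

    Plain-Python model of session-window semantics (an assumption,
    not Spark's implementation). Returns a list of (start, end) pairs.
    """
    sessions = []
    for t in sorted(times):
        if sessions and t < sessions[-1][1]:
            # The event falls inside the current session: extend its end.
            start, end = sessions[-1]
            sessions[-1] = (start, max(end, t + gap))
        else:
            # No overlap with the previous session: start a new one.
            sessions.append((t, t + gap))
    return sessions


logins = [date(2012, 1, 4), date(2012, 1, 11), date(2012, 1, 14), date(2012, 8, 11)]
for start, end in session_windows(logins, timedelta(days=5)):
    print(start, end)
```

In this model two logins share a session only when they are strictly less than `gap` apart, so for whole-day dates a 6-day gap would correspond to the `> 5` rule used in the Spark < 3.2 solution, and a 5-day gap to the question's `>= 5` rule. Treat that boundary behavior as something to verify against your Spark version.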

Spark < 3.2

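Here is the trick. Import the required functions:

import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{coalesce, datediff, lag, lit, min, sum}
import spark.implicits._  // for $"col" and toDF; assumes a SparkSession named `spark` (auto-imported in spark-shell)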

Define windows:

val userWindow = Window.partitionBy("user_name").orderBy("login_date")
val userSessionWindow = Window.partitionBy("user_name", "session")

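Find the points where new sessions start:

// 1 when the gap to the previous login exceeds 5 days, otherwise 0
// (coalesce maps the NULL that lag returns on a user's first login to 0)
val newSession = (coalesce(
  datediff($"login_date", lag($"login_date", 1).over(userWindow)),
  lit(0)
) > 5).cast("bigint")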
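The `newSession` expression can be checked by hand with a plain-Python model: group by user and sort by date (what `userWindow` does), take the previous date (`lag`), compute the day difference (`datediff`), replace the missing value on the first row with 0 (`coalesce`), and flag gaps greater than 5 days. The helper `new_session_flags` is hypothetical and only mirrors the Spark logic.

```python
from collections import defaultdict
from datetime import date

rows = [
    ("SirChillingtonIV", "2012-01-04"), ("Booooooo99900098", "2012-01-04"),
    ("Booooooo99900098", "2012-01-06"), ("OprahWinfreyJr", "2012-01-10"),
    ("SirChillingtonIV", "2012-01-11"), ("SirChillingtonIV", "2012-01-14"),
    ("SirChillingtonIV", "2012-08-11"),
]


def new_session_flags(rows, max_gap_days=5):
    """Return {user: [(login_date, flag), ...]} with flag 1 where a new session starts."""
    by_user = defaultdict(list)
    for user, login in rows:                      # partitionBy("user_name")
        by_user[user].append(date.fromisoformat(login))
    result = {}
    for user, logins in by_user.items():
        logins.sort()                             # orderBy("login_date")
        flags = []
        prev = None                               # lag(...) is NULL on the first row
        for d in logins:
            diff = (d - prev).days if prev is not None else None   # datediff
            diff = 0 if diff is None else diff                     # coalesce(..., lit(0))
            flags.append((d, int(diff > max_gap_days)))            # (... > 5).cast("bigint")
            prev = d
        result[user] = flags
    return result


for user, flags in new_session_flags(rows).items():
    print(user, [flag for _, flag in flags])
```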

val sessionized = df.withColumn("session", sum(newSession).over(userWindow))
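Because `userWindow` has an ORDER BY and no explicit frame, Spark uses the default frame RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW, so `sum(newSession)` is a running total: each 1 bumps the session number and each 0 keeps it. A plain-Python equivalent for one user's flags, using `itertools.accumulate`:

```python
from itertools import accumulate

# 0/1 "new session" flags for one user's logins, in login_date order
# (SirChillingtonIV: 2012-01-04, 2012-01-11, 2012-01-14, 2012-08-11).
flags = [0, 1, 0, 1]

# Running sum = session number, mirroring sum(newSession).over(userWindow).
# This model ignores how the RANGE frame treats duplicate dates (peers).
sessions = list(accumulate(flags))
print(sessions)  # [0, 1, 1, 2]
```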

Find the earliest date per session:

val result = sessionized
  .withColumn("became_active", min($"login_date").over(userSessionWindow))
  .drop("session")
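`userSessionWindow` has no ORDER BY, so the frame for `min` is the whole (user_name, session) partition and every row receives that session's earliest date. A dictionary-based model of this step, using (user_name, login_date, session) tuples with the session numbers that the previous step produces for SirChillingtonIV:

```python
# (user_name, login_date, session) rows, as produced by the sessionized step.
sessionized = [
    ("SirChillingtonIV", "2012-01-04", 0),
    ("SirChillingtonIV", "2012-01-11", 1),
    ("SirChillingtonIV", "2012-01-14", 1),
    ("SirChillingtonIV", "2012-08-11", 2),
]

# Earliest login per (user, session) partition.
# ISO-8601 date strings sort chronologically, so min() on the strings is safe.
first_login = {}
for user, login, session in sessionized:
    key = (user, session)
    first_login[key] = min(first_login.get(key, login), login)

# Attach became_active to every row and drop the helper session column.
result = [(user, login, first_login[(user, session)]) for user, login, session in sessionized]
for row in result:
    print(row)
```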

With the dataset defined as:

val df = Seq(
  ("SirChillingtonIV", "2012-01-04"), ("Booooooo99900098", "2012-01-04"),
  ("Booooooo99900098", "2012-01-06"), ("OprahWinfreyJr", "2012-01-10"), 
  ("SirChillingtonIV", "2012-01-11"), ("SirChillingtonIV", "2012-01-14"),
  ("SirChillingtonIV", "2012-08-11")
).toDF("user_name", "login_date")

The result is:

+----------------+----------+-------------+
|       user_name|login_date|became_active|
+----------------+----------+-------------+
|  OprahWinfreyJr|2012-01-10|   2012-01-10|
|SirChillingtonIV|2012-01-04|   2012-01-04| <- The first session for user
|SirChillingtonIV|2012-01-11|   2012-01-11| <- The second session for user
|SirChillingtonIV|2012-01-14|   2012-01-11| 
|SirChillingtonIV|2012-08-11|   2012-08-11| <- The third session for user
|Booooooo99900098|2012-01-04|   2012-01-04|
|Booooooo99900098|2012-01-06|   2012-01-04|
+----------------+----------+-------------+

Solution 2

Refactoring the other answer to work with PySpark:

Create the DataFrame:

# `spark` is the active SparkSession (predefined in the pyspark shell)
df = spark.createDataFrame(
    [
        ("SirChillingtonIV", "2012-01-04"),
        ("Booooooo99900098", "2012-01-04"),
        ("Booooooo99900098", "2012-01-06"),
        ("OprahWinfreyJr", "2012-01-10"),
        ("SirChillingtonIV", "2012-01-11"),
        ("SirChillingtonIV", "2012-01-14"),
        ("SirChillingtonIV", "2012-08-11"),
    ],
    ["user_name", "login_date"],
)

This creates the following DataFrame:

+----------------+----------+
|       user_name|login_date|
+----------------+----------+
|SirChillingtonIV|2012-01-04|
|Booooooo99900098|2012-01-04|
|Booooooo99900098|2012-01-06|
|  OprahWinfreyJr|2012-01-10|
|SirChillingtonIV|2012-01-11|
|SirChillingtonIV|2012-01-14|
|SirChillingtonIV|2012-08-11|
+----------------+----------+

Next, find the logins that come more than 5 days after the same user's previous login, and number the sessions accordingly.

Necessary imports and window definitions:

from pyspark.sql import functions as f
from pyspark.sql import Window

# defining window partitions
login_window = Window.partitionBy("user_name").orderBy("login_date")
session_window = Window.partitionBy("user_name", "session")

# 1 where the gap to the previous login exceeds 5 days (0 on each user's first login);
# a running sum of these flags then numbers the sessions per user
new_session = (
    f.coalesce(
        f.datediff("login_date", f.lag("login_date", 1).over(login_window)),
        f.lit(0),
    ) > 5
).cast("int")
session_df = df.withColumn("session", f.sum(new_session).over(login_window))

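On each user's first login, lag returns NULL, so datediff is NULL as well; coalesce replaces that NULL with 0, so the first row never starts a new session. The running sum of the 0/1 flags gives the session number: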
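The role of `coalesce` follows from SQL NULL semantics: `NULL > 5` is NULL, and a SUM over a frame that contains only NULL is NULL, so without `coalesce` each user's first row would get a NULL session instead of 0. A small plain-Python model of those two rules (None stands for NULL; the helper names are hypothetical), using the day gaps for SirChillingtonIV:

```python
from typing import List, Optional


def sql_gt(a: Optional[int], b: int) -> Optional[bool]:
    """SQL comparison: NULL compared with anything is NULL."""
    return None if a is None else a > b


def running_sql_sum(values: List[Optional[int]]) -> List[Optional[int]]:
    """Running SUM with SQL semantics: NULLs are skipped; an all-NULL frame sums to NULL."""
    out, total = [], None
    for v in values:
        if v is not None:
            total = v if total is None else total + v
        out.append(total)
    return out


# datediff to the previous login for SirChillingtonIV; NULL (None) on the first row.
diffs = [None, 7, 3, 210]

without_coalesce = [None if g is None else int(g) for g in (sql_gt(d, 5) for d in diffs)]
with_coalesce = [int(sql_gt(0 if d is None else d, 5)) for d in diffs]

print(running_sql_sum(without_coalesce))  # [None, 1, 1, 2]
print(running_sql_sum(with_coalesce))     # [0, 1, 1, 2]
```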

+----------------+----------+-------+
|       user_name|login_date|session|
+----------------+----------+-------+
|  OprahWinfreyJr|2012-01-10|      0|
|SirChillingtonIV|2012-01-04|      0|
|SirChillingtonIV|2012-01-11|      1|
|SirChillingtonIV|2012-01-14|      1|
|SirChillingtonIV|2012-08-11|      2|
|Booooooo99900098|2012-01-04|      0|
|Booooooo99900098|2012-01-06|      0|
+----------------+----------+-------+


# add became_active: the earliest login_date within each (user_name, session) partition created in the step above
final_df = session_df.withColumn("became_active", f.min("login_date").over(session_window)).drop("session")

+----------------+----------+-------------+
|       user_name|login_date|became_active|
+----------------+----------+-------------+
|  OprahWinfreyJr|2012-01-10|   2012-01-10|
|SirChillingtonIV|2012-01-04|   2012-01-04|
|SirChillingtonIV|2012-01-11|   2012-01-11|
|SirChillingtonIV|2012-01-14|   2012-01-11|
|SirChillingtonIV|2012-08-11|   2012-08-11|
|Booooooo99900098|2012-01-04|   2012-01-04|
|Booooooo99900098|2012-01-06|   2012-01-04|
+----------------+----------+-------------+

Author: user4601931

Updated on July 05, 2022

Comments

  • user4601931 (almost 2 years ago)

    This is probably easiest to explain through an example. Suppose I have a DataFrame of user logins to a website, for instance:

    scala> df.show(5)
    +----------------+----------+
    |       user_name|login_date|
    +----------------+----------+
    |SirChillingtonIV|2012-01-04|
    |Booooooo99900098|2012-01-04|
    |Booooooo99900098|2012-01-06|
    |  OprahWinfreyJr|2012-01-10|
    |SirChillingtonIV|2012-01-11|
    +----------------+----------+
    only showing top 5 rows
    

    I would like to add to this a column indicating when they became an active user on the site. But there is one caveat: there is a time period during which a user is considered active, and after this period, if they log in again, their became_active date resets. Suppose this period is 5 days. Then the desired table derived from the above table would be something like this:

    +----------------+----------+-------------+
    |       user_name|login_date|became_active|
    +----------------+----------+-------------+
    |SirChillingtonIV|2012-01-04|   2012-01-04|
    |Booooooo99900098|2012-01-04|   2012-01-04|
    |Booooooo99900098|2012-01-06|   2012-01-04|
    |  OprahWinfreyJr|2012-01-10|   2012-01-10|
    |SirChillingtonIV|2012-01-11|   2012-01-11|
    +----------------+----------+-------------+
    

    So, in particular, SirChillingtonIV's became_active date was reset because their second login came after the active period expired, but Booooooo99900098's became_active date was not reset the second time he/she logged in, because it fell within the active period.

    My initial thought was to use window functions with lag, and then use the lagged values to fill the became_active column; for instance, something starting roughly like:

    import org.apache.spark.sql.expressions.Window
    import org.apache.spark.sql.functions._
    
    val window = Window.partitionBy("user_name").orderBy("login_date")
    val df2 = df.withColumn("tmp", lag("login_date", 1).over(window))
    

    Then, the rule to fill in the became_active date would be, if tmp is null (i.e., if it's the first ever login) or if login_date - tmp >= 5 then became_active = login_date; otherwise, go to the next most recent value in tmp and apply the same rule. This suggests a recursive approach, which I'm having trouble imagining a way to implement.

    My questions: Is this a viable approach, and if so, how can I "go back" and look at earlier values of tmp until I find one where I stop? I can't, to my knowledge, iterate through values of a Spark SQL Column. Is there another way to achieve this result?
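The "go back" step in this rule does not actually need recursion: a login inside the active period simply inherits the previous row's became_active, so carrying the latest start date forward in one ordered pass gives the same result. The window-function answers encode exactly this carry-forward as a cumulative sum of 0/1 flags. A plain-Python sketch of the rule as stated in the question (note `>= 5` here, whereas the answers use `> 5`); the function name is hypothetical:

```python
from datetime import date
from typing import List, Tuple


def became_active(logins: List[date], period_days: int = 5) -> List[Tuple[date, date]]:
    """Single pass over one user's logins, in date order.

    A login starts a new active period if it is the first login or comes at
    least `period_days` after the previous login; otherwise it inherits the
    previous row's became_active date.
    """
    out = []
    prev = None
    current_start = None
    for d in sorted(logins):
        if prev is None or (d - prev).days >= period_days:
            current_start = d           # reset: a new active period begins here
        out.append((d, current_start))  # otherwise the earlier start is carried forward
        prev = d
    return out


for login, start in became_active(
    [date(2012, 1, 4), date(2012, 1, 11), date(2012, 1, 14), date(2012, 8, 11)]
):
    print(login, start)
```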

  • Sanchit Grover (about 6 years ago)
    I know it has been a long time, but can you help me understand the coalesce part of the solution?
  • zero323 (about 6 years ago)
    @SanchitGrover If datediff($"login_date", lag($"login_date", 1).over(userWindow)) evaluates to null (the first row of the partition, which has no previous login), use 0 instead.
  • Sanchit Grover (about 6 years ago)
    Then how does val sessionized = df.withColumn("session", sum(newSession).over(userWindow)) increase the count?
  • zero323 (about 6 years ago)
    It is a cumulative sum of values in the set {0, 1}.