PySpark DataFrame group-by filtering


Solution 1

First, I'll prepare a toy dataset from the data given above:

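from pyspark.sql import SparkSession
from pyspark.sql.functions import col
from pyspark.sql.types import IntegerType
import pyspark.sql.functions as fn

# Reuse the active Spark session (or start a local one)
spark = SparkSession.builder.getOrCreate()

df = spark.createDataFrame([[1, 'r1', 1],
                            [1, 'r2', 0],
                            [1, 'r2', 1],
                            [2, 'r1', 1],
                            [3, 'r1', 1],
                            [3, 'r2', 1],
                            [4, 'r1', 0],
                            [5, 'r1', 1],
                            [5, 'r2', 0],
                            [5, 'r1', 1]], schema=['cust_id', 'req', 'req_met'])
# Python ints are inferred as LongType; cast both numeric columns to IntegerType
df = df.withColumn('req_met', col("req_met").cast(IntegerType()))
df = df.withColumn('cust_id', col("cust_id").cast(IntegerType()))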

Then I group by cust_id and req and sum req_met. After that, I create a function that floors each sum to just 0 or 1, so a requirement that was met several times still counts only once.

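def floor_req(r):
    # Cap the per-requirement sum: met at least once -> 1, never met -> 0
    if r >= 1:
        return 1
    else:
        return 0

udf_floor_req = fn.udf(floor_req, IntegerType())

# One row per (customer, requirement) with the total number of times it was met
gr = df.groupby(['cust_id', 'req'])
df_grouped = gr.agg(fn.sum(col('req_met')).alias('sum_req_met'))
df_grouped_floor = df_grouped.withColumn('sum_req_met', udf_floor_req('sum_req_met'))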

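Now we can check whether each customer has met all requirements by comparing the number of distinct requirements with the number of requirements met. After the first grouping there is exactly one row per (cust_id, req), so counting req per customer gives the number of distinct requirements.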

df_req = df_grouped_floor.groupby('cust_id').agg(fn.sum('sum_req_met').alias('sum_req'), 
                                                 fn.count('req').alias('n_req'))

Finally, keep only the customers for whom the two columns are equal:

df_req.filter(df_req['sum_req'] == df_req['n_req']).select('cust_id').orderBy('cust_id').show()
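To see why the flooring step matters, here is a plain-Python walk-through of the same logic on the toy data (no Spark needed). The helper name `customers_meeting_all` is hypothetical and only for illustration. Without flooring, customer 5 (r1 met twice, r2 never) wrongly passes, because its two r1 hits make up for the missing r2.

```python
from collections import defaultdict

# Toy data from the question: (cust_id, req, req_met)
rows = [(1, 'r1', 1), (1, 'r2', 0), (1, 'r2', 1), (2, 'r1', 1), (3, 'r1', 1),
        (3, 'r2', 1), (4, 'r1', 0), (5, 'r1', 1), (5, 'r2', 0), (5, 'r1', 1)]

# Sum req_met per (cust_id, req), like the first groupby
sums = defaultdict(int)
for cust_id, req, met in rows:
    sums[(cust_id, req)] += met

def customers_meeting_all(floor):
    """Customers whose total of per-requirement sums equals their number of requirements."""
    total = defaultdict(int)  # sum of (optionally floored) per-requirement sums
    n_req = defaultdict(int)  # number of distinct requirements per customer
    for (cust_id, _req), s in sums.items():
        total[cust_id] += min(s, 1) if floor else s
        n_req[cust_id] += 1
    return sorted(c for c in n_req if total[c] == n_req[c])

without_floor = customers_meeting_all(floor=False)  # customer 5 slips through
with_floor = customers_meeting_all(floor=True)      # matches the expected output
```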

Solution 2

SELECT cust_id
FROM (
    SELECT cust_id, MIN(sum_value) AS m
    FROM (
        SELECT cust_id, req, SUM(req_met) AS sum_value
        FROM <data_frame>
        GROUP BY cust_id, req
    ) temp
    GROUP BY cust_id
) temp1
WHERE m > 0;

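This gives the desired result: a customer is kept only if the minimum per-requirement sum m is greater than 0, that is, if every requirement was met at least once.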

Author: Lijju Mathew

Updated on July 09, 2022

Comments

  • Lijju Mathew, almost 2 years ago

    I have a DataFrame like the one below:

    cust_id   req    req_met
    -------   ---    -------
     1         r1      1
     1         r2      0
     1         r2      1
     2         r1      1
     3         r1      1
     3         r2      1
     4         r1      0
     5         r1      1
     5         r2      0
     5         r1      1
    

    For each customer, I need to see how many requirements they have and check whether they have met each of them at least once. There can be multiple records with the same customer and requirement, some met and some not. For the data above, the output should be:

    cust_id
    -------
      1
      2
      3
    

    What I have tried:

    import pyspark.sql.functions as F

    # say initial dataframe is df
    df1 = (df
           .groupby('cust_id')
           .agg(F.countDistinct('req').alias('num_of_req'),
                F.sum('req_met').alias('sum_req_met')))

    df2 = df1.filter(df1.num_of_req == df1.sum_req_met)
    

    But in a few cases it does not give correct results.

    How can this be done?

  • Lijju Mathew, about 7 years ago
    Thanks for the solution. I was looking for something more like a DataFrame solution.