PySpark, top for DataFrame


Solution 1

First let's define a function to generate test data:

import numpy as np

def sample_df(num_records):
    def data():
        np.random.seed(42)
        while True:
            yield int(np.random.normal(100., 80.))

    data_iter = iter(data())
    df = sc.parallelize((
        (i, next(data_iter)) for i in range(int(num_records))
    )).toDF(('index', 'key_col'))

    return df

sample_df(1e3).show(n=5)
+-----+-------+
|index|key_col|
+-----+-------+
|    0|    139|
|    1|     88|
|    2|    151|
|    3|    221|
|    4|     81|
+-----+-------+
only showing top 5 rows


Now, let's propose three different ways to calculate TopK:

from pyspark.sql import Window
from pyspark.sql import functions


def top_df_0(df, key_col, K):
    """
    Using window functions.  Handles ties OK.
    """
    window = Window.orderBy(functions.col(key_col).desc())
    return (df
            .withColumn("rank", functions.rank().over(window))
            .filter(functions.col('rank') <= K)
            .drop('rank'))


def top_df_1(df, key_col, K):
    """
    Using limit(K). Does NOT handle ties appropriately.
    """
    return df.orderBy(functions.col(key_col).desc()).limit(K)


def top_df_2(df, key_col, K):
    """
    Using limit(K) and then filtering.  Handles ties OK.
    """
    num_records = df.count()
    value_at_k_rank = (df
                       .orderBy(functions.col(key_col).desc())
                       .limit(K)
                       .select(functions.min(key_col).alias('min'))
                       .first()['min'])

    return df.filter(df[key_col] >= value_at_k_rank)

The function top_df_1 is similar to the one you originally implemented. The reason it gives non-deterministic behavior is that it cannot handle ties nicely. That may be an acceptable trade-off if you have lots of data and only need an approximate answer for the sake of performance.
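
To make the tie issue concrete, here is a minimal sketch (it assumes a SparkSession named spark; the code above builds DataFrames through sc.parallelize instead). Three rows share the top value, so rank() keeps all of them while limit(2) keeps an arbitrary two:

tie_df = spark.createDataFrame(
    [(0, 10), (1, 10), (2, 10), (3, 5)],  # three rows tied at key_col = 10
    ('index', 'key_col'))

top_df_0(tie_df, 'key_col', K=2).count()  # 3: rank() keeps every tied row
top_df_1(tie_df, 'key_col', K=2).count()  # 2: limit() drops one tied row arbitrarily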


Finally, let's benchmark

For benchmarking, use a Spark DataFrame with 4 million entries and define a convenience function:

NUM_RECORDS = 4e6
test_df = sample_df(NUM_RECORDS).cache()

def show(func, df, key_col, K):
    func(df, key_col, K).select(
      functions.max(key_col),
      functions.min(key_col),
      functions.count(key_col)
    ).show()


Let's see the verdict:

%timeit show(top_df_0, test_df, "key_col", K=100)
+------------+------------+--------------+
|max(key_col)|min(key_col)|count(key_col)|
+------------+------------+--------------+
|         502|         420|           108|
+------------+------------+--------------+

1 loops, best of 3: 1.62 s per loop


%timeit show(top_df_1, test_df, "key_col", K=100)
+------------+------------+--------------+
|max(key_col)|min(key_col)|count(key_col)|
+------------+------------+--------------+
|         502|         420|           100|
+------------+------------+--------------+

1 loops, best of 3: 252 ms per loop


%timeit show(top_df_2, test_df, "key_col", K=100)
+------------+------------+--------------+
|max(key_col)|min(key_col)|count(key_col)|
+------------+------------+--------------+
|         502|         420|           108|
+------------+------------+--------------+

1 loops, best of 3: 725 ms per loop

(Note that top_df_0 and top_df_2 return 108 entries for the top 100. This is due to the presence of entries tied for 100th place; top_df_1 simply ignores the tied entries.)


The bottom line

If you want an exact answer, go with top_df_2 (it is about 2x faster than top_df_0). If you want roughly another 2x in performance and are OK with an approximate answer, go with top_df_1.
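
If the only objection to top_df_1 is that its output varies between runs, adding a tiebreaker column makes the ordering total, so limit(K) becomes deterministic. A sketch, assuming a unique column such as the index column from sample_df (this variant was not benchmarked above):

def top_df_1_deterministic(df, key_col, K):
    # Sorting on a unique secondary column ('index' here) removes the arbitrary
    # choice among tied rows: exactly K rows, the same ones on every run.
    return (df
            .orderBy(functions.col(key_col).desc(), functions.col('index'))
            .limit(K))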

Solution 2

Options:

1) Use pyspark.sql row_number() within a window function (see the sketch after this list) - relevant SO: spark dataframe grouping, sorting, and selecting top rows for a set of columns

2) Convert the ordered DataFrame to an RDD and use the top function there (hint: this doesn't appear to actually maintain ordering in my quick test, but YMMV)
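
A rough sketch of both options, reusing the test_df and key_col names from Solution 1 (those names, and K=100, are just placeholders here):

from pyspark.sql import Window, functions as F

# Option 1: row_number() returns exactly 100 rows; ties are broken arbitrarily
# unless you add a tiebreaker column to the ordering.
w = Window.orderBy(F.col('key_col').desc())
top_100_rownum = (test_df
                  .withColumn('rn', F.row_number().over(w))
                  .filter(F.col('rn') <= 100)
                  .drop('rn'))

# Option 2: drop to the RDD API; top() returns a local Python list of Rows.
top_100_rdd = test_df.rdd.top(100, key=lambda row: row['key_col'])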

Solution 3

You should try head() instead of limit():

#sample data
df = sc.parallelize([
    ['123', 'b'], ['666', 'a'],
    ['345', 'd'], ['555', 'a'],
    ['456', 'b'], ['444', 'a'],
    ['678', 'd'], ['333', 'a'],
    ['135', 'd'], ['234', 'd'],
    ['987', 'c'], ['987', 'e']
]).toDF(('col1', 'key_col'))

#select top 'n' 'key_col' values from dataframe 'df'
def retrieve_top_n(df, key, n):
    return sqlContext.createDataFrame(
        df.groupBy(key).count().orderBy('count', ascending=False).head(n)
    ).select(key)

retrieve_top_n(df, 'key_col', 3).show()
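
Since head(n) already returns a plain Python list of Row objects on the driver, a variant that skips the round trip through createDataFrame is also possible. A sketch for small n, where collecting to the driver is cheap:

top_rows = df.groupBy('key_col').count().orderBy('count', ascending=False).head(3)
top_keys = [row['key_col'] for row in top_rows]  # e.g. ['a', 'd', 'b'] for the sample data above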

Hope this helps!


Comments

  • Jing

    What I want to do is: given a DataFrame, take the top n elements according to some specified column. The top(self, num) method in the RDD API is exactly what I want. I wonder if there is an equivalent API in the DataFrame world?

    My first attempt is the following

    def retrieve_top_n(df, n):
        # assume we want to get most popular n 'key' in DataFrame
        return df.groupBy('key').count().orderBy('count', ascending=False).limit(n).select('key')
    

    However, I've realized that this results in non-deterministic behavior (I don't know the exact reason, but I guess limit(n) doesn't guarantee which n rows it takes).