How to select the first row of each group?

Solution 1

Window functions:

Something like this should do the trick:

import org.apache.spark.sql.functions.{row_number, max, broadcast, first, struct}
import org.apache.spark.sql.expressions.Window

val df = sc.parallelize(Seq(
  (0,"cat26",30.9), (0,"cat13",22.1), (0,"cat95",19.6), (0,"cat105",1.3),
  (1,"cat67",28.5), (1,"cat4",26.8), (1,"cat13",12.6), (1,"cat23",5.3),
  (2,"cat56",39.6), (2,"cat40",29.7), (2,"cat187",27.9), (2,"cat68",9.8),
  (3,"cat8",35.6))).toDF("Hour", "Category", "TotalValue")

val w = Window.partitionBy($"hour").orderBy($"TotalValue".desc)

val dfTop = df.withColumn("rn", row_number.over(w)).where($"rn" === 1).drop("rn")

dfTop.show
// +----+--------+----------+
// |Hour|Category|TotalValue|
// +----+--------+----------+
// |   0|   cat26|      30.9|
// |   1|   cat67|      28.5|
// |   2|   cat56|      39.6|
// |   3|    cat8|      35.6|
// +----+--------+----------+

This method will be inefficient in the presence of significant data skew. The problem is tracked by SPARK-34775 and might be resolved in the future (SPARK-37099).
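The same window also covers the "top N rows of each group" case mentioned in the question; a minimal sketch, where n is a hypothetical parameter:

val n = 3  // hypothetical: number of rows to keep per group
val dfTopN = df.withColumn("rn", row_number.over(w)).where($"rn" <= n).drop("rn")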

Plain SQL aggregation followed by join:

Alternatively you can join with aggregated data frame:

val dfMax = df.groupBy($"hour".as("max_hour")).agg(max($"TotalValue").as("max_value"))

val dfTopByJoin = df.join(broadcast(dfMax),
    ($"hour" === $"max_hour") && ($"TotalValue" === $"max_value"))
  .drop("max_hour")
  .drop("max_value")

dfTopByJoin.show

// +----+--------+----------+
// |Hour|Category|TotalValue|
// +----+--------+----------+
// |   0|   cat26|      30.9|
// |   1|   cat67|      28.5|
// |   2|   cat56|      39.6|
// |   3|    cat8|      35.6|
// +----+--------+----------+

It will keep duplicate values (if there is more than one category per hour with the same total value). You can remove these as follows:

dfTopByJoin
  .groupBy($"hour")
  .agg(
    first("category").alias("category"),
    first("TotalValue").alias("TotalValue"))

Using ordering over structs:

A neat, although not very well tested, trick that doesn't require joins or window functions:

val dfTop = df.select($"Hour", struct($"TotalValue", $"Category").alias("vs"))
  .groupBy($"hour")
  .agg(max("vs").alias("vs"))
  .select($"Hour", $"vs.Category", $"vs.TotalValue")

dfTop.show
// +----+--------+----------+
// |Hour|Category|TotalValue|
// +----+--------+----------+
// |   0|   cat26|      30.9|
// |   1|   cat67|      28.5|
// |   2|   cat56|      39.6|
// |   3|    cat8|      35.6|
// +----+--------+----------+

With the Dataset API (Spark 1.6+, 2.0+):

Spark 1.6:

case class Record(Hour: Integer, Category: String, TotalValue: Double)

df.as[Record]
  .groupBy($"hour")
  .reduce((x, y) => if (x.TotalValue > y.TotalValue) x else y)
  .show

// +---+--------------+
// | _1|            _2|
// +---+--------------+
// |[0]|[0,cat26,30.9]|
// |[1]|[1,cat67,28.5]|
// |[2]|[2,cat56,39.6]|
// |[3]| [3,cat8,35.6]|
// +---+--------------+

Spark 2.0 or later:

df.as[Record]
  .groupByKey(_.Hour)
  .reduceGroups((x, y) => if (x.TotalValue > y.TotalValue) x else y)
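reduceGroups returns a Dataset of (key, value) pairs, so a small extra step is needed to get back a flat result; a minimal sketch, assuming spark.implicits._ is in scope:

df.as[Record]
  .groupByKey(_.Hour)
  .reduceGroups((x, y) => if (x.TotalValue > y.TotalValue) x else y)
  .map(_._2)  // drop the grouping key, keep only the winning Record
  .toDF()
  .show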

The last two methods can leverage map-side combine and don't require a full shuffle, so most of the time they should exhibit better performance than window functions and joins. They can also be used with Structured Streaming in complete output mode.

Don't use:

df.orderBy(...).groupBy(...).agg(first(...), ...)

It may seem to work (especially in local mode) but it is unreliable (see SPARK-16207, credits to Tzach Zohar for linking the relevant JIRA issue, and SPARK-30335).

The same note applies to

df.orderBy(...).dropDuplicates(...)

which internally uses an equivalent execution plan.
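If in doubt, you can at least inspect the physical plan such a query produces on your Spark version (this only shows what the current version happens to do; it is not a correctness guarantee):

df.orderBy($"TotalValue".desc).dropDuplicates("Hour").explain()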

Solution 2

For Spark 2.0.2 with grouping by multiple columns:

import org.apache.spark.sql.functions.row_number
import org.apache.spark.sql.expressions.Window

val w = Window.partitionBy($"col1", $"col2", $"col3").orderBy($"timestamp".desc)

val refined_df = df.withColumn("rn", row_number.over(w)).where($"rn" === 1).drop("rn")
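If you want to keep ties (all rows sharing the latest timestamp) instead of exactly one row per group, a sketch swapping row_number for rank, which assigns the same number to tied rows (refined_df_with_ties is just an illustrative name):

import org.apache.spark.sql.functions.rank

val refined_df_with_ties = df.withColumn("rk", rank.over(w)).where($"rk" === 1).drop("rk")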

Solution 3

This is exactly the same as zero323's answer, but written as SQL queries.

Assuming that the dataframe is created and registered as

df.createOrReplaceTempView("table")
//+----+--------+----------+
//|Hour|Category|TotalValue|
//+----+--------+----------+
//|0   |cat26   |30.9      |
//|0   |cat13   |22.1      |
//|0   |cat95   |19.6      |
//|0   |cat105  |1.3       |
//|1   |cat67   |28.5      |
//|1   |cat4    |26.8      |
//|1   |cat13   |12.6      |
//|1   |cat23   |5.3       |
//|2   |cat56   |39.6      |
//|2   |cat40   |29.7      |
//|2   |cat187  |27.9      |
//|2   |cat68   |9.8       |
//|3   |cat8    |35.6      |
//+----+--------+----------+

Window function:

sqlContext.sql("select Hour, Category, TotalValue from (select *, row_number() OVER (PARTITION BY Hour ORDER BY TotalValue DESC) as rn  FROM table) tmp where rn = 1").show(false)
//+----+--------+----------+
//|Hour|Category|TotalValue|
//+----+--------+----------+
//|1   |cat67   |28.5      |
//|3   |cat8    |35.6      |
//|2   |cat56   |39.6      |
//|0   |cat26   |30.9      |
//+----+--------+----------+

Plain SQL aggregation followed by join:

sqlContext.sql("select Hour, first(Category) as Category, first(TotalValue) as TotalValue from " +
  "(select Hour, Category, TotalValue from table tmp1 " +
  "join " +
  "(select Hour as max_hour, max(TotalValue) as max_value from table group by Hour) tmp2 " +
  "on " +
  "tmp1.Hour = tmp2.max_hour and tmp1.TotalValue = tmp2.max_value) tmp3 " +
  "group by tmp3.Hour")
  .show(false)
//+----+--------+----------+
//|Hour|Category|TotalValue|
//+----+--------+----------+
//|1   |cat67   |28.5      |
//|3   |cat8    |35.6      |
//|2   |cat56   |39.6      |
//|0   |cat26   |30.9      |
//+----+--------+----------+

Using ordering over structs:

sqlContext.sql("select Hour, vs.Category, vs.TotalValue from (select Hour, max(struct(TotalValue, Category)) as vs from table group by Hour)").show(false)
//+----+--------+----------+
//|Hour|Category|TotalValue|
//+----+--------+----------+
//|1   |cat67   |28.5      |
//|3   |cat8    |35.6      |
//|2   |cat56   |39.6      |
//|0   |cat26   |30.9      |
//+----+--------+----------+

The Dataset approach and the don'ts are the same as in the original answer.

Solution 4

You can use the max_by() function, available since Spark 3.0!

https://spark.apache.org/docs/3.0.0-preview/api/sql/index.html#max_by

import org.apache.spark.sql.functions.{expr, max}

val df = sc.parallelize(Seq(
  (0,"cat26",30.9), (0,"cat13",22.1), (0,"cat95",19.6), (0,"cat105",1.3),
  (1,"cat67",28.5), (1,"cat4",26.8), (1,"cat13",12.6), (1,"cat23",5.3),
  (2,"cat56",39.6), (2,"cat40",29.7), (2,"cat187",27.9), (2,"cat68",9.8),
  (3,"cat8",35.6))).toDF("Hour", "Category", "TotalValue")

// Register the DataFrame as a SQL temporary view
df.createOrReplaceTempView("table")

// Using SQL
val result = spark.sql("select Hour, max_by(Category, TotalValue) AS Category, max(TotalValue) as TotalValue FROM table group by Hour order by Hour")

// or Using DataFrame API
val result = df.groupBy("Hour").
  agg(expr("max_by(Category, TotalValue)").as("Category"), max("TotalValue").as("TotalValue")).
  sort("Hour")

result.show

+----+--------+----------+
|Hour|Category|TotalValue|
+----+--------+----------+
|   0|   cat26|      30.9|
|   1|   cat67|      28.5|
|   2|   cat56|      39.6|
|   3|    cat8|      35.6|
+----+--------+----------+

Solution 5

The pattern is: group by keys => do something to each group, e.g. reduce => return to a dataframe.

I found the DataFrame abstraction a bit cumbersome in this case, so I used RDD functionality:

// reduceFunction: (Row, Row) => Row, supplied by the caller (e.g. keep the "best" row per group)
val rdd: RDD[Row] = originalDf
  .rdd
  .groupBy(row => row.getAs[String]("grouping_row"))
  .map(iterableTuple => {
    iterableTuple._2.reduce(reduceFunction)
  })

val productDf = sqlContext.createDataFrame(rdd, originalDf.schema)
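A minimal sketch of the same pattern against the question's (Hour, Category, TotalValue) dataframe; the reduceFunction here simply keeps the row with the larger TotalValue (an illustrative assumption, not necessarily the author's original function):

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row

// illustrative reduce function: keep the row with the larger TotalValue
val reduceFunction: (Row, Row) => Row =
  (a, b) => if (a.getAs[Double]("TotalValue") >= b.getAs[Double]("TotalValue")) a else b

val topRdd: RDD[Row] = df.rdd
  .groupBy(row => row.getAs[Int]("Hour"))
  .map { case (_, rows) => rows.reduce(reduceFunction) }

val topDf = spark.createDataFrame(topRdd, df.schema)
topDf.show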

Author: Rami

I am a Data Scientist with experience in Machine Learning, Big Data Analytics and Computer Vision. I am currently working on advanced Machine Learning technologies for Mobile Operator Subscriber Analytics to optimise advertisement, recommendation and customer care services. Co-founder and organiser of the Lifelogging@Dublin Meetup group, I am an active long-term Lifelogger with particular interest in human behaviour tracking and analytics.

Updated on July 26, 2022

Comments

  • Rami
    Rami almost 2 years

    I have a DataFrame generated as follow:

    df.groupBy($"Hour", $"Category")
      .agg(sum($"value") as "TotalValue")
      .sort($"Hour".asc, $"TotalValue".desc)
    

    The results look like:

    +----+--------+----------+
    |Hour|Category|TotalValue|
    +----+--------+----------+
    |   0|   cat26|      30.9|
    |   0|   cat13|      22.1|
    |   0|   cat95|      19.6|
    |   0|  cat105|       1.3|
    |   1|   cat67|      28.5|
    |   1|    cat4|      26.8|
    |   1|   cat13|      12.6|
    |   1|   cat23|       5.3|
    |   2|   cat56|      39.6|
    |   2|   cat40|      29.7|
    |   2|  cat187|      27.9|
    |   2|   cat68|       9.8|
    |   3|    cat8|      35.6|
    | ...|    ....|      ....|
    +----+--------+----------+
    

    As you can see, the DataFrame is ordered by Hour in an increasing order, then by TotalValue in a descending order.

    I would like to select the top row of each group, i.e.

    • from the group of Hour==0 select (0,cat26,30.9)
    • from the group of Hour==1 select (1,cat67,28.5)
    • from the group of Hour==2 select (2,cat56,39.6)
    • and so on

    So the desired output would be:

    +----+--------+----------+
    |Hour|Category|TotalValue|
    +----+--------+----------+
    |   0|   cat26|      30.9|
    |   1|   cat67|      28.5|
    |   2|   cat56|      39.6|
    |   3|    cat8|      35.6|
    | ...|     ...|       ...|
    +----+--------+----------+
    

    It might be handy to be able to select the top N rows of each group as well.

    Any help is highly appreciated.

  • Adam Szałucha
    Adam Szałucha over 6 years
    It looks like since Spark 1.6 it is row_number() instead of rowNumber.
  • Ignacio Alorre
    Ignacio Alorre over 6 years
    About the "Don't use df.orderBy(...).groupBy(...)": under what circumstances can we rely on orderBy(...)? Or, if we cannot be sure orderBy() is going to give the correct result, what alternatives do we have?
  • Thomas
    Thomas about 6 years
    I might be overlooking something, but in general it is recommended to avoid groupByKey, instead reduceByKey should be used. Also, you'll be saving one line.
  • soote
    soote almost 6 years
    @Thomas avoiding groupBy/groupByKey only applies when dealing with RDDs; you'll notice that the Dataset API doesn't even have a reduceByKey function.
  • Alper t. Turker
    Alper t. Turker almost 6 years
    But it shuffles everything first. It is hardly an improvement (maybe not worse than window functions, depending on the data).
  • elghoto
    elghoto almost 6 years
    You have a groupBy in the first place, and that triggers a shuffle. It's not worse than a window function, because a window function evaluates the window for each single row in the dataframe.
  • Brendan
    Brendan almost 4 years
    Update: both of those Spark issues have been resolved, so first should be good to use.
  • Abuw
    Abuw about 3 years
    @Brendan from which version is it good to use?
  • Brendan
    Brendan about 3 years
    I'm not sure exactly how to read Spark's issue tracker to confirm which version they fixed this in, but following the links above shows both issues as resolved. One had the affected version listed as 1.6.1 and the other as 3.1.0.
  • Eyal
    Eyal almost 3 years
    This code is more or less contained in Apache DataFu's dedupWithOrder method
  • hiryu
    hiryu over 2 years
    If you read both tickets attentively, you'll notice that while they are marked as "resolved", they didn't change any underlying behavior for first(). They only document that the output is non-deterministic. As such, first() without a window function remains unreliable to use...
  • Paulo Moreira
    Paulo Moreira over 2 years
    works very well for me, thanks a lot!