How to select the first row of each group?
Solution 1
Window functions:
Something like this should do the trick:
import org.apache.spark.sql.functions.{row_number, max, broadcast, first, struct}
import org.apache.spark.sql.expressions.Window
val df = sc.parallelize(Seq(
(0,"cat26",30.9), (0,"cat13",22.1), (0,"cat95",19.6), (0,"cat105",1.3),
(1,"cat67",28.5), (1,"cat4",26.8), (1,"cat13",12.6), (1,"cat23",5.3),
(2,"cat56",39.6), (2,"cat40",29.7), (2,"cat187",27.9), (2,"cat68",9.8),
(3,"cat8",35.6))).toDF("Hour", "Category", "TotalValue")
val w = Window.partitionBy($"hour").orderBy($"TotalValue".desc)
val dfTop = df.withColumn("rn", row_number.over(w)).where($"rn" === 1).drop("rn")
dfTop.show
// +----+--------+----------+
// |Hour|Category|TotalValue|
// +----+--------+----------+
// | 0| cat26| 30.9|
// | 1| cat67| 28.5|
// | 2| cat56| 39.6|
// | 3| cat8| 35.6|
// +----+--------+----------+
This method will be inefficient in case of significant data skew. This problem is tracked by SPARK-34775 and might be resolved in the future (SPARK-37099).
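The question also asks about the top N rows per group. The window approach above extends to that case by relaxing the filter from rn === 1 to rn <= n. The following is a minimal sketch rather than part of the original answer: it is written as a standalone script (in spark-shell the session setup lines can be dropped), and n and dfTopN are hypothetical names chosen for illustration.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, row_number}

val spark = SparkSession.builder().master("local[*]").appName("top-n-per-group").getOrCreate()
import spark.implicits._

val df = Seq(
  (0, "cat26", 30.9), (0, "cat13", 22.1), (0, "cat95", 19.6), (0, "cat105", 1.3),
  (1, "cat67", 28.5), (1, "cat4", 26.8), (1, "cat13", 12.6), (1, "cat23", 5.3),
  (2, "cat56", 39.6), (2, "cat40", 29.7), (2, "cat187", 27.9), (2, "cat68", 9.8),
  (3, "cat8", 35.6)).toDF("Hour", "Category", "TotalValue")

// n is a hypothetical parameter: how many rows to keep per Hour.
val n = 2

// Number the rows of each Hour from the largest TotalValue downwards.
val w = Window.partitionBy(col("Hour")).orderBy(col("TotalValue").desc)

// Keep the first n rows of every partition, then drop the helper column.
val dfTopN = df
  .withColumn("rn", row_number().over(w))
  .where(col("rn") <= n)
  .drop("rn")

dfTopN.orderBy(col("Hour"), col("TotalValue").desc).show()
```

Groups with fewer than n rows (Hour 3 here) simply return all of their rows. If tied values should share a rank, rank or dense_rank can replace row_number, at the cost of possibly returning more than n rows per group.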
Plain SQL aggregation followed by join:
Alternatively you can join with aggregated data frame:
val dfMax = df.groupBy($"hour".as("max_hour")).agg(max($"TotalValue").as("max_value"))
val dfTopByJoin = df.join(broadcast(dfMax),
($"hour" === $"max_hour") && ($"TotalValue" === $"max_value"))
.drop("max_hour")
.drop("max_value")
dfTopByJoin.show
// +----+--------+----------+
// |Hour|Category|TotalValue|
// +----+--------+----------+
// | 0| cat26| 30.9|
// | 1| cat67| 28.5|
// | 2| cat56| 39.6|
// | 3| cat8| 35.6|
// +----+--------+----------+
It will keep duplicate values (if there is more than one category per hour with the same total value). You can remove these as follows:
dfTopByJoin
  .groupBy($"hour")
  .agg(
    first("category").alias("category"),
    first("TotalValue").alias("TotalValue"))
Using ordering over structs:
A neat, although not very well tested, trick that doesn't require joins or window functions:
val dfTop = df.select($"Hour", struct($"TotalValue", $"Category").alias("vs"))
.groupBy($"hour")
.agg(max("vs").alias("vs"))
.select($"Hour", $"vs.Category", $"vs.TotalValue")
dfTop.show
// +----+--------+----------+
// |Hour|Category|TotalValue|
// +----+--------+----------+
// | 0| cat26| 30.9|
// | 1| cat67| 28.5|
// | 2| cat56| 39.6|
// | 3| cat8| 35.6|
// +----+--------+----------+
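The field order inside the struct matters, because Spark compares structs field by field from left to right. Putting TotalValue first makes it the primary sort key, and Category only breaks ties. The sketch below uses made-up data with a tie (the tied DataFrame and its category names are assumptions for illustration) to show which row wins.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, max, struct}

val spark = SparkSession.builder().master("local[*]").appName("struct-ordering").getOrCreate()
import spark.implicits._

// Hypothetical data: catA and catB share the top TotalValue in Hour 0.
val tied = Seq((0, "catA", 10.0), (0, "catB", 10.0), (0, "catC", 5.0))
  .toDF("Hour", "Category", "TotalValue")

val topOfTied = tied
  // TotalValue is the first struct field, so it is compared first.
  .select(col("Hour"), struct(col("TotalValue"), col("Category")).alias("vs"))
  .groupBy(col("Hour"))
  .agg(max(col("vs")).alias("vs"))
  .select(col("Hour"), col("vs.Category"), col("vs.TotalValue"))

// The tie on TotalValue is broken by the larger Category string, i.e. catB.
topOfTied.show()
```

Unlike the join approach, this always returns exactly one row per group, so no extra deduplication step is needed.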
With the Dataset API (Spark 1.6+, 2.0+):
Spark 1.6:
case class Record(Hour: Integer, Category: String, TotalValue: Double)
df.as[Record]
.groupBy($"hour")
.reduce((x, y) => if (x.TotalValue > y.TotalValue) x else y)
.show
// +---+--------------+
// | _1| _2|
// +---+--------------+
// |[0]|[0,cat26,30.9]|
// |[1]|[1,cat67,28.5]|
// |[2]|[2,cat56,39.6]|
// |[3]| [3,cat8,35.6]|
// +---+--------------+
Spark 2.0 or later:
df.as[Record]
.groupByKey(_.Hour)
.reduceGroups((x, y) => if (x.TotalValue > y.TotalValue) x else y)
The last two methods can leverage map-side combine and don't require a full shuffle, so most of the time they should perform better than window functions and joins. They can also be used with Structured Streaming in complete output mode.
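Note that reduceGroups returns a Dataset of (key, value) pairs rather than a Dataset of records. A possible way to get back to plain records is to map over the pairs and keep the second element. This is a sketch under assumptions: it is written as a standalone script, the case class uses Int instead of Integer for the key, and the name topRecords is hypothetical.

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").appName("reduce-groups").getOrCreate()
import spark.implicits._

case class Record(Hour: Int, Category: String, TotalValue: Double)

val ds = Seq(
  Record(0, "cat26", 30.9), Record(0, "cat13", 22.1),
  Record(1, "cat67", 28.5), Record(1, "cat4", 26.8),
  Record(2, "cat56", 39.6), Record(2, "cat40", 29.7),
  Record(3, "cat8", 35.6)).toDS()

val topRecords = ds
  .groupByKey(_.Hour)
  // Keep the record with the larger TotalValue within each key.
  .reduceGroups((x, y) => if (x.TotalValue > y.TotalValue) x else y)
  // reduceGroups yields (Hour, Record) pairs; drop the key.
  .map(_._2)

topRecords.show()
```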
Don't use:
df.orderBy(...).groupBy(...).agg(first(...), ...)
It may seem to work (especially in local mode), but it is unreliable (see SPARK-16207, credits to Tzach Zohar for linking the relevant JIRA issue, and SPARK-30335).
The same note applies to
df.orderBy(...).dropDuplicates(...)
which internally uses an equivalent execution plan.
Solution 2
For Spark 2.0.2 with grouping by multiple columns:
import org.apache.spark.sql.functions.row_number
import org.apache.spark.sql.expressions.Window
val w = Window.partitionBy($"col1", $"col2", $"col3").orderBy($"timestamp".desc)
val refined_df = df.withColumn("rn", row_number.over(w)).where($"rn" === 1).drop("rn")
Solution 3
This is exactly the same as zero323's answer, but expressed as SQL queries.
Assuming that dataframe is created and registered as
df.createOrReplaceTempView("table")
//+----+--------+----------+
//|Hour|Category|TotalValue|
//+----+--------+----------+
//|0 |cat26 |30.9 |
//|0 |cat13 |22.1 |
//|0 |cat95 |19.6 |
//|0 |cat105 |1.3 |
//|1 |cat67 |28.5 |
//|1 |cat4 |26.8 |
//|1 |cat13 |12.6 |
//|1 |cat23 |5.3 |
//|2 |cat56 |39.6 |
//|2 |cat40 |29.7 |
//|2 |cat187 |27.9 |
//|2 |cat68 |9.8 |
//|3 |cat8 |35.6 |
//+----+--------+----------+
Window function :
sqlContext.sql("select Hour, Category, TotalValue from (select *, row_number() OVER (PARTITION BY Hour ORDER BY TotalValue DESC) as rn FROM table) tmp where rn = 1").show(false)
//+----+--------+----------+
//|Hour|Category|TotalValue|
//+----+--------+----------+
//|1 |cat67 |28.5 |
//|3 |cat8 |35.6 |
//|2 |cat56 |39.6 |
//|0 |cat26 |30.9 |
//+----+--------+----------+
Plain SQL aggregation followed by join:
sqlContext.sql("select Hour, first(Category) as Category, first(TotalValue) as TotalValue from " +
"(select Hour, Category, TotalValue from table tmp1 " +
"join " +
"(select Hour as max_hour, max(TotalValue) as max_value from table group by Hour) tmp2 " +
"on " +
"tmp1.Hour = tmp2.max_hour and tmp1.TotalValue = tmp2.max_value) tmp3 " +
"group by tmp3.Hour")
.show(false)
//+----+--------+----------+
//|Hour|Category|TotalValue|
//+----+--------+----------+
//|1 |cat67 |28.5 |
//|3 |cat8 |35.6 |
//|2 |cat56 |39.6 |
//|0 |cat26 |30.9 |
//+----+--------+----------+
Using ordering over structs:
sqlContext.sql("select Hour, vs.Category, vs.TotalValue from (select Hour, max(struct(TotalValue, Category)) as vs from table group by Hour)").show(false)
//+----+--------+----------+
//|Hour|Category|TotalValue|
//+----+--------+----------+
//|1 |cat67 |28.5 |
//|3 |cat8 |35.6 |
//|2 |cat56 |39.6 |
//|0 |cat26 |30.9 |
//+----+--------+----------+
The Dataset approach and the "don't use" warnings are the same as in the original answer.
Solution 4
You can use the max_by() function, available since Spark 3.0:
https://spark.apache.org/docs/3.0.0-preview/api/sql/index.html#max_by
val df = sc.parallelize(Seq(
(0,"cat26",30.9), (0,"cat13",22.1), (0,"cat95",19.6), (0,"cat105",1.3),
(1,"cat67",28.5), (1,"cat4",26.8), (1,"cat13",12.6), (1,"cat23",5.3),
(2,"cat56",39.6), (2,"cat40",29.7), (2,"cat187",27.9), (2,"cat68",9.8),
(3,"cat8",35.6))).toDF("Hour", "Category", "TotalValue")
// Register the DataFrame as a SQL temporary view
df.createOrReplaceTempView("table")
// Using SQL
val result = spark.sql("select Hour, max_by(Category, TotalValue) AS Category, max(TotalValue) as TotalValue FROM table group by Hour order by Hour")
// Or using the DataFrame API (requires import org.apache.spark.sql.functions.{expr, max})
val result = df.groupBy("Hour").
  agg(expr("max_by(Category, TotalValue)").as("Category"), max("TotalValue").as("TotalValue")).
  sort("Hour")
+----+--------+----------+
|Hour|Category|TotalValue|
+----+--------+----------+
| 0| cat26| 30.9|
| 1| cat67| 28.5|
| 2| cat56| 39.6|
| 3| cat8| 35.6|
+----+--------+----------+
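When a row has more columns than Category and TotalValue, one max_by call per column becomes repetitive. One possible alternative is to pass a struct of the wanted columns as the first argument of max_by and expand it afterwards. This is a sketch assuming Spark 3.0 or later; the Channel column and its values are made up for illustration and are not part of the original data.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, expr}

val spark = SparkSession.builder().master("local[*]").appName("max-by-struct").getOrCreate()
import spark.implicits._

// Hypothetical data with an extra Channel column.
val wide = Seq(
  (0, "cat26", 30.9, "web"), (0, "cat13", 22.1, "app"),
  (1, "cat67", 28.5, "app"), (1, "cat4", 26.8, "web"))
  .toDF("Hour", "Category", "TotalValue", "Channel")

val topWide = wide
  .groupBy("Hour")
  // max_by returns the struct belonging to the row with the largest TotalValue.
  .agg(expr("max_by(struct(Category, TotalValue, Channel), TotalValue)").as("top"))
  // Expand the struct back into ordinary columns.
  .select(col("Hour"), col("top.*"))
  .sort("Hour")

topWide.show()
```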
Solution 5
The pattern is: group by keys => do something to each group (e.g. reduce) => convert back to a DataFrame.
I found the DataFrame abstraction a bit cumbersome in this case, so I used RDD functionality:
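import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
// reduceFunction: (Row, Row) => Row is supplied by the caller
val rdd: RDD[Row] = originalDf
  .rdd
  .groupBy(row => row.getAs[String]("grouping_row"))
  .map(iterableTuple => iterableTuple._2.reduce(reduceFunction))
// Rebuild a DataFrame from the reduced rows, reusing the original schema
val productDf = sqlContext.createDataFrame(rdd, originalDf.schema)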
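On RDDs, groupBy materialises every group before reducing it. A variant of the same pattern uses keyBy followed by reduceByKey, which combines rows on the map side before the shuffle. The sketch below is an illustration under assumptions, not the answer's own code: keepLarger is a hypothetical reduce function, and the data and column names are borrowed from the question.

```scala
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SparkSession}

val spark = SparkSession.builder().master("local[*]").appName("rdd-reduce-by-key").getOrCreate()
import spark.implicits._

val originalDf = Seq(
  (0, "cat26", 30.9), (0, "cat13", 22.1),
  (1, "cat67", 28.5), (1, "cat4", 26.8),
  (2, "cat56", 39.6), (3, "cat8", 35.6)).toDF("Hour", "Category", "TotalValue")

// Hypothetical reduce function: keep the row with the larger TotalValue.
val keepLarger: (Row, Row) => Row = (a, b) =>
  if (a.getAs[Double]("TotalValue") >= b.getAs[Double]("TotalValue")) a else b

val reduced: RDD[Row] = originalDf.rdd
  .keyBy(row => row.getAs[Int]("Hour")) // (Hour, Row) pairs
  .reduceByKey(keepLarger)              // combined per partition, then across partitions
  .values

// Rebuild a DataFrame from the reduced rows, reusing the original schema.
val productDf = spark.createDataFrame(reduced, originalDf.schema)
productDf.orderBy("Hour").show()
```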
Rami
Updated on July 26, 2022
Comments
-
Rami almost 2 years
I have a DataFrame generated as follows:
df.groupBy($"Hour", $"Category")
  .agg(sum($"value") as "TotalValue")
  .sort($"Hour".asc, $"TotalValue".desc)
The results look like:
+----+--------+----------+
|Hour|Category|TotalValue|
+----+--------+----------+
|   0|   cat26|      30.9|
|   0|   cat13|      22.1|
|   0|   cat95|      19.6|
|   0|  cat105|       1.3|
|   1|   cat67|      28.5|
|   1|    cat4|      26.8|
|   1|   cat13|      12.6|
|   1|   cat23|       5.3|
|   2|   cat56|      39.6|
|   2|   cat40|      29.7|
|   2|  cat187|      27.9|
|   2|   cat68|       9.8|
|   3|    cat8|      35.6|
| ...|    ....|      ....|
+----+--------+----------+
As you can see, the DataFrame is ordered by Hour in increasing order, then by TotalValue in descending order.
I would like to select the top row of each group, i.e.
- from the group of Hour==0 select (0,cat26,30.9)
- from the group of Hour==1 select (1,cat67,28.5)
- from the group of Hour==2 select (2,cat56,39.6)
- and so on
So the desired output would be:
+----+--------+----------+
|Hour|Category|TotalValue|
+----+--------+----------+
|   0|   cat26|      30.9|
|   1|   cat67|      28.5|
|   2|   cat56|      39.6|
|   3|    cat8|      35.6|
| ...|     ...|       ...|
+----+--------+----------+
It might be handy to be able to select the top N rows of each group as well.
Any help is highly appreciated.
-
Adam Szałucha over 6 years
It looks like since Spark 1.6 it is row_number() instead of rowNumber
-
Ignacio Alorre over 6 years
About the "Don't use df.orderBy(...).groupBy(...)" note: under what circumstances can we rely on orderBy(...)? Or, if we cannot be sure that orderBy() is going to give the correct result, what alternatives do we have?
-
Thomas about 6 years
I might be overlooking something, but in general it is recommended to avoid groupByKey; reduceByKey should be used instead. Also, you'll be saving one line.
-
soote almost 6 years
@Thomas avoiding groupBy/groupByKey only applies when dealing with RDDs; you'll notice that the Dataset API doesn't even have a reduceByKey function.
-
Alper t. Turker almost 6 years
But it shuffles everything first. It is hardly an improvement (maybe not worse than window functions, depending on the data).
-
elghoto almost 6 years
You have a group in the first place, which will trigger a shuffle. It's not worse than a window function, because a window function is going to evaluate the window for every single row in the DataFrame.
-
Brendan almost 4 years
Update: both of those Spark issues have been resolved, first should be good to use
-
Abuw about 3 years
@Brendan from which version is it good to use?
-
Brendan about 3 years
I'm not sure exactly how to read Spark's issue tracker to confirm which version they fixed this in, but following the links above shows both issues resolved. One had the affected version listed as 1.6.1 and the other as 3.1.0.
-
Eyal almost 3 years
This code is more or less contained in Apache DataFu's dedupWithOrder method
-
hiryu over 2 years
If you read both tickets attentively, you'll notice that while they are marked as "resolved", they didn't change any underlying behavior of first(). They only document that the output is non-deterministic. As such, first() without a window function remains unreliable to use...
-
Paulo Moreira over 2 years
Works very well for me, thanks a lot!