How to sort array of struct type in Spark DataFrame by particular field?


Solution 1

If you have complex objects, it is much better to use a statically typed Dataset.

case class Result(a: Int, b: Int, c: Int, c2: (java.sql.Date, Int))

val joined = first.join(second.withColumnRenamed("c", "c2"), Seq("a", "b"))
joined.as[Result]
  .groupByKey(_.a)
  .mapGroups((key, xs) => (key, xs.map(_.c2).toSeq.sortBy(_._2)))
  .show(false)

// +---+----------------------------------+            
// |_1 |_2                                |
// +---+----------------------------------+
// |1  |[[2018-01-01,20], [2018-01-02,30]]|
// |2  |[[2018-01-02,50], [2018-01-01,60]]|
// +---+----------------------------------+
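
The result columns come back as _1 and _2. If you want readable column names, one small follow-up (my own sketch; the names are purely illustrative) is to rename the resulting Dataset before showing it:

joined.as[Result]
  .groupByKey(_.a)
  .mapGroups((key, xs) => (key, xs.map(_.c2).toSeq.sortBy(_._2)))
  .toDF("a", "c2_sorted")   // illustrative names, not part of the original answer
  .show(false)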

In simple cases it is also possible to use a udf (as Solution 2 below shows), but in general this leads to inefficient and fragile code and quickly gets out of control as the complexity of the objects grows.

Solution 2

According to the Hive Wiki:

sort_array(Array<T>) : Sorts the input array in ascending order according to the natural ordering of the array elements and returns it (as of version 0.9.0).

This means that the array will be sorted lexicographically, which holds true even with complex data types.
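
You can see this lexicographic behaviour in the question's own snippet: plain sort_array orders the structs by their first field (the date), not by the Int:

first.join(second.withColumnRenamed("c", "c2"), Seq("a", "b"))
     .groupBy("a")
     .agg(sort_array(collect_list("c2")))
     .show(false)

// +---+----------------------------------+
// |a  |sort_array(collect_list(c2), true)|
// +---+----------------------------------+
// |1  |[[2018-01-01,20], [2018-01-02,30]]|
// |2  |[[2018-01-01,60], [2018-01-02,50]]|
// +---+----------------------------------+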

Alternatively, you can create a UDF to sort it based on the second element (and witness the performance degradation):

val sortUdf = udf { (xs: Seq[Row]) =>
  xs.sortBy(_.getAs[Int](1))                               // sort by the second struct field (the Int)
    .map { case Row(d: java.sql.Date, v: Int) => (d, v) }  // rebuild each struct as a tuple
}

first.join(second.withColumnRenamed("c", "c2"), Seq("a", "b"))
     .groupBy("a")
     .agg(sortUdf(collect_list("c2")))
     .show(false)

//+---+----------------------------------+
//|a  |UDF(collect_list(c2, 0, 0))       |
//+---+----------------------------------+
//|1  |[[2018-01-01,20], [2018-01-02,30]]|
//|2  |[[2018-01-02,50], [2018-01-01,60]]|
//+---+----------------------------------+

Solution 3

For Spark 3+, you can pass a custom comparator function to array_sort:

The comparator will take two arguments representing two elements of the array. It returns -1, 0, or 1 as the first element is less than, equal to, or greater than the second element. If the comparator function returns other values (including null), the function will fail and raise an error.

val df = first
  .join(second.withColumnRenamed("c", "c2"), Seq("a", "b"))
  .groupBy("a")
  .agg(collect_list("c2").alias("list"))

val df2 = df.withColumn(
  "list",
  expr(
    "array_sort(list, (left, right) -> case when left._2 < right._2 then -1 when left._2 > right._2 then 1 else 0 end)"
  )
)

df2.show(false)
//+---+------------------------------------+
//|a  |list                                |
//+---+------------------------------------+
//|1  |[[2018-01-01, 20], [2018-01-02, 30]]|
//|2  |[[2018-01-02, 50], [2018-01-01, 60]]|
//+---+------------------------------------+

Where _2 is the name of the struct field you want to use for sorting.
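
For instance, if the struct field had a real name instead of the positional _2 (say value, a purely hypothetical name here), the comparator would reference it by that name:

// assumes a struct field named "value"; substitute whatever your schema actually uses
val df3 = df.withColumn(
  "list",
  expr("array_sort(list, (left, right) -> case when left.value < right.value then -1 when left.value > right.value then 1 else 0 end)")
)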


Comments

  • addmeaning (almost 2 years ago)

    Given the following code:

    import java.sql.Date
    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.functions._
    
    object SortQuestion extends App{
    
      val spark = SparkSession.builder().appName("local").master("local[*]").getOrCreate()
      import spark.implicits._
      case class ABC(a: Int, b: Int, c: Int)
    
      val first = Seq(
        ABC(1, 2, 3),
        ABC(1, 3, 4),
        ABC(2, 4, 5),
        ABC(2, 5, 6)
      ).toDF("a", "b", "c")
    
      val second = Seq(
        (1, 2, (Date.valueOf("2018-01-02"), 30)),
        (1, 3, (Date.valueOf("2018-01-01"), 20)),
        (2, 4, (Date.valueOf("2018-01-02"), 50)),
        (2, 5, (Date.valueOf("2018-01-01"), 60))
      ).toDF("a", "b", "c")
    
      first.join(second.withColumnRenamed("c", "c2"), Seq("a", "b")).groupBy("a").agg(sort_array(collect_list("c2")))
        .show(false)
    
    }
    

    Spark produces the following result:

    +---+----------------------------------+
    |a  |sort_array(collect_list(c2), true)|
    +---+----------------------------------+
    |1  |[[2018-01-01,20], [2018-01-02,30]]|
    |2  |[[2018-01-01,60], [2018-01-02,50]]|
    +---+----------------------------------+
    

    This implies that Spark is sorting the array by date (since it is the first field), but I want to instruct Spark to sort by a specific field of that nested struct.

    I know I could reshape the array to (value, date), but that seems inconvenient; I want a general solution (imagine I have a big nested struct, 5 layers deep, and I want to sort that structure by a particular column).

    Is there a way to do that? Am I missing something?

  • angelcervera (over 3 years ago)
    The question is clear about the array type. sort_array does not work with structs; you will get the error "sort_array does not support sorting array of type struct". So your comment about sort_array is not valid.