Spark dataframe transform multiple rows to column

Solution 1

Let's start with some example data:

df = sqlContext.createDataFrame([
    ("a", 1, "m1"), ("a", 1, "m2"), ("a", 2, "m3"),
    ("a", 3, "m4"), ("b", 4, "m1"), ("b", 1, "m2"),
    ("b", 2, "m3"), ("c", 3, "m1"), ("c", 4, "m3"),
    ("c", 5, "m4"), ("d", 6, "m1"), ("d", 1, "m2"),
    ("d", 2, "m3"), ("d", 3, "m4"), ("d", 4, "m5"),
    ("e", 4, "m1"), ("e", 5, "m2"), ("e", 1, "m3"),
    ("e", 1, "m4"), ("e", 1, "m5")], 
    ("a", "cnt", "major"))

Please note that I've changed count to cnt. count is a reserved word in most SQL dialects and not a good choice for a column name.
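If you do have to work with a column literally named count, Spark SQL lets you escape it with backticks. A minimal sketch, where df_raw is a hypothetical frame that kept the original name:

    # Minimal sketch: escaping a reserved-word column name with backticks
    # (df_raw is hypothetical; it keeps the original `count` column name)
    df_raw = sqlContext.createDataFrame(
        [("a", 1, "m1")], ("a", "count", "major"))
    df_raw.registerTempTable("src")
    sqlContext.sql("SELECT a, `count`, major FROM src").show()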

There are at least two ways to reshape this data (a sanity check that both agree is sketched after the list):

  • aggregating over DataFrame

    from pyspark.sql.functions import col, when, max
    
    # Collect the distinct majors; these become the new column names
    # (.rdd.map instead of .map keeps this working on Spark 2.x as well)
    majors = sorted(df.select("major")
        .distinct()
        .rdd.map(lambda row: row[0])
        .collect())
    
    # One conditional column per major: cnt where the major matches, else null
    cols = [when(col("major") == m, col("cnt")).otherwise(None).alias(m)
        for m in majors]
    # max ignores nulls, so it picks out the single cnt per (a, major) group
    maxs = [max(col(m)).alias(m) for m in majors]
    
    reshaped1 = (df
        .select(col("a"), *cols)
        .groupBy("a")
        .agg(*maxs)
        .na.fill(0))
    
    reshaped1.show()
    
    ## +---+---+---+---+---+---+
    ## |  a| m1| m2| m3| m4| m5|
    ## +---+---+---+---+---+---+
    ## |  a|  1|  1|  2|  3|  0|
    ## |  b|  4|  1|  2|  0|  0|
    ## |  c|  3|  0|  4|  5|  0|
    ## |  d|  6|  1|  2|  3|  4|
    ## |  e|  4|  5|  1|  1|  1|
    ## +---+---+---+---+---+---+
    
  • groupBy over RDD

    from pyspark.sql import Row
    
    # Key each row by a and pair it with its (major, cnt) tuple
    # (.rdd.map again, for Spark 2.x compatibility)
    grouped = (df
        .rdd.map(lambda row: (row.a, (row.major, row.cnt)))
        .groupByKey())
    
    def make_row(kv):
        k, vs = kv
        # Merge the (major, cnt) pairs and the key into one dict,
        # then fill in 0 for any major missing from this group
        tmp = dict(list(vs) + [("a", k)])
        return Row(**{k: tmp.get(k, 0) for k in ["a"] + majors})
    
    reshaped2 = sqlContext.createDataFrame(grouped.map(make_row))
    
    reshaped2.show()
    
    ## +---+---+---+---+---+---+
    ## |  a| m1| m2| m3| m4| m5|
    ## +---+---+---+---+---+---+
    ## |  a|  1|  1|  2|  3|  0|
    ## |  e|  4|  5|  1|  1|  1|
    ## |  c|  3|  0|  4|  5|  0|
    ## |  b|  4|  1|  2|  0|  0|
    ## |  d|  6|  1|  2|  3|  4|
    ## +---+---+---+---+---+---+
    
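A quick sanity check (a sketch) that the two approaches agree up to row order, using subtract, which returns the rows present in one frame but missing from the other:

    # Sketch: both reshaped frames should be equal up to row ordering
    assert reshaped1.subtract(reshaped2).count() == 0
    assert reshaped2.subtract(reshaped1).count() == 0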

Solution 2

Using zero323's dataframe from Solution 1, you could also use the built-in pivot function (available since Spark 1.6):

    reshaped_df = df.groupby('a').pivot('major').max('cnt').fillna(0)
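
If the set of majors is known up front, you can pass it to pivot explicitly, which spares Spark the extra job that computes the distinct values. A sketch, with the list written out by hand to match the data above:

    # Sketch: passing the pivot values explicitly avoids an extra pass
    # over the data to discover them (list assumed known in advance)
    majors = ["m1", "m2", "m3", "m4", "m5"]
    reshaped_df = df.groupby('a').pivot('major', majors).max('cnt').fillna(0)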
Comments

  • resec over 4 years

    I am a novice to Spark, and I want to transform the source dataframe below (loaded from a JSON file):

    +--+-----+-----+
    |A |count|major|
    +--+-----+-----+
    | a|    1|   m1|
    | a|    1|   m2|
    | a|    2|   m3|
    | a|    3|   m4|
    | b|    4|   m1|
    | b|    1|   m2|
    | b|    2|   m3|
    | c|    3|   m1|
    | c|    4|   m3|
    | c|    5|   m4|
    | d|    6|   m1|
    | d|    1|   m2|
    | d|    2|   m3|
    | d|    3|   m4|
    | d|    4|   m5|
    | e|    4|   m1|
    | e|    5|   m2|
    | e|    1|   m3|
    | e|    1|   m4|
    | e|    1|   m5|
    +--+-----+-----+
    

    into the result dataframe below:

    +--+--+--+--+--+--+
    |A |m1|m2|m3|m4|m5|
    +--+--+--+--+--+--+
    | a| 1| 1| 2| 3| 0|
    | b| 4| 1| 2| 0| 0|
    | c| 3| 0| 4| 5| 0|
    | d| 6| 1| 2| 3| 4|
    | e| 4| 5| 1| 1| 1|
    +--+--+--+--+--+--+
    

    Here are the transformation rules:

    1. The result dataframe consists of A plus n major columns, where the major column names are given by:

      sorted(src_df.map(lambda x: x[2]).distinct().collect())
      
    2. The result dataframe contains m rows, where the values for the A column are given by:

      sorted(src_df.map(lambda x: x[0]).distinct().collect())
      
    3. The value for each major column in the result dataframe is the value from the source dataframe at the corresponding A and major (e.g. the count in row 1 of the source dataframe maps to the cell where A is a and the column is m1).

    4. The combinations of A and major in the source dataframe have no duplicates (consider the two columns a composite primary key in SQL); a quick check is sketched below.
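
    A sketch of that check, assuming the source dataframe is named src_df as in the snippets above:

      # Sketch: each (A, major) pair should appear at most once
      from pyspark.sql.functions import col
      dupes = (src_df
          .groupBy("A", "major")
          .count()
          .filter(col("count") > 1))
      assert dupes.count() == 0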