How to flatten a struct in a Spark dataframe?

Solution 1

Here is a function that does what you want and that can deal with multiple nested columns containing columns with the same name:

import pyspark.sql.functions as F

def flatten_df(nested_df):
    # columns whose type is not a struct are kept as-is
    flat_cols = [c[0] for c in nested_df.dtypes if c[1][:6] != 'struct']
    # struct columns are expanded one level, prefixed with the parent name
    nested_cols = [c[0] for c in nested_df.dtypes if c[1][:6] == 'struct']

    flat_df = nested_df.select(flat_cols +
                               [F.col(nc + '.' + c).alias(nc + '_' + c)
                                for nc in nested_cols
                                for c in nested_df.select(nc + '.*').columns])
    return flat_df

Before:

root
 |-- x: string (nullable = true)
 |-- y: string (nullable = true)
 |-- foo: struct (nullable = true)
 |    |-- a: float (nullable = true)
 |    |-- b: float (nullable = true)
 |    |-- c: integer (nullable = true)
 |-- bar: struct (nullable = true)
 |    |-- a: float (nullable = true)
 |    |-- b: float (nullable = true)
 |    |-- c: integer (nullable = true)

After:

root
 |-- x: string (nullable = true)
 |-- y: string (nullable = true)
 |-- foo_a: float (nullable = true)
 |-- foo_b: float (nullable = true)
 |-- foo_c: integer (nullable = true)
 |-- bar_a: float (nullable = true)
 |-- bar_b: float (nullable = true)
 |-- bar_c: integer (nullable = true)
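
To see it in action, here is a minimal sketch that builds the "Before" schema above and runs the function over it (the values are made up; it assumes an active SparkSession named spark):

from pyspark.sql.types import (FloatType, IntegerType, StringType,
                               StructField, StructType)

inner = StructType([
    StructField("a", FloatType()),
    StructField("b", FloatType()),
    StructField("c", IntegerType()),
])
schema = StructType([
    StructField("x", StringType()),
    StructField("y", StringType()),
    StructField("foo", inner),
    StructField("bar", inner),
])

df = spark.createDataFrame([("x1", "y1", (1.0, 2.0, 3), (4.0, 5.0, 6))], schema)
flatten_df(df).printSchema()  # should match the "After" tree above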

Solution 2

For Spark 2.4.5,

df.select(df.col("data.*")) will give you an org.apache.spark.sql.AnalysisException: No such struct field * in ... exception, while this will work:

df.select($"data.*")

Solution 3

This flatten_df version flattens the dataframe at every nesting level, using a stack to avoid recursive calls:

from pyspark.sql.functions import col


def flatten_df(nested_df):
    stack = [((), nested_df)]
    columns = []

    while len(stack) > 0:
        parents, df = stack.pop()

        flat_cols = [
            col(".".join(parents + (c[0],))).alias("_".join(parents + (c[0],)))
            for c in df.dtypes
            if c[1][:6] != "struct"
        ]

        nested_cols = [
            c[0]
            for c in df.dtypes
            if c[1][:6] == "struct"
        ]

        columns.extend(flat_cols)

        for nested_col in nested_cols:
            projected_df = df.select(nested_col + ".*")
            stack.append((parents + (nested_col,), projected_df))

    return nested_df.select(columns)

Example:

from pyspark.sql.types import StringType, StructField, StructType


schema = StructType([
    StructField("some", StringType()),

    StructField("nested", StructType([
        StructField("nestedchild1", StringType()),
        StructField("nestedchild2", StringType())
    ])),

    StructField("renested", StructType([
        StructField("nested", StructType([
            StructField("nestedchild1", StringType()),
            StructField("nestedchild2", StringType())
        ]))
    ]))
])

data = [
    {
        "some": "value1",
        "nested": {
            "nestedchild1": "value2",
            "nestedchild2": "value3",
        },
        "renested": {
            "nested": {
                "nestedchild1": "value4",
                "nestedchild2": "value5",
            }
        }
    }
]

df = spark.createDataFrame(data, schema)
flat_df = flatten_df(df)
print(flat_df.collect())

Prints:

[Row(some=u'value1', renested_nested_nestedchild1=u'value4', renested_nested_nestedchild2=u'value5', nested_nestedchild1=u'value2', nested_nestedchild2=u'value3')]
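
Note the column order in the output: because the stack pops entries last-in-first-out, the deeply renested columns surface before the shallower nested ones. If a stable order matters, one option (my addition, not part of the original answer) is to sort the flattened columns afterwards:

flat_df = flat_df.select(sorted(flat_df.columns))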

Solution 4

I generalized the solution from stecos a bit more, so the flattening can be done on more than two struct layers deep:

from pyspark.sql.functions import col


def flatten_df(nested_df, layers):
    flat_cols = []
    nested_cols = []
    flat_df = []

    # first pass: split top-level columns into flat and struct columns
    flat_cols.append([c[0] for c in nested_df.dtypes if c[1][:6] != 'struct'])
    nested_cols.append([c[0] for c in nested_df.dtypes if c[1][:6] == 'struct'])

    flat_df.append(nested_df.select(flat_cols[0] +
                                    [col(nc + '.' + c).alias(nc + '_' + c)
                                     for nc in nested_cols[0]
                                     for c in nested_df.select(nc + '.*').columns])
                   )
    # each further pass flattens one more struct layer
    for i in range(1, layers):
        flat_cols.append([c[0] for c in flat_df[i - 1].dtypes if c[1][:6] != 'struct'])
        nested_cols.append([c[0] for c in flat_df[i - 1].dtypes if c[1][:6] == 'struct'])

        flat_df.append(flat_df[i - 1].select(flat_cols[i] +
                                             [col(nc + '.' + c).alias(nc + '_' + c)
                                              for nc in nested_cols[i]
                                              for c in flat_df[i - 1].select(nc + '.*').columns])
                       )

    return flat_df[-1]

Just call it with:

my_flattened_df = flatten_df(my_df_having_nested_structs, 3)

(the second parameter is the number of layers to be flattened; in my case it is 3)
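
If you would rather not guess the depth up front, a small wrapper (a sketch of mine, not part of the original answer) can apply single-layer passes until no struct columns remain:

def flatten_fully(nested_df):
    # flatten_df(df, 1) performs exactly one flattening pass
    df = nested_df
    while any(t.startswith('struct') for _, t in df.dtypes):
        df = flatten_df(df, 1)
    return df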

Solution 5

A PySpark solution to flatten a nested dataframe with both struct and array types, at any level of depth. This improves on https://stackoverflow.com/a/56533459/7131019.

from pyspark.sql import functions as f

def flatten_structs(nested_df):
    # same stack-based traversal as above: expand struct columns
    # level by level, keeping everything else (including arrays) as-is
    stack = [((), nested_df)]
    columns = []

    while len(stack) > 0:
        parents, df = stack.pop()

        # non-struct columns are selected by their full dotted path
        # and aliased with that path joined by underscores
        flat_cols = [
            f.col(".".join(parents + (c[0],))).alias("_".join(parents + (c[0],)))
            for c in df.dtypes
            if c[1][:6] != "struct"
        ]

        nested_cols = [
            c[0]
            for c in df.dtypes
            if c[1][:6] == "struct"
        ]

        columns.extend(flat_cols)

        # push each struct column back onto the stack for further expansion
        for nested_col in nested_cols:
            projected_df = df.select(nested_col + ".*")
            stack.append((parents + (nested_col,), projected_df))

    return nested_df.select(columns)

def flatten_array_struct_df(df):
    array_cols = [
        c[0]
        for c in df.dtypes
        if c[1][:5] == "array"
    ]

    # alternate between exploding array columns (one row per element)
    # and flattening the structs they expose, until no arrays remain
    while len(array_cols) > 0:
        for array_col in array_cols:
            df = df.withColumn(array_col, f.explode(f.col(array_col)))

        df = flatten_structs(df)

        array_cols = [
            c[0]
            for c in df.dtypes
            if c[1][:5] == "array"
        ]

    return df

flat_df = flatten_array_struct_df(df)
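
To illustrate the combined behavior, here is a hypothetical end-to-end example with an array of structs (assuming an active SparkSession named spark):

from pyspark.sql.types import (ArrayType, LongType, StringType,
                               StructField, StructType)

schema = StructType([
    StructField("id", LongType()),
    StructField("events", ArrayType(StructType([
        StructField("name", StringType()),
        StructField("ts", LongType()),
    ]))),
])

df = spark.createDataFrame([(1, [("a", 1), ("b", 2)])], schema)
flatten_array_struct_df(df).show()

Each array element becomes its own row, and the struct fields come out as events_name and events_ts. Note that explode drops rows whose array is empty or null; swapping in explode_outer inside flatten_array_struct_df would keep them.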
Original question (djWann)

I have a dataframe with the following structure:

 |-- data: struct (nullable = true)
 |    |-- id: long (nullable = true)
 |    |-- keyNote: struct (nullable = true)
 |    |    |-- key: string (nullable = true)
 |    |    |-- note: string (nullable = true)
 |    |-- details: map (nullable = true)
 |    |    |-- key: string
 |    |    |-- value: string (valueContainsNull = true)

How is it possible to flatten the structure and create a new dataframe:

 |-- id: long (nullable = true)
 |-- keyNote: struct (nullable = true)
 |    |-- key: string (nullable = true)
 |    |-- note: string (nullable = true)
 |-- details: map (nullable = true)
 |    |-- key: string
 |    |-- value: string (valueContainsNull = true)

Is there something like explode, but for structs?