How to return a "Tuple type" in a UDF in PySpark?

Solution 1

There is no such thing as a TupleType in Spark. Product types are represented as structs with fields of a specific type. For example, if you want to return an array of (integer, string) pairs, you can use a schema like this:

from pyspark.sql.types import *

schema = ArrayType(StructType([
    StructField("char", StringType(), False),
    StructField("count", IntegerType(), False)
]))

Example usage:

from pyspark.sql.functions import udf
from collections import Counter

char_count_udf = udf(
    lambda s: Counter(s).most_common(),
    schema
)

df = sc.parallelize([(1, "foo"), (2, "bar")]).toDF(["id", "value"])

df.select("*", char_count_udf(df["value"])).show(2, False)

## +---+-----+-------------------------+
## |id |value|PythonUDF#<lambda>(value)|
## +---+-----+-------------------------+
## |1  |foo  |[[o,2], [f,1]]           |
## |2  |bar  |[[r,1], [a,1], [b,1]]    |
## +---+-----+-------------------------+
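
Since the UDF returns an array of structs, the nested fields can be unpacked by name afterwards. A minimal sketch, assuming the same df and char_count_udf as above (the explode call and the pair alias are illustrative, not part of the original answer):

from pyspark.sql.functions import explode

# one row per (char, count) struct, then address the struct fields by name
pairs = df.select("id", explode(char_count_udf(df["value"])).alias("pair"))
pairs.select("id", "pair.char", "pair.count").show()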

Solution 2

Stack Overflow keeps directing me to this question, so I guess I'll add some info here.

Returning simple types from UDF:

from pyspark.sql.types import *
from pyspark.sql import functions as F
from pyspark.sql.functions import udf

def get_df():
  d = [(0.0, 0.0), (0.0, 3.0), (1.0, 6.0), (1.0, 9.0)]
  df = sqlContext.createDataFrame(d, ['x', 'y'])
  return df

df = get_df()
df.show()

# +---+---+
# |  x|  y|
# +---+---+
# |0.0|0.0|
# |0.0|3.0|
# |1.0|6.0|
# |1.0|9.0|
# +---+---+

func = udf(lambda x: str(x), StringType())
df = df.withColumn('y_str', func('y'))

func = udf(lambda x: int(x), IntegerType())
df = df.withColumn('y_int', func('y'))

df.show()

# +---+---+-----+-----+
# |  x|  y|y_str|y_int|
# +---+---+-----+-----+
# |0.0|0.0|  0.0|    0|
# |0.0|3.0|  3.0|    3|
# |1.0|6.0|  6.0|    6|
# |1.0|9.0|  9.0|    9|
# +---+---+-----+-----+

df.printSchema()

# root
#  |-- x: double (nullable = true)
#  |-- y: double (nullable = true)
#  |-- y_str: string (nullable = true)
#  |-- y_int: integer (nullable = true)
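
As a side note, newer Spark releases (2.3 and later) also accept a DDL-formatted string as the return type, so the type classes need not be imported; a minimal sketch of the same cast:

func = udf(lambda x: int(x), 'int')  # the string 'int' is parsed into IntegerType()
df = df.withColumn('y_int2', func('y'))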

When integers are not enough:

df = get_df()

func = udf(lambda x: [0]*int(x), ArrayType(IntegerType()))
df = df.withColumn('list', func('y'))

func = udf(lambda x: {float(y): str(y) for y in range(int(x))}, 
           MapType(FloatType(), StringType()))
df = df.withColumn('map', func('y'))

df.show()
# +---+---+--------------------+--------------------+
# |  x|  y|                list|                 map|
# +---+---+--------------------+--------------------+
# |0.0|0.0|                  []|               Map()|
# |0.0|3.0|           [0, 0, 0]|Map(2.0 -> 2, 0.0...|
# |1.0|6.0|  [0, 0, 0, 0, 0, 0]|Map(0.0 -> 0, 5.0...|
# |1.0|9.0|[0, 0, 0, 0, 0, 0...|Map(0.0 -> 0, 5.0...|
# +---+---+--------------------+--------------------+

df.printSchema()
# root
#  |-- x: double (nullable = true)
#  |-- y: double (nullable = true)
#  |-- list: array (nullable = true)
#  |    |-- element: integer (containsNull = true)
#  |-- map: map (nullable = true)
#  |    |-- key: float
#  |    |-- value: string (valueContainsNull = true)
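
One gotcha worth knowing: the declared return type is enforced loosely. If the Python value cannot be converted to it, the result is silently null rather than an error (observed behavior up to Spark 2.x); a small sketch of the failure mode:

bad = udf(lambda x: str(x), IntegerType())  # lambda returns a string, not an int
df.withColumn('broken', bad('y')).show()
# the 'broken' column is null on every row, since '0.0' is not convertible to int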

Returning complex datatypes from UDF:

df = get_df()
df = df.groupBy('x').agg(F.collect_list('y').alias('y[]'))
df.show()

# +---+----------+
# |  x|       y[]|
# +---+----------+
# |0.0|[0.0, 3.0]|
# |1.0|[9.0, 6.0]|
# +---+----------+

schema = StructType([
    StructField("min", FloatType(), True),
    StructField("size", IntegerType(), True),
    StructField("edges",  ArrayType(FloatType()), True),
    StructField("val_to_index",  MapType(FloatType(), IntegerType()), True)
    # StructField('insanity', StructType([StructField("min_", FloatType(), True), StructField("size_", IntegerType(), True)]))
])

def func(values):
  mn = min(values)
  size = len(values)
  lst = sorted(values)[::-1]
  val_to_index = {x: i for i, x in enumerate(values)}
  return (mn, size, lst, val_to_index)

func = udf(func, schema)
dff = df.select('*', func('y[]').alias('complex_type'))
dff.show(10, False)

# +---+----------+------------------------------------------------------+
# |x  |y[]       |complex_type                                          |
# +---+----------+------------------------------------------------------+
# |0.0|[0.0, 3.0]|[0.0,2,WrappedArray(3.0, 0.0),Map(0.0 -> 0, 3.0 -> 1)]|
# |1.0|[6.0, 9.0]|[6.0,2,WrappedArray(9.0, 6.0),Map(9.0 -> 1, 6.0 -> 0)]|
# +---+----------+------------------------------------------------------+

dff.printSchema()

# root
#  |-- x: double (nullable = true)
#  |-- y[]: array (nullable = true)
#  |    |-- element: double (containsNull = false)
#  |-- complex_type: struct (nullable = true)
#  |    |-- min: float (nullable = true)
#  |    |-- size: integer (nullable = true)
#  |    |-- edges: array (nullable = true)
#  |    |    |-- element: float (containsNull = true)
#  |    |-- val_to_index: map (nullable = true)
#  |    |    |-- key: float
#  |    |    |-- value: integer (valueContainsNull = true)
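
Individual fields of the returned struct can then be addressed with dot notation; a minimal sketch:

dff.select('x', 'complex_type.min', 'complex_type.edges').show()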

Passing multiple arguments to a UDF:

df = get_df()
func = udf(lambda arr: arr[0] * arr[1], FloatType())
df = df.withColumn('x*y', func(F.array('x', 'y')))

# +---+---+---+
# |  x|  y|x*y|
# +---+---+---+
# |0.0|0.0|0.0|
# |0.0|3.0|0.0|
# |1.0|6.0|6.0|
# |1.0|9.0|9.0|
# +---+---+---+
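
Packing the columns into an array is not strictly required; a UDF can also take several columns directly. A minimal sketch of the same computation (DoubleType is used here because x and y are doubles):

func = udf(lambda x, y: x * y, DoubleType())
df = df.withColumn('x*y', func('x', 'y'))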

The code is purely for demo purposes; all of the above transformations are available as built-in Spark functions and would yield much better performance. As @zero323 noted in the comments above, UDFs should generally be avoided in PySpark; needing to return complex types should make you think about simplifying your logic.
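
For instance, the last example needs no UDF at all; native column arithmetic stays in the JVM and avoids Python serialization entirely:

df = df.withColumn('x*y', df['x'] * df['y'])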

Comments

  • kamalbanga (question author)

    All the data types in pyspark.sql.types are:

    __all__ = [
        "DataType", "NullType", "StringType", "BinaryType", "BooleanType", "DateType",
        "TimestampType", "DecimalType", "DoubleType", "FloatType", "ByteType", "IntegerType",
        "LongType", "ShortType", "ArrayType", "MapType", "StructField", "StructType"]
    

    I have to write a UDF (in PySpark) which returns an array of tuples. What do I pass as the second argument to udf, i.e. the return type? It would be something along the lines of ArrayType(TupleType())...