How to use correlation in Spark with Dataframes?

16,840

There is no method that can be used directly to achieve what you want. Python wrappers for the method implemented in SPARK-19636 are present in pyspark.ml.stat:

from pyspark.ml.stat import Correlation

Correlation.corr(df_cat, "features")

but this method computes the correlation matrix for a single Vector column.

You could:

  • Assemble features and fail_mode_meas using VectorAssembler and apply pyspark.ml.stat.Correlation afterwards, but it will compute a number of values you don't need (the correlations between pairs of features).
  • Expand the vector column and use pyspark.sql.functions.corr, but this will be expensive for a large number of columns and adds significant overhead when used with a Python udf.
Author: y.selivonchyk

Deep learning enthusiast.

Updated on June 08, 2022

Comments

  • y.selivonchyk, almost 2 years ago

    Spark 2.2.0 adds correlation support for DataFrames. More information can be found in the pull request.

    MLlib New algorithms in DataFrame-based API:

    SPARK-19636: Correlation in DataFrame-based API (Scala/Java/Python)

    Yet it is entirely unclear how to use this change or what has changed compared to the previous version.

    I expected something like:

    df_cat = spark.read.parquet('/dataframe')
    df_cat.printSchema()
    df_cat.show()
    df_cat.corr(col1='features', col2='fail_mode_meas')
    
    root
     |-- features: vector (nullable = true)
     |-- fail_mode_meas: double (nullable = true)
    
    
    +--------------------+--------------+
    |            features|fail_mode_meas|
    +--------------------+--------------+
    |[0.0,0.5,0.0,0.0,...|          22.7|
    |[0.9,0.0,0.7,0.0,...|           0.1|
    |[0.0,5.1,1.0,0.0,...|           2.0|
    |[0.0,0.0,0.0,0.0,...|           3.1|
    |[0.1,0.0,0.0,1.7,...|           0.0|
    ...
    
    pyspark.sql.utils.IllegalArgumentException: 'requirement failed: Currently correlation calculation for columns with dataType org.apache.spark.ml.linalg.VectorUDT not supported.'
    

    Can someone explain how to take advantage of the new Spark 2.2.0 feature for correlation in DataFrames?