How to plot a correlation heatmap when using pyspark + Databricks


I think the point where you get confused is:

matrix.collect()[0]["pearson({})".format(vector_col)].values

Calling .values on a DenseMatrix gives you one flat array of all the values (stored column by column), but what you are actually looking for is a list of lists representing the correlation matrix.
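Here is why the shape gets lost. By default, pyspark's DenseMatrix keeps its entries in a single column-major array, and .values returns that array as-is. The plain-Python sketch below rebuilds the rows from such a flat sequence. The helper name column_major_to_rows is hypothetical. In practice .toArray().tolist() does this for you.

```python
def column_major_to_rows(values, num_rows, num_cols):
    """Rebuild a row-wise list of lists from a flat column-major sequence.

    Hypothetical helper for illustration: in column-major storage,
    entry (i, j) sits at index i + j * num_rows.
    """
    if len(values) != num_rows * num_cols:
        raise ValueError("values does not match the given shape")
    return [[float(values[i + j * num_rows]) for j in range(num_cols)]
            for i in range(num_rows)]


# The 2x3 matrix [[1, 3, 5], [2, 4, 6]] stored column by column:
flat = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
print(column_major_to_rows(flat, 2, 3))  # [[1.0, 3.0, 5.0], [2.0, 4.0, 6.0]]
```

A correlation matrix is symmetric, so row-major and column-major order give the same numbers. Only the shape information is missing from .values.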

import matplotlib.pyplot as plt
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.stat import Correlation

columns = ['col1','col2','col3']

myGraph=spark.createDataFrame([(1.3,2.1,3.0),
                               (2.5,4.6,3.1),
                               (6.5,7.2,10.0)],
                              columns)
vector_col = "corr_features"
assembler = VectorAssembler(inputCols=['col1','col2','col3'], 
                            outputCol=vector_col)
myGraph_vector = assembler.transform(myGraph).select(vector_col)
matrix = Correlation.corr(myGraph_vector, vector_col)

Up to this point it is basically your code. Instead of calling .values, use .toArray().tolist() to get a list of lists representing the correlation matrix:

matrix = Correlation.corr(myGraph_vector, vector_col).collect()[0][0]
corrmatrix = matrix.toArray().tolist()
print(corrmatrix)

Output:

[[1.0, 0.9582184104641529, 0.9780872729407004], [0.9582184104641529, 1.0, 0.8776695567739841], [0.9780872729407004, 0.8776695567739841, 1.0]]
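Correlation.corr computes the Pearson coefficient unless a different method is passed. As a sanity check, the plain-Python sketch below recomputes the off-diagonal entries from the three example rows. The pearson helper is hypothetical and not part of pyspark. Its results should agree with the values above up to floating-point rounding.

```python
import math


def pearson(xs, ys):
    """Pearson correlation coefficient of two equal-length sequences."""
    n = len(xs)
    if n != len(ys) or n < 2:
        raise ValueError("need two sequences of equal length >= 2")
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    # covariance and variances without the 1/(n-1) factor, which cancels out
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    var_x = sum((x - mean_x) ** 2 for x in xs)
    var_y = sum((y - mean_y) ** 2 for y in ys)
    # a constant column (zero variance) would make this divide by zero
    return cov / math.sqrt(var_x * var_y)


# the three columns of myGraph
col1 = [1.3, 2.5, 6.5]
col2 = [2.1, 4.6, 7.2]
col3 = [3.0, 3.1, 10.0]
print(pearson(col1, col2))  # approximately 0.958, matching corrmatrix[0][1]
```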

The advantage of this approach is that a list of lists is easy to turn into a DataFrame:

df = spark.createDataFrame(corrmatrix,columns)
df.show()

Output:

+------------------+------------------+------------------+ 
|              col1|              col2|              col3| 
+------------------+------------------+------------------+ 
|               1.0|0.9582184104641529|0.9780872729407004|
|0.9582184104641529|               1.0|0.8776695567739841| 
|0.9780872729407004|0.8776695567739841|               1.0|  
+------------------+------------------+------------------+
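The DataFrame above does not record which row belongs to which feature. One option is to prepend each column's name to its row before creating the DataFrame. The sketch below does this with a hypothetical label_rows helper and an assumed extra column named feature.

```python
def label_rows(corrmatrix, columns):
    """Prepend each row's feature name so the table can be read without an index.

    Hypothetical helper; the 'feature' column name used below is an assumption.
    """
    if len(corrmatrix) != len(columns):
        raise ValueError("need exactly one name per row")
    return [[name] + [float(v) for v in row] for name, row in zip(columns, corrmatrix)]


columns = ['col1', 'col2', 'col3']
corrmatrix = [[1.0, 0.9582184104641529, 0.9780872729407004],
              [0.9582184104641529, 1.0, 0.8776695567739841],
              [0.9780872729407004, 0.8776695567739841, 1.0]]
rows = label_rows(corrmatrix, columns)
print(rows[0])  # ['col1', 1.0, 0.9582184104641529, 0.9780872729407004]

# On Databricks, where `spark` is predefined, the labeled rows could then be loaded with:
# df = spark.createDataFrame(rows, ["feature"] + columns)
```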

To answer your second question: this is just one of the many ways to plot a heatmap (seaborn gives even better-looking results).

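def plot_corr_matrix(correlations, attr, fig_no):
    fig = plt.figure(fig_no)
    ax = fig.add_subplot(111)
    ax.set_title("Correlation Matrix for Specified Attributes")
    # draw the matrix first; vmin/vmax pin the colour scale to the full range of a correlation
    cax = ax.matshow(correlations, vmax=1, vmin=-1)
    # one tick per attribute, so each label sits next to its own row/column
    ax.set_xticks(range(len(attr)))
    ax.set_yticks(range(len(attr)))
    ax.set_xticklabels(attr)
    ax.set_yticklabels(attr)
    fig.colorbar(cax)
    plt.show()

plot_corr_matrix(corrmatrix, columns, 234)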
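Correlation heatmaps often print the coefficient inside each cell. Below is a possible variation using imshow and ax.text. The function name plot_annotated_corr, the coolwarm colormap, and the two-decimal labels are illustrative choices. The Agg backend is set only so the sketch runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend for this sketch only
import matplotlib.pyplot as plt


def plot_annotated_corr(correlations, attr):
    """Draw a correlation heatmap with each coefficient written in its cell.

    Illustrative sketch: colormap and number format are arbitrary choices.
    """
    n = len(attr)
    fig, ax = plt.subplots()
    # fix the colour scale to [-1, 1] so colours are comparable across plots
    im = ax.imshow(correlations, cmap="coolwarm", vmin=-1, vmax=1)
    ax.set_xticks(range(n))
    ax.set_yticks(range(n))
    ax.set_xticklabels(attr)
    ax.set_yticklabels(attr)
    # write the coefficient at the centre of each cell (x = column, y = row)
    for i in range(n):
        for j in range(n):
            ax.text(j, i, "{:.2f}".format(correlations[i][j]),
                    ha="center", va="center")
    ax.set_title("Correlation Matrix")
    fig.colorbar(im, ax=ax)
    return fig


columns = ['col1', 'col2', 'col3']
corrmatrix = [[1.0, 0.9582184104641529, 0.9780872729407004],
              [0.9582184104641529, 1.0, 0.8776695567739841],
              [0.9780872729407004, 0.8776695567739841, 1.0]]
fig = plot_annotated_corr(corrmatrix, columns)
```

In a notebook you would drop the matplotlib.use line and call plt.show() after creating the figure. seaborn's heatmap function with annot=True produces a similar result in a single call.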

Author: Feng Chen

Updated on July 24, 2022

Comments

  • Feng Chen, over 1 year ago

    I am studying pyspark in databricks. I want to generate a correlation heatmap. Let's say this is my data:

    myGraph=spark.createDataFrame([(1.3,2.1,3.0),
                                   (2.5,4.6,3.1),
                                   (6.5,7.2,10.0)],
                                  ['col1','col2','col3'])
    

    And this is my code:

    import pyspark
    from pyspark.sql import SparkSession
    import matplotlib.pyplot as plt
    import pandas as pd
    import numpy as np
    from ggplot import *
    from pyspark.ml.feature import VectorAssembler
    from pyspark.ml.stat import Correlation
    from pyspark.mllib.stat import Statistics
    
    myGraph=spark.createDataFrame([(1.3,2.1,3.0),
                                   (2.5,4.6,3.1),
                                   (6.5,7.2,10.0)],
                                  ['col1','col2','col3'])
    vector_col = "corr_features"
    assembler = VectorAssembler(inputCols=['col1','col2','col3'], 
                                outputCol=vector_col)
    myGraph_vector = assembler.transform(myGraph).select(vector_col)
    matrix = Correlation.corr(myGraph_vector, vector_col)
    matrix.collect()[0]["pearson({})".format(vector_col)].values
    

    Up to this point I can get the correlation matrix. The result looks like this:

    [image: the resulting correlation matrix]

    My questions are:

    1. How do I convert the matrix to a DataFrame? I have tried the methods from "How to convert DenseMatrix to spark DataFrame in pyspark?" and "How to get correlation matrix values pyspark", but they do not work for me.
    2. How do I generate a correlation heatmap that looks like this:

    [image: example of the desired correlation heatmap]

    I have only just started learning pyspark and Databricks, so either ggplot or matplotlib is fine for my problem.

  • mwhee, over 4 years ago
    Cronoik, do the values have to be integers? I'm computing the correlation between float values and getting NaN in the resulting matrix.
  • cronoik, over 4 years ago
    No, that is not necessary. I also used floats in the example above. Can you please open your own question and show us your code? I will have a look at it.