DataFrame equality in Apache Spark

Solution 1

There are some standard approaches in the Apache Spark test suites; however, most of them involve collecting the data locally, so if you want to do equality testing on large DataFrames they are likely not a suitable solution.

Check the schema first, then take the intersection of the two DataFrames as df3 and verify that the counts of df1, df2, and df3 are all equal. (However, this only works if there are no duplicate rows; if the DataFrames contain different duplicated rows, this method could still return true.)
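
A rough PySpark sketch of that check (the function name is just for illustration; remember the duplicate-row caveat):

def dfs_probably_equal(df1, df2):
    # Schemas must match first.
    if df1.schema != df2.schema:
        return False
    # intersect() deduplicates, so duplicate rows break this check.
    df3 = df1.intersect(df2)
    return df1.count() == df2.count() == df3.count()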

Another option is to get the underlying RDDs of both DataFrames, map each row to (Row, 1), do a reduceByKey to count the occurrences of each row, cogroup the two resulting RDDs, and then do a regular aggregate that returns false if any of the iterators are not equal.
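
A hedged PySpark sketch of that RDD-based approach (the function name is illustrative; rows containing unhashable types such as maps cannot be used as keys here):

def dfs_equal_by_row_counts(df1, df2):
    if df1.schema != df2.schema:
        return False
    # Count how many times each row occurs in each DataFrame.
    counts1 = df1.rdd.map(lambda row: (row, 1)).reduceByKey(lambda a, b: a + b)
    counts2 = df2.rdd.map(lambda row: (row, 1)).reduceByKey(lambda a, b: a + b)
    # Cogroup by row and keep only rows whose counts disagree on the two sides.
    mismatched = counts1.cogroup(counts2).filter(
        lambda kv: list(kv[1][0]) != list(kv[1][1])
    )
    return mismatched.isEmpty()

Unlike a collect()-based comparison, nothing here is pulled onto the driver, and duplicate rows are handled because the per-row counts must match on both sides.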

Solution 2

Scala (see below for PySpark)

The spark-fast-tests library has two methods for making DataFrame comparisons (I'm the creator of the library):

The assertSmallDataFrameEquality method collects DataFrames on the driver node and makes the comparison

def assertSmallDataFrameEquality(actualDF: DataFrame, expectedDF: DataFrame): Unit = {
  if (!actualDF.schema.equals(expectedDF.schema)) {
    throw new DataFrameSchemaMismatch(schemaMismatchMessage(actualDF, expectedDF))
  }
  if (!actualDF.collect().sameElements(expectedDF.collect())) {
    throw new DataFrameContentMismatch(contentMismatchMessage(actualDF, expectedDF))
  }
}

The assertLargeDataFrameEquality method compares DataFrames spread on multiple machines (the code is basically copied from spark-testing-base)

def assertLargeDataFrameEquality(actualDF: DataFrame, expectedDF: DataFrame): Unit = {
  if (!actualDF.schema.equals(expectedDF.schema)) {
    throw new DataFrameSchemaMismatch(schemaMismatchMessage(actualDF, expectedDF))
  }
  try {
    actualDF.rdd.cache
    expectedDF.rdd.cache

    val actualCount = actualDF.rdd.count
    val expectedCount = expectedDF.rdd.count
    if (actualCount != expectedCount) {
      throw new DataFrameContentMismatch(countMismatchMessage(actualCount, expectedCount))
    }

    val expectedIndexValue = zipWithIndex(actualDF.rdd)
    val resultIndexValue = zipWithIndex(expectedDF.rdd)

    val unequalRDD = expectedIndexValue
      .join(resultIndexValue)
      .filter {
        case (idx, (r1, r2)) =>
          !(r1.equals(r2) || RowComparer.areRowsEqual(r1, r2, 0.0))
      }

    val maxUnequalRowsToShow = 10
    assertEmpty(unequalRDD.take(maxUnequalRowsToShow))

  } finally {
    actualDF.rdd.unpersist()
    expectedDF.rdd.unpersist()
  }
}

assertSmallDataFrameEquality is faster for small DataFrame comparisons and I've found it sufficient for my test suites.

PySpark

Here's a simple function that returns true if the DataFrames are equal:

def are_dfs_equal(df1, df2):
    if df1.schema != df2.schema:
        return False
    if df1.collect() != df2.collect():
        return False
    return True

or simplified

def are_dfs_equal(df1, df2): 
    return (df1.schema == df2.schema) and (df1.collect() == df2.collect())

You'll typically perform DataFrame equality comparisons in a test suite and will want a descriptive error message when the comparisons fail (a True / False return value doesn't help much when debugging).

Use the chispa library to access the assert_df_equality method that returns descriptive error messages for test suite workflows.
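
A minimal usage sketch (assuming a pytest-style spark session fixture):

from chispa import assert_df_equality

def test_dataframes_match(spark):
    expected_df = spark.createDataFrame([("nick", 30), ("bob", 40)], ["name", "age"])
    actual_df = spark.createDataFrame([("nick", 30), ("bob", 40)], ["name", "age"])
    # Raises an AssertionError with a readable, row-by-row diff on mismatch.
    assert_df_equality(actual_df, expected_df)

chispa also takes flags such as ignore_row_order for order-insensitive comparisons; check the project README for the exact options.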

Solution 3

I don't know about idiomatic, but I think the following gives you a robust way to compare DataFrames as you describe. (I'm using PySpark for illustration, but the approach carries across languages.)

a = spark.range(5)
b = spark.range(5)

a_prime = a.groupBy(sorted(a.columns)).count()
b_prime = b.groupBy(sorted(b.columns)).count()

assert a_prime.subtract(b_prime).count() == b_prime.subtract(a_prime).count() == 0

This approach correctly handles cases where the DataFrames may have duplicate rows, rows in different orders, and/or columns in different orders.

For example:

a = spark.createDataFrame([('nick', 30), ('bob', 40)], ['name', 'age'])
b = spark.createDataFrame([(40, 'bob'), (30, 'nick')], ['age', 'name'])
c = spark.createDataFrame([('nick', 30), ('bob', 40), ('nick', 30)], ['name', 'age'])

a_prime = a.groupBy(sorted(a.columns)).count()
b_prime = b.groupBy(sorted(b.columns)).count()
c_prime = c.groupBy(sorted(c.columns)).count()

assert a_prime.subtract(b_prime).count() == b_prime.subtract(a_prime).count() == 0
assert a_prime.subtract(c_prime).count() != 0

This approach is quite expensive, but most of the expense is unavoidable given the need to perform a full diff. And this should scale fine as it doesn't require collecting anything locally. If you relax the constraint that the comparison should account for duplicate rows, then you can drop the groupBy() and just do the subtract(), which would probably speed things up notably.
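
For instance, a sketch of that relaxed variant (schemas and column order assumed to already match; duplicate counts are ignored):

def dfs_equal_ignoring_duplicates(a, b):
    # Order-insensitive, but [r, r] and [r] compare as equal.
    return a.subtract(b).count() == 0 and b.subtract(a).count() == 0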

Solution 4

Java:

assert resultDs.union(answerDs).distinct().count() == resultDs.intersect(answerDs).count();
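
The same check rendered as a PySpark sketch (as the comments below note, duplicate rows can produce a false positive because both distinct() and intersect() deduplicate):

def union_intersect_equal(result_df, answer_df):
    return (result_df.union(answer_df).distinct().count()
            == result_df.intersect(answer_df).count())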

Solution 5

Try doing the following:

df1.except(df2).isEmpty
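
As the comments below point out, this check is one-directional (rows present only in df2 are missed) and except() drops duplicate counts. A PySpark sketch of a symmetric, duplicate-aware variant using exceptAll (available since Spark 2.4):

def dfs_equal(df1, df2):
    return df1.exceptAll(df2).count() == 0 and df2.exceptAll(df1).count() == 0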

Comments

  • Sim
    Sim almost 2 years

    Assume df1 and df2 are two DataFrames in Apache Spark, computed using two different mechanisms, e.g., Spark SQL vs. the Scala/Java/Python API.

    Is there an idiomatic way to determine whether the two data frames are equivalent (equal, isomorphic), where equivalence is determined by the data (column names and column values for each row) being identical save for the ordering of rows & columns?

    The motivation for the question is that there are often many ways to compute some big data result, each with its own trade-offs. As one explores these trade-offs, it is important to maintain correctness and hence the need to check for the equivalence/equality on a meaningful test data set.

  • Sim
    Sim almost 9 years
    Using the testing suite is an interesting idea. Collecting the data may be an option for small/medium data sets. What are the standard tools from there?
  • numeral
    numeral about 7 years
    Just a note that this doesn't work with any unorderable data types such as maps, in which case you might have to drop those columns and do them separately.
  • Sim
    Sim almost 7 years
    Looks like a nice library!
  • Sim
    Sim almost 5 years
    That won't work in the case where df2 is larger than df1. Perhaps if you make it symmetric by adding && df2.except(df1).isEmpty...
  • Nick Chammas
    Nick Chammas over 4 years
    Interesting solution, but I believe this does not handle duplicate rows correctly. For example (in Python): a = spark.createDataFrame([(1,), (1,)], schema='id int'); b = spark.createDataFrame([(1,)], schema='id int'); assert a.union(b).distinct().count() == a.intersect(b).count(); The assert succeeds where it should instead fail.
  • J. P
    J. P over 4 years
    In the case of duplicate rows, how about appending an extra 'count' column (computed via functions.agg or SQL) and then taking the intersection as df3?
  • J. P
    J. P over 4 years
    And how about taking a union of both datasets, then grouping by all the columns (using a Sequence), taking the count, and filtering on count % 2? If it is > 0, return false. Union is faster than intersection and will throw an exception if the columns are different (correct me if I am wrong).
  • J. P
    J. P over 4 years
    try { return ds1.union(ds2) .groupBy(columns(ds1, ds1.columns())) .count() .filter("count % 2 > 0") .count() == 0; } catch (Exception e) { return false; } where columns method returns Seq<Columns> or Column[]
  • J. P
    J. P over 4 years
    Can someone confirm if this union solution has a better performance compared to joins solutions provided above? (and also it works with duplicate rows)
  • Clemens Valiente
    Clemens Valiente over 4 years
    even if you compare it each way it's still not correct since duplicate rows in df2 are matched by one row in df1 and vice versa.
  • Holden
    Holden over 4 years
    I don't think that will be any faster, the slow part of intersection is the shuffle which you'll also have with groupBy.
  • zetaprime
    zetaprime about 4 years
    This unfortunately is not correct, if one of the datasets has a distinct row repeated twice you'll have a false positive.
  • jgtrz
    jgtrz almost 4 years
    @Powers, do you know of any similar libraries for pySpark instead of Scala?
  • Powers
    Powers almost 4 years
    @jgtrz - I started building a PySpark version of spark-fast-tests called chispa: github.com/MrPowers/chispa. Need to finish it!
  • runrig
    runrig almost 4 years
    Is this more scalable than the PySpark solution above using collect() ? Especially if you don't need a list of the differences?
  • EnricoM
    EnricoM almost 4 years
    If you mean the df1.collect() != df2.collect() PySpark solution, this is not scalable at all. Both DataFrames are loaded into the driver's memory. The above diff transformation scales with the cluster, meaning if your cluster can handle the DataFrames, it can handle the diff. So the answer then is: yes.
  • jk-kim
    jk-kim over 3 years
    For those of us who stumble here and implemented the collect comparison with !actualDF.collect().sameElements(expectedDF.collect()): please note the post below and be wary of the ridiculousness of sameElements() stackoverflow.com/questions/29008500/…
  • Sim
    Sim about 3 years
    Re: option 2, concat does not work for all column types and md5 can have collisions on big data. Nice addition of Option 4 with exceptAll, which was only added in 2.4.0.
  • dhalfageme
    dhalfageme about 3 years
    I guess the count goes inside an agg() method, otherwise a_prime, b_prime, and c_prime are numbers instead of DataFrames.
  • Nick Chammas
    Nick Chammas about 3 years
    @dhalfageme - No, .count() on a GroupedData object -- which is what .groupBy() returns -- yields a DataFrame. Try it: spark.range(3).groupBy('id').count().show()
  • matkurek
    matkurek over 2 years
    For the PySpark folks: the function provided is order-sensitive. If you care only about the contents, replace the second condition with: if df1.orderBy(*df1.columns).collect() != df2.orderBy(*df2.columns).collect():