How to get the number of elements in a partition?

23,648

Solution 1

The following gives you a new RDD with elements that are the sizes of each partition:

rdd.mapPartitions(iter => Array(iter.size).iterator, true) 
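To see the shape of the result without a cluster, here is a plain-Python sketch. map_partitions is a hypothetical stand-in for RDD.mapPartitions, and partitions are modeled as in-memory lists (a real RDD does not hold them that way). Each partition is replaced by a single element holding its count.

```python
from typing import Callable, Iterable, Iterator, List, TypeVar

T = TypeVar("T")
U = TypeVar("U")

def map_partitions(partitions: List[List[T]],
                   f: Callable[[Iterator[T]], Iterable[U]]) -> List[List[U]]:
    # Hypothetical stand-in for RDD.mapPartitions: call f with an iterator
    # over each partition and keep its output as the new "partition".
    return [list(f(iter(p))) for p in partitions]

def partition_size(it: Iterator[T]) -> Iterator[int]:
    # Emit exactly one element per partition: how many elements it held.
    yield sum(1 for _ in it)

partitions = [[1, 2, 3], [4, 5], [], [6, 7, 8, 9]]
sizes = map_partitions(partitions, partition_size)
print(sizes)  # [[3], [2], [0], [4]]
```

Note that an empty partition still yields one element (0), so the result always has one entry per partition.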

Solution 2

PySpark:

num_partitions = 20000
a = sc.parallelize(range(int(1e6)), num_partitions)
l = a.glom().map(len).collect()  # get length of each partition
print(min(l), max(l), sum(l)/len(l), len(l))  # check if skewed

Spark/Scala:

val numPartitions = 20000
val a = sc.parallelize(0 until 1e6.toInt, numPartitions)
val l = a.glom().map(_.length).collect()  // get length of each partition
println((l.min, l.max, l.sum.toDouble / l.length, l.length))  // check if skewed
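For the 1e6-element example above, the check should report no skew. A plain-Python sketch of the slicing arithmetic, assuming parallelize splits n elements into k contiguous slices whose boundaries fall at i*n//k (this matches Scala's ParallelCollectionRDD as far as I know; treat it as an assumption for other versions). slice_sizes is a hypothetical helper.

```python
from typing import List

def slice_sizes(n: int, num_slices: int) -> List[int]:
    # Assumed split rule: slice i covers [i*n//k, (i+1)*n//k) with k = num_slices.
    return [((i + 1) * n) // num_slices - (i * n) // num_slices
            for i in range(num_slices)]

l = slice_sizes(10**6, 20000)
print(min(l), max(l), sum(l) / len(l), len(l))  # 50 50 50.0 20000
```

When n is not divisible by k, the sizes under this rule differ by at most one, e.g. slice_sizes(10, 4) gives [2, 3, 2, 3].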

The same is possible for a DataFrame, not just for an RDD: use df.rdd.glom() in place of a.glom() in the code above.

Note that glom() turns each partition into a list held in memory, so it is memory-intensive. A less memory-intensive version (PySpark only):

import statistics 

def get_table_partition_distribution(table_name: str):

    def get_partition_len(iterator):
        yield sum(1 for _ in iterator)

    l = spark.table(table_name).rdd.mapPartitions(get_partition_len, True).collect()  # get length of each partition
    num_partitions = len(l)
    min_count = min(l)
    max_count = max(l)
    avg_count = sum(l)/num_partitions
    stddev = statistics.stdev(l) if num_partitions > 1 else 0.0  # stdev needs at least two values
    print(f"{table_name} each of {num_partitions} partition's counts: min={min_count:,} avg±stddev={avg_count:,.1f} ±{stddev:,.1f} max={max_count:,}")


get_table_partition_distribution('someTable')

This outputs something like:

someTable each of 1445 partition's counts: min=1,201,201 avg±stddev=1,202,811.6 ±21,783.4 max=2,030,137
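The difference between the two approaches is whether a partition is ever held as a list. A plain-Python sketch on local generators; count_streaming, count_materialized and summarize are hypothetical helpers, and summarize mirrors the print in get_table_partition_distribution while guarding stdev, which needs at least two values.

```python
import statistics
from typing import Iterable, List

def count_streaming(it: Iterable[int]) -> int:
    # Consume one element at a time; only a running count is kept.
    return sum(1 for _ in it)

def count_materialized(it: Iterable[int]) -> int:
    # Conceptually what glom() + len does: build the whole list first.
    return len(list(it))

def summarize(counts: List[int]) -> str:
    # Same summary line as get_table_partition_distribution, on a plain list.
    avg = sum(counts) / len(counts)
    # statistics.stdev raises StatisticsError for fewer than two values.
    sd = statistics.stdev(counts) if len(counts) > 1 else 0.0
    return (f"min={min(counts):,} avg±stddev={avg:,.1f} "
            f"±{sd:,.1f} max={max(counts):,}")

# Model three partitions as lazy generators; none is ever turned into a list.
partitions = [(x for x in range(n)) for n in (1000, 1200, 800)]
counts = [count_streaming(p) for p in partitions]
print(counts)             # [1000, 1200, 800]
print(summarize(counts))  # min=800 avg±stddev=1,000.0 ±200.0 max=1,200
```

Both counting functions return the same number; only count_materialized needs memory proportional to the partition size.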

Solution 3

I know I'm a little late here, but I have another approach to get the number of elements in each partition by leveraging Spark's built-in spark_partition_id function. It works for Spark versions above 2.1.

Explanation: We create a sample DataFrame (df), add each row's partition ID as a column, group by that partition ID, and count the records in each group.

PySpark:

>>> from pyspark.sql.functions import spark_partition_id, count as _count
>>> df = spark.sql("set -v").unionAll(spark.sql("set -v")).repartition(4)
>>> df.rdd.getNumPartitions()
4
>>> df.withColumn("partition_id", spark_partition_id()).groupBy("partition_id").agg(_count("key")).orderBy("partition_id").show()
+------------+----------+
|partition_id|count(key)|
+------------+----------+
|           0|        48|
|           1|        44|
|           2|        32|
|           3|        48|
+------------+----------+

Scala:

scala> val df = spark.sql("set -v").unionAll(spark.sql("set -v")).repartition(4)
df: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [key: string, value: string ... 1 more field]

scala> df.rdd.getNumPartitions
res0: Int = 4

scala> df.withColumn("partition_id", spark_partition_id()).groupBy("partition_id").agg(count("key")).orderBy("partition_id").show()
+------------+----------+
|partition_id|count(key)|
+------------+----------+
|           0|        48|
|           1|        44|
|           2|        32|
|           3|        48|
+------------+----------+
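Stripped of Spark, the query just counts rows per partition ID. A plain-Python sketch with collections.Counter; rows_per_partition is a hypothetical helper, and the round-robin tagging is only a stand-in for spark_partition_id(), not how repartition actually places rows.

```python
from collections import Counter
from typing import List, Tuple

def rows_per_partition(tagged: List[Tuple[int, str]]) -> List[Tuple[int, int]]:
    # Plain-Python analogue of
    # groupBy("partition_id").agg(count("key")).orderBy("partition_id")
    counts = Counter(pid for pid, _ in tagged)
    return sorted(counts.items())

rows = [f"key{i}" for i in range(10)]
tagged = [(i % 4, r) for i, r in enumerate(rows)]  # hypothetical round-robin IDs
print(rows_per_partition(tagged))  # [(0, 3), (1, 3), (2, 2), (3, 2)]
```

One caveat of the group-by approach: a partition with no rows produces no group, so it is missing from the output, whereas the mapPartitions approaches report it with a count of 0.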

Solution 4

pzecevic's answer works, but conceptually there's no need to construct an array and then convert it to an iterator. I would just construct the iterator directly and then get the counts with a collect call.

rdd.mapPartitions(iter => Iterator(iter.size), true).collect()

P.S. I'm not sure whether his answer actually does more work, since Iterator.apply will likely convert its arguments into an array anyway.

Author: Geo

Updated on July 12, 2022

Comments

  • Geo, almost 2 years ago

    Is there any way to get the number of elements in a Spark RDD partition, given the partition ID, without scanning the entire partition?

    Something like this:

    Rdd.partitions().get(index).size()
    

    Except I don't see such an API in Spark. Any ideas or workarounds?

    Thanks

  • Geo, about 9 years ago
    Thank you! From what I understand, iter.size iterates through the entire partition to get its size (correct me if I'm wrong). Is there any way to get the partition size without iterating through it?
  • Jacek Laskowski, over 8 years ago
    That's correct: there's no way to know the size without iterating over the partition. Its data is exposed as an iterator and fetched on demand rather than all at once, which is more memory-efficient (the whole partition might not fit into available memory).
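A plain-Python illustration of that point, where a generator plays the role of a partition iterator (lazy_partition is a hypothetical name): the length is not known up front, and counting it consumes the iterator.

```python
from typing import Iterator

def lazy_partition(n: int) -> Iterator[int]:
    # Elements are produced on demand, like a Spark partition iterator.
    for i in range(n):
        yield i

it = lazy_partition(5)
try:
    len(it)  # generators do not support len()
except TypeError:
    print("len() is not supported on a generator")

print(sum(1 for _ in it))  # 5: counting walks through every element
print(sum(1 for _ in it))  # 0: the iterator is now exhausted
```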