How to use Spark intersection() by key or filter() with two RDDs?


Solution 1

This can be achieved in a few different ways:

1. Broadcast variable in filter() - needs a scalability improvement if data2 is large, since its keys are collected to the driver (see the Set-based sketch after the output below)

val data1 = sc.parallelize(Seq(("a", 1), ("a", 2), ("b", 2), ("b", 3), ("c", 1)))
val data2 = sc.parallelize(Seq(("a", 3), ("b", 5)))

// broadcast data2 key list to use in filter method, which runs in executor nodes
val bcast = sc.broadcast(data2.map(_._1).collect())

val result = data1.filter(r => bcast.value.contains(r._1))


println(result.collect().toList)
//Output
List((a,1), (a,2), (b,2), (b,3))
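
If data2 has many keys, broadcasting them as a Set instead of an Array keeps the contains check constant-time. A minimal sketch, assuming the same sc, data1 and data2 as above:

// broadcast the keys of data2 as a Set for O(1) lookups inside the filter closure
val keySet = sc.broadcast(data2.map(_._1).collect().toSet)

val resultSet = data1.filter { case (k, _) => keySet.value.contains(k) }

println(resultSet.collect().toList)
//List((a,1), (a,2), (b,2), (b,3))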

2. cogroup (similar to group by key)

import org.apache.spark.rdd.RDD

val data1 = sc.parallelize(Seq(("a", 1), ("a", 2), ("b", 2), ("b", 3), ("c", 1)))
val data2 = sc.parallelize(Seq(("a", 3), ("b", 5)))

val cogroupRdd: RDD[(String, (Iterable[Int], Iterable[Int]))] = data1.cogroup(data2)
/* List(
  (a, (CompactBuffer(1, 2), CompactBuffer(3))),
  (b, (CompactBuffer(2, 3), CompactBuffer(5))),
  (c, (CompactBuffer(1), CompactBuffer()))
) */

// Now keep only the keys whose two CompactBuffers are both non-empty. You can also
// write this as filter(row => row._2._1.nonEmpty && row._2._2.nonEmpty).
val filterRdd = cogroupRdd.filter { case (k, (v1, v2)) => v1.nonEmpty && v2.nonEmpty } 
/* List(
  (a, (CompactBuffer(1, 2), CompactBuffer(3))),
  (b, (CompactBuffer(2, 3), CompactBuffer(5)))
) */

// Since we only care about data1's values, pick the first CompactBuffer
// by doing v1.map(val1 => (k, val1))
val result = filterRdd.flatMap { case (k, (v1, v2)) => v1.map(val1 => (k, val1)) }
//List((a, 1), (a, 2), (b, 2), (b, 3))
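
An equivalent, slightly shorter way to write that last step (a sketch, using the same filterRdd as above) is flatMapValues, which keeps the key and flattens only the values:

// flatten the first CompactBuffer while keeping the key
val result2 = filterRdd.flatMapValues { case (v1, _) => v1 }
//List((a,1), (a,2), (b,2), (b,3))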

3. Using inner join

val resultRdd = data1.join(data2).map(r => (r._1, r._2._1)).distinct()
//List((b,2), (b,3), (a,2), (a,1)) 

Here data1.join(data2) holds only the pairs whose keys appear in both RDDs (an inner join):

//List((a,(1,3)), (a,(2,3)), (b,(2,5)), (b,(3,5)))
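
A slightly more idiomatic variant of the same join (a sketch, assuming the same data1 and data2) uses mapValues so the key does not have to be repeated:

// keep only the value coming from data1 for each joined pair
val resultRdd2 = data1.join(data2).mapValues(_._1).distinct()
//Same result: List((a,1), (a,2), (b,2), (b,3))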

Solution 2

For your problem, I think cogroup() is better suited. The intersection() method will consider both keys and values in your data, and will result in an empty RDD.
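
For example, with the data from the question (a quick sketch assuming the same data1 and data2):

val data1 = sc.parallelize(Seq(("a", 1), ("a", 2), ("b", 2), ("b", 3), ("c", 1)))
val data2 = sc.parallelize(Seq(("a", 3), ("b", 5)))

// intersection() compares whole (key, value) pairs; none are shared here
data1.intersection(data2).collect()
// Array[(String, Int)] = Array()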

The function cogroup() groups the values of both RDDs by key and gives us (key, (vals1, vals2)), where vals1 and vals2 contain the values of data1 and data2, respectively, for each key. Note that if a certain key is not shared by both datasets, the corresponding vals1 or vals2 will be an empty Seq, so we first have to filter out these tuples to arrive at the intersection of the two RDDs.

Next, we grab vals1 - which contains the values from data1 for the common keys - and convert it to the format (key, Array). Lastly, we use flatMapValues() to unpack the result into (key, value) pairs.

val result = (data1.cogroup(data2)
  .filter{case (k, (vals1, vals2)) => vals1.nonEmpty && vals2.nonEmpty }
  .map{case (k, (vals1, vals2)) => (k, vals1.toArray)}
  .flatMapValues(identity[Array[Int]]))

result.collect()
// Array[(String, Int)] = Array((a,1), (a,2), (b,2), (b,3))

Comments

  • S.Kang almost 2 years

    I want to use intersection() by key or filter() in Spark.

    But I really don't know how to use intersection() by key.

    So I tried to use filter(), but it didn't work.

    Example - here are two RDDs:

    data1 //RDD[(String, Int)] = Array(("a", 1), ("a", 2), ("b", 2), ("b", 3), ("c", 1))
    data2 //RDD[(String, Int)] = Array(("a", 3), ("b", 5))
    
    val data3 = data2.map{_._1}
    
    data1.filter{_._1 == data3}.collect //Array[(String, Int)] = Array()
    

    I want to get the (key, value) pairs from data1 whose keys also appear in data2.

    Array(("a", 1), ("a", 2), ("b", 2), ("b", 3)) is the result I want.

    Is there a method to solve this problem using intersection() by key or filter()?