Apache Spark – aggregateByKey

January 18, 2021

Let us say we have an input RDD[(K, V)] and we want to group the values for each key K, process them, and output an RDD[(K, U)].

In an earlier post we saw that reduceByKey is more efficient than running groupByKey and then reducing the result.

We cannot use reduceByKey in this case because reduceByKey does not let you change the type of the value you output: its function has the type (V, V) => V. In such cases “aggregateByKey” comes in handy.

“aggregateByKey” lets us process the values for a key within each partition to produce a result of a different type. In the second step, Spark merges these per-partition results for each key across partitions, so the output RDD carries values of a different data type than the input RDD.

There are multiple variants of “aggregateByKey” (the others additionally take a numPartitions or a Partitioner argument). I will demonstrate the use of:

aggregateByKey[U](zeroValue: U)(seqOp: (U, V) ⇒ U, combOp: (U, U) ⇒ U)(implicit arg0: ClassTag[U]): RDD[(K, U)]

Suppose we have tweets from different users and we want to bring together all the tweets for each user. We also need to make sure each tweet occurs only once per user in the output.

val part1 = List((1, "Tweet-1"), (2, "Tweet-2"),(1, "Tweet-3"), (1, "Tweet-1") ,(2, "Tweet-4") ,(1, "Tweet-5"),(3,"Tweet-6"))
val part2 = List((1, "Tweet-4"), (2, "Tweet-5"),(1, "Tweet-6"),(2, "Tweet-7") ,(3,"Tweet-8") ,(4,"Tweet-9"))
val part3 = List((3, "Tweet-6") ,(4, "Tweet-7") ,(4,"Tweet-1") ,(2,"Tweet-10"))

val inputdata = part1 ::: part2 ::: part3
val input = sc.parallelize(inputdata,3)
//input RDD is of type RDD[(Int,String)]

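Our input RDD is of type RDD[(Int, String)], representing (userid, tweet). Based on the scenario above, the output RDD should be of type RDD[(Int, Set[String])]. (The names part1, part2 and part3 are only for readability: sc.parallelize slices the combined 17-element list into 3 roughly equal partitions, so the partition boundaries do not necessarily match the three lists.)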

We could compute this with groupByKey() and then run a map over the grouped data to produce a Set of unique tweets per user. However, groupByKey needs all the tweets for a user on one partition, so every tweet, duplicates included, is shuffled before we even start merging tweets into a Set. Wouldn't it be nice to remove duplicates while still working within each partition? Creating a Set for each key within each partition (this is the seqOp, explained further below) does exactly that, and reduces the amount of data that has to be shuffled.
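To put rough numbers on the shuffle argument, here is a local sketch in plain Scala (no Spark). It is a simplification: it treats part1, part2 and part3 as if each were one partition, and it only counts what each approach would hand to the shuffle, ignoring serialization size.

```scala
// Local illustration only (no Spark): part1, part2 and part3 are treated as three
// hypothetical partitions, which is NOT exactly how sc.parallelize slices the data.
val part1 = List((1, "Tweet-1"), (2, "Tweet-2"), (1, "Tweet-3"), (1, "Tweet-1"), (2, "Tweet-4"), (1, "Tweet-5"), (3, "Tweet-6"))
val part2 = List((1, "Tweet-4"), (2, "Tweet-5"), (1, "Tweet-6"), (2, "Tweet-7"), (3, "Tweet-8"), (4, "Tweet-9"))
val part3 = List((3, "Tweet-6"), (4, "Tweet-7"), (4, "Tweet-1"), (2, "Tweet-10"))
val partitions = List(part1, part2, part3)

// groupByKey: every (userid, tweet) record goes through the shuffle, duplicates included.
val tweetsShuffledByGroupByKey = partitions.map(_.size).sum

// Per-partition Sets: each partition emits one (userid, Set[tweet]) record per user,
// with duplicates inside that partition already removed.
val perPartitionSets: List[Map[Int, Set[String]]] =
  partitions.map(part => part.groupBy(_._1).map { case (user, pairs) => user -> pairs.map(_._2).toSet })

val recordsShuffledWithSets = perPartitionSets.map(_.size).sum
val tweetsShuffledWithSets = perPartitionSets.flatMap(_.values).map(_.size).sum

println(s"groupByKey: $tweetsShuffledByGroupByKey records")
println(s"per-partition Sets: $recordsShuffledWithSets records carrying $tweetsShuffledWithSets tweets")
```

On this tiny dataset only one duplicate (Tweet-1 for user 1 in part1) is removed before the shuffle, so the saving is small; it grows with the number of repeated tweets per user within a partition.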

“aggregateByKey”, as shown below, helps us combine the values for a key within a given partition into an interim Set[String]. This is represented by the seqOp, or sequence operation. We need to provide a neutral zero value to start the seqOp; in our case it is an EMPTY Set[String].

Hence our seqOp should add the tweet to the Set available for that key in that partition. seqOp takes two values: the first is the existing Set[String] for the key being processed, and the second is the current tweet. If a key is processed for the first time within a partition, the neutral zero value (the EMPTY Set) is used for the computation. If Spark has already seen that key while working on this partition, it uses the accumulated Set instead of the empty one.

(setOfTweets,tweet)=> setOfTweets + tweet
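A minimal local sketch (plain Scala, no Spark) of what the seqOp does for a single key inside a single partition: start from the zero value and fold each tweet in. The list tweetsForUser1 is a hypothetical partition holding user 1's tweets from part1.

```scala
// seqOp from the post: add one tweet to the Set accumulated so far for a key.
val seqOp: (Set[String], String) => Set[String] =
  (setOfTweets, tweet) => setOfTweets + tweet

// Hypothetical partition content for user 1 (taken from part1); note the repeated Tweet-1.
val tweetsForUser1 = List("Tweet-1", "Tweet-3", "Tweet-1", "Tweet-5")

// Start from the neutral zero value (an empty Set) and apply seqOp to each tweet in turn.
val interim = tweetsForUser1.foldLeft(Set.empty[String])(seqOp)
println(interim) // prints Set(Tweet-1, Tweet-3, Tweet-5)
```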

Once all the partitions have been processed, the values produced by the seqOp step are merged to compute the final output for each key. Refer to the definition of combOp in the signature above: it takes two values of type U (in our case Set[String]) and outputs a U. We implement combOp by merging two Sets together.

(setOfTweets1,setOfTweets2) => setOfTweets1 ++ setOfTweets2
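Here is the combOp on its own, again as a local plain-Scala sketch. The two interim Sets are hypothetical per-partition results for user 3. Set union is associative and commutative, which is what lets Spark merge the per-partition results in whatever order they arrive, and the empty Set is neutral for it.

```scala
// combOp from the post: merge two interim Sets produced for the same key
// in different partitions.
val combOp: (Set[String], Set[String]) => Set[String] =
  (setOfTweets1, setOfTweets2) => setOfTweets1 ++ setOfTweets2

// Hypothetical interim Sets for user 3 coming from two different partitions.
val fromPartitionA = Set("Tweet-6")
val fromPartitionB = Set("Tweet-8", "Tweet-6")

// Union drops the tweet that appeared in both partitions.
val merged = combOp(fromPartitionA, fromPartitionB)
println(merged)
```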

Putting all of this together, it looks like this:

val output = input.aggregateByKey(Set.empty[String])((setOfTweets,tweet)=> setOfTweets + tweet ,
(setOfTweets1,setOfTweets2) => setOfTweets1 ++ setOfTweets2)
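To see both steps end to end without a cluster, here is a small local model of the aggregateByKey semantics in plain Scala. aggregateByKeyLocal is a hypothetical helper written for illustration; it is not Spark's implementation, and it again treats part1, part2 and part3 as the three partitions.

```scala
// Hypothetical, Spark-free model of aggregateByKey semantics (not Spark's implementation).
// Each inner List stands for one partition of an RDD[(K, V)].
def aggregateByKeyLocal[K, V, U](partitions: List[List[(K, V)]])(zeroValue: U)(
    seqOp: (U, V) => U, combOp: (U, U) => U): Map[K, U] = {
  // Step 1: inside each partition, fold every key's values with seqOp, starting from zeroValue.
  val perPartition: List[Map[K, U]] = partitions.map { part =>
    part.foldLeft(Map.empty[K, U]) { case (acc, (k, v)) =>
      acc.updated(k, seqOp(acc.getOrElse(k, zeroValue), v))
    }
  }
  // Step 2: merge the per-partition results for each key with combOp.
  perPartition.flatten.groupBy(_._1).map { case (k, kus) =>
    k -> kus.map(_._2).reduce(combOp)
  }
}

val part1 = List((1, "Tweet-1"), (2, "Tweet-2"), (1, "Tweet-3"), (1, "Tweet-1"), (2, "Tweet-4"), (1, "Tweet-5"), (3, "Tweet-6"))
val part2 = List((1, "Tweet-4"), (2, "Tweet-5"), (1, "Tweet-6"), (2, "Tweet-7"), (3, "Tweet-8"), (4, "Tweet-9"))
val part3 = List((3, "Tweet-6"), (4, "Tweet-7"), (4, "Tweet-1"), (2, "Tweet-10"))

val output = aggregateByKeyLocal(List(part1, part2, part3))(Set.empty[String])(
  (setOfTweets, tweet) => setOfTweets + tweet,
  (setOfTweets1, setOfTweets2) => setOfTweets1 ++ setOfTweets2)

output.toSeq.sortBy(_._1).foreach(println)
```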

I am including a diagram of the transformation to help you understand and picture how Spark works through it.

The complete Spark code that illustrates “aggregateByKey” is included for your reference.

val part1 = List((1, "Tweet-1"), (2, "Tweet-2"),(1, "Tweet-3"), (1, "Tweet-1") ,(2, "Tweet-4") ,(1, "Tweet-5"),(3,"Tweet-6"))
val part2 = List((1, "Tweet-4"), (2, "Tweet-5"),(1, "Tweet-6"),(2, "Tweet-7") ,(3,"Tweet-8") ,(4,"Tweet-9"))
val part3 = List((3, "Tweet-6") ,(4, "Tweet-7") ,(4,"Tweet-1") ,(2,"Tweet-10"))

val inputdata = part1 ::: part2 ::: part3

val input = sc.parallelize(inputdata,3)

val output = input.aggregateByKey(Set.empty[String])((setOfTweets,tweet)=> setOfTweets + tweet , (setOfTweets1,setOfTweets2) => setOfTweets1 ++ setOfTweets2)

output.collect.foreach(println)

Output (the order of keys, and of elements within each Set, may differ between runs):
(3,Set(Tweet-6, Tweet-8))
(4,Set(Tweet-9, Tweet-7, Tweet-1))
(1,Set(Tweet-5, Tweet-6, Tweet-4, Tweet-1, Tweet-3))
(2,Set(Tweet-5, Tweet-2, Tweet-10, Tweet-7, Tweet-4))

Number of Records in an RDD partition

November 9, 2020

I was recently answering a question on stackoverflow.com about how data is partitioned when an RDD has fewer elements than partitions. I ended up writing a simple code snippet to see how many records end up in each partition when the number of elements in an RDD is less than the number of partitions specified.

I thought I would put together a couple of lines of code that are useful if someone is looking for a way to count the number of records in each partition.

This can also be useful to check whether your partitions are skewed. Sometimes you run into an OutOfMemoryError because one partition is much bigger than the others. This usually happens when many elements share the same key (or keys with the same hash code) and thus end up in the same partition.

For example, if the key is null, all such elements end up in the same partition (Spark's HashPartitioner assigns every null key to partition 0).
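A local sketch of why null keys pile up in one partition. partitionFor is a hypothetical helper modelled on the logic of Spark's HashPartitioner (null keys go to partition 0, other keys to a non-negative hashCode modulo the number of partitions); it is plain Scala, not a Spark API.

```scala
// Hypothetical helper modelled on HashPartitioner's logic:
// null -> partition 0, otherwise a non-negative key.hashCode % numPartitions.
def partitionFor(key: Any, numPartitions: Int): Int = key match {
  case null => 0
  case _ =>
    val raw = key.hashCode % numPartitions
    if (raw < 0) raw + numPartitions else raw
}

// A skewed dataset: most records have a null key.
val keys: Seq[Any] = Seq(null, null, null, null, null, null, "a", "b", "c")
val numPartitions = 4

val recordsPerPartition: Map[Int, Int] =
  keys.groupBy(k => partitionFor(k, numPartitions)).map { case (p, ks) => p -> ks.size }

(0 until numPartitions).foreach { p =>
  println(s"partition $p has ${recordsPerPartition.getOrElse(p, 0)} records")
}
```

All six null-keyed records land in partition 0, while "a", "b" and "c" spread over the other partitions.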

Here is code that finds the number of records in each partition.

Below, the RDD has 4 elements and 8 partitions.


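val rdd = sc.parallelize(List(1,2,3,4),8)
// Return (partition index, record count) so the counts are collected and printed on the driver.
// Calling y.length and then returning y would exhaust the iterator, so collect would return no records.
rdd.mapPartitionsWithIndex((index, records) => Iterator((index, records.size))).
  collect.foreach { case (index, count) => println(s"partition $index has $count records") }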
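Why do some partitions stay empty? A local sketch in plain Scala, assuming that sc.parallelize slices a generic Seq such as List(1,2,3,4) by index, with slice i covering positions from i * n / numSlices up to (i + 1) * n / numSlices. sliceSizes is a hypothetical helper written from that assumption, not a Spark API.

```scala
// Hypothetical helper: size of each slice when n elements are split into numSlices
// slices, slice i covering indices [i * n / numSlices, (i + 1) * n / numSlices).
def sliceSizes(n: Int, numSlices: Int): Seq[Int] =
  (0 until numSlices).map { i =>
    val start = ((i.toLong * n) / numSlices).toInt
    val end = (((i + 1).toLong * n) / numSlices).toInt
    end - start
  }

// 4 elements into 8 partitions: every other partition gets one record.
val sizes = sliceSizes(n = 4, numSlices = 8)
sizes.zipWithIndex.foreach { case (size, idx) => println(s"partition $idx has $size records") }
```

Under this assumption, partitions 1, 3, 5 and 7 hold one record each and the even-numbered partitions are empty.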


GroupBy and count using Spark DataFrame

June 18, 2019

Here we group by a key column and count the occurrences of each key.

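import spark.implicits._ // already in scope in spark-shell; needed for toDF and $"..."

val datardd = sc.parallelize(Seq("a"->1, "b"->1, "a"->1, "c"->1))

// Name the columns explicitly; a bare toDF would call them _1 and _2.
val mydf = datardd.toDF("name", "count")

mydf.groupBy($"name").agg("count" -> "count").
withColumnRenamed("count(count)", "noofoccurrences").
orderBy($"noofoccurrences".desc).show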

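+----+---------------+
|name|noofoccurrences|
+----+---------------+
|   a|              2|
|   b|              1|
|   c|              1|
+----+---------------+

(b and c both occur once, so their relative order may vary.)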
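For comparison, here is the same group-and-count on a plain Scala collection, as a local sketch of what the DataFrame query computes. Ties are broken by name here only to make the order deterministic; the Spark query above does not do that. Spark also offers mydf.groupBy("name").count(), which names the result column count.

```scala
// Plain Scala collections (no Spark): the same data as datardd.
val data = Seq("a" -> 1, "b" -> 1, "a" -> 1, "c" -> 1)

// Group by the key ("name"), count the rows in each group, then sort by
// count descending, breaking ties by name so the order is deterministic.
val noOfOccurrences: Seq[(String, Int)] =
  data.groupBy(_._1)
    .map { case (name, rows) => name -> rows.size }
    .toSeq
    .sortBy { case (name, count) => (-count, name) }

noOfOccurrences.foreach { case (name, count) => println(s"$name $count") }
```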