Apache Spark – aggregateByKey

January 18, 2021

Let us say we have an input RDD[(K,V)] and we want to group the values for each K, process them, and output an RDD[(K,U)].

In an earlier post we saw how reduceByKey is a more efficient operation than groupByKey followed by reducing the grouped result; a quick sketch of the two approaches is shown below.
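As a rough refresher, here is a minimal sketch of that comparison using a hypothetical counts RDD (not part of the tweet example that follows):

// Hypothetical example: counting occurrences per key.
val counts = sc.parallelize(List(("a", 1), ("b", 1), ("a", 1)))

// reduceByKey combines values within each partition first, so only partial sums are shuffled.
val efficient = counts.reduceByKey(_ + _)

// groupByKey shuffles every individual value before we reduce it ourselves.
val lessEfficient = counts.groupByKey().mapValues(_.sum)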

We cannot use reduceByKey in this case because reduceByKey does not allow us to change the type of the value we output, as illustrated below. In such cases “aggregateByKey” comes in handy.
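To illustrate the constraint, here is a small sketch with hypothetical (userid, tweet) pairs. reduceByKey’s function has the shape (V, V) ⇒ V, so with String values it can only ever produce another String:

// Hypothetical (Int, String) pairs standing in for (userid, tweet).
val pairs = sc.parallelize(List((1, "Tweet-1"), (1, "Tweet-2")))

// With V = String, reduceByKey can only merge two Strings into a String,
// e.g. concatenate the tweets -- the result stays RDD[(Int, String)].
val concatenated = pairs.reduceByKey((a, b) => a + "," + b)

// We cannot make it emit RDD[(Int, Set[String])]; something like
//   pairs.reduceByKey((acc: Set[String], t: String) => acc + t)
// does not compile, which is where aggregateByKey comes in.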

“aggregateByKey” gives us the ability to process the values for a key within each partition and produce a result of a different type. In the second step, Spark merges these per-partition results for each key, transforming the input RDD into one whose values are of a different datatype.

There are multiple variants of “aggregateByKey”. I will demonstrate the use of:

aggregateByKey[U](zeroValue: U)(seqOp: (U, V) ⇒ U, combOp: (U, U) ⇒ U)(implicit arg0: ClassTag[U]): RDD[(K, U)]

Suppose we have tweets from different users and we want to bring together all the tweets for each user. We also need to make sure each tweet occurs only once in the output.

val part1 = List((1, "Tweet-1"), (2, "Tweet-2"),(1, "Tweet-3"), (1, "Tweet-1") ,(2, "Tweet-4") ,(1, "Tweet-5"),(3,"Tweet-6"))
val part2 = List((1, "Tweet-4"), (2, "Tweet-5"),(1, "Tweet-6"),(2, "Tweet-7") ,(3,"Tweet-8") ,(4,"Tweet-9"))
val part3 = List((3, "Tweet-6") ,(4, "Tweet-7") ,(4,"Tweet-1") ,(2,"Tweet-10"))

val inputdata = part1 ::: part2 ::: part3
val input = sc.parallelize(inputdata,3)
//input RDD is of type RDD[(Int,String)]

In this case our input RDD is of type (Int, String), representing (userid, tweet). Based on the scenario above, the output RDD will be of type (Int, Set[String]).

We could compute this with groupByKey() followed by a map operation on the grouped data to produce the Set of unique tweets per user, as sketched below. However, in that case all the tweets for a user would first need to be made available on one partition, which causes data shuffling before we even start merging the tweets into a Set. Wouldn’t it be nice to have duplicates removed while working within each partition itself? Building a Set for each key in a partition (this is the seqOp, explained further below) achieves exactly that and reduces the amount of data that needs to be shuffled at the end.
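For comparison, a minimal sketch of that groupByKey-based alternative, using the input RDD defined above:

// groupByKey first shuffles every (userid, tweet) pair to the partition owning
// that key, and only then do we deduplicate by building a Set per user.
val viaGroupByKey = input.groupByKey().mapValues(_.toSet)   // RDD[(Int, Set[String])]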

“aggregateByKey”, as shown below, helps us combine the values for a key within a given partition to form an interim Set[String]. This is expressed by the seqOp (sequence operation). We also need to provide a neutral zero value to start the seqOp; in our case it is the empty Set[String].

Hence our seqOp appends the tweet to the Set available for that key in that partition. seqOp takes two values: the first is the existing Set[String] for the key being processed, and the second is the incoming tweet. If a key is processed for the first time within a partition, the neutral value, the empty Set, is used for the computation. If Spark has already seen that key while working on this partition, it uses the Set accumulated so far rather than the empty one.

(setOfTweets,tweet)=> setOfTweets + tweet

Once all the partitions have been processed, the values produced by the seqOp step are merged together to compute the final output for a key. Please refer to the definition of combOp above: it takes two values of type U (in our case Set[String]) and outputs a U. Our combOp simply merges the two Sets together.

(setOfTweets1,setOfTweets2) => setOfTweets1 ++ setOfTweets2

Putting all of this together, it looks like this:

val output = input.aggregateByKey(Set.empty[String])(
  (setOfTweets, tweet) => setOfTweets + tweet,
  (setOfTweets1, setOfTweets2) => setOfTweets1 ++ setOfTweets2)

A diagrammatic representation of the transformation is included to help you visualize how Spark works on this.
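If you want to inspect this yourself, one small sketch using the standard RDD.glom method shows which (userid, tweet) pairs each partition holds before the seqOp runs on them:

// glom() turns each partition into an Array, so we can print the raw
// contents of every partition of the input RDD.
input.glom().collect().zipWithIndex.foreach { case (partition, idx) =>
  println(s"Partition $idx: ${partition.mkString(", ")}")
}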

The overall Spark application illustrating “aggregateByKey” is included below for your reference.

val part1 = List((1, "Tweet-1"), (2, "Tweet-2"),(1, "Tweet-3"), (1, "Tweet-1") ,(2, "Tweet-4") ,(1, "Tweet-5"),(3,"Tweet-6"))
val part2 = List((1, "Tweet-4"), (2, "Tweet-5"),(1, "Tweet-6"),(2, "Tweet-7") ,(3,"Tweet-8") ,(4,"Tweet-9"))
val part3 = List((3, "Tweet-6") ,(4, "Tweet-7") ,(4,"Tweet-1") ,(2,"Tweet-10"))

val inputdata = part1 ::: part2 ::: part3

val input = sc.parallelize(inputdata,3)

// seqOp adds each tweet to the per-partition Set for its key;
// combOp merges the per-partition Sets for a key across partitions.
val output = input.aggregateByKey(Set.empty[String])(
  (setOfTweets, tweet) => setOfTweets + tweet,
  (setOfTweets1, setOfTweets2) => setOfTweets1 ++ setOfTweets2)

output.collect.foreach(println)

Output:
(3,Set(Tweet-6, Tweet-8))
(4,Set(Tweet-9, Tweet-7, Tweet-1))
(1,Set(Tweet-5, Tweet-6, Tweet-4, Tweet-1, Tweet-3))
(2,Set(Tweet-5, Tweet-2, Tweet-10, Tweet-7, Tweet-4))