A Gentle Intro to UDAFs In Apache Spark

As part of any data analysis workflow, aggregating across groups or columns is common. In an analysis of basketball data, you might want to know the mean points per 100 possessions for each team, find the player with the highest true shooting percentage, or even make up your own statistic and rank all rookies by it. All of this work requires some aggregation. Spark SQL has many APIs that let us aggregate data, but the built-in methods are not always adequate for our needs. Fortunately, in these cases we can define our own aggregation functions, called user-defined aggregate functions (UDAFs). When looking for resources to share on UDAFs and the motivation for using them, I found most of them to be somewhat beginner hostile. I decided to try my best and write a gentle intro for someone looking to write their own UDAF.

When Do I Need A UDAF vs a UDF

Most people acquainted with Spark SQL have written a User Defined Function. A UDF looks something like this:
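A minimal sketch of such a UDF, assuming a local SparkSession and a made-up DataFrame of game totals. The object name, column names, and the per100 function are all hypothetical and chosen only for this illustration:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, udf}

object UdfSketch {
  // Plain row-at-a-time logic: one row's values in, one value out.
  def per100(points: Double, possessions: Double): Double =
    100.0 * points / possessions

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("udf-sketch")
      .getOrCreate()
    import spark.implicits._

    // Hypothetical data, made up for this illustration.
    val games = Seq(("GSW", 115.0, 102.0), ("BOS", 108.0, 98.0))
      .toDF("team", "points", "possessions")

    // Wrap the Scala function so it can be applied to columns.
    val per100Udf = udf(per100 _)

    games
      .withColumn("points_per_100", per100Udf(col("points"), col("possessions")))
      .show()

    spark.stop()
  }
}
```

The output has exactly as many rows as the input, which is the key difference from an aggregation.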

It takes columns as arguments and returns a column with the transformation applied, one row at a time.

We wouldn't be able to write a SUM with a UDF, because a sum requires looking at more than one row at a time. For this we need some kind of aggregation. Spark SQL has a few built-in aggregate functions, such as sum.
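For comparison, here is a small sketch of the built-in sum used as an aggregate. The DataFrame layout (team and points columns) and the helper name totalPointsByTeam are assumptions made for this example:

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.sql.functions.sum

object BuiltInSumSketch {
  // Collapse all rows of each team into a single total.
  def totalPointsByTeam(games: DataFrame): DataFrame =
    games.groupBy("team").agg(sum("points").as("total_points"))

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("sum-sketch")
      .getOrCreate()
    import spark.implicits._

    // Hypothetical data, made up for this illustration.
    val games = Seq(("GSW", 115.0), ("GSW", 101.0), ("BOS", 108.0))
      .toDF("team", "points")

    totalPointsByTeam(games).show()
    spark.stop()
  }
}
```

Here many input rows per team become one output row, which a row-at-a-time UDF cannot do.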


We might want to define our own aggregate function, a user-defined aggregate function that computes some summary over a column. For example, I may want to define a harmonic mean. The harmonic mean has the following definition:

$$H = \frac{n}{\sum_{i=1}^{n} \frac{1}{x_i}}$$

Defining this as a plain function in Scala is not too hard:
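One possible version, as a sketch rather than the original code from this post. The object and function names are illustrative:

```scala
object HarmonicMeanPlain {
  // Harmonic mean: the number of values divided by the sum of their reciprocals.
  // Assumes a non-empty sequence of non-zero values; an empty input yields NaN.
  def harmonicMean(xs: Seq[Double]): Double =
    xs.length / xs.map(1.0 / _).sum
}
```

This works on an in-memory collection, but it has no notion of partitions or partial results, which is what the DataFrame version needs.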

Here is an example of how finding the harmonic mean of 3 numbers would work:

[Image: worked example of the harmonic mean of three numbers]
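As an illustration with three hypothetical values, 1, 2 and 4 (picked here for easy arithmetic, not taken from the original image), the calculation works out as:

```latex
H = \frac{3}{\frac{1}{1} + \frac{1}{2} + \frac{1}{4}} = \frac{3}{1.75} = \frac{12}{7} \approx 1.714
```

For comparison, the arithmetic mean of the same three values is 7/3, or about 2.333, so the harmonic mean is pulled toward the smaller values.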

However, there is a bit more overhead to get this to work with a DataFrame. It requires digging into some of Spark's implementation details rather than staying in a higher-level API. Fortunately, the code is well written, so it is easy enough to make sense of.

The UserDefinedAggregateFunction Class

In order to define a custom UDAF, we have to create a new class that extends the UserDefinedAggregateFunction class in Spark. Doing so is not too tedious:
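A bare skeleton might look like the following. The class name HarmonicMeanSkeleton is made up, and every body is left as a placeholder. In Spark 3.0 and later this base class is marked deprecated in favor of Aggregator, so expect a compiler warning on those versions:

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, StructType}

// Skeleton only: each ??? is a placeholder that throws NotImplementedError if called.
class HarmonicMeanSkeleton extends UserDefinedAggregateFunction {
  override def inputSchema: StructType = ???
  override def bufferSchema: StructType = ???
  override def dataType: DataType = ???
  override def deterministic: Boolean = ???
  override def initialize(buffer: MutableAggregationBuffer): Unit = ???
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = ???
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = ???
  override def evaluate(buffer: Row): Any = ???
}
```

These eight members are the full set the base class asks for.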

From here it's not as simple as we would expect. We need to override a number of the methods in the original class. I'll simply share my implementation and walk through each part. I've implemented

inputSchema

The input schema should look familiar: it is a StructType, the same syntax used for schemas in many other places in Spark. In this case, I am just expecting a single column of type Double.
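A sketch of that schema on its own, assuming the single Double column described above. The field name "value" is arbitrary and chosen for this example:

```scala
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}

object InputSchemaSketch {
  // One input column of type Double. The field name is arbitrary.
  val inputSchema: StructType = StructType(StructField("value", DoubleType) :: Nil)
}
```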

bufferSchema

The buffer schema describes the intermediate state the aggregation keeps while it works through the input: the running aggregate plus anything else you need to track along the way. For example, our function keeps a counter of the values it has seen.

Now, we can quickly imagine that iterating over a billion-row dataset would require a large value, so we simply store the counter as a Long. The buffer itself, which the methods below read and write, is a MutableAggregationBuffer. You can look at its definition and implementation to understand more about the specifics, but suffice it to say it is a mutable Row that follows the buffer schema.
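For the harmonic mean, a buffer schema along these lines would do: a Long counter and a Double running sum of reciprocals. The field names are assumptions made for this sketch:

```scala
import org.apache.spark.sql.types.{DoubleType, LongType, StructField, StructType}

object BufferSchemaSketch {
  // Intermediate state: how many values were seen, and the running sum of 1/x.
  val bufferSchema: StructType = StructType(
    StructField("count", LongType) ::
    StructField("sum_of_reciprocals", DoubleType) :: Nil
  )
}
```

The order of the fields matters, because the buffer is read and written by position (index 0 for the count, index 1 for the sum).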

dataType

This is the data type of the value the function returns. For the harmonic mean, that is a Double.

deterministic

Deterministic just means that for the same input the function will return the same output. Unless there is some randomness in your function it'll always be deterministic.

initialize

This is where you initialize the values in the buffer. In our case we need a counter (which starts at zero) and a placeholder for our running sum (which also starts at zero).

update

Update is simply the method you use to update your aggregation buffer given a new input row. Do you want to add to it, divide, multiply? In our case we want to add 1 divided by the input value to the running sum, and increment the counter.

merge

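Merge is how you combine two aggregation buffers that share the buffer schema, for example partial results computed on different partitions. In this case we just need to add the corresponding fields, because both the counter and the sum of reciprocals are plain sums.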
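To see why plain addition is enough, here is a small model in ordinary Scala with no Spark involved. The Buf case class and the aggregate helper are stand-ins invented for this sketch, not Spark types:

```scala
object MergeModel {
  // Plain-Scala stand-in for the aggregation buffer (not a Spark type).
  final case class Buf(count: Long, sumOfReciprocals: Double)

  // Plays the role of initialize: both fields start at zero.
  val empty: Buf = Buf(0L, 0.0)

  // Plays the role of update: fold one value into the buffer.
  def update(b: Buf, x: Double): Buf =
    Buf(b.count + 1, b.sumOfReciprocals + 1.0 / x)

  // Plays the role of merge: add two partial buffers field by field.
  def merge(a: Buf, b: Buf): Buf =
    Buf(a.count + b.count, a.sumOfReciprocals + b.sumOfReciprocals)

  // Plays the role of evaluate: count divided by the sum of reciprocals.
  def evaluate(b: Buf): Double = b.count / b.sumOfReciprocals

  // Aggregate each "partition" separately, then merge the partial buffers.
  def aggregate(partitions: Seq[Seq[Double]]): Double =
    evaluate(partitions.map(_.foldLeft(empty)(update)).foldLeft(empty)(merge))
}
```

Splitting the same values across partitions in different ways, for instance Seq(Seq(1.0, 2.0), Seq(4.0)) versus Seq(Seq(1.0, 2.0, 4.0)), should give the same answer up to floating-point rounding.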

evaluate

Finally, what's the final value you need? In our case we need the total number of values divided by our sum of reciprocals. And that's all!
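Putting the pieces together, a complete harmonic mean UDAF could look roughly like this. It is a sketch based on the descriptions above, not necessarily the original implementation: the class name, field names, demo data, and the decision to skip null inputs are all assumptions. On Spark 3.0 and later the base class compiles with a deprecation warning.

```scala
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types._

class HarmonicMean extends UserDefinedAggregateFunction {
  // One Double input column; the field name is arbitrary.
  override def inputSchema: StructType =
    StructType(StructField("value", DoubleType) :: Nil)

  // Intermediate state: values seen so far and the running sum of 1/x.
  override def bufferSchema: StructType = StructType(
    StructField("count", LongType) ::
    StructField("sum_of_reciprocals", DoubleType) :: Nil
  )

  override def dataType: DataType = DoubleType

  override def deterministic: Boolean = true

  // Both the counter and the running sum start at zero.
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = 0.0
  }

  // Fold one input row into the buffer. Skipping nulls is a choice made in this sketch.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      buffer(0) = buffer.getLong(0) + 1L
      buffer(1) = buffer.getDouble(1) + 1.0 / input.getDouble(0)
    }
  }

  // Combine two partial buffers by adding field by field.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    buffer1(1) = buffer1.getDouble(1) + buffer2.getDouble(1)
  }

  // Count divided by the sum of reciprocals; NaN if no values were seen.
  override def evaluate(buffer: Row): Any =
    buffer.getLong(0) / buffer.getDouble(1)
}

object HarmonicMeanDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("harmonic-mean-demo")
      .getOrCreate()
    import spark.implicits._

    // Hypothetical input values, made up for this illustration.
    val df = Seq(1.0, 2.0, 4.0).toDF("value")

    val harmonicMean = new HarmonicMean
    df.agg(harmonicMean(col("value")).as("harmonic_mean")).show()

    spark.stop()
  }
}
```

An instance of the class is applied to columns like any built-in aggregate, so it also works after a groupBy.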

Let's see if this works:

Sales
315.0
952.0

That's It

Hopefully that provided some clarity on UDAFs. They can be pretty intimidating, but they are one of the more powerful parts of Spark. It's a worthwhile exercise to try to re-implement some of your favorite aggregations in Spark, or take some of the Excel aggregation functions and try to implement them in Spark.

Bonus

One thing I forgot to mention is that you can use more than one column in a user-defined aggregate function. The apply method takes a variable number of Columns (Column*), not just one column.

We can then define a UDAF that uses multiple columns like a product sum:

[Image: definition of the product sum]

I've implemented it here. Hopefully, it's easy to reason about after reading this post.
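For readers without access to that implementation, here is a sketch of what a two-column UDAF might look like. It assumes that "product sum" means the sum over rows of the product of the two columns; the class name, column names, demo data, and null handling are all choices made for this example:

```scala
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types._

class ProductSum extends UserDefinedAggregateFunction {
  // Two Double input columns; the field names are arbitrary.
  override def inputSchema: StructType = StructType(
    StructField("x", DoubleType) ::
    StructField("y", DoubleType) :: Nil
  )

  // Intermediate state: the running sum of x * y.
  override def bufferSchema: StructType =
    StructType(StructField("sum", DoubleType) :: Nil)

  override def dataType: DataType = DoubleType

  override def deterministic: Boolean = true

  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0.0
  }

  // Add x * y for this row. Rows with a null in either column are skipped in this sketch.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0) && !input.isNullAt(1)) {
      buffer(0) = buffer.getDouble(0) + input.getDouble(0) * input.getDouble(1)
    }
  }

  // Partial sums from different partitions are simply added.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getDouble(0) + buffer2.getDouble(0)
  }

  override def evaluate(buffer: Row): Any = buffer.getDouble(0)
}

object ProductSumDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("product-sum-demo")
      .getOrCreate()
    import spark.implicits._

    // Hypothetical data, made up for this illustration.
    val df = Seq((2.0, 3.0), (4.0, 5.0)).toDF("price", "quantity")

    val productSum = new ProductSum
    // Both columns are passed to apply, in the order declared in inputSchema.
    df.agg(productSum(col("price"), col("quantity")).as("product_sum")).show()

    spark.stop()
  }
}
```

The only structural differences from a single-column UDAF are the extra field in inputSchema and the second column read inside update.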