I mentioned in a previous article that, for performance reasons, you should avoid UDFs wherever possible. That advice still stands, but if you absolutely must use a UDF, you should consider a Pandas UDF rather than the ones that come out of the box with Spark.
The standard UDFs in Spark operate one row at a time, which carries a massive serialization overhead: every row is serialized individually as it is transferred from the JVM to the Python process. As of Spark 2.3, we can use Pandas UDFs instead.
The difference with Pandas UDFs (also known as vectorized UDFs) is that they batch rows from the Spark DataFrame into Pandas DataFrames (by default 10,000 rows per batch, configurable via spark.sql.execution.arrow.maxRecordsPerBatch). Serializing small batches rather than one row at a time leads to much higher performance, and it is achieved by working with Apache Arrow.
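To see why batching matters, here is a plain-pandas illustration (no Spark involved): the same arithmetic applied once per value, like a standard UDF, versus applied to the whole batch at once, like a Pandas UDF. The function and values are made up for the demonstration.

```python
import pandas as pd

# A batch of 10,000 values, the default Pandas UDF batch size.
s = pd.Series(range(1, 10001))

# Row-at-a-time: one Python-level call per value, like a standard UDF.
per_row = s.apply(lambda x: x * 2 + 1)

# Vectorized: a single operation over the whole batch, like a Pandas UDF.
vectorized = s * 2 + 1

# Identical results; the vectorized form avoids the per-row overhead.
assert per_row.equals(vectorized)
```

The results match exactly; the difference is purely in how many times the Python interpreter is invoked, which is where the speed-up comes from.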
There are two types of Pandas UDF:
- A Scalar UDF, which should be used when you’re running a withColumn() command.
- A Grouped Map UDF, which should be used with groupBy().apply() operations.
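For a sense of the Grouped Map flavour, the function you write receives one whole group as a pandas DataFrame and must return a pandas DataFrame. Below is a minimal sketch; the function name, columns, and the centring logic are my own hypothetical example, and the Spark wiring is shown in comments because it needs a live SparkSession.

```python
import pandas as pd

# Hypothetical grouped map function: centre a value column "v"
# within each group by subtracting the group mean.
def subtract_mean(pdf: pd.DataFrame) -> pd.DataFrame:
    return pdf.assign(v=pdf["v"] - pdf["v"].mean())

# Spark 2.3 wiring (requires pyspark and an active SparkSession):
#
# from pyspark.sql.functions import pandas_udf, PandasUDFType
#
# @pandas_udf("id long, v double", PandasUDFType.GROUPED_MAP)
# def subtract_mean_udf(pdf):
#     return subtract_mean(pdf)
#
# df.groupby("id").apply(subtract_mean_udf).show()
```

Spark calls the function once per group, so inside it the DataFrame only ever contains one group’s rows.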
In this post, we will look at an example of a Scalar UDF; you can read more about both types in the Spark documentation.
An Example Scalar UDF
To be honest, there isn’t a whole lot to talk about in the below. We define a UDF in exactly the same way as we would usually, except instead of udf(function_name…, we say pandas_udf(function_name…; and of course we need to import some slightly different modules at the top.
This simple change could improve your script performance significantly. As with all of the performance tips I suggest, the benefit depends on your environment and your data, but there is no harm in trying: if you have to use a UDF, you are already taking a performance hit, so trying these kinds of simple methods to claw some of it back has to be worthwhile.
import pandas as pd
from pyspark.sql.functions import col, pandas_udf
from pyspark.sql.types import FloatType

# Declare the function and create the UDF; the function receives and
# returns a whole pandas Series, not a single value
def multiply_func(x: pd.Series) -> pd.Series:
    return x * 56 / 31 * 1000

multiply = pandas_udf(multiply_func, returnType=FloatType())

# Make a pandas Series with numbers from 1 up to 1,000,000
x = pd.Series(list(range(1, 1000000)))

# Create a Spark DataFrame from the series, using the active
# SparkSession, 'spark'
df = spark.createDataFrame(pd.DataFrame(x, columns=["x"]))

# Execute the function against the 'x' column
df.select(multiply(col("x"))).show()
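For comparison, here is what the standard row-at-a-time equivalent would look like. The underlying function now receives a single value per call rather than a Series; the helper name multiply_one is mine, and the Spark registration is shown in comments since it needs pyspark and a live session.

```python
# Row-at-a-time equivalent of multiply_func: called once per value.
def multiply_one(x):
    return float(x * 56 / 31 * 1000)

# Standard UDF registration (requires pyspark and an active
# SparkSession):
#
# from pyspark.sql.functions import udf, col
# from pyspark.sql.types import FloatType
#
# multiply_slow = udf(multiply_one, returnType=FloatType())
# df.select(multiply_slow(col("x"))).show()
```

The two produce the same column; the only change on your side is udf versus pandas_udf and the shape of the function, which is exactly why this swap is such an easy win to test.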