Improving performance when calculating percentiles in Spark

Performance is a major concern when you’re working in a distributed environment with a massive amount of data. I’ve discussed Spark performance in quite a lot of detail before here and here.

Today I am going to talk specifically about percentiles, because calculating percentiles over large distributed datasets is a mammoth task. You will likely find that a MapReduce or Spark job takes absolutely ages to calculate them. This is because the data is distributed across many nodes, and an exact percentile needs the values in sorted order, so a significant amount of data has to be shuffled between nodes before the percentile values can be calculated.

Before we get into that, a note on data types: if the approx functions give you trouble with floating point columns, you can cast those columns to integers, which drops everything after the decimal point.

#select the fields we need, limited to given social catid's
df1 =

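There are ways to get around the performance problem. PySpark offers two approaches: the DataFrame method approxQuantile and the SQL function percentile_approx (also available under the alias approx_percentile). approxQuantile takes three arguments: the field you want to calculate the percentile of; a list of the percentiles you want to calculate (e.g. [0.25] for the 25th percentile); and the relative error you are willing to accept (between 0 and 1, where 0 gives the exact value but is the most expensive to compute). This returns a plain Python list rather than a DataFrame column, which is not all that helpful if you want to keep working in Spark; so in most cases, percentile_approx is preferable.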

# general form: df.approxQuantile('fieldname', [quantiles to calculate], accepted_error)
df.approxQuantile('usage', [0.5], 0.25)  # approximate median of 'usage', returned as a one-element list
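To get a feel for what the accepted error means, the Spark documentation for approxQuantile describes a rank bound: for N rows, probability p and error err, the returned value x satisfies floor((p - err) * N) <= rank(x) <= ceil((p + err) * N). The helper below is a plain Python sketch of that bound (the function name is my own, not part of Spark):

```python
import math

def acceptable_rank_range(n, p, err):
    """Rank bounds for approxQuantile: floor((p - err) * n) <= rank <= ceil((p + err) * n)."""
    low = max(math.floor((p - err) * n), 0)
    high = min(math.ceil((p + err) * n), n)
    return low, high

# Median of 1,000 rows with the 0.25 error used in the example above.
print(acceptable_rank_range(1000, 0.5, 0.25))

# The same request with an error of 0, i.e. the exact value.
print(acceptable_rank_range(1000, 0.5, 0.0))
```

In other words, with an error of 0.25 the "median" you get back could sit anywhere between the 25th and 75th percentile, so in practice you would probably want a much smaller error, such as 0.01.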

To use percentile_approx, we have a number of options. First, we could simply call the function by running SQL against our dataframe, once it has been registered as a temporary view. This is a super useful way to use the function, especially if you’re writing back to a database table.

df.createOrReplaceTempView("df")  # register the dataframe so SQL can see it as 'df'
spark.sql("SELECT percentile_approx(x, 0.5) FROM df")

We could also run the below, which gives us a new column called ‘%10’ holding the result of the percentile calculation for each group. This again has its advantages – keeping all your code in a single language, rather than mixing and matching between Spark functions & SQL.

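from pyspark.sql import functions as sqlfunc

# 10th percentile of 'fieldname' per group, rounded to 2 decimal places
df \
    .groupby('grp') \
    .agg(sqlfunc.round(sqlfunc.expr('percentile_approx(fieldname, 0.10)'), 2).alias('%10'))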

These simple changes to your script should result in a big improvement in performance – but there will be a dip in absolute accuracy, the size of which depends on the error you choose to accept. Unless you’re chasing accuracy to an extreme level (for example in financial trading or medical use cases), this will probably be sufficient.