A comprehensive guide to windowing functions in PySpark for data science

Window functions let you compute aggregates, rankings and row-to-row comparisons across groups of rows without collapsing them into a single result, all within one query. In this article, I will cover the key window functions in PySpark.

First off, we need to load our dataframe – you can get the data to play along here.

df = spark.read.format("csv").option("header", "true").load("sales.csv")
df.show()
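
One thing worth checking before we start: with only the header option set, Spark reads every column – including Sales and the dates – as a string, which you can confirm with printSchema. We'll stick with that load here, but adding .option("inferSchema", "true") to the read is an alternative if you'd rather Spark guess the types for you.

# with header-only loading, every column comes back as a string
df.printSchema()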

Now we have a dataframe holding a whole bunch of information about the orders placed with our business. The first example we're going to work through is a simple count. Here, I want to return a dataframe with three columns: City, Customer Name and a new column containing the total count of Customer IDs that have placed an order within that city.

So, we define our window and call it 'w'. Here we simply say 'partition the data by City', i.e. give me the answer per city. We can now use this definition in our select statement – we count Customer ID over our window (i.e. per city) and alias the column (give it a new name) as simply 'count'.

from pyspark.sql.window import Window
import pyspark.sql.functions as func
from pyspark.sql.functions import *
#count the total number of customers per city
w = Window.partitionBy("City")
df.select("City","Customer Name", count("Customer ID").over(w).alias("count")).show()

The output of our analysis is as below. As you can see, the count is the same for every record associated with a city. In Tyler, we have 5 orders, so the count is always equal to 5. Note, though, that the customer is actually the same on 4 of the 5 occasions – this is because we did not count distinct, so duplicate Customer IDs were not disregarded in our analysis.

+-------------+----------------+-----+
|         City|   Customer Name|count|
+-------------+----------------+-----+
|        Tyler|    Carol Darley|    5|
|        Tyler|    Carol Darley|    5|
|        Tyler|    Carol Darley|    5|
|        Tyler|    Carol Darley|    5|
|        Tyler|     Lisa Hazard|    5|
|Bowling Green| Katharine Harms|   10|
|Bowling Green|      Nora Preis|   10|
|Bowling Green|      Brian Derr|   10|
|Bowling Green|    Neoma Murray|   10|
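
If you do want distinct customers per city rather than a raw row count, distinct aggregates aren't supported as window functions in Spark, but a common workaround is to collect the unique IDs into a set and take its size (a sketch along those lines; approx_count_distinct over the window is another option if an approximation is good enough):

# distinct customer count per city: collect the unique IDs, then measure the set
df.select("City", "Customer Name",
          func.size(func.collect_set("Customer ID").over(w)).alias("distinct_count")).show()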

We can do the same with sum; the code below shows the total sales value per city in a new column.

df.select("City","Customer Name", "Sales", sum("Sales").over(w).alias("City_Sales")).show()
+-------------+----------------+-------+------------------+
|         City|   Customer Name|  Sales|        City_Sales|
+-------------+----------------+-------+------------------+
|        Tyler|    Carol Darley|  2.688|           347.206|
|        Tyler|    Carol Darley| 27.816|           347.206|
|        Tyler|    Carol Darley| 82.524|           347.206|
|        Tyler|    Carol Darley|182.994|           347.206|
|        Tyler|     Lisa Hazard| 51.184|           347.206|
|Bowling Green| Katharine Harms| 140.81|          2077.375|
|Bowling Green|      Nora Preis| 264.32|          2077.375|
|Bowling Green|      Brian Derr|     71|          2077.375|
|Bowling Green|    Neoma Murray|  5.553|          2077.375|
|Bowling Green|   Maurice Satty| 26.982|          2077.375|
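
The same idea can be turned into a running total by adding an ordering and a frame to the window. A minimal sketch, assuming we order each city's rows by Order Date (w_running and the column alias are just illustrative names):

# running total of sales within each city, ordered by date
w_running = Window.partitionBy("City").orderBy("Order Date") \
    .rowsBetween(Window.unboundedPreceding, Window.currentRow)
df.select("City", "Order Date", "Sales",
          sum("Sales").over(w_running).alias("Running_City_Sales")).show()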

With min and max, I am going to do something a little different – instead of a select, I am going to show you how to create new columns with the values stored in them. Here, I have used the .withColumn function to define a new column, which is populated with the value from our window function.

w = Window.partitionBy("City")
df = df.withColumn('Min', min("Sales").over(w))
df = df.withColumn('Max', max("Sales").over(w))
df.select("City","Customer Name", "Sales",'Min', 'Max').show()
+-------------+----------------+-------+-------+-------+
|         City|   Customer Name|  Sales|    Min|    Max|
+-------------+----------------+-------+-------+-------+
|Bowling Green| Katharine Harms| 140.81| 139.96|899.982|
|Bowling Green|      Nora Preis| 264.32| 139.96|899.982|
|Bowling Green|      Brian Derr|     71| 139.96|899.982|
|Bowling Green|    Neoma Murray|  5.553| 139.96|899.982|
|Bowling Green|   Maurice Satty| 26.982| 139.96|899.982|
|Bowling Green|   Maurice Satty|  6.912| 139.96|899.982|
|Bowling Green|   Maurice Satty|435.504| 139.96|899.982|
|Bowling Green|    Bill Stewart|899.982| 139.96|899.982|
|Bowling Green|    Bill Stewart| 86.352| 139.96|899.982|
|Bowling Green|    Bill Stewart| 139.96| 139.96|899.982|
|      Edmonds|Zuschuss Carroll|  11.52|  11.52|  81.98|
|      Edmonds|Zuschuss Carroll|1298.55|  11.52|  81.98|
|      Edmonds|Zuschuss Carroll| 213.92|  11.52|  81.98|
|      Edmonds|Zuschuss Carroll|  25.78|  11.52|  81.98|
|      Edmonds|   Bruce Stewart|   19.6|  11.52|  81.98|
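
One thing to watch in the output above: because we loaded the CSV without a schema, Sales is a string column, so min and max compare the values as text rather than as numbers – which is why the Min and Max shown don't always match the smallest and largest sales visible in the rows. If the column is meant to be numeric, a quick fix is to cast it before applying the window:

# cast Sales to a numeric type so min/max compare numerically
df = df.withColumn("Sales", col("Sales").cast("double"))
df = df.withColumn('Min', min("Sales").over(w))
df = df.withColumn('Max', max("Sales").over(w))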

The final aggregate window function we will use is average. Here, we compare each transaction value with the average sale value for its city.

w = Window.partitionBy("City")
df = df.select("City","Customer Name", "Sales", func.round(avg("Sales").over(w)).alias("City_Sales"))
df = df.withColumn('diff_to_avg', func.round(df.Sales - df.City_Sales))
df.show()
+-------------+----------------+-------+----------+-----------+
|         City|   Customer Name|  Sales|City_Sales|diff_to_avg|
+-------------+----------------+-------+----------+-----------+
|        Tyler|    Carol Darley|  2.688|      69.0|      -66.0|
|        Tyler|    Carol Darley| 27.816|      69.0|      -41.0|
|        Tyler|    Carol Darley| 82.524|      69.0|       14.0|
|        Tyler|    Carol Darley|182.994|      69.0|      114.0|
|        Tyler|     Lisa Hazard| 51.184|      69.0|      -18.0|
|Bowling Green| Katharine Harms| 140.81|     208.0|      -67.0|
|Bowling Green|      Nora Preis| 264.32|     208.0|       56.0|
|Bowling Green|      Brian Derr|     71|     208.0|     -137.0|
|Bowling Green|    Neoma Murray|  5.553|     208.0|     -202.0|
|Bowling Green|   Maurice Satty| 26.982|     208.0|     -181.0|
|Bowling Green|   Maurice Satty|  6.912|     208.0|     -201.0|
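
A small variation on the same idea: rather than the difference from the city average, we can express each sale as a share of its city's total. A sketch, with pct_of_city just an illustrative column name:

# each transaction as a percentage of its city's total sales
df.withColumn("pct_of_city",
              func.round(col("Sales") / sum("Sales").over(w) * 100, 2)).show()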

Next, we will look at row_number, which simply gives each row a number within an ordered set of data. Below, we have ordered by Sales, descending. Notice that the row numbers reset for each city, as we have partitioned by City in our window definition.

w = Window.partitionBy("City").orderBy(desc('Sales'))
df.select("City","Customer Name", row_number().over(w).alias("row")).show()
+-------------+-----------------+---+
|         City|    Customer Name|row|
+-------------+-----------------+---+
|        Tyler|     Carol Darley|  1|
|        Tyler|      Lisa Hazard|  2|
|        Tyler|     Carol Darley|  3|
|        Tyler|     Carol Darley|  4|
|        Tyler|     Carol Darley|  5|
|Bowling Green|     Bill Stewart|  1|
|Bowling Green|     Bill Stewart|  2|
|Bowling Green|       Brian Derr|  3|
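
A common use of row_number is picking the top record per group – for example, the single largest sale in each city. A sketch along those lines (the ordering is only sensible if Sales is numeric – see the cast note earlier):

# keep only the largest sale per city by filtering on the row number
top_sale = df.withColumn("row", row_number().over(w)) \
             .filter(col("row") == 1) \
             .drop("row")
top_sale.show()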

We could also use the rank function, which gives the same result here only because there are no tied Sales values within a city:

w = Window.partitionBy("City").orderBy(desc('Sales'))
df.select("City","Customer Name", rank().over(w).alias("rank")).show()
+-------------+-----------------+----+
|         City|    Customer Name|rank|
+-------------+-----------------+----+
|        Tyler|     Carol Darley|   1|
|        Tyler|      Lisa Hazard|   2|
|        Tyler|     Carol Darley|   3|
|        Tyler|     Carol Darley|   4|
|        Tyler|     Carol Darley|   5|
|Bowling Green|     Bill Stewart|   1|
|Bowling Green|     Bill Stewart|   2|
|Bowling Green|       Brian Derr|   3|

We can also use dense_rank. It behaves like rank, except in how ties are handled: when two values are equal, rank skips the following position, leaving a gap in the sequence, while dense_rank does not.

w = Window.partitionBy("City").orderBy(desc('Sales'))
df.select("City","Customer Name", dense_rank().over(w).alias("rank")).show()
+-------------+-----------------+----+
|         City|    Customer Name|rank|
+-------------+-----------------+----+
|        Tyler|     Carol Darley|   1|
|        Tyler|      Lisa Hazard|   2|
|        Tyler|     Carol Darley|   3|
|        Tyler|     Carol Darley|   4|
|        Tyler|     Carol Darley|   5|
|Bowling Green|     Bill Stewart|   1|
|Bowling Green|     Bill Stewart|   2|
|Bowling Green|       Brian Derr|   3|
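
The difference only shows up when there are ties, which this sample doesn't happen to have. A quick illustration on a tiny made-up dataframe (hypothetical values, just to show the gap):

# the two rows tied on 10 both get position 1; the next row is 3 with rank()
# but 2 with dense_rank()
tie_df = spark.createDataFrame([("a", 10), ("b", 10), ("c", 5)], ["id", "value"])
tw = Window.orderBy(desc("value"))
tie_df.select("id", "value",
              rank().over(tw).alias("rank"),
              dense_rank().over(tw).alias("dense_rank")).show()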

Often, in time-series data, you want to compare the current record with the preceding or following record; for this, we can use lead (to get the next value) or lag (to get the previous value).

Here, we require a bit of data prep, as the date can't be ordered in the format it currently takes in the dataset. So, I've created a new column 'dt' and parsed our original date field as a timestamp. We then use the lead function in a very similar way to the functions above. Notice the 1 in lead(col('Sales'), 1) – this means take the value 1 row ahead of the current row.

from pyspark.sql.functions import to_timestamp
# rename the original date column and parse it into a proper timestamp
df = df.withColumnRenamed('Order Date', 'Date')
df = df.withColumn('dt', to_timestamp(df.Date, 'MM/dd/yyyy'))
# partition by date; since we also order by the partition column, the row
# order within each date is effectively arbitrary
w = Window.partitionBy("dt").orderBy(asc('dt'))
df = df.select("dt", "Sales", lead(col("Sales"), 1).over(w).alias("Next_Row_sales"))
df.show()
+-------------------+--------+--------------+
|                 dt|   Sales|Next_Row_sales|
+-------------------+--------+--------------+
|2014-08-04 00:00:00| 1089.75|        447.84|
|2014-08-04 00:00:00|  447.84|          16.4|
|2014-08-04 00:00:00|    16.4|        399.96|
|2014-08-04 00:00:00|  399.96|        Black"|
|2014-08-04 00:00:00|  Black"|        13.184|
|2014-08-04 00:00:00|  13.184|        101.96|
|2014-08-04 00:00:00|  101.96|        259.74|
|2014-08-04 00:00:00|  259.74|        255.42|
|2014-08-04 00:00:00|  255.42|          null|
|2015-04-26 00:00:00| 831.936|         97.04|
|2015-04-26 00:00:00|   97.04|        72.784|
|2015-04-26 00:00:00|  72.784|       408.422|
|2015-04-26 00:00:00| 408.422|        63.936|
|2015-04-26 00:00:00|  63.936|         59.52|

We can do exactly the same with the lag function, reusing the dt column and window we have already defined. Note that I have aliased the new column Prev_Row_Sales, since lag looks at the previous row rather than the next one:

w = Window.partitionBy("dt").orderBy(asc('dt'))
df = df.select("dt", "Sales", lag(col("Sales"), 1).over(w).alias("Prev_Row_Sales"))
df.show()
+-------------------+--------+--------------+
|                 dt|   Sales|Prev_Row_Sales|
+-------------------+--------+--------------+
|2014-08-04 00:00:00| 1089.75|          null|
|2014-08-04 00:00:00|  447.84|       1089.75|
|2014-08-04 00:00:00|    16.4|        447.84|
|2014-08-04 00:00:00|  399.96|          16.4|
|2014-08-04 00:00:00|  Black"|        399.96|
|2014-08-04 00:00:00|  13.184|        Black"|
|2014-08-04 00:00:00|  101.96|        13.184|
|2014-08-04 00:00:00|  259.74|        101.96|
|2014-08-04 00:00:00|  255.42|        259.74|
|2015-04-26 00:00:00| 831.936|          null|
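
Lag becomes more obviously useful when the partition and the ordering differ – for instance, the change in sales from a customer's previous order to their current one. A sketch, assuming we are working from a dataframe that still has Customer Name, dt and a numeric Sales column (i.e. before the selects above narrowed df down):

# change in sales versus the same customer's previous order
w_cust = Window.partitionBy("Customer Name").orderBy(asc("dt"))
df.select("Customer Name", "dt", "Sales",
          (col("Sales") - lag("Sales", 1).over(w_cust)).alias("change_vs_prev")).show()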
