
Scalable Cumulative Sum in PySpark: From Window Functions to Partition-Aware Aggregation

7 min read · Jun 17, 2025


While working on a large-scale search analytics pipeline, I encountered a classic but deceptively complex challenge — calculating a cumulative sum. On paper, it sounded simple. In practice, it was anything but.
The dataset in question was huge — over 5 billion records per day. Each record represented a search query with metrics like search_count, click_count, and order_count. My task was to compute the cumulative distribution of search counts to understand which keyword groups contributed to the top search volume — a kind of percentile analysis across billions of rows.
Naturally, I turned to PySpark’s window functions. But very soon, I hit performance walls: memory overflows, long job durations, and excessive shuffling. It became clear that the standard approach wouldn’t cut it at this scale.
This article walks through how I rethought the problem — going from the familiar window functions to a partition-aware cumulative sum strategy that scales seamlessly on distributed data.

Let’s start by understanding why the direct approach breaks down and how to build a more robust solution step by step.

The Direct Window Function Approach — And Why It Breaks at Scale

The most straightforward way to calculate a cumulative sum in PySpark is to use the Window function. Let’s say we have a dataset of search keywords and their associated search_count. We want to compute a cumulative sum sorted in descending order of search_count.

from pyspark.sql import functions as F
from pyspark.sql.window import Window

window_spec = Window.orderBy(F.col("search_count").desc())
df = df.withColumn("cumulative_sum", F.sum("search_count").over(window_spec))

This code is elegant, declarative, and easy to write. And it works perfectly — until it doesn’t.

What’s Really Happening Under the Hood?
When you use a Window function with an orderBy clause across the entire dataset (i.e., no partitionBy), Spark must impose a global ordering. This means it performs a wide shuffle — redistributing all records across the cluster to maintain the required sort order. Here’s what that entails:

  1. Full Dataset Shuffle: To establish a total ordering by the sort key (search_count in our case), every row has to be moved across the cluster. That shuffle alone is extremely expensive when the dataset contains billions of rows.
  2. Single Window Group: Since there is no partitionBy, Spark treats the entire dataset as one window group and funnels every row into a single partition for the window stage, eliminating parallelism exactly where it is needed most (you can confirm this in the plan sketch after this list).
  3. Memory Pressure and GC Overhead: Spark attempts to accumulate and sort the entire dataset in memory to apply the cumulative sum logic. This can easily exceed executor memory limits, causing out-of-memory (OOM) errors or triggering excessive garbage collection (GC) cycles that drastically slow down execution.
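
You can see this bottleneck before ever running the job. Here is a minimal sketch, reusing the df and search_count names from the snippet above; on typical Spark versions the physical plan shows all rows being exchanged into a single partition and then sorted, and the driver log warns that no partition is defined for the window operation:

from pyspark.sql import functions as F
from pyspark.sql.window import Window

# Global window: no partitionBy, ordered over the entire dataset.
window_spec = Window.orderBy(F.col("search_count").desc())

# Inspect the plan instead of executing the job. Expect an exchange that
# funnels every row into a single partition, followed by a full sort.
df.withColumn("cumulative_sum", F.sum("search_count").over(window_spec)).explain()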

When I initially tried this direct method on our search dataset (~5 billion records/day), the Spark job:

a. Hung indefinitely at the stage involving the cumulative sum
b. Spilled to disk heavily, despite large executor memory
c. Failed frequently with ExecutorLostFailure and OOM exceptions
d. Took hours to finish even on the smaller subsets where it did succeed

Why This Approach Doesn’t Scale
In a distributed system like Spark, global operations such as total ordering and full cumulative tracking across all rows violate the principles of distributed parallelism. They force serialization and coordination where you’d ideally want independent computation.
The direct approach treats the dataset monolithically. Spark has no choice but to “collect” and sort it in logical terms — defeating the purpose of distributing the computation.

A Partition-Aware Approach to Scalable Cumulative Sum in Spark

After running into serious scalability issues with the window function approach, I turned to a more distributed-friendly strategy: computing partition-wise cumulative sums and then applying partition-level offsets. This method leverages Spark’s distributed architecture and avoids full data shuffles or global state computation.
Let’s walk through the approach, step-by-step.

Step 1: Range-Based Repartitioning with Sort Order
To distribute data across partitions in a logically ordered manner, we use repartitionByRange() — a powerful alternative to standard repartition().

🔍 What repartitionByRange() Actually Does:

  • Internally, Spark samples the sort key column (search_count in our case) to determine the range boundaries for each partition.
  • Each partition receives rows with values falling into specific non-overlapping ranges.
  • Importantly, Spark does not sort rows within each partition by default — you still need sortWithinPartitions() if strict local ordering is required for deterministic output.

num_partitions = 8

ordered_df = (
    input_df.repartitionByRange(num_partitions, F.col("search_count").desc())
    .sortWithinPartitions(F.col("search_count").desc())
)

Step 2: Add Partition ID for Local Grouping
We identify each row’s physical partition using the spark_partition_id() function.
This column helps us group rows logically by their assigned partition during the next steps.

ordered_df = ordered_df.withColumn("partition_id", F.spark_partition_id())
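
Before moving on, a quick sanity check (my own addition, not part of the pipeline itself) can confirm that the sampled range boundaries produced a reasonably even spread of rows across partitions:

# How many rows landed in each range partition? Useful for spotting skew
# introduced by the sampled range boundaries.
ordered_df.groupBy("partition_id").count().orderBy("partition_id").show()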

Step 3: Compute Partial Cumulative Sums Per Partition
Now that the data is range-partitioned and locally ordered, we calculate partial cumulative sums within each partition.
💡 This works because each partition holds a slice of data sorted in descending search_count order — allowing local cumulative sums to be correct within that range.

window_spec = Window.partitionBy("partition_id").orderBy(F.col("search_count").desc())
partial_df = ordered_df.withColumn("partial_cumsum", F.sum("search_count").over(window_spec))

Step 4: Aggregate Partition Totals
We compute the total search_count contributed by each partition. This will help us later apply offsets to downstream partitions.

partition_totals_df = (
    ordered_df.groupBy("partition_id")
    .agg(F.sum("search_count").alias("partition_total"))
    .orderBy("partition_id")
)

Step 5: Compute Cumulative Offsets for Each Partition
We now determine how much to “offset” each partition’s local cumulative sum by, using a rolling sum over the partition totals. Although this window also uses a global orderBy, it operates on a tiny DataFrame with just one row per partition, so it is inexpensive.
This gives us an offset for each partition: the total search count of all preceding partitions (null for the first partition, which we treat as 0 later).

offset_window = Window.orderBy("partition_id").rowsBetween(Window.unboundedPreceding, -1)

partition_offsets_df = partition_totals_df.withColumn(
    "partition_offset", F.sum("partition_total").over(offset_window)
)

Step 6: Join Offsets and Finalize Global Cumulative Sum
Now we join the offsets back to each row using partition_id and compute the true global cumulative sum.

final_df = (
    partial_df
    .join(partition_offsets_df.select("partition_id", "partition_offset"), on="partition_id", how="left")
    .fillna({"partition_offset": 0})
    .withColumn("cumulative_sum", F.col("partial_cumsum") + F.col("partition_offset"))
    .drop("partition_offset", "partial_cumsum", "partition_id")
)
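
Since the original goal was a percentile-style view of search volume, a natural follow-up (my own addition, not part of the pipeline above) is to turn the cumulative sum into a cumulative share of the total. The grand total is a single scalar, so a plain aggregate is cheap:

# Convert the cumulative sum into a cumulative share of total search volume.
total_search_count = final_df.agg(F.sum("search_count").alias("total")).collect()[0]["total"]

final_df = final_df.withColumn(
    "cumulative_pct", F.col("cumulative_sum") / F.lit(total_search_count)
)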

Illustrating Partition-Aware Cumulative Sum with a Sample Dataset

Let’s say you have the following input DataFrame representing search activity. You want to compute the cumulative search count ordered by search_count in descending order. For simplicity, we’ll use 2 partitions here.

+---------+--------------+
| keyword | search_count |
+---------+--------------+
| shoes   | 100          |
| bags    | 90           |
| socks   | 80           |
| watch   | 70           |
| perfume | 60           |
| hat     | 50           |
| belt    | 40           |
| tie     | 30           |
+---------+--------------+
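
If you want to follow along, here is a minimal sketch for building this sample DataFrame (assuming an active SparkSession, and using num_partitions = 2 in Step 1 to mirror the two partitions below):

from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("cumsum-demo").getOrCreate()

sample_data = [
    ("shoes", 100), ("bags", 90), ("socks", 80), ("watch", 70),
    ("perfume", 60), ("hat", 50), ("belt", 40), ("tie", 30),
]
input_df = spark.createDataFrame(sample_data, ["keyword", "search_count"])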

🔹 Step 1: Repartition by Range
Let’s assume Spark’s range sampler distributes data like this:

  • Partition 0 (higher values): shoes, bags, socks, watch
  • Partition 1 (lower values): perfume, hat, belt, tie

After repartitioning, we add a column to show partition ID:

Partition 0
+---------+--------------+--------------+
| keyword | search_count | partition_id |
+---------+--------------+--------------+
| shoes   | 100          | 0            |
| bags    | 90           | 0            |
| socks   | 80           | 0            |
| watch   | 70           | 0            |
+---------+--------------+--------------+

Partition 1
+---------+--------------+--------------+
| keyword | search_count | partition_id |
+---------+--------------+--------------+
| perfume | 60           | 1            |
| hat     | 50           | 1            |
| belt    | 40           | 1            |
| tie     | 30           | 1            |
+---------+--------------+--------------+

🔹 Step 2: Partial Cumulative Sum (Within Partition)

Partition 0
+---------+--------------+--------------+----------------+
| keyword | search_count | partition_id | partial_cumsum |
+---------+--------------+--------------+----------------+
| shoes   | 100          | 0            | 100            |
| bags    | 90           | 0            | 190            |
| socks   | 80           | 0            | 270            |
| watch   | 70           | 0            | 340            |
+---------+--------------+--------------+----------------+

Partition 1
+---------+--------------+--------------+----------------+
| keyword | search_count | partition_id | partial_cumsum |
+---------+--------------+--------------+----------------+
| perfume | 60           | 1            | 60             |
| hat     | 50           | 1            | 110            |
| belt    | 40           | 1            | 150            |
| tie     | 30           | 1            | 180            |
+---------+--------------+--------------+----------------+

🔹 Step 3: Compute Partition Totals

Partition Totals:
+--------------+-----------------+
| partition_id | partition_total |
+--------------+-----------------+
| 0            | 340             |
| 1            | 180             |
+--------------+-----------------+

🔹 Step 4: Calculate Partition Offsets

Partition Offsets:
+--------------+-----------------+------------------+
| partition_id | partition_total | partition_offset |
+--------------+-----------------+------------------+
| 0            | 340             | null (→ 0)       |
| 1            | 180             | 340              |
+--------------+-----------------+------------------+

🔹 Step 5: Final Global Cumulative Sum
We now join partition_offset back and compute the final cumulative sum. This gives us the correct global cumulative sum without ever sorting or windowing the full dataset!

Partition 0
+---------+--------------+----------------+-----------------+
| keyword | search_count | partial_cumsum | cumulative_sum  |
+---------+--------------+----------------+-----------------+
| shoes   | 100          | 100            | 100 + 0 = 100   |
| bags    | 90           | 190            | 190 + 0 = 190   |
| socks   | 80           | 270            | 270 + 0 = 270   |
| watch   | 70           | 340            | 340 + 0 = 340   |
+---------+--------------+----------------+-----------------+

Partition 1
+---------+--------------+----------------+-----------------+
| keyword | search_count | partial_cumsum | cumulative_sum  |
+---------+--------------+----------------+-----------------+
| perfume | 60           | 60             | 60 + 340 = 400  |
| hat     | 50           | 110            | 110 + 340 = 450 |
| belt    | 40           | 150            | 150 + 340 = 490 |
| tie     | 30           | 180            | 180 + 340 = 520 |
+---------+--------------+----------------+-----------------+
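
On a sample this small, it is easy to cross-check the result. Assuming you ran Steps 1 through 6 above on input_df with num_partitions = 2, the partition-aware output should match the plain global-window version, which is perfectly fine at this scale:

# Cross-check on small data: both approaches must agree row for row.
check_window = Window.orderBy(F.col("search_count").desc())
expected_df = input_df.withColumn("cumulative_sum", F.sum("search_count").over(check_window))

expected_df.orderBy(F.col("search_count").desc()).show()
final_df.select("keyword", "search_count", "cumulative_sum").orderBy(F.col("search_count").desc()).show()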

Computing a cumulative sum might appear trivial at first glance — until you’re faced with billions of rows of data, distributed across dozens or hundreds of executors. Traditional window functions, while elegant and concise, quickly become bottlenecks at this scale due to their inherent need for full data shuffling and centralized aggregation.

The approach we explored embraces Spark’s distributed processing philosophy. By breaking the problem down into:

  • local, partition-level computations,
  • computing partition-wise rollups,
  • and finally offsetting partial results to derive global context,

…we effectively eliminate the need for a full shuffle or a global sort. This not only preserves parallelism and partition locality, but also scales linearly as data grows.
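
For reference, here is one way the whole pipeline might be wrapped into a reusable helper. Treat it as a sketch of the steps above rather than a drop-in utility; the name partition_aware_cumsum and its parameters are my own choices:

from pyspark.sql import DataFrame, functions as F
from pyspark.sql.window import Window

def partition_aware_cumsum(df: DataFrame, value_col: str, num_partitions: int = 8) -> DataFrame:
    """Global cumulative sum of value_col (descending) without a global sort or window."""
    # Step 1: range-partition on the value column and sort each partition locally.
    ordered_df = (
        df.repartitionByRange(num_partitions, F.col(value_col).desc())
        .sortWithinPartitions(F.col(value_col).desc())
    )

    # Step 2: tag each row with its physical partition id.
    ordered_df = ordered_df.withColumn("partition_id", F.spark_partition_id())

    # Step 3: partial cumulative sum within each partition.
    local_window = Window.partitionBy("partition_id").orderBy(F.col(value_col).desc())
    partial_df = ordered_df.withColumn("partial_cumsum", F.sum(value_col).over(local_window))

    # Steps 4-5: per-partition totals plus rolling offsets (one row per partition, so cheap).
    offset_window = Window.orderBy("partition_id").rowsBetween(Window.unboundedPreceding, -1)
    partition_offsets_df = (
        ordered_df.groupBy("partition_id")
        .agg(F.sum(value_col).alias("partition_total"))
        .withColumn("partition_offset", F.sum("partition_total").over(offset_window))
    )

    # Step 6: add each partition's offset to its local cumulative sums.
    return (
        partial_df
        .join(partition_offsets_df.select("partition_id", "partition_offset"), "partition_id", "left")
        .fillna({"partition_offset": 0})
        .withColumn("cumulative_sum", F.col("partial_cumsum") + F.col("partition_offset"))
        .drop("partition_offset", "partial_cumsum", "partition_id")
    )

# Example usage:
# result_df = partition_aware_cumsum(input_df, "search_count", num_partitions=8)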

If you found the article to be helpful, you can buy me a coffee here:
Buy Me A Coffee.

Written by VivekR

Data Engineer, Big Data Enthusiast and Automation using Python