The core challenge with large-scale, long-running stream aggregations isn’t the computation itself, but the management of state. Stateless transformations running on a framework like Dask are trivial to parallelize and recover; if a worker fails, the scheduler simply re-runs the task on another machine. The moment state enters the picture—for instance, calculating a `nunique` over a multi-terabyte dataset—this model breaks down. Storing intermediate sets of unique IDs in worker memory is both a capacity bottleneck and a guarantee of data loss on failure. A full re-computation from source is often operationally infeasible.
Our initial problem was a stream of user interaction events, billions per day, from which we needed to derive hourly unique user counts per feature flag. The naive Dask approach, `dd.read_parquet(...).groupby('feature_flag').user_id.nunique().compute()`, would work on a small scale but failed spectacularly in production. It required shuffling enormous amounts of data and held gigabytes of intermediate sets in memory, leading to worker restarts and cascading failures. The system was fragile.
The concept we landed on was to push state down to the worker level, making it local and persistent. Instead of a massive, distributed shuffle, each worker would be responsible for aggregating a subset of partitions. It would maintain its own state on local disk. This avoids network overhead for state access and leverages fast local SSDs. On failure, a task could be rescheduled, reopen its local state store, and continue where it left off, provided the input data was replayable. This design inherently forgoes strong consistency. We are not building a distributed transactional database. We are building a system that is Basically Available, operates in a Soft state, and is Eventually consistent—a classic BASE system.
The technology selection followed logically. Dask remains the orchestrator due to its flexible task scheduling and Python-native ecosystem. For the local state store, an embedded key-value database was the only viable option. A network-attached store like Redis or a relational database would re-introduce network latency as the primary bottleneck, defeating the purpose of local state. LevelDB, via the `plyvel` library, was chosen for its simplicity, high write throughput thanks to its Log-Structured Merge-Tree (LSM) design, and minimal operational overhead. It’s a library, not a server, which is exactly what’s needed for a worker-local store.
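To make "embedded, not a server" concrete, here is a minimal plyvel sketch outside of Dask. The path and keys are illustrative; only the byte-oriented open/put/get/iterate calls are relied on later.

```python
# A minimal plyvel sketch: the database is a directory on local disk, opened
# in-process. No server, no network hop. (Path and keys are illustrative.)
import plyvel

db = plyvel.DB('/tmp/example_leveldb', create_if_missing=True)

# LevelDB is byte-oriented: keys and values are bytes.
db.put(b'feature_flag:flag_0', b'some-serialized-state')
print(db.get(b'feature_flag:flag_0'))  # b'some-serialized-state'

# Keys are stored sorted, so prefix scans are cheap.
for key, value in db.iterator(prefix=b'feature_flag:'):
    print(key, value)

db.close()
```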
The Architecture of Worker-Local State
The implementation hinges on a `StateStore` class that runs within the Dask worker process. This class is responsible for managing a dedicated LevelDB instance. Each worker gets its own database directory, ensuring isolation. A critical design decision is that the state itself must be structured to support idempotent operations. Simply updating a value is insufficient, as a task might fail after writing to LevelDB but before acknowledging completion to the Dask scheduler, leading to double processing upon re-execution.
To achieve idempotency, every data batch we process must have a unique, monotonically increasing identifier (e.g., a Kafka offset, a batch timestamp, or a simple integer sequence). Our state in LevelDB will not just be the aggregated value but a tuple: `(aggregated_value, last_processed_batch_id)`. The core processing logic then becomes:
- Read the current state `(value, last_id)` for a given aggregation key.
- If the incoming batch’s ID is less than or equal to `last_id`, skip processing.
- Otherwise, perform the aggregation and write the new state `(new_value, current_batch_id)` back to LevelDB.
This ensures that re-running a task on the same batch of data has no adverse effect. For example, if batch 7 is replayed after batch 9 has already been folded into a key, the stored `last_id` of 9 causes the replayed write to be skipped outright. The system always moves forward, eventually converging on the correct result.
```mermaid
flowchart TD
    subgraph Dask Cluster
        A[Dask Scheduler]
        subgraph Worker 1
            direction LR
            W1_Task[Task α] --> W1_DB[(LevelDB Instance 1)]
        end
        subgraph Worker 2
            direction LR
            W2_Task[Task β] --> W2_DB[(LevelDB Instance 2)]
        end
        subgraph Worker N
            direction LR
            WN_Task[Task γ] --> WN_DB[(LevelDB Instance N)]
        end
    end
    B[Input Data Source] -- Partitions --> A
    A -- Schedules Tasks --> W1_Task
    A -- Schedules Tasks --> W2_Task
    A -- Schedules Tasks --> WN_Task
    W1_Task -- Reads/Writes State --> W1_DB
    W2_Task -- Reads/Writes State --> W2_DB
    WN_Task -- Reads/Writes State --> WN_DB
    C{Final Aggregation}
    W1_Task -- Partial Result --> C
    W2_Task -- Partial Result --> C
    WN_Task -- Partial Result --> C
```
Implementation: The StateStore and Worker Setup
First, let’s define the `StateStore` manager. This class is instantiated once on each worker. A common pitfall here is trying to serialize the `StateStore` object itself and send it along with a Dask task; this fails because the underlying LevelDB connection object is not pickleable. Instead, we create it once per worker at startup and attach it to the worker object, so tasks can look it up locally at run time.
The code needs robust serialization. We use `msgpack` for its performance and cross-language compatibility, and `blosc` for compression, as the sets of unique IDs can become large. One wrinkle: msgpack cannot serialize Python sets or NumPy integers natively, so the aggregation state is converted to a plain list of ints before it is packed.
```python
# state_manager.py
import logging
import os
import threading

import blosc
import msgpack
import plyvel

# Configure basic logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(threadName)s - %(levelname)s - %(message)s'
)


class StateStore:
    """
    Manages a worker-local LevelDB instance for stateful operations.

    This class is designed to be instantiated once per worker process.
    """

    def __init__(self, db_path: str):
        self.db_path = db_path
        self._db = None
        # In a real-world project, a more robust lock per key might be needed
        # if multiple threads on a worker could access the same aggregation key.
        # For Dask's default threading model, this simple lock is sufficient.
        self._lock = threading.Lock()
        self._initialize_db()

    def _initialize_db(self):
        """Creates the DB directory and opens the LevelDB instance."""
        try:
            os.makedirs(self.db_path, exist_ok=True)
            self._db = plyvel.DB(self.db_path, create_if_missing=True)
            logging.info(f"Initialized LevelDB at {self.db_path}")
        except Exception as e:
            logging.error(f"Failed to initialize LevelDB at {self.db_path}: {e}")
            raise

    def close(self):
        """Closes the database connection."""
        if self._db and not self._db.closed:
            self._db.close()
            logging.info(f"Closed LevelDB at {self.db_path}")

    def _serialize(self, data) -> bytes:
        """Serializes and compresses data.

        Note: msgpack has no native set or NumPy integer types, so callers
        must pass plain lists and ints.
        """
        packed = msgpack.packb(data, use_bin_type=True)
        return blosc.compress(packed, typesize=8, cname='zstd')

    def _deserialize(self, data: bytes):
        """Decompresses and deserializes data."""
        unpacked = blosc.decompress(data)
        return msgpack.unpackb(unpacked, raw=False)

    def get_state(self, key: str) -> tuple:
        """
        Retrieves state for a given key.

        Returns a tuple of (state_data, last_batch_id).
        Returns (None, -1) if the key does not exist.
        """
        key_bytes = key.encode('utf-8')
        raw_value = self._db.get(key_bytes)
        if raw_value is None:
            return None, -1  # Default state for a new key
        try:
            # msgpack round-trips the stored (value, batch_id) pair as a
            # two-element list; unpack it back into a tuple.
            state_data, last_batch_id = self._deserialize(raw_value)
            return state_data, last_batch_id
        except Exception as e:
            logging.error(f"Deserialization failed for key '{key}': {e}")
            # In a production system, this might trigger a quarantine or alert
            return None, -1

    def update_state(self, key: str, new_state_data, batch_id: int):
        """
        Atomically updates the state for a key if the new batch_id is higher.

        This is the core of our idempotency mechanism.
        """
        with self._lock:
            # Re-check inside the lock to handle concurrent updates from threads
            _, last_batch_id = self.get_state(key)
            if batch_id <= last_batch_id:
                logging.warning(
                    f"Skipping update for key '{key}'. "
                    f"Incoming batch ID {batch_id} <= stored ID {last_batch_id}."
                )
                return
            key_bytes = key.encode('utf-8')
            value_tuple = (new_state_data, batch_id)
            serialized_value = self._serialize(value_tuple)
            # Use a write batch for atomicity, though for a single put it's less
            # critical. It's good practice for more complex transactions.
            with self._db.write_batch() as wb:
                wb.put(key_bytes, serialized_value)
            logging.debug(f"Updated state for key '{key}' to batch ID {batch_id}")

    def get_all_states(self) -> dict:
        """
        Retrieves all states from the database. Used for final aggregation.

        This can be memory-intensive and should only be called at the end.
        """
        results = {}
        for key_bytes, value_bytes in self._db:
            key = key_bytes.decode('utf-8')
            results[key] = self._deserialize(value_bytes)
        return results


# This function will be run on each Dask worker when it starts up.
def initialize_worker_state(dask_worker):
    """
    Initializes the StateStore on a Dask worker and attaches it.
    """
    worker_id = dask_worker.id
    # Ensure the path is unique per worker and in a non-temporary location
    db_path = os.path.join(os.getcwd(), f"dask-worker-space/{worker_id}/leveldb_state")
    # A common mistake is to overlook cleanup. This simplistic example doesn't
    # handle it, but a production system would need to clear this on restart.
    store = StateStore(db_path)
    dask_worker.state_store = store
    logging.info(f"StateStore attached to worker {worker_id} at {db_path}")
    # Return True to indicate success
    return True
```
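One caveat with the `client.run`-based setup used in the runner script later: it only reaches workers that are alive at that moment. A minimal sketch of an alternative, assuming distributed's `WorkerPlugin` hook, which also covers workers that join the cluster afterwards (this plugin is not wired into the code below):

```python
# worker_plugin.py -- an alternative to client.run(initialize_worker_state):
# a WorkerPlugin runs for every worker, including ones that join later.
# This is a sketch and is not used by main_runner.py below.
from distributed.diagnostics.plugin import WorkerPlugin

from state_manager import initialize_worker_state


class StateStorePlugin(WorkerPlugin):
    def setup(self, worker):
        # Runs once per worker, at registration time or when the worker starts.
        initialize_worker_state(worker)

    def teardown(self, worker):
        # Close the LevelDB handle cleanly when the worker shuts down.
        store = getattr(worker, "state_store", None)
        if store is not None:
            store.close()


# Usage:
#   client.register_worker_plugin(StateStorePlugin(), name="state-store")
```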
The Core Aggregation Task
With the state management infrastructure in place, the actual data processing function can be written. This function is what Dask will execute on a partition of data. It needs to be designed to fetch the state store from the worker it’s running on.
```python
# processor.py
import logging

import pandas as pd
from dask.distributed import get_worker


def process_partition(df: pd.DataFrame, batch_id: int) -> dict:
    """
    Processes one partition of data (a DataFrame), updating worker-local state.

    Args:
        df: The pandas DataFrame for the current partition.
        batch_id: The unique ID for this batch of data.

    Returns:
        A dictionary summarizing the updates performed.
    """
    try:
        worker = get_worker()
        state_store = worker.state_store
    except Exception as e:
        # This will happen if run outside a Dask worker context
        logging.error(
            f"Could not get state_store from worker. "
            f"Are you running in a Dask cluster? Error: {e}"
        )
        raise

    # Group data within the partition first. 'feature_flag' may be a column
    # or the index, depending on whether the caller has set the index.
    grouped = df.groupby('feature_flag')['user_id'].apply(set)

    update_counts = {}
    for feature_flag, user_ids_in_batch in grouped.items():
        # The aggregation key is the feature_flag
        key = str(feature_flag)

        # 1. Get current state. Stored state comes back as a plain list,
        #    so convert it to a set for the union below.
        current_state, last_batch_id = state_store.get_state(key)
        current_user_set = set(current_state) if current_state is not None else set()

        # 2. Check for idempotency
        if batch_id <= last_batch_id:
            # This partition has already been processed for this key.
            logging.info(f"Idempotency check: Skipping batch {batch_id} for key '{key}'")
            continue

        # 3. Perform the stateful update
        # A real-world project might use a more memory-efficient structure
        # like a roaring bitmap or HyperLogLog here. A Python set is simple but hungry.
        new_user_set = current_user_set.union(user_ids_in_batch)

        # 4. Write back the new state as a plain list of ints, since msgpack
        #    cannot serialize sets or NumPy integers.
        state_store.update_state(key, [int(u) for u in new_user_set], batch_id)
        update_counts[key] = len(new_user_set)

    logging.info(
        f"Processed batch {batch_id} on worker {worker.id}, updates: {len(update_counts)}"
    )
    return {"batch_id": batch_id, "worker_id": worker.id, "updates": len(update_counts)}
```
Orchestrating the Execution
Now we tie it all together. The main script sets up a Dask cluster, preloads the `initialize_worker_state` function on every worker, generates mock data representing batches from our stream, and then submits the `process_partition` tasks.
A critical part of the orchestration is data partitioning. So that all data for a given `feature_flag` lands in a single partition (and therefore touches a single LevelDB instance per batch), we partition the input data by `feature_flag`; in Dask DataFrame this is achieved via `set_index('feature_flag')`. This co-locates the data with the state, which is the entire point of this architecture. Note that Dask does not guarantee that the same key range is scheduled onto the same worker for every batch, which is why the final aggregation step below still merges partial state across all workers.
```python
# main_runner.py
import os
import shutil
import time

import dask
import dask.dataframe as dd
import numpy as np
import pandas as pd
from dask.distributed import Client, LocalCluster

# Import our custom modules
from state_manager import initialize_worker_state
from processor import process_partition


def generate_mock_data(batch_id, num_rows, num_flags):
    """Generates a mock DataFrame for a single batch.

    Seeding the RNG with the batch_id makes every batch reproducible, which
    mimics a replayable stream: re-running batch 3 yields the same events.
    """
    rng = np.random.default_rng(batch_id)
    flags = [f"flag_{i}" for i in range(num_flags)]
    data = {
        'feature_flag': rng.choice(flags, size=num_rows),
        'user_id': rng.integers(0, num_rows // 2, size=num_rows),
        'timestamp': pd.to_datetime(time.time() - (10 - batch_id) * 3600, unit='s'),
    }
    return pd.DataFrame(data)


def cleanup_worker_space():
    """Utility to clean up state directories before a run."""
    path = "dask-worker-space"
    if os.path.exists(path):
        print(f"Cleaning up old worker state at ./{path}")
        shutil.rmtree(path)


def get_final_results(client):
    """
    Collects the final aggregated state from all workers.
    """
    def get_local_state(dask_worker):
        return dask_worker.state_store.get_all_states()

    # Run the function on all workers and gather results
    results_per_worker = client.run(get_local_state)

    # The state is partitioned by key across workers. We need to merge.
    final_aggregates = {}
    for worker, worker_data in results_per_worker.items():
        for key, (user_set, batch_id) in worker_data.items():
            if key not in final_aggregates:
                final_aggregates[key] = {'users': set(), 'last_batch_id': -1}
            # Merge sets (stored state is a plain list of ints)
            final_aggregates[key]['users'].update(user_set)
            # Take the max batch ID
            final_aggregates[key]['last_batch_id'] = max(
                final_aggregates[key]['last_batch_id'], batch_id
            )

    # Return the final counts
    return {k: len(v['users']) for k, v in final_aggregates.items()}


if __name__ == "__main__":
    cleanup_worker_space()

    # Set up a local cluster with 2 workers, each with 2 threads
    cluster = LocalCluster(n_workers=2, threads_per_worker=2)
    client = Client(cluster)

    # Preload the state initialization function on all workers.
    # This is a crucial step.
    client.run(initialize_worker_state)

    # --- Simulate processing 5 batches of data ---
    num_batches = 5
    batch_summaries = []
    print(f"Submitting {num_batches} batches for processing...")
    for i in range(num_batches):
        # Generate data and convert to a Dask DataFrame
        pdf = generate_mock_data(batch_id=i, num_rows=10000, num_flags=10)
        ddf = dd.from_pandas(pdf, npartitions=4)

        # The key to co-locating data with state: partition by the aggregation key.
        ddf = ddf.set_index('feature_flag')

        # process_partition returns a plain dict per partition, so wrap each
        # partition with dask.delayed rather than map_partitions (which expects
        # DataFrame-like return values).
        tasks = [
            dask.delayed(process_partition)(part, batch_id=i)
            for part in ddf.to_delayed()
        ]

        # In a real stream, this would be a continuous loop. Batches are applied
        # strictly in order: the monotonic batch-ID check in update_state
        # assumes batch i is fully written before batch i+1 arrives.
        batch_summaries.extend(client.gather(client.compute(tasks)))

    print(f"Processing complete. Applied {len(batch_summaries)} partition updates.")

    # --- Retrieve Final Results ---
    print("\n--- Final Aggregated Counts ---")
    final_counts = get_final_results(client)
    for flag, count in sorted(final_counts.items()):
        print(f"{flag}: {count} unique users")

    # --- Simulate a Failure and Re-run ---
    # Imagine batch 3 failed and needs to be re-run. Due to our idempotency
    # logic (and the replayable, seeded input), this must not change the result.
    print("\n--- Simulating re-run of batch 3 ---")
    pdf_rerun = generate_mock_data(batch_id=3, num_rows=10000, num_flags=10)
    ddf_rerun = dd.from_pandas(pdf_rerun, npartitions=4).set_index('feature_flag')
    rerun_tasks = [
        dask.delayed(process_partition)(part, batch_id=3)
        for part in ddf_rerun.to_delayed()
    ]
    client.gather(client.compute(rerun_tasks))
    print("Re-run complete.")

    print("\n--- Final Aggregated Counts After Re-run ---")
    final_counts_after_rerun = get_final_results(client)
    for flag, count in sorted(final_counts_after_rerun.items()):
        print(f"{flag}: {count} unique users")

    # Validate that counts are identical
    assert final_counts == final_counts_after_rerun
    print("\nValidation successful: Results are identical after re-run.")

    client.close()
    cluster.close()
```
Limitations and Future Iterations
This architecture solves the immediate problem of volatile in-memory state, but it is not a silver bullet. The most significant limitation is that state is bound to the lifecycle of a Dask worker’s local storage. If a machine suffers a disk failure, its portion of the state is lost. A production-grade system would require a background process to periodically snapshot the LevelDB directories to a durable object store like S3. Recovery would then involve a more complex procedure of restoring that state to a new worker before it begins processing tasks.
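A minimal sketch of what such a snapshot job might look like, assuming boto3, a hypothetical bucket name, and execution on each worker via `client.run`. It dumps the key space through the open handle rather than copying live LevelDB files; restore would be the reverse, replaying the dump into a fresh `StateStore` before the worker accepts tasks.

```python
# snapshot.py -- hypothetical periodic backup of worker-local state to S3.
# Assumes boto3 is installed and STATE_BACKUP_BUCKET exists; run on every
# worker via client.run(snapshot_worker_state).
import io
import time

import boto3
import msgpack

STATE_BACKUP_BUCKET = "my-state-backups"  # assumption, not part of the original setup


def snapshot_worker_state(dask_worker):
    """Dump all key/value pairs from the worker's LevelDB into one S3 object."""
    store = dask_worker.state_store
    buf = io.BytesIO()
    packer = msgpack.Packer(use_bin_type=True)
    # Iterate the open DB; values are already serialized + compressed blobs.
    for key_bytes, value_bytes in store._db:
        buf.write(packer.pack((key_bytes, value_bytes)))
    buf.seek(0)
    object_key = f"leveldb-snapshots/{dask_worker.id}/{int(time.time())}.msgpack"
    boto3.client("s3").upload_fileobj(buf, STATE_BACKUP_BUCKET, object_key)
    return object_key
```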
Secondly, the state itself can grow without bounds. Our use of Python sets for unique counts is simple but memory-inefficient. For high-cardinality keys, this can still exhaust a worker’s memory during the `union` operation. A more advanced implementation should use probabilistic data structures like HyperLogLog, which offer a trade-off between accuracy and a fixed, small memory footprint. This would change the `StateStore` and `process_partition` logic to handle serialized HLL objects.
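As an illustration of that trade-off, the sketch below uses the third-party `datasketch` package's HyperLogLog (an assumption; any HLL implementation with a mergeable, serializable register array would do). The stored state becomes a small opaque blob instead of an ever-growing list of IDs, and the cross-worker merge in `get_final_results` would call `hll.merge()` instead of a set union.

```python
# A sketch of swapping the Python set for a HyperLogLog, using the
# third-party `datasketch` package (an assumption, not part of the code above).
import pickle

from datasketch import HyperLogLog


def merge_hll_state(stored_blob, user_ids_in_batch):
    """Return (new_blob, approx_count) after folding a batch into the HLL."""
    hll = pickle.loads(stored_blob) if stored_blob is not None else HyperLogLog(p=14)
    for user_id in user_ids_in_batch:
        # datasketch hashes raw bytes, so encode the ID first.
        hll.update(str(user_id).encode("utf-8"))
    return pickle.dumps(hll), int(hll.count())
```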
Finally, the process of garbage collecting old state (e.g., hourly aggregations from previous days) is not addressed. A separate, scheduled Dask job would be needed to scan the LevelDB instances across all workers and delete keys corresponding to expired time windows. This adds another layer of operational complexity to the system. The current design is best suited for aggregations where state is either perpetually relevant or managed within finite, rolling windows that are periodically cleared in their entirety.
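For completeness, one possible shape for that cleanup sweep, assuming the aggregation keys embed their hourly window (for example `2024010112:flag_7`, which is not the key scheme used in the demo above):

```python
# gc_state.py -- hypothetical expiry sweep, assuming keys are prefixed with
# their hourly window, e.g. b"2024010112:flag_7". Shown only to illustrate
# the mechanism; the demo code above does not use this key scheme.
def expire_old_windows(dask_worker, cutoff_window: str):
    """Delete every key whose window component sorts before the cutoff."""
    db = dask_worker.state_store._db
    deleted = 0
    with db.write_batch() as wb:
        # include_value=False skips value reads; we only need the keys.
        for key_bytes in db.iterator(include_value=False):
            window = key_bytes.split(b":", 1)[0].decode("utf-8")
            if window < cutoff_window:
                wb.delete(key_bytes)
                deleted += 1
    return deleted


# Scheduled from the client, e.g. once per hour:
#   client.run(expire_old_windows, cutoff_window="2024010100")
```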