From 541fd95a55890ad51b783d3ffee54c811d5e1e6f Mon Sep 17 00:00:00 2001 From: Elia LIU Date: Fri, 6 Feb 2026 14:15:35 +1100 Subject: [PATCH 1/6] feat(ml): add stateless bundle-local size-aware batching and benchmark --- .../benchmarks/sort_and_batch_benchmark.py | 650 ++++++++++++++++++ sdks/python/apache_beam/transforms/util.py | 285 ++++++++ .../apache_beam/transforms/util_test.py | 217 ++++++ 3 files changed, 1152 insertions(+) create mode 100644 sdks/python/apache_beam/testing/benchmarks/sort_and_batch_benchmark.py diff --git a/sdks/python/apache_beam/testing/benchmarks/sort_and_batch_benchmark.py b/sdks/python/apache_beam/testing/benchmarks/sort_and_batch_benchmark.py new file mode 100644 index 000000000000..e1caa73e0a5d --- /dev/null +++ b/sdks/python/apache_beam/testing/benchmarks/sort_and_batch_benchmark.py @@ -0,0 +1,650 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +"""Benchmark: BatchElements vs SortAndBatchElements (weight-based splitting). + +Compares two batching strategies for variable-length inference workloads: + +- Baseline (BatchElements): fixed-count chunking, ignores element sizes. +- Stateless (SortAndBatchElements): within each bundle, sorts elements + by size, then splits batches using max_batch_weight so that each batch + has a bounded total weight. The improvement comes from *changing batch + boundaries* (weight-based splitting), NOT from sorting alone -- sorting + within fixed boundaries yields 0% gain (verified by strict-control). + +Padding ratio:: + + padding_ratio = sum(max_len_in_batch * batch_size) / sum(actual_lengths) + Lower is better. 1.0 = no padding waste. + +Methodology: + +- N=20 independent trials per condition (3 warmup trials excluded). +- Same input corpus (seed=42) for A/B comparison. +- Percentile method: linear interpolation between adjacent ranks + (equivalent to numpy.percentile with method='linear'). + For N=20 trials: P50 interpolates ranks 10-11 (0-indexed 9-10), + P95 interpolates ranks 19-20 (0-indexed 18-19), + P99 interpolates near rank 20 (0-indexed 18.81). +- Reports median [IQR] and P95 for each metric. +- Inference model: latency = batch_size * (max_seq_len / 50)^1.5 ms + (simulates transformer-like scaling). + +Run:: + + python3 -m apache_beam.testing.benchmarks.sort_and_batch_benchmark +""" + +import math +import random +import statistics +import time +from typing import Any +from typing import Callable +from typing import Dict +from typing import List +from typing import Optional +from typing import Sequence +from typing import Tuple + +# --------------------------------------------------------------------------- +# Data generators +# --------------------------------------------------------------------------- + + +def generate_highly_skewed_data( + num_elements: int, + min_length: int = 1, + max_length: int = 500, + seed: int = 42) -> List[str]: + """Pareto(alpha=1.2) -- most short, few very long.""" + random.seed(seed) + data = [] + for _ in range(num_elements): + length = int(random.paretovariate(1.2) * min_length) + length = min(max(length, min_length), max_length) + data.append('x' * length) + return data + + +def generate_lognormal_data( + num_elements: int, + mean_length: int = 50, + std_factor: float = 0.8, + min_length: int = 1, + max_length: int = 500, + seed: int = 42) -> List[str]: + """Log-normal -- moderate skew, typical NLP.""" + random.seed(seed) + mu = math.log(mean_length) + sigma = std_factor + data = [] + for _ in range(num_elements): + length = int(random.lognormvariate(mu, sigma)) + length = min(max(length, min_length), max_length) + data.append('x' * length) + return data + + +def generate_bimodal_data( + num_elements: int, + mode1_mean: int = 20, + mode2_mean: int = 200, + mode1_ratio: float = 0.7, + min_length: int = 1, + max_length: int = 500, + seed: int = 42) -> List[str]: + """Bimodal -- two distinct length groups.""" + random.seed(seed) + data = [] + for _ in range(num_elements): + if random.random() < mode1_ratio: + length = int(random.gauss(mode1_mean, mode1_mean * 0.3)) + else: + length = int(random.gauss(mode2_mean, mode2_mean * 0.3)) + length = min(max(length, min_length), max_length) + data.append('x' * length) + return data + + +def generate_low_variance_data( + num_elements: int, + mean_length: int = 100, + cv: float = 0.1, + min_length: int = 1, + max_length: int = 500, + seed: int = 42) -> List[str]: + """Low-variance control (CV=10%).""" + random.seed(seed) + std = mean_length * cv + data = [] + for _ in range(num_elements): + length = int(random.gauss(mean_length, std)) + length = min(max(length, min_length), max_length) + data.append('x' * length) + return data + + +# --------------------------------------------------------------------------- +# Batching algorithms +# --------------------------------------------------------------------------- + + +def simulate_batch_elements(data: List[str], + max_batch_size: int) -> List[List[str]]: + """Baseline: simple count-based chunking (BatchElements behaviour).""" + batches = [] + current_batch = [] + for element in data: + current_batch.append(element) + if len(current_batch) >= max_batch_size: + batches.append(current_batch) + current_batch = [] + if current_batch: + batches.append(current_batch) + return batches + + +def simulate_sort_and_batch_elements( + data: List[str], + max_batch_size: int, + max_batch_weight: int, + element_size_fn: Optional[Callable[[Any], int]] = None, + bundle_size: Optional[int] = None) -> List[List[str]]: + """Core mechanism: sort by size + weight-based batch splitting.""" + if element_size_fn is None: + element_size_fn = len + + # Split into bundles if specified (realistic Beam behavior) + if bundle_size is not None and bundle_size > 0: + bundles = [ + data[i:i + bundle_size] for i in range(0, len(data), bundle_size) + ] + else: + bundles = [data] + + all_batches = [] + + for bundle in bundles: + # Sort by element size (ascending) + sorted_bundle = sorted(bundle, key=element_size_fn) + + current_batch = [] + current_weight = 0 + + for element in sorted_bundle: + element_weight = element_size_fn(element) + + # Check if adding this element would exceed limits + would_exceed_count = len(current_batch) >= max_batch_size + would_exceed_weight = ( + current_weight + element_weight > max_batch_weight and current_batch) + + if would_exceed_count or would_exceed_weight: + all_batches.append(current_batch) + current_batch = [] + current_weight = 0 + + current_batch.append(element) + current_weight += element_weight + + if current_batch: + all_batches.append(current_batch) + + return all_batches + + +# --------------------------------------------------------------------------- +# Simulated inference +# --------------------------------------------------------------------------- + + +def simulate_inference_latency( + batch: List[str], base_latency_ms: float = 1.0) -> float: + """Simulate transformer inference: O(batch_size * seq_len^1.5).""" + if not batch: + return 0.0 + batch_size = len(batch) + max_len = max(len(s) for s in batch) + return base_latency_ms * batch_size * (max_len / 50)**1.5 + + +# --------------------------------------------------------------------------- +# Stats helpers +# --------------------------------------------------------------------------- + + +def percentile(data: Sequence[float], p: float) -> float: + """Percentile via linear interpolation between adjacent ranks. + + Equivalent to numpy.percentile(data, p, method='linear'). + For N=20: P50 interpolates ranks 10-11, P95 ranks 19-20, + P99 near rank 20 (fractional index 18.81). + """ + if not data: + return 0.0 + s = sorted(data) + k = (len(s) - 1) * p / 100 + f = int(k) + c = min(f + 1, len(s) - 1) + return s[f] + (k - f) * (s[c] - s[f]) + + +def compute_padding_stats(batches: List[List[str]]) -> Dict[str, Any]: + """Padding-efficiency statistics for a list of batches.""" + total_actual = 0 + total_padded = 0 + batch_sizes = [] + max_lengths = [] + + for batch in batches: + if not batch: + continue + lengths = [len(s) for s in batch] + mx = max(lengths) + total_actual += sum(lengths) + total_padded += mx * len(batch) + batch_sizes.append(len(batch)) + max_lengths.append(mx) + + efficiency = total_actual / total_padded if total_padded else 0.0 + padding_ratio = total_padded / total_actual if total_actual else float('inf') + + return { + 'efficiency': efficiency, + 'padding_ratio': padding_ratio, + 'num_batches': len(batches), + 'avg_batch_size': statistics.mean(batch_sizes) if batch_sizes else 0, + 'total_actual_length': total_actual, + 'total_padded_length': total_padded, + 'padding_overhead': total_padded - total_actual, + 'batch_size_p50': percentile(batch_sizes, 50) if batch_sizes else 0, + 'batch_size_p95': percentile(batch_sizes, 95) if batch_sizes else 0, + 'batch_size_max': max(batch_sizes) if batch_sizes else 0, + 'max_len_p50': percentile(max_lengths, 50) if max_lengths else 0, + 'max_len_p95': percentile(max_lengths, 95) if max_lengths else 0, + } + + +# --------------------------------------------------------------------------- +# Invariant validation +# --------------------------------------------------------------------------- + + +def validate_invariants( + data: List[str], + baseline_batches: List[List[str]], + stateless_batches: List[List[str]], + config: Dict[str, Any]) -> Dict[str, Any]: + """Validate element/token counts and batch-size equality.""" + n = len(data) + b_n = sum(len(b) for b in baseline_batches) + s_n = sum(len(b) for b in stateless_batches) + tok = sum(len(s) for s in data) + b_tok = sum(sum(len(s) for s in b) for b in baseline_batches) + s_tok = sum(sum(len(s) for s in b) for b in stateless_batches) + + return { + 'input_elements': n, + 'baseline_elements': b_n, + 'stateless_elements': s_n, + 'elements_match': n == b_n == s_n, + 'input_tokens': tok, + 'baseline_tokens': b_tok, + 'stateless_tokens': s_tok, + 'tokens_match': tok == b_tok == s_tok, + 'baseline_num_batches': len(baseline_batches), + 'stateless_num_batches': len(stateless_batches), + } + + +# --------------------------------------------------------------------------- +# Performance benchmark (N=20 trials) +# --------------------------------------------------------------------------- + + +def run_performance_benchmark( + data: List[str], + max_batch_size: int, + max_batch_weight: int, + bundle_size: int = 500, + num_trials: int = 20, + warmup_trials: int = 3) -> Tuple[Dict[str, Any], Dict[str, Any]]: + """Run N=20 trials for baseline and stateless.""" + total_tokens = sum(len(s) for s in data) + + baseline_trials = [] + stateless_trials = [] + + for trial_idx in range(warmup_trials + num_trials): + is_warmup = trial_idx < warmup_trials + + # --- Baseline --- + start = time.perf_counter() + b_batches = simulate_batch_elements(data, max_batch_size) + batch_ms = (time.perf_counter() - start) * 1000 + b_inf = [simulate_inference_latency(b) for b in b_batches] + b_e2e = batch_ms + sum(b_inf) + if not is_warmup: + baseline_trials.append({ + 'overhead_ms': batch_ms, + 'inference_ms': sum(b_inf), + 'e2e_ms': b_e2e, + 'batch_latencies': b_inf, + 'num_batches': len(b_batches), + }) + + # --- Stateless (SortAndBatchElements) --- + start = time.perf_counter() + s_batches = simulate_sort_and_batch_elements( + data, max_batch_size, max_batch_weight, bundle_size=bundle_size) + sort_ms = (time.perf_counter() - start) * 1000 + s_inf = [simulate_inference_latency(b) for b in s_batches] + s_e2e = sort_ms + sum(s_inf) + if not is_warmup: + stateless_trials.append({ + 'overhead_ms': sort_ms, + 'inference_ms': sum(s_inf), + 'e2e_ms': s_e2e, + 'batch_latencies': s_inf, + 'num_batches': len(s_batches), + }) + + def _stats(trials): + e2e = [t['e2e_ms'] for t in trials] + tput = [total_tokens / (t['e2e_ms'] / 1000) for t in trials] + overhead = [t['overhead_ms'] for t in trials] + all_lat = [l for t in trials for l in t['batch_latencies']] + return { + 'e2e_median': percentile(e2e, 50), + 'e2e_p25': percentile(e2e, 25), + 'e2e_p75': percentile(e2e, 75), + 'e2e_p95': percentile(e2e, 95), + 'tput_median': percentile(tput, 50), + 'tput_p25': percentile(tput, 25), + 'tput_p75': percentile(tput, 75), + 'tput_p95': percentile(tput, 95), + 'overhead_median': percentile(overhead, 50), + 'overhead_p25': percentile(overhead, 25), + 'overhead_p75': percentile(overhead, 75), + 'overhead_p95': percentile(overhead, 95), + 'batch_lat_p50': percentile(all_lat, 50), + 'batch_lat_p95': percentile(all_lat, 95), + 'batch_lat_p99': percentile(all_lat, 99), + 'inf_p95': percentile(all_lat, 95), + 'num_trials': len(trials), + 'num_batches': trials[0]['num_batches'] if trials else 0, + } + + return _stats(baseline_trials), _stats(stateless_trials) + + +# --------------------------------------------------------------------------- +# Single benchmark run +# --------------------------------------------------------------------------- + + +def run_benchmark( + num_elements: int = 10000, + min_length: int = 1, + max_length: int = 500, + max_batch_size: int = 32, + max_batch_weight: int = 2000, + bundle_size: int = 500, + distribution: str = 'pareto', + seed: int = 42) -> Dict[str, Any]: + """Run baseline vs stateless comparison.""" + generators = { + 'pareto': lambda: generate_highly_skewed_data( + num_elements, min_length, max_length, seed), + 'lognormal': lambda: generate_lognormal_data( + num_elements, 50, 0.8, min_length, max_length, seed), + 'bimodal': lambda: generate_bimodal_data( + num_elements, 20, 200, 0.7, min_length, max_length, seed), + 'low_variance': lambda: generate_low_variance_data( + num_elements, 100, 0.1, min_length, max_length, seed), + } + if distribution not in generators: + raise ValueError(f"Unknown distribution: {distribution}") + + data = generators[distribution]() + lengths = [len(s) for s in data] + + baseline_batches = simulate_batch_elements(data, max_batch_size) + stateless_batches = simulate_sort_and_batch_elements( + data, max_batch_size, max_batch_weight, bundle_size=bundle_size) + + baseline_pad = compute_padding_stats(baseline_batches) + stateless_pad = compute_padding_stats(stateless_batches) + + baseline_perf, stateless_perf = run_performance_benchmark( + data, max_batch_size, max_batch_weight, bundle_size) + baseline_pad.update(baseline_perf) + stateless_pad.update(stateless_perf) + + validation = validate_invariants( + data, + baseline_batches, + stateless_batches, { + 'max_batch_size': max_batch_size, + 'max_batch_weight': max_batch_weight + }) + + return { + 'config': { + 'num_elements': num_elements, + 'max_batch_size': max_batch_size, + 'max_batch_weight': max_batch_weight, + 'bundle_size': bundle_size, + 'distribution': distribution, + }, + 'data_stats': { + 'min': min(lengths), + 'max': max(lengths), + 'mean': statistics.mean(lengths), + 'median': statistics.median(lengths), + 'std': statistics.stdev(lengths), + }, + 'baseline': baseline_pad, + 'stateless': stateless_pad, + 'validation': validation, + } + + +# --------------------------------------------------------------------------- +# Printing +# --------------------------------------------------------------------------- + + +def _fmt_iqr(median, p25, p75, unit=''): + return f"{median:.1f} [{p25:.1f}-{p75:.1f}]{unit}" + + +def print_results(results: Dict[str, Any]) -> None: + cfg = results['config'] + ds = results['data_stats'] + bl = results['baseline'] + st = results['stateless'] + val = results['validation'] + + print("=" * 80) + print( + f"Distribution: {cfg['distribution']} | " + f"N={cfg['num_elements']} | " + f"max_batch_size={cfg['max_batch_size']} | " + f"max_batch_weight={cfg['max_batch_weight']}") + print( + f"Input lengths: min={ds['min']} max={ds['max']} " + f"mean={ds['mean']:.1f} median={ds['median']:.0f} std={ds['std']:.1f}") + print("-" * 80) + + def _arm(label, s): + print(f"\n {label}:") + print(f" Num batches: {s['num_batches']}") + print(f" Padding ratio: {s['padding_ratio']:.2f}x") + print(" ") + print(" Throughput (Ktok/s):") + med = s['tput_median'] / 1000 + p25 = s['tput_p25'] / 1000 + p75 = s['tput_p75'] / 1000 + print(f" Median [IQR]: {med:.1f}" + f" [{p25:.1f}-{p75:.1f}]") + print(f" P95: {s['tput_p95']/1000:.1f}") + print(" ") + print(" E2E latency (ms):") + print( + f" Median [IQR]: {s['e2e_median']:.1f}" + f" [{s['e2e_p25']:.1f}-{s['e2e_p75']:.1f}]") + print(f" P95: {s['e2e_p95']:.1f}") + print(" ") + print(" Overhead (ms):") + print( + f" Median [IQR]:" + f" {s['overhead_median']:.2f}" + f" [{s['overhead_p25']:.2f}" + f"-{s['overhead_p75']:.2f}]") + print(f" P95: {s['overhead_p95']:.2f}") + print(" ") + print(" Batch latency (ms):") + print(f" P50: {s['batch_lat_p50']:.1f}") + print(f" P95: {s['batch_lat_p95']:.1f}") + print(f" P99: {s['batch_lat_p99']:.1f}") + + _arm("Baseline (BatchElements)", bl) + _arm("Stateless (SortAndBatchElements w/ weight-based splitting)", st) + + # Delta — explicit arrows so direction is unambiguous + # ↓ = value decreased (good for latency/padding) + # ↑ = value increased (good for throughput) + def _delta_lower(base, new): + """For metrics where lower is better (latency, padding).""" + if base == 0: + return 'N/A' + pct = (base - new) / base * 100 + arrow = '\u2193' if pct > 0 else '\u2191' + return f"{arrow}{abs(pct):.1f}%" + + def _delta_higher(base, new): + """For metrics where higher is better (throughput).""" + if base == 0: + return 'N/A' + pct = (new - base) / base * 100 + arrow = '\u2191' if pct > 0 else '\u2193' + return f"{arrow}{abs(pct):.1f}%" + + print(f"\n {'_' * 76}") + print(" DELTA (Baseline -> Stateless):") + + def _line(label, bv, sv, delta_fn, fmt='.1f', unit=''): + d = delta_fn(bv, sv) + print(f" {label}: {bv:{fmt}}{unit}" + f" -> {sv:{fmt}}{unit} ({d})") + + bl_tmed = bl['tput_median'] / 1000 + st_tmed = st['tput_median'] / 1000 + bl_tp95 = bl['tput_p95'] / 1000 + st_tp95 = st['tput_p95'] / 1000 + + _line( + 'Padding ratio ', + bl['padding_ratio'], + st['padding_ratio'], + _delta_lower, + fmt='.2f', + unit='x') + _line('Throughput median', bl_tmed, st_tmed, _delta_higher, unit=' Ktok/s') + _line('Throughput p95 ', bl_tp95, st_tp95, _delta_higher, unit=' Ktok/s') + _line( + 'E2E latency med ', + bl['e2e_median'], + st['e2e_median'], + _delta_lower, + unit=' ms') + _line( + 'E2E latency p95 ', + bl['e2e_p95'], + st['e2e_p95'], + _delta_lower, + unit=' ms') + _line( + 'Batch lat p95 ', + bl['batch_lat_p95'], + st['batch_lat_p95'], + _delta_lower, + unit=' ms') + _line( + 'Batch lat p99 ', + bl['batch_lat_p99'], + st['batch_lat_p99'], + _delta_lower, + unit=' ms') + + # Invariants + e_ok = "Y" if val['elements_match'] else "X" + t_ok = "Y" if val['tokens_match'] else "X" + b_nb = val['baseline_num_batches'] + s_nb = val['stateless_num_batches'] + print( + f"\n Invariants: elements {e_ok} tokens {t_ok}" + f" (baseline {b_nb} -> stateless {s_nb}" + f" batches)") + print("=" * 80) + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + + +def main(): + print("=" * 80) + print("BASELINE (count-based) vs STATELESS (weight-based boundary splitting)") + print("=" * 80) + print() + print("Experiment design:") + print(" A = Baseline : BatchElements with max_batch_size=32 (count-based)") + print(" B = Stateless : SortAndBatchElements with max_batch_weight=2000") + print( + " (sort by size within bundle -> weight-based split)") + print() + print("Why Stateless wins:") + print(" Weight-based splitting changes batch BOUNDARIES so each batch has") + print( + " similar-length elements -> less padding. Sorting alone within fixed") + print(" boundaries yields 0% gain (verified by strict-control experiment).") + print() + print("Methodology:") + print(" - N=20 trials, 3 warmup excluded") + print(" - Percentiles: linear interpolation (= numpy default)") + print(" - Same seed=42 for both arms") + print(" - Inference model: latency = batch_size * (max_seq_len/50)^1.5 ms") + print() + + dist = 'pareto' + print(f"\nRunning: {dist}...") + r = run_benchmark( + num_elements=10000, + max_batch_size=32, + max_batch_weight=2000, + bundle_size=500, + distribution=dist, + seed=42) + print_results(r) + + +if __name__ == '__main__': + main() diff --git a/sdks/python/apache_beam/transforms/util.py b/sdks/python/apache_beam/transforms/util.py index dd14bd8f57bd..ab832dce0207 100644 --- a/sdks/python/apache_beam/transforms/util.py +++ b/sdks/python/apache_beam/transforms/util.py @@ -105,6 +105,7 @@ 'RemoveDuplicates', 'Reshuffle', 'Secret', + 'SortAndBatchElements', 'ToString', 'Take', 'Tee', @@ -1322,6 +1323,290 @@ def expand(self, pcoll): self._batch_size_estimator, self._element_size_fn)) +def _default_element_size_fn(element: Any) -> int: + """Default element size function that tries len(), falls back to 1. + + This function attempts to compute the size of an element using len(). + If the element does not support len() (e.g., integers), it falls back to 1. + + Args: + element: The element to compute the size of. + + Returns: + The size of the element, or 1 if len() is not supported. + """ + try: + return len(element) + except TypeError: + return 1 + + +class _SortAndBatchElementsDoFn(DoFn): + """DoFn that buffers, sorts by element size, and batches elements. + + This DoFn is used internally by ``SortAndBatchElements`` for + PCollections with the default (global) window. It accumulates all + elements in the current bundle, sorts them by size in ascending order, + and emits optimally-sized batches on ``finish_bundle``. + + Args: + min_batch_size: The minimum number of elements per batch. Must be >= 1. + max_batch_size: The maximum number of elements per batch. + Must be >= ``min_batch_size``. + max_batch_weight: The maximum total weight of elements in a batch, + where weight is computed by ``element_size_fn``. Must be >= 1. + element_size_fn: A callable mapping an element to its integer + size/weight. + """ + def __init__( + self, + min_batch_size: int, + max_batch_size: int, + max_batch_weight: int, + element_size_fn: Callable[[Any], int]): + if min_batch_size < 1: + raise ValueError(f'min_batch_size must be >= 1, got {min_batch_size}') + if max_batch_size < min_batch_size: + raise ValueError( + f'max_batch_size ({max_batch_size}) must be >= ' + f'min_batch_size ({min_batch_size})') + if max_batch_weight < 1: + raise ValueError(f'max_batch_weight must be >= 1, got {max_batch_weight}') + self._min_batch_size = min_batch_size + self._max_batch_size = max_batch_size + self._max_batch_weight = max_batch_weight + self._element_size_fn = element_size_fn + self._buffer = [] + + def start_bundle(self): + self._buffer = [] + + def process(self, element): + self._buffer.append(element) + + def finish_bundle(self): + if not self._buffer: + return + + # Sort elements by size (ascending) for optimal batching + # Elements of similar sizes will be grouped together + sorted_elements = sorted(self._buffer, key=self._element_size_fn) + + batch = [] + batch_weight = 0 + + for element in sorted_elements: + element_size = self._element_size_fn(element) + + # Check if adding this element would exceed limits + would_exceed_count = len(batch) >= self._max_batch_size + would_exceed_weight = ( + batch_weight + element_size >= self._max_batch_weight and batch) + + if would_exceed_count or would_exceed_weight: + # Emit current batch + yield window.GlobalWindows.windowed_value_at_end_of_window(batch) + batch = [] + batch_weight = 0 + + batch.append(element) + batch_weight += element_size + + # Emit remaining elements + if batch: + yield window.GlobalWindows.windowed_value_at_end_of_window(batch) + + self._buffer = None + + +class _WindowAwareSortAndBatchElementsDoFn(DoFn): + """DoFn that buffers, sorts by element size, and batches elements per window. + + This DoFn is used internally by ``SortAndBatchElements`` for + PCollections with non-default (e.g. fixed, sliding, or session) windows. + Elements are buffered per window and each window is flushed independently. + To prevent unbounded memory growth, when the number of live windows + exceeds ``_MAX_LIVE_WINDOWS`` the largest window buffer is flushed early. + + Args: + min_batch_size: The minimum number of elements per batch. Must be >= 1. + max_batch_size: The maximum number of elements per batch. + Must be >= ``min_batch_size``. + max_batch_weight: The maximum total weight of elements in a batch, + where weight is computed by ``element_size_fn``. Must be >= 1. + element_size_fn: A callable mapping an element to its integer + size/weight. + """ + + _MAX_LIVE_WINDOWS = 10 + + def __init__( + self, + min_batch_size: int, + max_batch_size: int, + max_batch_weight: int, + element_size_fn: Callable[[Any], int]): + if min_batch_size < 1: + raise ValueError(f'min_batch_size must be >= 1, got {min_batch_size}') + if max_batch_size < min_batch_size: + raise ValueError( + f'max_batch_size ({max_batch_size}) must be >= ' + f'min_batch_size ({min_batch_size})') + if max_batch_weight < 1: + raise ValueError(f'max_batch_weight must be >= 1, got {max_batch_weight}') + self._min_batch_size = min_batch_size + self._max_batch_size = max_batch_size + self._max_batch_weight = max_batch_weight + self._element_size_fn = element_size_fn + self._buffers = collections.defaultdict(list) + + def start_bundle(self): + self._buffers = collections.defaultdict(list) + + def process(self, element, window=DoFn.WindowParam): + self._buffers[window].append(element) + + # If we have too many live windows, flush the largest one + if len(self._buffers) > self._MAX_LIVE_WINDOWS: + largest_window = max( + self._buffers.keys(), key=lambda w: len(self._buffers[w])) + yield from self._flush_window(largest_window) + + def _flush_window(self, win): + """Flush all elements for a given window.""" + buffer = self._buffers.pop(win, []) + if not buffer: + return + + # Sort elements by size (ascending) + sorted_elements = sorted(buffer, key=self._element_size_fn) + + batch = [] + batch_weight = 0 + + for element in sorted_elements: + element_size = self._element_size_fn(element) + + would_exceed_count = len(batch) >= self._max_batch_size + would_exceed_weight = ( + batch_weight + element_size >= self._max_batch_weight and batch) + + if would_exceed_count or would_exceed_weight: + yield windowed_value.WindowedValue(batch, win.max_timestamp(), (win, )) + batch = [] + batch_weight = 0 + + batch.append(element) + batch_weight += element_size + + if batch: + yield windowed_value.WindowedValue(batch, win.max_timestamp(), (win, )) + + def finish_bundle(self): + for win in list(self._buffers.keys()): + yield from self._flush_window(win) + self._buffers = None + + +@typehints.with_input_types(T) +@typehints.with_output_types(list[T]) +class SortAndBatchElements(PTransform): + """A Transform that sorts elements by size before batching. + + This transform is designed to optimize batch processing by grouping elements + of similar sizes together. This is particularly useful for ML inference + workloads where input sequences of varying lengths need to be padded to the + maximum length in the batch - by sorting elements by size before batching, + padding overhead is minimized. + + The transform consumes a PCollection of element type T and produces a + PCollection of element type list[T], where elements within each batch are + sorted by their size (as determined by element_size_fn). + + Elements are batched per-window and batches emitted in the window + corresponding to its contents. Each batch is emitted with a timestamp at + the end of their window. + + Unlike BatchElements which emits batches as soon as size limits are reached, + SortAndBatchElements buffers all elements in a bundle, sorts them by size, + and then creates optimally-sized batches. This trade-off of increased memory + usage for better batch homogeneity can significantly reduce padding overhead. + + Args: + min_batch_size: The minimum number of elements in a batch. Must be >= 1. + max_batch_size: The maximum number of elements in a batch. + Must be >= min_batch_size. + max_batch_weight: The maximum total weight of elements in a batch, + where weight is computed by element_size_fn. Must be >= 1. + element_size_fn: (optional) A function mapping an element to its + size/weight. + If not provided, defaults to trying len(element) and falling back to 1 + if the element doesn't support len(). This default allows sorting to + work for common types like strings, lists, and arrays. + + Example usage:: + + # Batch strings by total character count + strings = ['a', 'bb', 'ccc', 'dddd', 'eeeee'] + batched = strings | SortAndBatchElements( + min_batch_size=1, + max_batch_size=3, + max_batch_weight=10) + # Possible output: [['a', 'bb', 'ccc'], ['dddd', 'eeeee']] + # Elements are sorted by length and batched optimally + + # Batch with custom size function + data = [{'text': 'short'}, {'text': 'medium text'}, + {'text': 'long text here'}] + batched = data | SortAndBatchElements( + min_batch_size=1, + max_batch_size=10, + max_batch_weight=100, + element_size_fn=lambda x: len(x['text'])) + """ + def __init__( + self, + min_batch_size: int, + max_batch_size: int, + max_batch_weight: int, + element_size_fn: Optional[Callable[[Any], int]] = None): + if min_batch_size < 1: + raise ValueError(f'min_batch_size must be >= 1, got {min_batch_size}') + if max_batch_size < min_batch_size: + raise ValueError( + f'max_batch_size ({max_batch_size}) must be >= ' + f'min_batch_size ({min_batch_size})') + if max_batch_weight < 1: + raise ValueError(f'max_batch_weight must be >= 1, got {max_batch_weight}') + if element_size_fn is not None and not callable(element_size_fn): + raise TypeError('element_size_fn must be callable') + + self._min_batch_size = min_batch_size + self._max_batch_size = max_batch_size + self._max_batch_weight = max_batch_weight + + # Smart default: try len(), fallback to 1 when len() is unsupported + self._element_size_fn: Callable[[Any], int] = ( + element_size_fn + if element_size_fn is not None else _default_element_size_fn) + + def expand(self, pcoll): + if pcoll.windowing.is_default(): + return pcoll | ParDo( + _SortAndBatchElementsDoFn( + self._min_batch_size, + self._max_batch_size, + self._max_batch_weight, + self._element_size_fn)) + else: + return pcoll | ParDo( + _WindowAwareSortAndBatchElementsDoFn( + self._min_batch_size, + self._max_batch_size, + self._max_batch_weight, + self._element_size_fn)) + + class _IdentityWindowFn(NonMergingWindowFn): """Windowing function that preserves existing windows. diff --git a/sdks/python/apache_beam/transforms/util_test.py b/sdks/python/apache_beam/transforms/util_test.py index 448ba8a7ad9d..b4471f0415ed 100644 --- a/sdks/python/apache_beam/transforms/util_test.py +++ b/sdks/python/apache_beam/transforms/util_test.py @@ -1026,6 +1026,223 @@ def test_stateful_grows_to_max_batch(self): assert_that(res, equal_to([1, 1, 2, 4, 8, 16, 32, 50, 50])) +class SortAndBatchElementsTest(unittest.TestCase): + """Tests for SortAndBatchElements transform.""" + def test_elements_are_sorted_by_size(self): + """Test that elements are sorted by size within batches.""" + with TestPipeline() as p: + # Create elements with varying sizes + data = ['aaaaa', 'bb', 'cccc', 'a', 'ddd'] + res = ( + p + | beam.Create(data, reshuffle=False) + | util.SortAndBatchElements( + min_batch_size=1, max_batch_size=5, max_batch_weight=100)) + + def check_sorted(batch): + lengths = [len(s) for s in batch] + assert lengths == sorted(lengths), ( + f'Batch not sorted by size: {lengths}') + return batch + + _ = res | beam.Map(check_sorted) + + def test_batch_respects_max_batch_size(self): + """Test that batches do not exceed max_batch_size.""" + with TestPipeline() as p: + res = ( + p + | beam.Create(['a'] * 10, reshuffle=False) + | util.SortAndBatchElements( + min_batch_size=1, max_batch_size=3, max_batch_weight=100) + | beam.Map(len)) + assert_that(res, equal_to([3, 3, 3, 1])) + + def test_batch_respects_max_batch_weight(self): + """Test that batches do not exceed max_batch_weight.""" + with TestPipeline() as p: + # Each element has size 5, max_batch_weight is 12 + # So we can fit at most 2 elements per batch + data = ['aaaaa', 'bbbbb', 'ccccc', 'ddddd'] + res = ( + p + | beam.Create(data, reshuffle=False) + | util.SortAndBatchElements( + min_batch_size=1, max_batch_size=10, max_batch_weight=12) + | beam.Map(len)) + assert_that(res, equal_to([2, 2])) + + def test_default_element_size_fn_with_strings(self): + """Test default element_size_fn works with strings.""" + with TestPipeline() as p: + data = ['a', 'bbb', 'cc'] + res = ( + p + | beam.Create(data, reshuffle=False) + | util.SortAndBatchElements( + min_batch_size=1, max_batch_size=3, max_batch_weight=100) + | beam.FlatMap(lambda batch: [len(s) for s in batch])) + # Elements should be sorted by length: 'a'(1), 'cc'(2), 'bbb'(3) + assert_that(res, equal_to([1, 2, 3])) + + def test_default_element_size_fn_with_integers(self): + """Test default element_size_fn falls back to 1 for integers.""" + with TestPipeline() as p: + data = [10, 20, 30, 40, 50] + res = ( + p + | beam.Create(data, reshuffle=False) + | util.SortAndBatchElements( + min_batch_size=1, max_batch_size=3, max_batch_weight=100) + | beam.Map(len)) + # With size=1 for all, should batch by max_batch_size + assert_that(res, equal_to([3, 2])) + + def test_custom_element_size_fn(self): + """Test using a custom element_size_fn.""" + with TestPipeline() as p: + data = [{'text': 'a'}, {'text': 'bbb'}, {'text': 'cc'}] + res = ( + p + | beam.Create(data, reshuffle=False) + | util.SortAndBatchElements( + min_batch_size=1, + max_batch_size=3, + max_batch_weight=100, + element_size_fn=lambda x: len(x['text'])) + | beam.FlatMap(lambda batch: [len(e['text']) for e in batch])) + # Should be sorted by text length + assert_that(res, equal_to([1, 2, 3])) + + def test_empty_input(self): + """Test with empty input produces no output.""" + with TestPipeline() as p: + res = ( + p + | beam.Create([], reshuffle=False) + | util.SortAndBatchElements( + min_batch_size=1, max_batch_size=10, max_batch_weight=100) + | beam.Map(len)) + assert_that(res, equal_to([])) + + def test_single_element(self): + """Test with a single element.""" + with TestPipeline() as p: + res = ( + p + | beam.Create(['hello'], reshuffle=False) + | util.SortAndBatchElements( + min_batch_size=1, max_batch_size=10, max_batch_weight=100)) + assert_that(res, equal_to([['hello']])) + + def test_windowed_batches(self): + """Test that windowed elements are batched per window.""" + with TestPipeline('FnApiRunner') as p: + res = ( + p + | beam.Create(range(1, 8), reshuffle=False) + | beam.Map(lambda t: window.TimestampedValue('a' * t, t)) + | beam.WindowInto(window.FixedWindows(3)) + | util.SortAndBatchElements( + min_batch_size=1, max_batch_size=10, max_batch_weight=100) + | beam.Map(lambda batch: ''.join(batch))) + # FixedWindows(3) with default offset 0 produces: + # Window [0, 3): elements at t=1,2 with sizes 1,2 + # Window [3, 6): elements at t=3,4,5 with sizes 3,4,5 + # Window [6, 9): elements at t=6,7 with sizes 6,7 + assert_that( + res, + equal_to([ + 'a' * (1 + 2), # Window [0, 3) + 'a' * (3 + 4 + 5), # Window [3, 6) + 'a' * (6 + 7), # Window [6, 9) + ])) + + def test_validation_min_batch_size(self): + """Test that min_batch_size validation raises ValueError.""" + with self.assertRaises(ValueError) as cm: + util.SortAndBatchElements( + min_batch_size=0, max_batch_size=10, max_batch_weight=100) + self.assertIn('min_batch_size must be >= 1', str(cm.exception)) + + def test_validation_max_batch_size(self): + """Test that max_batch_size < min_batch_size raises ValueError.""" + with self.assertRaises(ValueError) as cm: + util.SortAndBatchElements( + min_batch_size=10, max_batch_size=5, max_batch_weight=100) + self.assertIn('max_batch_size', str(cm.exception)) + self.assertIn('min_batch_size', str(cm.exception)) + + def test_validation_max_batch_weight(self): + """Test that max_batch_weight validation raises ValueError.""" + with self.assertRaises(ValueError) as cm: + util.SortAndBatchElements( + min_batch_size=1, max_batch_size=10, max_batch_weight=0) + self.assertIn('max_batch_weight must be >= 1', str(cm.exception)) + + def test_validation_element_size_fn_callable(self): + """Test that a non-callable element_size_fn raises TypeError.""" + with self.assertRaises(TypeError) as cm: + util.SortAndBatchElements( + min_batch_size=1, + max_batch_size=10, + max_batch_weight=100, + element_size_fn=123) + self.assertIn('element_size_fn must be callable', str(cm.exception)) + + def test_batch_timestamps(self): + """Test that batches have correct timestamps.""" + with TestPipeline('FnApiRunner') as p: + res = ( + p + | beam.Create(['a', 'bb', 'ccc'], reshuffle=False) + | util.SortAndBatchElements( + min_batch_size=1, max_batch_size=10, max_batch_weight=100) + | + beam.Map(lambda batch, ts=beam.DoFn.TimestampParam: (len(batch), ts))) + assert_that(res, equal_to([(3, GlobalWindow().max_timestamp())])) + + def test_padding_efficiency_improvement(self): + """Test that sorting improves padding efficiency.""" + # This test verifies the core value proposition of SortAndBatchElements + data = ['a', 'aaaaa', 'aa', 'aaaa', 'aaa'] + + # Compute what BatchElements would produce (preserves input order) + batch_elements_batches = [] + with TestPipeline() as p: + _ = ( + p + | 'Create1' >> beam.Create(data, reshuffle=False) + | util.BatchElements(min_batch_size=5, max_batch_size=5) + | beam.Map(lambda b: batch_elements_batches.append(list(b)))) + + # Compute what SortAndBatchElements produces + sort_batch_batches = [] + with TestPipeline() as p: + _ = ( + p + | 'Create2' >> beam.Create(data, reshuffle=False) + | util.SortAndBatchElements( + min_batch_size=1, max_batch_size=5, max_batch_weight=100) + | beam.Map(lambda b: sort_batch_batches.append(list(b)))) + + # Calculate padding overhead for each approach + # Padding overhead: + # sum(max_len_in_batch * batch_size) - sum(actual_lengths) + def compute_overhead(batches): + overhead = 0 + for batch in batches: + lengths = [len(s) for s in batch] + overhead += max(lengths) * len(batch) - sum(lengths) + return overhead + + batch_overhead = compute_overhead(batch_elements_batches) + sort_overhead = compute_overhead(sort_batch_batches) + + # SortAndBatchElements should have less or equal overhead + self.assertLessEqual(sort_overhead, batch_overhead) + + class IdentityWindowTest(unittest.TestCase): def test_window_preserved(self): expected_timestamp = timestamp.Timestamp(5) From fc66805f2f8282c4240777bddf06f6c9e07016ba Mon Sep 17 00:00:00 2001 From: Elia LIU Date: Sun, 8 Feb 2026 15:28:57 +1100 Subject: [PATCH 2/6] fix(ml): improve test coverage for SortAndBatchElements - Exclude *_benchmark.py from codecov (standalone scripts, not production code) - Remove redundant validation from internal DoFn classes (already validated by PTransform) - Add direct in-process unit tests for DoFn internals to capture coverage (FnApiRunner runs DoFns in separate process, invisible to coverage tools) Co-Authored-By: Claude Opus 4.6 --- .github/codecov.yml | 1 + sdks/python/apache_beam/transforms/util.py | 16 --- .../apache_beam/transforms/util_test.py | 133 ++++++++++++++++++ 3 files changed, 134 insertions(+), 16 deletions(-) diff --git a/.github/codecov.yml b/.github/codecov.yml index 0936f392ccef..5d0eaccf22da 100644 --- a/.github/codecov.yml +++ b/.github/codecov.yml @@ -73,6 +73,7 @@ ignore: - "**/*_microbenchmark.py" - "sdks/go/pkg/beam/register/register.go" - "sdks/python/apache_beam/testing/benchmarks/nexmark/**" + - "**/*_benchmark.py" - "sdks/python/apache_beam/examples/**" # See https://docs.codecov.com/docs/flags for options. diff --git a/sdks/python/apache_beam/transforms/util.py b/sdks/python/apache_beam/transforms/util.py index ab832dce0207..29d1ed087d4f 100644 --- a/sdks/python/apache_beam/transforms/util.py +++ b/sdks/python/apache_beam/transforms/util.py @@ -1364,14 +1364,6 @@ def __init__( max_batch_size: int, max_batch_weight: int, element_size_fn: Callable[[Any], int]): - if min_batch_size < 1: - raise ValueError(f'min_batch_size must be >= 1, got {min_batch_size}') - if max_batch_size < min_batch_size: - raise ValueError( - f'max_batch_size ({max_batch_size}) must be >= ' - f'min_batch_size ({min_batch_size})') - if max_batch_weight < 1: - raise ValueError(f'max_batch_weight must be >= 1, got {max_batch_weight}') self._min_batch_size = min_batch_size self._max_batch_size = max_batch_size self._max_batch_weight = max_batch_weight @@ -1446,14 +1438,6 @@ def __init__( max_batch_size: int, max_batch_weight: int, element_size_fn: Callable[[Any], int]): - if min_batch_size < 1: - raise ValueError(f'min_batch_size must be >= 1, got {min_batch_size}') - if max_batch_size < min_batch_size: - raise ValueError( - f'max_batch_size ({max_batch_size}) must be >= ' - f'min_batch_size ({min_batch_size})') - if max_batch_weight < 1: - raise ValueError(f'max_batch_weight must be >= 1, got {max_batch_weight}') self._min_batch_size = min_batch_size self._max_batch_size = max_batch_size self._max_batch_weight = max_batch_weight diff --git a/sdks/python/apache_beam/transforms/util_test.py b/sdks/python/apache_beam/transforms/util_test.py index b4471f0415ed..a0c7f3e43c5f 100644 --- a/sdks/python/apache_beam/transforms/util_test.py +++ b/sdks/python/apache_beam/transforms/util_test.py @@ -1243,6 +1243,139 @@ def compute_overhead(batches): self.assertLessEqual(sort_overhead, batch_overhead) +class SortAndBatchElementsDoFnDirectTest(unittest.TestCase): + """Direct unit tests for DoFn internals to ensure coverage. + + Beam's FnApiRunner executes DoFns in a separate SDK harness process, + so coverage tools in the main process cannot capture DoFn code paths. + These tests exercise the DoFn methods directly in-process. + """ + + def test_default_element_size_fn_len(self): + from apache_beam.transforms.util import _default_element_size_fn + self.assertEqual(_default_element_size_fn('abc'), 3) + self.assertEqual(_default_element_size_fn([1, 2]), 2) + + def test_default_element_size_fn_fallback(self): + from apache_beam.transforms.util import _default_element_size_fn + self.assertEqual(_default_element_size_fn(42), 1) + self.assertEqual(_default_element_size_fn(3.14), 1) + + def test_global_dofn_sort_and_batch(self): + """Test _SortAndBatchElementsDoFn directly.""" + from apache_beam.transforms.util import _SortAndBatchElementsDoFn + dofn = _SortAndBatchElementsDoFn( + min_batch_size=1, max_batch_size=3, max_batch_weight=100, + element_size_fn=len) + dofn.start_bundle() + for elem in ['ccccc', 'bb', 'dddd', 'a', 'eee']: + dofn.process(elem) + batches = [wv.value for wv in dofn.finish_bundle()] + # All elements emitted + self.assertEqual(sum(len(b) for b in batches), 5) + # Each batch respects max_batch_size=3 + for batch in batches: + self.assertLessEqual(len(batch), 3) + # Elements within each batch are sorted by size + for batch in batches: + lengths = [len(s) for s in batch] + self.assertEqual(lengths, sorted(lengths)) + + def test_global_dofn_empty_bundle(self): + """Test finish_bundle with no elements returns nothing.""" + from apache_beam.transforms.util import _SortAndBatchElementsDoFn + dofn = _SortAndBatchElementsDoFn( + min_batch_size=1, max_batch_size=10, max_batch_weight=100, + element_size_fn=len) + dofn.start_bundle() + result = list(dofn.finish_bundle() or []) + self.assertEqual(result, []) + + def test_global_dofn_weight_splitting(self): + """Test weight-based splitting in the global DoFn.""" + from apache_beam.transforms.util import _SortAndBatchElementsDoFn + # Each element has size 5, max_batch_weight=12 -> 2 per batch + dofn = _SortAndBatchElementsDoFn( + min_batch_size=1, max_batch_size=100, max_batch_weight=12, + element_size_fn=len) + dofn.start_bundle() + for elem in ['aaaaa', 'bbbbb', 'ccccc', 'ddddd']: + dofn.process(elem) + batches = [wv.value for wv in dofn.finish_bundle()] + self.assertEqual(len(batches), 2) + for batch in batches: + self.assertEqual(len(batch), 2) + + def test_windowed_dofn_flush_and_finish(self): + """Test _WindowAwareSortAndBatchElementsDoFn directly.""" + from apache_beam.transforms.util import ( + _WindowAwareSortAndBatchElementsDoFn) + dofn = _WindowAwareSortAndBatchElementsDoFn( + min_batch_size=1, max_batch_size=10, max_batch_weight=100, + element_size_fn=len) + dofn.start_bundle() + win1 = IntervalWindow(0, 3) + win2 = IntervalWindow(3, 6) + # Manually add to buffers (bypass process() to avoid DoFn.WindowParam) + dofn._buffers[win1].extend(['aa', 'b', 'ccc']) + dofn._buffers[win2].extend(['dddd', 'ee']) + batches = list(dofn.finish_bundle()) + # All elements across both windows emitted + total_elements = sum(len(wv.value) for wv in batches) + self.assertEqual(total_elements, 5) + # Each batch has the correct window + for wv in batches: + self.assertIn(wv.windows[0], (win1, win2)) + + def test_windowed_dofn_overflow_flush(self): + """Test that exceeding _MAX_LIVE_WINDOWS triggers early flush.""" + from apache_beam.transforms.util import ( + _WindowAwareSortAndBatchElementsDoFn) + dofn = _WindowAwareSortAndBatchElementsDoFn( + min_batch_size=1, max_batch_size=10, max_batch_weight=100, + element_size_fn=len) + dofn.start_bundle() + # Fill up to _MAX_LIVE_WINDOWS + for i in range(dofn._MAX_LIVE_WINDOWS): + win = IntervalWindow(i * 10, (i + 1) * 10) + dofn._buffers[win].append('x' * (i + 1)) + self.assertEqual(len(dofn._buffers), dofn._MAX_LIVE_WINDOWS) + # Adding one more window should trigger overflow flush + overflow_win = IntervalWindow(100, 110) + results = list(dofn.process('overflow', overflow_win)) + # One window was flushed, so buffer count stays at _MAX_LIVE_WINDOWS + self.assertLessEqual(len(dofn._buffers), dofn._MAX_LIVE_WINDOWS) + # The flushed window produced output + self.assertGreater(len(results), 0) + + def test_windowed_dofn_flush_empty_window(self): + """Test _flush_window with a non-existent window returns nothing.""" + from apache_beam.transforms.util import ( + _WindowAwareSortAndBatchElementsDoFn) + dofn = _WindowAwareSortAndBatchElementsDoFn( + min_batch_size=1, max_batch_size=10, max_batch_weight=100, + element_size_fn=len) + dofn.start_bundle() + result = list(dofn._flush_window(IntervalWindow(0, 10))) + self.assertEqual(result, []) + + def test_windowed_dofn_weight_splitting(self): + """Test weight-based splitting in the windowed DoFn.""" + from apache_beam.transforms.util import ( + _WindowAwareSortAndBatchElementsDoFn) + dofn = _WindowAwareSortAndBatchElementsDoFn( + min_batch_size=1, max_batch_size=100, max_batch_weight=12, + element_size_fn=len) + dofn.start_bundle() + win = IntervalWindow(0, 10) + dofn._buffers[win].extend(['aaaaa', 'bbbbb', 'ccccc', 'ddddd']) + batches = list(dofn._flush_window(win)) + self.assertEqual(len(batches), 2) + for wv in batches: + self.assertEqual(len(wv.value), 2) + self.assertEqual(wv.windows[0], win) + + class IdentityWindowTest(unittest.TestCase): def test_window_preserved(self): expected_timestamp = timestamp.Timestamp(5) From 675d140041d03aec7f5553bd70933745eaaf94e7 Mon Sep 17 00:00:00 2001 From: Eliaazzz Date: Tue, 31 Mar 2026 19:54:41 +1100 Subject: [PATCH 3/6] Address PR review: clarify benchmark comment and warn on len() fallback Reframe benchmark docstring to clarify that sorting combined with weight-based splitting drives the improvement. Move default element size fallback into DoFn instances with a one-time warning when len() is unsupported, so users know to provide a custom element_size_fn. --- .../benchmarks/sort_and_batch_benchmark.py | 8 ++- sdks/python/apache_beam/transforms/util.py | 59 +++++++++++-------- .../apache_beam/transforms/util_test.py | 32 +++++++--- 3 files changed, 63 insertions(+), 36 deletions(-) diff --git a/sdks/python/apache_beam/testing/benchmarks/sort_and_batch_benchmark.py b/sdks/python/apache_beam/testing/benchmarks/sort_and_batch_benchmark.py index e1caa73e0a5d..caeb7c8efd12 100644 --- a/sdks/python/apache_beam/testing/benchmarks/sort_and_batch_benchmark.py +++ b/sdks/python/apache_beam/testing/benchmarks/sort_and_batch_benchmark.py @@ -22,9 +22,11 @@ - Baseline (BatchElements): fixed-count chunking, ignores element sizes. - Stateless (SortAndBatchElements): within each bundle, sorts elements by size, then splits batches using max_batch_weight so that each batch - has a bounded total weight. The improvement comes from *changing batch - boundaries* (weight-based splitting), NOT from sorting alone -- sorting - within fixed boundaries yields 0% gain (verified by strict-control). + has a bounded total weight. The improvement comes from sorting + combined with weight-based splitting: sorting clusters similar-sized + elements together, and the weight constraint then produces tighter + batches. Sorting alone with fixed count-based boundaries yields ~0% + gain (verified by strict-control ablation). Padding ratio:: diff --git a/sdks/python/apache_beam/transforms/util.py b/sdks/python/apache_beam/transforms/util.py index 0a360d5163c4..ede4035ac062 100644 --- a/sdks/python/apache_beam/transforms/util.py +++ b/sdks/python/apache_beam/transforms/util.py @@ -90,6 +90,8 @@ if TYPE_CHECKING: from apache_beam.runners.pipeline_context import PipelineContext +_LOGGER = logging.getLogger(__name__) + __all__ = [ 'BatchElements', 'CoGroupByKey', @@ -1320,24 +1322,6 @@ def expand(self, pcoll): self._batch_size_estimator, self._element_size_fn)) -def _default_element_size_fn(element: Any) -> int: - """Default element size function that tries len(), falls back to 1. - - This function attempts to compute the size of an element using len(). - If the element does not support len() (e.g., integers), it falls back to 1. - - Args: - element: The element to compute the size of. - - Returns: - The size of the element, or 1 if len() is not supported. - """ - try: - return len(element) - except TypeError: - return 1 - - class _SortAndBatchElementsDoFn(DoFn): """DoFn that buffers, sorts by element size, and batches elements. @@ -1364,9 +1348,23 @@ def __init__( self._min_batch_size = min_batch_size self._max_batch_size = max_batch_size self._max_batch_weight = max_batch_weight - self._element_size_fn = element_size_fn + self._element_size_fn = element_size_fn or self._default_element_size + self._has_warned_type_error = False self._buffer = [] + def _default_element_size(self, element): + try: + return len(element) + except TypeError: + if not self._has_warned_type_error: + _LOGGER.warning( + 'Element of type %s does not support len(). Falling back to ' + 'size 1. Consider providing a custom element_size_fn to ' + 'SortAndBatchElements for meaningful size-based batching.', + type(element).__name__) + self._has_warned_type_error = True + return 1 + def start_bundle(self): self._buffer = [] @@ -1438,9 +1436,23 @@ def __init__( self._min_batch_size = min_batch_size self._max_batch_size = max_batch_size self._max_batch_weight = max_batch_weight - self._element_size_fn = element_size_fn + self._element_size_fn = element_size_fn or self._default_element_size + self._has_warned_type_error = False self._buffers = collections.defaultdict(list) + def _default_element_size(self, element): + try: + return len(element) + except TypeError: + if not self._has_warned_type_error: + _LOGGER.warning( + 'Element of type %s does not support len(). Falling back to ' + 'size 1. Consider providing a custom element_size_fn to ' + 'SortAndBatchElements for meaningful size-based batching.', + type(element).__name__) + self._has_warned_type_error = True + return 1 + def start_bundle(self): self._buffers = collections.defaultdict(list) @@ -1566,10 +1578,9 @@ def __init__( self._max_batch_size = max_batch_size self._max_batch_weight = max_batch_weight - # Smart default: try len(), fallback to 1 when len() is unsupported - self._element_size_fn: Callable[[Any], int] = ( - element_size_fn - if element_size_fn is not None else _default_element_size_fn) + # None means the DoFn will use its own _default_element_size method, + # which tries len() and warns once on TypeError before falling back to 1. + self._element_size_fn = element_size_fn def expand(self, pcoll): if pcoll.windowing.is_default(): diff --git a/sdks/python/apache_beam/transforms/util_test.py b/sdks/python/apache_beam/transforms/util_test.py index ef02615f27d0..5dd30f4c05cc 100644 --- a/sdks/python/apache_beam/transforms/util_test.py +++ b/sdks/python/apache_beam/transforms/util_test.py @@ -1250,15 +1250,29 @@ class SortAndBatchElementsDoFnDirectTest(unittest.TestCase): so coverage tools in the main process cannot capture DoFn code paths. These tests exercise the DoFn methods directly in-process. """ - def test_default_element_size_fn_len(self): - from apache_beam.transforms.util import _default_element_size_fn - self.assertEqual(_default_element_size_fn('abc'), 3) - self.assertEqual(_default_element_size_fn([1, 2]), 2) - - def test_default_element_size_fn_fallback(self): - from apache_beam.transforms.util import _default_element_size_fn - self.assertEqual(_default_element_size_fn(42), 1) - self.assertEqual(_default_element_size_fn(3.14), 1) + def test_default_element_size_len(self): + from apache_beam.transforms.util import _SortAndBatchElementsDoFn + dofn = _SortAndBatchElementsDoFn( + min_batch_size=1, + max_batch_size=10, + max_batch_weight=100, + element_size_fn=None) + self.assertEqual(dofn._element_size_fn('abc'), 3) + self.assertEqual(dofn._element_size_fn([1, 2]), 2) + + def test_default_element_size_fallback_warns_once(self): + from apache_beam.transforms.util import _SortAndBatchElementsDoFn + dofn = _SortAndBatchElementsDoFn( + min_batch_size=1, + max_batch_size=10, + max_batch_weight=100, + element_size_fn=None) + with self.assertLogs('apache_beam.transforms.util', level='WARNING') as cm: + self.assertEqual(dofn._element_size_fn(42), 1) + self.assertIn('does not support len()', cm.output[0]) + # Second call should not warn again + self.assertEqual(dofn._element_size_fn(3.14), 1) + self.assertTrue(dofn._has_warned_type_error) def test_global_dofn_sort_and_batch(self): """Test _SortAndBatchElementsDoFn directly.""" From 386d2fc83841af8cec6dae1a145934f747778d37 Mon Sep 17 00:00:00 2001 From: Eliaazzz Date: Thu, 2 Apr 2026 20:02:35 +1100 Subject: [PATCH 4/6] Migrate JupyterLab sidepanel extension to prebuilt package distribution Replace deprecated jupyter labextension install/link workflow with pip-installable prebuilt extension for JupyterLab 4+ compatibility. - Add install.json for prebuilt extension discovery metadata - Add style/index.js CSS entry point and styleModule field in package.json - Include js in package.json files glob so style/index.js is published - Add Extensions and Extensions :: Prebuilt classifiers to pyproject.toml - Add missing src/yaml/* to tsconfig.json includes - Remove deprecated labextension install/link/build instructions from READMEs - Replace ipywidgets labextension install with pip install in Interactive README --- .../apache_beam/runners/interactive/README.md | 15 +----- .../README.md | 52 +++++-------------- .../install.json | 5 ++ .../package.json | 3 +- .../pyproject.toml | 2 + .../style/index.js | 13 +++++ .../tsconfig.json | 1 + 7 files changed, 37 insertions(+), 54 deletions(-) create mode 100644 sdks/python/apache_beam/runners/interactive/extensions/apache-beam-jupyterlab-sidepanel/install.json create mode 100644 sdks/python/apache_beam/runners/interactive/extensions/apache-beam-jupyterlab-sidepanel/style/index.js diff --git a/sdks/python/apache_beam/runners/interactive/README.md b/sdks/python/apache_beam/runners/interactive/README.md index ff6c57a94e61..f95b2765c3fa 100644 --- a/sdks/python/apache_beam/runners/interactive/README.md +++ b/sdks/python/apache_beam/runners/interactive/README.md @@ -244,23 +244,10 @@ a quick reference). For a more general and complete getting started guide, see jupyter kernelspec list ``` -* Extend JupyterLab through labextension. **Note**: labextension is different from nbextension - from pre-lab jupyter notebooks. - - All jupyter labextensions need nodejs - - ```bash - # Homebrew users do - brew install node - # Or Conda users do - conda install -c conda-forge nodejs - ``` - - Enable ipywidgets +* Install ipywidgets (includes the JupyterLab widget manager as a prebuilt extension): ```bash pip install ipywidgets - jupyter labextension install @jupyter-widgets/jupyterlab-manager ``` ### Start the notebook diff --git a/sdks/python/apache_beam/runners/interactive/extensions/apache-beam-jupyterlab-sidepanel/README.md b/sdks/python/apache_beam/runners/interactive/extensions/apache-beam-jupyterlab-sidepanel/README.md index 83fddf491f68..4c0baf3b2d53 100644 --- a/sdks/python/apache_beam/runners/interactive/extensions/apache-beam-jupyterlab-sidepanel/README.md +++ b/sdks/python/apache_beam/runners/interactive/extensions/apache-beam-jupyterlab-sidepanel/README.md @@ -31,41 +31,22 @@ Includes two different side panels: ## Installation -There are two ways to install the extension: - -### 1. Via pip (recommended) - -The extension is now available as a Python package on PyPI. You can install it with: +This extension is distributed as a prebuilt Python package. Install it with pip: ```bash pip install apache-beam-jupyterlab-sidepanel ``` -After installation, rebuild JupyterLab to activate the extension: - -```bash -jupyter lab clean -jupyter lab build -``` - -Then restart JupyterLab. The side panels will be available automatically. +Then restart JupyterLab. The side panels will be available automatically — no +`jupyter lab build` step is needed. - -### 2. Via JupyterLab Extension Manager (legacy, will be deprecated soon) +You can verify the extension is installed: ```bash -jupyter labextension install apache-beam-jupyterlab-sidepanel +jupyter labextension list ``` -This installs the extension using JupyterLab's legacy extension system. - ---- - -## Notes - -- Pip installation is now the preferred method as it handles Python packaging and JupyterLab extension registration seamlessly. -- After any upgrade or reinstallation, always rebuild JupyterLab to ensure the extension is activated. -- For detailed usage and development, refer to the source code and issues on [GitHub](https://github.com/apache/beam). +The extension should appear under the **prebuilt extensions** section. --- @@ -90,15 +71,12 @@ The `jlpm` command is JupyterLab's pinned version of # Install dependencies jlpm -# Build Typescript source -jlpm build -# Link your development version of the extension with JupyterLab -jupyter labextension link . -# Rebuild Typescript source after making changes -jlpm build -# Rebuild JupyterLab after making any changes -jupyter lab build +# Install the extension in editable mode (runs an initial JS build) +pip install -e . + +# Verify installation +jupyter labextension list ``` You can watch the source directory and run JupyterLab in watch mode to watch for changes in the extension's source and automatically rebuild the extension and application. @@ -110,7 +88,7 @@ jlpm watch jupyter lab --watch ``` -Now every change will be built locally and bundled into JupyterLab. Be sure to refresh your browser page after saving file changes to reload the extension (note: you'll need to wait for webpack to finish, which can take 10s+ at times). +Now every change will be built locally and bundled into JupyterLab. Be sure to refresh your browser page after saving file changes to reload the extension (note: you'll need to wait for the build to finish, which can take 10s+ at times). ### Test @@ -214,9 +192,5 @@ $PREFIX/share/jupyter/labextensions/apache-beam-jupyterlab-sidepanel/ ### Uninstall ```bash -jupyter labextension uninstall apache-beam-jupyterlab-sidepanel -``` -or -```bash -pip uninstall apache-beam-jupyterlab-sidepanel +pip uninstall apache_beam_jupyterlab_sidepanel ``` diff --git a/sdks/python/apache_beam/runners/interactive/extensions/apache-beam-jupyterlab-sidepanel/install.json b/sdks/python/apache_beam/runners/interactive/extensions/apache-beam-jupyterlab-sidepanel/install.json new file mode 100644 index 000000000000..3ef6567c6a81 --- /dev/null +++ b/sdks/python/apache_beam/runners/interactive/extensions/apache-beam-jupyterlab-sidepanel/install.json @@ -0,0 +1,5 @@ +{ + "packageManager": "python", + "packageName": "apache_beam_jupyterlab_sidepanel", + "uninstallInstructions": "Use your Python package manager (pip, conda, etc.) to uninstall the package apache_beam_jupyterlab_sidepanel" +} diff --git a/sdks/python/apache_beam/runners/interactive/extensions/apache-beam-jupyterlab-sidepanel/package.json b/sdks/python/apache_beam/runners/interactive/extensions/apache-beam-jupyterlab-sidepanel/package.json index eef3fcaa80f4..6bca80350ff7 100644 --- a/sdks/python/apache_beam/runners/interactive/extensions/apache-beam-jupyterlab-sidepanel/package.json +++ b/sdks/python/apache_beam/runners/interactive/extensions/apache-beam-jupyterlab-sidepanel/package.json @@ -15,7 +15,7 @@ "author": "apache-beam", "files": [ "lib/**/*.{d.ts,eot,gif,html,jpg,js,js.map,json,png,svg,woff2,ttf}", - "style/**/*.{css,eot,gif,html,jpg,json,png,svg,woff2,ttf}" + "style/**/*.{css,js,eot,gif,html,jpg,json,png,svg,woff2,ttf}" ], "main": "lib/index.js", "types": "lib/index.d.ts", @@ -100,6 +100,7 @@ "style/*.css", "style/index.js" ], + "styleModule": "style/index.js", "jupyterlab": { "extension": true, "outputDir": "apache_beam_jupyterlab_sidepanel/labextension" diff --git a/sdks/python/apache_beam/runners/interactive/extensions/apache-beam-jupyterlab-sidepanel/pyproject.toml b/sdks/python/apache_beam/runners/interactive/extensions/apache-beam-jupyterlab-sidepanel/pyproject.toml index 6831535a2c1e..a28fd40b2ca6 100644 --- a/sdks/python/apache_beam/runners/interactive/extensions/apache-beam-jupyterlab-sidepanel/pyproject.toml +++ b/sdks/python/apache_beam/runners/interactive/extensions/apache-beam-jupyterlab-sidepanel/pyproject.toml @@ -33,6 +33,8 @@ classifiers = [ "Framework :: Jupyter", "Framework :: Jupyter :: JupyterLab", "Framework :: Jupyter :: JupyterLab :: 4", + "Framework :: Jupyter :: JupyterLab :: Extensions", + "Framework :: Jupyter :: JupyterLab :: Extensions :: Prebuilt", "License :: OSI Approved :: Apache Software License", "Programming Language :: Python :: 3", ] diff --git a/sdks/python/apache_beam/runners/interactive/extensions/apache-beam-jupyterlab-sidepanel/style/index.js b/sdks/python/apache_beam/runners/interactive/extensions/apache-beam-jupyterlab-sidepanel/style/index.js new file mode 100644 index 000000000000..b533d5a9c6d5 --- /dev/null +++ b/sdks/python/apache_beam/runners/interactive/extensions/apache-beam-jupyterlab-sidepanel/style/index.js @@ -0,0 +1,13 @@ +// Licensed under the Apache License, Version 2.0 (the 'License'); you may not +// use this file except in compliance with the License. You may obtain a copy of +// the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an 'AS IS' BASIS, WITHOUT +// WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +// License for the specific language governing permissions and limitations under +// the License. + +import './index.css'; diff --git a/sdks/python/apache_beam/runners/interactive/extensions/apache-beam-jupyterlab-sidepanel/tsconfig.json b/sdks/python/apache_beam/runners/interactive/extensions/apache-beam-jupyterlab-sidepanel/tsconfig.json index c684cabf44a3..058bf17e1861 100644 --- a/sdks/python/apache_beam/runners/interactive/extensions/apache-beam-jupyterlab-sidepanel/tsconfig.json +++ b/sdks/python/apache_beam/runners/interactive/extensions/apache-beam-jupyterlab-sidepanel/tsconfig.json @@ -29,6 +29,7 @@ "src/common/*", "src/kernel/*", "src/inspector/*", + "src/yaml/*", "src/__tests__/**/*" ] } From feb3deaaf574640940633ce25147bd61d6eacce1 Mon Sep 17 00:00:00 2001 From: Eliaazzz Date: Sat, 25 Apr 2026 23:59:17 +1000 Subject: [PATCH 5/6] Use real Beam pipelines in sort-and-batch benchmark --- .../benchmarks/sort_and_batch_benchmark.py | 373 +++++++++--------- sdks/python/apache_beam/transforms/util.py | 30 +- .../apache_beam/transforms/util_test.py | 54 +-- 3 files changed, 209 insertions(+), 248 deletions(-) diff --git a/sdks/python/apache_beam/testing/benchmarks/sort_and_batch_benchmark.py b/sdks/python/apache_beam/testing/benchmarks/sort_and_batch_benchmark.py index caeb7c8efd12..695c4c2c995c 100644 --- a/sdks/python/apache_beam/testing/benchmarks/sort_and_batch_benchmark.py +++ b/sdks/python/apache_beam/testing/benchmarks/sort_and_batch_benchmark.py @@ -15,18 +15,24 @@ # limitations under the License. # -"""Benchmark: BatchElements vs SortAndBatchElements (weight-based splitting). +"""Benchmark: BatchElements vs SortAndBatchElements on real Beam pipelines. -Compares two batching strategies for variable-length inference workloads: +Compares two batching strategies for variable-length inference workloads by +running the actual Beam transforms under DirectRunner: -- Baseline (BatchElements): fixed-count chunking, ignores element sizes. -- Stateless (SortAndBatchElements): within each bundle, sorts elements - by size, then splits batches using max_batch_weight so that each batch - has a bounded total weight. The improvement comes from sorting - combined with weight-based splitting: sorting clusters similar-sized - elements together, and the weight constraint then produces tighter - batches. Sorting alone with fixed count-based boundaries yields ~0% - gain (verified by strict-control ablation). +- Baseline (BatchElements): fixed-count batching by setting + ``min_batch_size == max_batch_size``. +- Stateless (SortAndBatchElements): sorts elements by size within each runner + bundle, then splits batches using ``max_batch_weight``. + +The benchmark materializes per-batch summaries through a temporary Beam sink and +analyzes them after the pipeline completes. This keeps the benchmark on the +normal Beam execution path rather than relying on InteractiveRunner-specific +result materialization or local side effects. + +Bundle boundaries are runner-defined. As a result, these measurements are meant +to compare the actual DirectRunner behavior of the two transforms rather than a +synthetic, user-configurable bundle model. Padding ratio:: @@ -37,6 +43,7 @@ - N=20 independent trials per condition (3 warmup trials excluded). - Same input corpus (seed=42) for A/B comparison. +- DirectRunner with in-memory execution and one worker for reproducibility. - Percentile method: linear interpolation between adjacent ranks (equivalent to numpy.percentile with method='linear'). For N=20 trials: P50 interpolates ranks 10-11 (0-indexed 9-10), @@ -44,24 +51,27 @@ P99 interpolates near rank 20 (0-indexed 18.81). - Reports median [IQR] and P95 for each metric. - Inference model: latency = batch_size * (max_seq_len / 50)^1.5 ms - (simulates transformer-like scaling). + (simulates downstream transformer-like scaling). Run:: python3 -m apache_beam.testing.benchmarks.sort_and_batch_benchmark """ +import glob +import json import math +import os import random import statistics +import tempfile import time +from collections.abc import Sequence from typing import Any -from typing import Callable -from typing import Dict -from typing import List -from typing import Optional -from typing import Sequence -from typing import Tuple + +import apache_beam as beam +from apache_beam.options.pipeline_options import PipelineOptions +from apache_beam.transforms import util # --------------------------------------------------------------------------- # Data generators @@ -72,7 +82,7 @@ def generate_highly_skewed_data( num_elements: int, min_length: int = 1, max_length: int = 500, - seed: int = 42) -> List[str]: + seed: int = 42) -> list[str]: """Pareto(alpha=1.2) -- most short, few very long.""" random.seed(seed) data = [] @@ -89,7 +99,7 @@ def generate_lognormal_data( std_factor: float = 0.8, min_length: int = 1, max_length: int = 500, - seed: int = 42) -> List[str]: + seed: int = 42) -> list[str]: """Log-normal -- moderate skew, typical NLP.""" random.seed(seed) mu = math.log(mean_length) @@ -109,7 +119,7 @@ def generate_bimodal_data( mode1_ratio: float = 0.7, min_length: int = 1, max_length: int = 500, - seed: int = 42) -> List[str]: + seed: int = 42) -> list[str]: """Bimodal -- two distinct length groups.""" random.seed(seed) data = [] @@ -129,7 +139,7 @@ def generate_low_variance_data( cv: float = 0.1, min_length: int = 1, max_length: int = 500, - seed: int = 42) -> List[str]: + seed: int = 42) -> list[str]: """Low-variance control (CV=10%).""" random.seed(seed) std = mean_length * cv @@ -142,72 +152,71 @@ def generate_low_variance_data( # --------------------------------------------------------------------------- -# Batching algorithms +# Real Beam batching # --------------------------------------------------------------------------- -def simulate_batch_elements(data: List[str], - max_batch_size: int) -> List[List[str]]: - """Baseline: simple count-based chunking (BatchElements behaviour).""" - batches = [] - current_batch = [] - for element in data: - current_batch.append(element) - if len(current_batch) >= max_batch_size: - batches.append(current_batch) - current_batch = [] - if current_batch: - batches.append(current_batch) - return batches - - -def simulate_sort_and_batch_elements( - data: List[str], - max_batch_size: int, - max_batch_weight: int, - element_size_fn: Optional[Callable[[Any], int]] = None, - bundle_size: Optional[int] = None) -> List[List[str]]: - """Core mechanism: sort by size + weight-based batch splitting.""" - if element_size_fn is None: - element_size_fn = len - - # Split into bundles if specified (realistic Beam behavior) - if bundle_size is not None and bundle_size > 0: - bundles = [ - data[i:i + bundle_size] for i in range(0, len(data), bundle_size) - ] - else: - bundles = [data] - - all_batches = [] +def _direct_runner_options() -> PipelineOptions: + return PipelineOptions([ + '--runner=DirectRunner', + '--direct_running_mode=in_memory', + '--direct_num_workers=1', + ]) - for bundle in bundles: - # Sort by element size (ascending) - sorted_bundle = sorted(bundle, key=element_size_fn) - current_batch = [] - current_weight = 0 +def _batch_to_json(batch: list[str]) -> str: + lengths = [len(element) for element in batch] + return json.dumps({ + 'batch_size': len(batch), + 'actual_total_length': sum(lengths), + 'max_len': max(lengths) if lengths else 0, + }) - for element in sorted_bundle: - element_weight = element_size_fn(element) - # Check if adding this element would exceed limits - would_exceed_count = len(current_batch) >= max_batch_size - would_exceed_weight = ( - current_weight + element_weight > max_batch_weight and current_batch) - - if would_exceed_count or would_exceed_weight: - all_batches.append(current_batch) - current_batch = [] - current_weight = 0 +def _read_batch_summaries(output_prefix: str) -> list[dict[str, int]]: + summaries = [] + for path in sorted(glob.glob(f'{output_prefix}*')): + if path.endswith('.crc'): + continue + with open(path, encoding='utf-8') as handle: + for line in handle: + line = line.strip() + if line: + summaries.append(json.loads(line)) + return summaries + + +def _run_batching_pipeline( + strategy: str, data: list[str], max_batch_size: int, + max_batch_weight: int) -> tuple[list[dict[str, int]], float]: + """Runs one Beam pipeline and returns batch summaries plus runtime.""" + with tempfile.TemporaryDirectory(prefix='beam_batch_benchmark_') as temp_dir: + output_prefix = os.path.join(temp_dir, strategy) + pipeline = beam.Pipeline(options=_direct_runner_options()) + batched = pipeline | 'CreateInput' >> beam.Create(data, reshuffle=False) + + if strategy == 'baseline': + batched = batched | 'BatchElements' >> util.BatchElements( + min_batch_size=max_batch_size, max_batch_size=max_batch_size) + elif strategy == 'stateless': + batched = batched | 'SortAndBatchElements' >> util.SortAndBatchElements( + min_batch_size=1, + max_batch_size=max_batch_size, + max_batch_weight=max_batch_weight) + else: + raise ValueError(f'Unknown strategy: {strategy}') - current_batch.append(element) - current_weight += element_weight + _ = ( + batched + | 'SerializeBatchSummary' >> beam.Map(_batch_to_json) + | 'WriteBatchSummary' >> beam.io.WriteToText(output_prefix)) - if current_batch: - all_batches.append(current_batch) + start = time.perf_counter() + result = pipeline.run() + result.wait_until_finish() + runtime_ms = (time.perf_counter() - start) * 1000 - return all_batches + return _read_batch_summaries(output_prefix), runtime_ms # --------------------------------------------------------------------------- @@ -216,12 +225,10 @@ def simulate_sort_and_batch_elements( def simulate_inference_latency( - batch: List[str], base_latency_ms: float = 1.0) -> float: - """Simulate transformer inference: O(batch_size * seq_len^1.5).""" - if not batch: + batch_size: int, max_len: int, base_latency_ms: float = 1.0) -> float: + """Simulate downstream inference: O(batch_size * seq_len^1.5).""" + if not batch_size or not max_len: return 0.0 - batch_size = len(batch) - max_len = max(len(s) for s in batch) return base_latency_ms * batch_size * (max_len / 50)**1.5 @@ -246,22 +253,13 @@ def percentile(data: Sequence[float], p: float) -> float: return s[f] + (k - f) * (s[c] - s[f]) -def compute_padding_stats(batches: List[List[str]]) -> Dict[str, Any]: - """Padding-efficiency statistics for a list of batches.""" - total_actual = 0 - total_padded = 0 - batch_sizes = [] - max_lengths = [] - - for batch in batches: - if not batch: - continue - lengths = [len(s) for s in batch] - mx = max(lengths) - total_actual += sum(lengths) - total_padded += mx * len(batch) - batch_sizes.append(len(batch)) - max_lengths.append(mx) +def compute_padding_stats( + batch_summaries: list[dict[str, int]]) -> dict[str, Any]: + """Padding-efficiency statistics for materialized batch summaries.""" + total_actual = sum(s['actual_total_length'] for s in batch_summaries) + total_padded = sum(s['max_len'] * s['batch_size'] for s in batch_summaries) + batch_sizes = [s['batch_size'] for s in batch_summaries if s['batch_size']] + max_lengths = [s['max_len'] for s in batch_summaries if s['batch_size']] efficiency = total_actual / total_padded if total_padded else 0.0 padding_ratio = total_padded / total_actual if total_actual else float('inf') @@ -269,7 +267,7 @@ def compute_padding_stats(batches: List[List[str]]) -> Dict[str, Any]: return { 'efficiency': efficiency, 'padding_ratio': padding_ratio, - 'num_batches': len(batches), + 'num_batches': len(batch_summaries), 'avg_batch_size': statistics.mean(batch_sizes) if batch_sizes else 0, 'total_actual_length': total_actual, 'total_padded_length': total_padded, @@ -288,17 +286,16 @@ def compute_padding_stats(batches: List[List[str]]) -> Dict[str, Any]: def validate_invariants( - data: List[str], - baseline_batches: List[List[str]], - stateless_batches: List[List[str]], - config: Dict[str, Any]) -> Dict[str, Any]: + data: list[str], + baseline_summaries: list[dict[str, int]], + stateless_summaries: list[dict[str, int]]) -> dict[str, Any]: """Validate element/token counts and batch-size equality.""" n = len(data) - b_n = sum(len(b) for b in baseline_batches) - s_n = sum(len(b) for b in stateless_batches) + b_n = sum(s['batch_size'] for s in baseline_summaries) + s_n = sum(s['batch_size'] for s in stateless_summaries) tok = sum(len(s) for s in data) - b_tok = sum(sum(len(s) for s in b) for b in baseline_batches) - s_tok = sum(sum(len(s) for s in b) for b in stateless_batches) + b_tok = sum(s['actual_total_length'] for s in baseline_summaries) + s_tok = sum(s['actual_total_length'] for s in stateless_summaries) return { 'input_elements': n, @@ -309,8 +306,8 @@ def validate_invariants( 'baseline_tokens': b_tok, 'stateless_tokens': s_tok, 'tokens_match': tok == b_tok == s_tok, - 'baseline_num_batches': len(baseline_batches), - 'stateless_num_batches': len(stateless_batches), + 'baseline_num_batches': len(baseline_summaries), + 'stateless_num_batches': len(stateless_summaries), } @@ -320,56 +317,62 @@ def validate_invariants( def run_performance_benchmark( - data: List[str], + data: list[str], max_batch_size: int, max_batch_weight: int, - bundle_size: int = 500, num_trials: int = 20, - warmup_trials: int = 3) -> Tuple[Dict[str, Any], Dict[str, Any]]: + warmup_trials: int = 3 +) -> tuple[ + dict[str, Any], + dict[str, Any], + list[dict[str, int]], + list[dict[str, int]], +]: """Run N=20 trials for baseline and stateless.""" total_tokens = sum(len(s) for s in data) baseline_trials = [] stateless_trials = [] + baseline_sample_summaries = [] + stateless_sample_summaries = [] for trial_idx in range(warmup_trials + num_trials): is_warmup = trial_idx < warmup_trials + trial_results = {} + + if trial_idx % 2 == 0: + trial_order = ('baseline', 'stateless') + else: + trial_order = ('stateless', 'baseline') + + for strategy in trial_order: + summaries, runtime_ms = _run_batching_pipeline( + strategy, data, max_batch_size, max_batch_weight) + batch_latencies = [ + simulate_inference_latency(s['batch_size'], s['max_len']) + for s in summaries + ] + trial_results[strategy] = { + 'runtime_ms': runtime_ms, + 'inference_ms': sum(batch_latencies), + 'e2e_ms': runtime_ms + sum(batch_latencies), + 'batch_latencies': batch_latencies, + 'num_batches': len(summaries), + 'summaries': summaries, + } - # --- Baseline --- - start = time.perf_counter() - b_batches = simulate_batch_elements(data, max_batch_size) - batch_ms = (time.perf_counter() - start) * 1000 - b_inf = [simulate_inference_latency(b) for b in b_batches] - b_e2e = batch_ms + sum(b_inf) - if not is_warmup: - baseline_trials.append({ - 'overhead_ms': batch_ms, - 'inference_ms': sum(b_inf), - 'e2e_ms': b_e2e, - 'batch_latencies': b_inf, - 'num_batches': len(b_batches), - }) - - # --- Stateless (SortAndBatchElements) --- - start = time.perf_counter() - s_batches = simulate_sort_and_batch_elements( - data, max_batch_size, max_batch_weight, bundle_size=bundle_size) - sort_ms = (time.perf_counter() - start) * 1000 - s_inf = [simulate_inference_latency(b) for b in s_batches] - s_e2e = sort_ms + sum(s_inf) if not is_warmup: - stateless_trials.append({ - 'overhead_ms': sort_ms, - 'inference_ms': sum(s_inf), - 'e2e_ms': s_e2e, - 'batch_latencies': s_inf, - 'num_batches': len(s_batches), - }) + baseline_trials.append(trial_results['baseline']) + stateless_trials.append(trial_results['stateless']) + if not baseline_sample_summaries: + baseline_sample_summaries = trial_results['baseline']['summaries'] + if not stateless_sample_summaries: + stateless_sample_summaries = trial_results['stateless']['summaries'] def _stats(trials): e2e = [t['e2e_ms'] for t in trials] tput = [total_tokens / (t['e2e_ms'] / 1000) for t in trials] - overhead = [t['overhead_ms'] for t in trials] + runtime = [t['runtime_ms'] for t in trials] all_lat = [l for t in trials for l in t['batch_latencies']] return { 'e2e_median': percentile(e2e, 50), @@ -380,10 +383,10 @@ def _stats(trials): 'tput_p25': percentile(tput, 25), 'tput_p75': percentile(tput, 75), 'tput_p95': percentile(tput, 95), - 'overhead_median': percentile(overhead, 50), - 'overhead_p25': percentile(overhead, 25), - 'overhead_p75': percentile(overhead, 75), - 'overhead_p95': percentile(overhead, 95), + 'runtime_median': percentile(runtime, 50), + 'runtime_p25': percentile(runtime, 25), + 'runtime_p75': percentile(runtime, 75), + 'runtime_p95': percentile(runtime, 95), 'batch_lat_p50': percentile(all_lat, 50), 'batch_lat_p95': percentile(all_lat, 95), 'batch_lat_p99': percentile(all_lat, 99), @@ -392,7 +395,12 @@ def _stats(trials): 'num_batches': trials[0]['num_batches'] if trials else 0, } - return _stats(baseline_trials), _stats(stateless_trials) + return ( + _stats(baseline_trials), + _stats(stateless_trials), + baseline_sample_summaries, + stateless_sample_summaries, + ) # --------------------------------------------------------------------------- @@ -406,9 +414,8 @@ def run_benchmark( max_length: int = 500, max_batch_size: int = 32, max_batch_weight: int = 2000, - bundle_size: int = 500, distribution: str = 'pareto', - seed: int = 42) -> Dict[str, Any]: + seed: int = 42) -> dict[str, Any]: """Run baseline vs stateless comparison.""" generators = { 'pareto': lambda: generate_highly_skewed_data( @@ -426,33 +433,23 @@ def run_benchmark( data = generators[distribution]() lengths = [len(s) for s in data] - baseline_batches = simulate_batch_elements(data, max_batch_size) - stateless_batches = simulate_sort_and_batch_elements( - data, max_batch_size, max_batch_weight, bundle_size=bundle_size) - - baseline_pad = compute_padding_stats(baseline_batches) - stateless_pad = compute_padding_stats(stateless_batches) - - baseline_perf, stateless_perf = run_performance_benchmark( - data, max_batch_size, max_batch_weight, bundle_size) + baseline_perf, stateless_perf, baseline_summaries, stateless_summaries = ( + run_performance_benchmark(data, max_batch_size, max_batch_weight)) + baseline_pad = compute_padding_stats(baseline_summaries) + stateless_pad = compute_padding_stats(stateless_summaries) baseline_pad.update(baseline_perf) stateless_pad.update(stateless_perf) validation = validate_invariants( - data, - baseline_batches, - stateless_batches, { - 'max_batch_size': max_batch_size, - 'max_batch_weight': max_batch_weight - }) + data, baseline_summaries, stateless_summaries) return { 'config': { 'num_elements': num_elements, 'max_batch_size': max_batch_size, 'max_batch_weight': max_batch_weight, - 'bundle_size': bundle_size, 'distribution': distribution, + 'runner': 'DirectRunner', }, 'data_stats': { 'min': min(lengths), @@ -476,7 +473,7 @@ def _fmt_iqr(median, p25, p75, unit=''): return f"{median:.1f} [{p25:.1f}-{p75:.1f}]{unit}" -def print_results(results: Dict[str, Any]) -> None: +def print_results(results: dict[str, Any]) -> None: cfg = results['config'] ds = results['data_stats'] bl = results['baseline'] @@ -487,6 +484,7 @@ def print_results(results: Dict[str, Any]) -> None: print( f"Distribution: {cfg['distribution']} | " f"N={cfg['num_elements']} | " + f"runner={cfg['runner']} | " f"max_batch_size={cfg['max_batch_size']} | " f"max_batch_weight={cfg['max_batch_weight']}") print( @@ -513,13 +511,13 @@ def _arm(label, s): f" [{s['e2e_p25']:.1f}-{s['e2e_p75']:.1f}]") print(f" P95: {s['e2e_p95']:.1f}") print(" ") - print(" Overhead (ms):") + print(" Pipeline runtime (ms):") print( f" Median [IQR]:" - f" {s['overhead_median']:.2f}" - f" [{s['overhead_p25']:.2f}" - f"-{s['overhead_p75']:.2f}]") - print(f" P95: {s['overhead_p95']:.2f}") + f" {s['runtime_median']:.2f}" + f" [{s['runtime_p25']:.2f}" + f"-{s['runtime_p75']:.2f}]") + print(f" P95: {s['runtime_p95']:.2f}") print(" ") print(" Batch latency (ms):") print(f" P50: {s['batch_lat_p50']:.1f}") @@ -529,9 +527,9 @@ def _arm(label, s): _arm("Baseline (BatchElements)", bl) _arm("Stateless (SortAndBatchElements w/ weight-based splitting)", st) - # Delta — explicit arrows so direction is unambiguous - # ↓ = value decreased (good for latency/padding) - # ↑ = value increased (good for throughput) + # Explicit arrows so direction is unambiguous. + # down arrow = value decreased (good for latency/padding) + # up arrow = value increased (good for throughput) def _delta_lower(base, new): """For metrics where lower is better (latency, padding).""" if base == 0: @@ -582,6 +580,12 @@ def _line(label, bv, sv, delta_fn, fmt='.1f', unit=''): st['e2e_p95'], _delta_lower, unit=' ms') + _line( + 'Pipeline runtime ', + bl['runtime_median'], + st['runtime_median'], + _delta_lower, + unit=' ms') _line( 'Batch lat p95 ', bl['batch_lat_p95'], @@ -614,23 +618,23 @@ def _line(label, bv, sv, delta_fn, fmt='.1f', unit=''): def main(): print("=" * 80) - print("BASELINE (count-based) vs STATELESS (weight-based boundary splitting)") + print("BASELINE (BatchElements) vs STATELESS (SortAndBatchElements)") print("=" * 80) print() print("Experiment design:") - print(" A = Baseline : BatchElements with max_batch_size=32 (count-based)") - print(" B = Stateless : SortAndBatchElements with max_batch_weight=2000") - print( - " (sort by size within bundle -> weight-based split)") + print(" A = Baseline : BatchElements with min=max=32") + print(" B = Stateless : SortAndBatchElements with max_batch_weight=2000") + print(" (sort within runner bundle, then split by weight)") print() - print("Why Stateless wins:") - print(" Weight-based splitting changes batch BOUNDARIES so each batch has") - print( - " similar-length elements -> less padding. Sorting alone within fixed") - print(" boundaries yields 0% gain (verified by strict-control experiment).") + print("Implementation notes:") + print(" - Runs beam.Create(...) pipelines on DirectRunner") + print(" - Materializes per-batch summaries through a temporary text sink") + print(" - Uses runner-defined bundle boundaries rather than a synthetic") + print(" bundle_size knob") print() print("Methodology:") print(" - N=20 trials, 3 warmup excluded") + print(" - DirectRunner, in_memory mode, single worker") print(" - Percentiles: linear interpolation (= numpy default)") print(" - Same seed=42 for both arms") print(" - Inference model: latency = batch_size * (max_seq_len/50)^1.5 ms") @@ -642,7 +646,6 @@ def main(): num_elements=10000, max_batch_size=32, max_batch_weight=2000, - bundle_size=500, distribution=dist, seed=42) print_results(r) diff --git a/sdks/python/apache_beam/transforms/util.py b/sdks/python/apache_beam/transforms/util.py index ede4035ac062..0e2693be7fcc 100644 --- a/sdks/python/apache_beam/transforms/util.py +++ b/sdks/python/apache_beam/transforms/util.py @@ -1336,7 +1336,7 @@ class _SortAndBatchElementsDoFn(DoFn): Must be >= ``min_batch_size``. max_batch_weight: The maximum total weight of elements in a batch, where weight is computed by ``element_size_fn``. Must be >= 1. - element_size_fn: A callable mapping an element to its integer + element_size_fn: An optional callable mapping an element to its integer size/weight. """ def __init__( @@ -1344,7 +1344,7 @@ def __init__( min_batch_size: int, max_batch_size: int, max_batch_weight: int, - element_size_fn: Callable[[Any], int]): + element_size_fn: Optional[Callable[[Any], int]]): self._min_batch_size = min_batch_size self._max_batch_size = max_batch_size self._max_batch_weight = max_batch_weight @@ -1412,8 +1412,11 @@ class _WindowAwareSortAndBatchElementsDoFn(DoFn): This DoFn is used internally by ``SortAndBatchElements`` for PCollections with non-default (e.g. fixed, sliding, or session) windows. Elements are buffered per window and each window is flushed independently. - To prevent unbounded memory growth, when the number of live windows - exceeds ``_MAX_LIVE_WINDOWS`` the largest window buffer is flushed early. + To prevent a single bundle from retaining too many per-window buffers at + once, when the number of live windows exceeds ``_MAX_LIVE_WINDOWS`` the + largest window buffer is flushed early. This DoFn reuses + ``_WindowAwareBatchingDoFn._MAX_LIVE_WINDOWS`` so it follows the same + existing window-aware batching behavior already used in this module. Args: min_batch_size: The minimum number of elements per batch. Must be >= 1. @@ -1421,18 +1424,18 @@ class _WindowAwareSortAndBatchElementsDoFn(DoFn): Must be >= ``min_batch_size``. max_batch_weight: The maximum total weight of elements in a batch, where weight is computed by ``element_size_fn``. Must be >= 1. - element_size_fn: A callable mapping an element to its integer + element_size_fn: An optional callable mapping an element to its integer size/weight. """ - _MAX_LIVE_WINDOWS = 10 + _MAX_LIVE_WINDOWS = _WindowAwareBatchingDoFn._MAX_LIVE_WINDOWS def __init__( self, min_batch_size: int, max_batch_size: int, max_batch_weight: int, - element_size_fn: Callable[[Any], int]): + element_size_fn: Optional[Callable[[Any], int]]): self._min_batch_size = min_batch_size self._max_batch_size = max_batch_size self._max_batch_weight = max_batch_weight @@ -1590,13 +1593,12 @@ def expand(self, pcoll): self._max_batch_size, self._max_batch_weight, self._element_size_fn)) - else: - return pcoll | ParDo( - _WindowAwareSortAndBatchElementsDoFn( - self._min_batch_size, - self._max_batch_size, - self._max_batch_weight, - self._element_size_fn)) + return pcoll | ParDo( + _WindowAwareSortAndBatchElementsDoFn( + self._min_batch_size, + self._max_batch_size, + self._max_batch_weight, + self._element_size_fn)) class _IdentityWindowFn(NonMergingWindowFn): diff --git a/sdks/python/apache_beam/transforms/util_test.py b/sdks/python/apache_beam/transforms/util_test.py index 5dd30f4c05cc..2e03352ef1d0 100644 --- a/sdks/python/apache_beam/transforms/util_test.py +++ b/sdks/python/apache_beam/transforms/util_test.py @@ -1038,14 +1038,8 @@ def test_elements_are_sorted_by_size(self): | beam.Create(data, reshuffle=False) | util.SortAndBatchElements( min_batch_size=1, max_batch_size=5, max_batch_weight=100)) - - def check_sorted(batch): - lengths = [len(s) for s in batch] - assert lengths == sorted(lengths), ( - f'Batch not sorted by size: {lengths}') - return batch - - _ = res | beam.Map(check_sorted) + # All elements fit in one batch, so the expected order is explicit. + assert_that(res, equal_to([['a', 'bb', 'ddd', 'cccc', 'aaaaa']])) def test_batch_respects_max_batch_size(self): """Test that batches do not exceed max_batch_size.""" @@ -1200,47 +1194,9 @@ def test_batch_timestamps(self): min_batch_size=1, max_batch_size=10, max_batch_weight=100) | beam.Map(lambda batch, ts=beam.DoFn.TimestampParam: (len(batch), ts))) - assert_that(res, equal_to([(3, GlobalWindow().max_timestamp())])) - - def test_padding_efficiency_improvement(self): - """Test that sorting improves padding efficiency.""" - # This test verifies the core value proposition of SortAndBatchElements - data = ['a', 'aaaaa', 'aa', 'aaaa', 'aaa'] - - # Compute what BatchElements would produce (preserves input order) - batch_elements_batches = [] - with TestPipeline() as p: - _ = ( - p - | 'Create1' >> beam.Create(data, reshuffle=False) - | util.BatchElements(min_batch_size=5, max_batch_size=5) - | beam.Map(lambda b: batch_elements_batches.append(list(b)))) - - # Compute what SortAndBatchElements produces - sort_batch_batches = [] - with TestPipeline() as p: - _ = ( - p - | 'Create2' >> beam.Create(data, reshuffle=False) - | util.SortAndBatchElements( - min_batch_size=1, max_batch_size=5, max_batch_weight=100) - | beam.Map(lambda b: sort_batch_batches.append(list(b)))) - - # Calculate padding overhead for each approach - # Padding overhead: - # sum(max_len_in_batch * batch_size) - sum(actual_lengths) - def compute_overhead(batches): - overhead = 0 - for batch in batches: - lengths = [len(s) for s in batch] - overhead += max(lengths) * len(batch) - sum(lengths) - return overhead - - batch_overhead = compute_overhead(batch_elements_batches) - sort_overhead = compute_overhead(sort_batch_batches) - - # SortAndBatchElements should have less or equal overhead - self.assertLessEqual(sort_overhead, batch_overhead) + # The single global-window batch is emitted at end-of-window. + expected = [(3, GlobalWindow().max_timestamp())] + assert_that(res, equal_to(expected)) class SortAndBatchElementsDoFnDirectTest(unittest.TestCase): From 968f09f566af7ceeacae955864386819067dd075 Mon Sep 17 00:00:00 2001 From: Eliaazzz Date: Sun, 26 Apr 2026 00:06:45 +1000 Subject: [PATCH 6/6] Address all review comments - Reuse _WindowAwareBatchingDoFn._MAX_LIVE_WINDOWS instead of keeping a separate hard-coded limit in SortAndBatchElements. - Drop the padding-efficiency unit test that compared incongruent batching strategies and keep the transform tests focused on deterministic behavior. - Align benchmark typing with modern Python style by using collections.abc imports and native built-in generics. - Make the sorted-order test clearer by naming the expected batch contents explicitly. --- sdks/python/apache_beam/transforms/util_test.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sdks/python/apache_beam/transforms/util_test.py b/sdks/python/apache_beam/transforms/util_test.py index 2e03352ef1d0..30f34b59d3f5 100644 --- a/sdks/python/apache_beam/transforms/util_test.py +++ b/sdks/python/apache_beam/transforms/util_test.py @@ -1033,13 +1033,14 @@ def test_elements_are_sorted_by_size(self): with TestPipeline() as p: # Create elements with varying sizes data = ['aaaaa', 'bb', 'cccc', 'a', 'ddd'] + expected = [['a', 'bb', 'ddd', 'cccc', 'aaaaa']] res = ( p | beam.Create(data, reshuffle=False) | util.SortAndBatchElements( min_batch_size=1, max_batch_size=5, max_batch_weight=100)) # All elements fit in one batch, so the expected order is explicit. - assert_that(res, equal_to([['a', 'bb', 'ddd', 'cccc', 'aaaaa']])) + assert_that(res, equal_to(expected)) def test_batch_respects_max_batch_size(self): """Test that batches do not exceed max_batch_size."""