diff --git a/sdks/python/apache_beam/io/gcp/bigquery_change_history.py b/sdks/python/apache_beam/io/gcp/bigquery_change_history.py
index ba90c4a8963d..dad56d26e499 100644
--- a/sdks/python/apache_beam/io/gcp/bigquery_change_history.py
+++ b/sdks/python/apache_beam/io/gcp/bigquery_change_history.py
@@ -711,9 +711,13 @@ class _ReadStorageStreamsSDF(beam.DoFn,
   def __init__(
       self,
       batch_arrow_read: bool = True,
-      change_timestamp_column: str = 'change_timestamp') -> None:
+      change_timestamp_column: str = 'change_timestamp',
+      max_split_rounds: int = 1,
+      emit_raw_batches: bool = False) -> None:
     self._batch_arrow_read = batch_arrow_read
     self._change_timestamp_column = change_timestamp_column
+    self._max_split_rounds = max_split_rounds
+    self._emit_raw_batches = emit_raw_batches
     self._storage_client = None

   def _ensure_client(self) -> None:
@@ -730,16 +734,80 @@ def _ensure_client(self) -> None:
   def setup(self) -> None:
     self._ensure_client()

+  def _split_all_streams(
+      self, stream_names: Tuple[str, ...],
+      max_split_rounds: int) -> Tuple[str, ...]:
+    """Split each stream at fraction=0.5 for up to max_split_rounds rounds.
+
+    Each round attempts to split every stream in the current list. A
+    successful split replaces the original stream with primary + remainder.
+    A refused split (both fields empty) keeps the original stream intact.
+    Stops when max_split_rounds is reached or a full round produces zero
+    new splits.
+
+    BQ's server-side granularity controls how many splits are possible.
+    Small tables may not split at all; large tables may allow multiple
+    rounds of doubling.
+    """
+    result = list(stream_names)
+    no_split = set()
+    for round_num in range(1, max_split_rounds + 1):
+      new_result = []
+      made_progress = False
+      for name in result:
+        if name in no_split:
+          new_result.append(name)
+          continue
+        response = self._storage_client.split_read_stream(
+            request=bq_storage.types.SplitReadStreamRequest(
+                name=name, fraction=0.5))
+        primary = response.primary_stream.name
+        remainder = response.remainder_stream.name
+        if primary and remainder:
+          new_result.extend([primary, remainder])
+          made_progress = True
+        else:
+          new_result.append(name)
+          no_split.add(name)
+      result = new_result
+      _LOGGER.info(
+          '[Read] _split_all_streams round %d/%d: %d streams '
+          '(progress=%s)',
+          round_num,
+          max_split_rounds,
+          len(result),
+          made_progress)
+      if not made_progress:
+        break
+    return tuple(result)
+
   def initial_restriction(self, element: _QueryResult) -> _StreamRestriction:
-    """Create ReadSession and return _StreamRestriction with stream names."""
+    """Create ReadSession and return _StreamRestriction with stream names.
+
+    When max_split_rounds > 0, uses SplitReadStream to subdivide each
+    stream at fraction=0.5 for up to max_split_rounds rounds, maximizing
+    parallelism beyond what CreateReadSession provides.
+    """
     self._ensure_client()
     table_key = bigquery_tools.get_hashable_destination(element.temp_table_ref)

     session = self._create_read_session(element.temp_table_ref)
     stream_names = tuple(s.name for s in session.streams)
+    original_count = len(stream_names)
     _LOGGER.info(
-        '[Read] initial_restriction for %s: %d streams',
+        '[Read] initial_restriction for %s: %d streams from CreateReadSession',
         table_key,
-        len(stream_names))
+        original_count)
+
+    if self._max_split_rounds > 0:
+      stream_names = self._split_all_streams(
+          stream_names, self._max_split_rounds)
+      _LOGGER.info(
+          '[Read] initial_restriction for %s: %d -> %d streams '
+          'after SplitReadStream',
+          table_key,
+          original_count,
+          len(stream_names))
+
     return _StreamRestriction(stream_names, 0, len(stream_names))

   def create_tracker(
@@ -767,8 +835,7 @@ def process(
       self,
       element: _QueryResult,
       restriction_tracker=beam.DoFn.RestrictionParam(),
       watermark_estimator=beam.DoFn.WatermarkEstimatorParam(
-          _CDCWatermarkEstimatorProvider())
-  ) -> Iterable[Dict[str, Any]]:
+          _CDCWatermarkEstimatorProvider())):
     self._ensure_client()
     table_key = bigquery_tools.get_hashable_destination(element.temp_table_ref)
@@ -785,7 +852,6 @@ def process(

     total_streams = len(stream_names)
     streams_read = 0
-    total_rows = 0

     _LOGGER.info(
         '[Read] Reading streams [%d, %d) of %d total for %s',
@@ -808,19 +874,27 @@ def process(
           '[Read] try_claim(%d) succeeded: reading stream %s', i, stream_name)

       stream_rows = 0
-      for row in self._read_stream(stream_name):
-        ts = row.get(self._change_timestamp_column)
-        if ts is None:
-          raise ValueError(
-              'Row missing %r column. Row keys: %s' %
-              (self._change_timestamp_column, list(row.keys())))
-        if isinstance(ts, datetime.datetime):
-          ts = Timestamp.from_utc_datetime(ts)
-
-        yield TimestampedValue(row, ts)
-        stream_rows += 1
-        total_rows += 1
-      Metrics.counter('BigQueryChangeHistory', 'rows_emitted').inc(total_rows)
+      if self._emit_raw_batches:
+        stream_batches = 0
+        for raw_batch in self._read_stream_raw(stream_name):
+          yield TimestampedValue(raw_batch, element.range_start)
+          stream_batches += 1
+        Metrics.counter('BigQueryChangeHistory',
+                        'batches_emitted').inc(stream_batches)
+      else:
+        for row in self._read_stream(stream_name):
+          ts = row.get(self._change_timestamp_column)
+          if ts is None:
+            raise ValueError(
+                'Row missing %r column. Row keys: %s' %
+                (self._change_timestamp_column, list(row.keys())))
+          if isinstance(ts, datetime.datetime):
+            ts = Timestamp.from_utc_datetime(ts)
+
+          yield TimestampedValue(row, ts)
+          stream_rows += 1
+        Metrics.counter('BigQueryChangeHistory',
+                        'rows_emitted').inc(stream_rows)

       streams_read += 1
       _LOGGER.info(
@@ -838,16 +912,19 @@ def process(
         _utc(element.range_end),
         table_key)

+    # Release the storage client so the gRPC channel doesn't go stale
+    # between process() calls. _ensure_client() will create a fresh one.
+    self._storage_client = None
+
     # Emit cleanup signal. Every split that reads at least one stream
     # reports how many it read.
     if streams_read > 0:
       _LOGGER.info(
           '[Read] Emitting cleanup signal for %s: '
-          'streams_read=%d, total_streams=%d, total_rows=%d',
+          'streams_read=%d, total_streams=%d',
           table_key,
           streams_read,
-          total_streams,
-          total_rows)
+          total_streams)
       yield beam.pvalue.TaggedOutput(
           _CLEANUP_TAG, (table_key, (streams_read, total_streams)))

@@ -863,7 +940,7 @@ def _create_read_session(self, table_ref: 'bigquery.TableReference') -> Any:
     requested_session.data_format = bq_storage.types.DataFormat.ARROW
     read_options = requested_session.read_options
     read_options.arrow_serialization_options.buffer_compression = (
-        bq_storage.types.ArrowSerializationOptions.CompressionCodec.LZ4_FRAME)
+        bq_storage.types.ArrowSerializationOptions.CompressionCodec.ZSTD)

     session = self._storage_client.create_read_session(
         parent=f'projects/{table_ref.projectId}',
@@ -879,7 +956,7 @@ def _read_stream(self, stream_name: str) -> Iterable[Dict[str, Any]]:
     """Read all rows from a single Storage API stream as dicts.

     When batch_arrow_read is enabled, converts entire Arrow RecordBatches
-    at once using to_pydict() instead of calling .as_py() on each cell
+    at once using to_pylist() instead of calling .as_py() on each cell
     individually. This is ~1.5x faster for large tables at the cost of
     ~2x peak memory per batch.
     """
@@ -925,6 +1002,56 @@ def _read_stream_batch(self, stream_name: str) -> Iterable[Dict[str, Any]]:
         elapsed,
         row_count / elapsed if elapsed > 0 else 0)

+  def _read_stream_raw(self, stream_name: str) -> Iterable[Tuple[bytes, bytes]]:
+    """Yield raw (schema_bytes, batch_bytes) without decompression.
+
+    Used when emit_raw_batches is enabled to defer decompression and
+    Arrow-to-Python conversion to a downstream DoFn after reshuffling.
+    Schema bytes are included in each tuple so each batch is
+    self-contained and can be decoded independently.
+    """
+    schema_bytes = b''
+    batch_count = 0
+    t0 = time.time()
+    for response in self._storage_client.read_rows(stream_name):
+      if not schema_bytes and response.arrow_schema.serialized_schema:
+        schema_bytes = bytes(response.arrow_schema.serialized_schema)
+      batch_bytes = response.arrow_record_batch.serialized_record_batch
+      if batch_bytes and schema_bytes:
+        yield (schema_bytes, bytes(batch_bytes))
+        batch_count += 1
+    elapsed = time.time() - t0
+    _LOGGER.info('[Read] raw_read: %d batches in %.2fs', batch_count, elapsed)
+
+
+class _DecompressArrowBatchesFn(beam.DoFn):
+  """Decompress and convert raw Arrow batches to timestamped row dicts.
+
+  Receives individual (schema_bytes, batch_bytes) tuples after Reshuffle
+  and converts each batch to individual row dicts with event timestamps
+  extracted from the change_timestamp column.
+  """
+  def __init__(self, change_timestamp_column: str = 'change_timestamp') -> None:
+    self._change_timestamp_column = change_timestamp_column
+
+  def process(self, element: Tuple[bytes, bytes]) -> Iterable[Dict[str, Any]]:
+    schema_bytes, batch_bytes = element
+    schema = pyarrow.ipc.read_schema(pyarrow.py_buffer(schema_bytes))
+    batch = pyarrow.ipc.read_record_batch(
+        pyarrow.py_buffer(batch_bytes), schema)
+
+    rows = batch.to_pylist()
+    for row in rows:
+      ts = row.get(self._change_timestamp_column)
+      if ts is None:
+        raise ValueError(
+            'Row missing %r column. Row keys: %s' %
+            (self._change_timestamp_column, list(row.keys())))
+      if isinstance(ts, datetime.datetime):
+        ts = Timestamp.from_utc_datetime(ts)
+      yield TimestampedValue(row, ts)
+    Metrics.counter('BigQueryChangeHistory', 'rows_emitted').inc(len(rows))
+

 # =============================================================================
 # Cleanup: _CleanupTempTablesFn
@@ -1038,9 +1165,21 @@ class ReadBigQueryChangeHistory(beam.PTransform):
       on the CHANGES/APPENDS query. Do not include the WHERE keyword.
       Example: ``'status = "active" AND region = "US"'``.
     batch_arrow_read: If True (default), convert Arrow RecordBatches in
-      bulk using to_pydict() instead of per-cell .as_py() calls.
+      bulk using to_pylist() instead of per-cell .as_py() calls.
      This is 1.5x faster for large tables at the cost of ~2x peak
      memory per RecordBatch. Set to False for minimal memory usage.
+    max_split_rounds: Maximum number of recursive SplitReadStream
+      rounds. Each round splits every stream at fraction=0.5,
+      potentially doubling the stream count (if BQ allows). Default
+      1 (one round of splitting). Set 0 to disable splitting
+      entirely. Set higher for very large tables where more
+      parallelism is needed.
+    reshuffle_decompress: If True (default), the Read SDF emits raw
+      compressed Arrow batches instead of decoded rows. The batches
+      are reshuffled for fan-out and then decoded in a separate DoFn.
+      This spreads decompression and Arrow-to-Python conversion CPU
+      across more workers. Set to False to decode rows inline within
+      the Read SDF.
   """
   def __init__(
       self,
@@ -1057,7 +1196,9 @@ def __init__(
       change_timestamp_column: str = 'change_timestamp',
       columns: Optional[List[str]] = None,
       row_filter: Optional[str] = None,
-      batch_arrow_read: bool = True) -> None:
+      batch_arrow_read: bool = True,
+      max_split_rounds: int = 1,
+      reshuffle_decompress: bool = True) -> None:
     super().__init__()
     if bq_storage is None:
       raise ImportError(
@@ -1091,6 +1232,8 @@ def __init__(
     self._columns = columns
     self._row_filter = row_filter
     self._batch_arrow_read = batch_arrow_read
+    self._max_split_rounds = max_split_rounds
+    self._reshuffle_decompress = reshuffle_decompress

   def expand(self, pbegin: beam.pvalue.PBegin) -> beam.PCollection:
     project = self._project
@@ -1170,16 +1313,37 @@ def expand(self, pbegin: beam.pvalue.PBegin) -> beam.PCollection:
                 row_filter=self._row_filter))
         | 'CommitQueryResults' >> beam.Reshuffle())

+    emit_raw = self._reshuffle_decompress
+
+    read_sdf = beam.ParDo(
+        _ReadStorageStreamsSDF(
+            batch_arrow_read=self._batch_arrow_read,
+            change_timestamp_column=self._change_timestamp_column,
+            max_split_rounds=self._max_split_rounds,
+            emit_raw_batches=emit_raw))
+    if emit_raw:
+      read_sdf = read_sdf.with_output_types(Tuple[bytes, bytes])
+    else:
+      read_sdf = read_sdf.with_output_types(Dict[str, Any])
+
     read_outputs = (
         query_results
-        | 'ReadStorageStreams' >> beam.ParDo(
-            _ReadStorageStreamsSDF(
-                batch_arrow_read=self._batch_arrow_read,
-                change_timestamp_column=self._change_timestamp_column)).
-        with_outputs(_CLEANUP_TAG, main='rows'))
+        | 'ReadStorageStreams' >> read_sdf.with_outputs(
+            _CLEANUP_TAG, main='rows'))

     _ = (
         read_outputs[_CLEANUP_TAG]
         | 'CleanupTempTables' >> beam.ParDo(_CleanupTempTablesFn()))

-    return read_outputs['rows']
+    if emit_raw:
+      # Reshuffle raw Arrow batches for fan-out, then decompress and
+      # convert to timestamped row dicts in a separate DoFn.
+      rows = (
+          read_outputs['rows']
+          | 'ReshuffleForFanout' >> beam.Reshuffle()
+          | 'DecompressBatches' >> beam.ParDo(
+              _DecompressArrowBatchesFn(
+                  change_timestamp_column=(self._change_timestamp_column))))
+      return rows
+    else:
+      return read_outputs['rows']
diff --git a/sdks/python/apache_beam/io/gcp/bigquery_change_history_it_test.py b/sdks/python/apache_beam/io/gcp/bigquery_change_history_it_test.py
index 0af5a80fd434..ef41fc393af7 100644
--- a/sdks/python/apache_beam/io/gcp/bigquery_change_history_it_test.py
+++ b/sdks/python/apache_beam/io/gcp/bigquery_change_history_it_test.py
@@ -490,6 +490,19 @@ def check_rows(actual):
     assert_that(rows, check_rows)


+_EXPECTED_ROWS = [
+    {
+        'id': 1, 'name': 'alice', 'value': 10.0, 'change_type': 'INSERT'
+    },
+    {
+        'id': 2, 'name': 'bob', 'value': 20.0, 'change_type': 'INSERT'
+    },
+    {
+        'id': 3, 'name': 'charlie', 'value': 30.0, 'change_type': 'INSERT'
+    },
+]
+
+
 class EndToEndTest(BigQueryChangeHistoryIntegrationBase):
   """End-to-end test using the public ReadBigQueryChangeHistory API.

@@ -545,25 +558,77 @@ def check_rows(actual):
             for k, v in row.items() if k != 'change_timestamp'
         } for row in actual],
                      key=lambda r: r['id'])
-        expected = [
-            {
-                'id': 1,
-                'name': 'alice',
-                'value': 10.0,
-                'change_type': 'INSERT'
-            },
-            {
-                'id': 2, 'name': 'bob', 'value': 20.0, 'change_type': 'INSERT'
-            },
-            {
-                'id': 3,
-                'name': 'charlie',
-                'value': 30.0,
-                'change_type': 'INSERT'
-            },
-        ]
-        assert got == expected, (
-            f'Row mismatch:\n  got: {got}\n  expected: {expected}')
+        assert got == _EXPECTED_ROWS, (
+            f'Row mismatch:\n  got: {got}\n  expected: '
+            f'{_EXPECTED_ROWS}')

       assert_that(rows, check_rows)

+  def test_public_api_reads_inserted_row_inline_decompress(self):
+    """ReadBigQueryChangeHistory with inline decompression (no reshuffle)."""
+    table_str = f'{self.project}:{self.dataset}.{self.test_table_id}'
+    start_time = self.insert_time - 120  # 2 min before insert
+    stop_time = time.time() + 15
+
+    with beam.Pipeline(argv=self.args) as p:
+      rows = (
+          p
+          | ReadBigQueryChangeHistory(
+              table=table_str,
+              poll_interval_sec=15,
+              start_time=start_time,
+              stop_time=stop_time,
+              change_function='APPENDS',
+              buffer_sec=10,
+              project=self.project,
+              temp_dataset=self.temp_dataset,
+              location=self.location,
+              reshuffle_decompress=False))
+
+      def check_rows(actual):
+        assert len(actual) == 3, f'Expected 3 rows, got {len(actual)}'
+        got = sorted([{
+            k: v
+            for k, v in row.items() if k != 'change_timestamp'
+        } for row in actual],
+                     key=lambda r: r['id'])
+        assert got == _EXPECTED_ROWS, (
+            f'Row mismatch:\n  got: {got}\n  expected: '
+            f'{_EXPECTED_ROWS}')
+
+      assert_that(rows, check_rows)
+
+  def test_public_api_reads_inserted_row_with_split_streams(self):
+    """ReadBigQueryChangeHistory with max_split_rounds=10."""
+    table_str = f'{self.project}:{self.dataset}.{self.test_table_id}'
+    start_time = self.insert_time - 120  # 2 min before insert
+    stop_time = time.time() + 15
+
+    with beam.Pipeline(argv=self.args) as p:
+      rows = (
+          p
+          | ReadBigQueryChangeHistory(
+              table=table_str,
+              poll_interval_sec=15,
+              start_time=start_time,
+              stop_time=stop_time,
+              change_function='APPENDS',
+              buffer_sec=10,
+              project=self.project,
+              temp_dataset=self.temp_dataset,
+              location=self.location,
+              max_split_rounds=10))
+
+      def check_rows(actual):
+        assert len(actual) == 3, f'Expected 3 rows, got {len(actual)}'
+        got = sorted([{
+            k: v
+            for k, v in row.items() if k != 'change_timestamp'
+        } for row in actual],
+                     key=lambda r: r['id'])
+        assert got == _EXPECTED_ROWS, (
+            f'Row mismatch:\n  got: {got}\n  expected: '
+            f'{_EXPECTED_ROWS}')
+
+      assert_that(rows, check_rows)