230 changes: 197 additions & 33 deletions sdks/python/apache_beam/io/gcp/bigquery_change_history.py
@@ -711,9 +711,13 @@ class _ReadStorageStreamsSDF(beam.DoFn,
def __init__(
self,
batch_arrow_read: bool = True,
change_timestamp_column: str = 'change_timestamp') -> None:
change_timestamp_column: str = 'change_timestamp',
max_split_rounds: int = 1,
emit_raw_batches: bool = False) -> None:
self._batch_arrow_read = batch_arrow_read
self._change_timestamp_column = change_timestamp_column
self._max_split_rounds = max_split_rounds
self._emit_raw_batches = emit_raw_batches
self._storage_client = None

def _ensure_client(self) -> None:
@@ -730,16 +734,80 @@ def _ensure_client(self) -> None:
def setup(self) -> None:
self._ensure_client()

def _split_all_streams(
self, stream_names: Tuple[str, ...],
max_split_rounds: int) -> Tuple[str, ...]:
"""Split each stream at fraction=0.5 for up to max_split_rounds rounds.

Each round attempts to split every stream in the current list. A
successful split replaces the original stream with primary + remainder.
A refused split (both fields empty) keeps the original stream intact.
Stops when max_split_rounds is reached or a full round produces zero
new splits.

BQ's server-side granularity controls how many splits are possible.
Small tables may not split at all; large tables may allow multiple
rounds of doubling.
"""
result = list(stream_names)
no_split = set()
for round_num in range(1, max_split_rounds + 1):
new_result = []
made_progress = False
for name in result:
if name in no_split:
new_result.append(name)
continue
response = self._storage_client.split_read_stream(
request=bq_storage.types.SplitReadStreamRequest(
name=name, fraction=0.5))
primary = response.primary_stream.name
remainder = response.remainder_stream.name
if primary and remainder:
new_result.extend([primary, remainder])
made_progress = True
else:
new_result.append(name)
no_split.add(name)
result = new_result
_LOGGER.info(
'[Read] _split_all_streams round %d/%d: %d streams '
'(progress=%s)',
round_num,
max_split_rounds,
len(result),
made_progress)
if not made_progress:
break
return tuple(result)
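
# --- Illustrative sketch (standalone; fake_split is a hypothetical stand-in
# --- for the SplitReadStream RPC): the round loop above doubles the stream
# --- count each fully-honored round (initial * 2**r after r rounds) and
# --- stops early once a round makes no progress. The no_split bookkeeping
# --- is omitted here for brevity.
def _sketch_split_rounds(max_split_rounds=3):
  def fake_split(name, max_depth=2):
    # Pretend the server honors splits only two levels deep.
    if name.count('/') < max_depth:
      return name + '/0', name + '/1'  # primary, remainder
    return '', ''  # refused split: both fields empty
  streams = ['s']
  for _ in range(max_split_rounds):
    new_streams, made_progress = [], False
    for s in streams:
      primary, remainder = fake_split(s)
      if primary and remainder:
        new_streams.extend([primary, remainder])
        made_progress = True
      else:
        new_streams.append(s)
    streams = new_streams
    if not made_progress:
      break
  return streams  # ['s/0/0', 's/0/1', 's/1/0', 's/1/1']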

def initial_restriction(self, element: _QueryResult) -> _StreamRestriction:
"""Create ReadSession and return _StreamRestriction with stream names."""
"""Create ReadSession and return _StreamRestriction with stream names.

When max_split_rounds > 0, uses SplitReadStream to subdivide each
stream at fraction=0.5 for up to max_split_rounds rounds, maximizing
parallelism beyond what CreateReadSession provides.
"""
self._ensure_client()
table_key = bigquery_tools.get_hashable_destination(element.temp_table_ref)
session = self._create_read_session(element.temp_table_ref)
stream_names = tuple(s.name for s in session.streams)
original_count = len(stream_names)
_LOGGER.info(
'[Read] initial_restriction for %s: %d streams',
'[Read] initial_restriction for %s: %d streams from CreateReadSession',
table_key,
len(stream_names))
original_count)

if self._max_split_rounds > 0:
stream_names = self._split_all_streams(
stream_names, self._max_split_rounds)
_LOGGER.info(
'[Read] initial_restriction for %s: %d -> %d streams '
'after SplitReadStream',
table_key,
original_count,
len(stream_names))

return _StreamRestriction(stream_names, 0, len(stream_names))

def create_tracker(
@@ -767,8 +835,7 @@ def process(
element: _QueryResult,
restriction_tracker=beam.DoFn.RestrictionParam(),
watermark_estimator=beam.DoFn.WatermarkEstimatorParam(
_CDCWatermarkEstimatorProvider())
) -> Iterable[Dict[str, Any]]:
_CDCWatermarkEstimatorProvider())):
self._ensure_client()
table_key = bigquery_tools.get_hashable_destination(element.temp_table_ref)

@@ -785,7 +852,6 @@
total_streams = len(stream_names)

streams_read = 0
total_rows = 0

_LOGGER.info(
'[Read] Reading streams [%d, %d) of %d total for %s',
@@ -808,19 +874,27 @@
'[Read] try_claim(%d) succeeded: reading stream %s', i, stream_name)

stream_rows = 0
for row in self._read_stream(stream_name):
ts = row.get(self._change_timestamp_column)
if ts is None:
raise ValueError(
'Row missing %r column. Row keys: %s' %
(self._change_timestamp_column, list(row.keys())))
if isinstance(ts, datetime.datetime):
ts = Timestamp.from_utc_datetime(ts)

yield TimestampedValue(row, ts)
stream_rows += 1
total_rows += 1
Metrics.counter('BigQueryChangeHistory', 'rows_emitted').inc(total_rows)
if self._emit_raw_batches:
stream_batches = 0
for raw_batch in self._read_stream_raw(stream_name):
yield TimestampedValue(raw_batch, element.range_start)
stream_batches += 1
Metrics.counter('BigQueryChangeHistory',
'batches_emitted').inc(stream_batches)
else:
for row in self._read_stream(stream_name):
ts = row.get(self._change_timestamp_column)
if ts is None:
raise ValueError(
'Row missing %r column. Row keys: %s' %
(self._change_timestamp_column, list(row.keys())))
if isinstance(ts, datetime.datetime):
ts = Timestamp.from_utc_datetime(ts)

yield TimestampedValue(row, ts)
stream_rows += 1
Metrics.counter('BigQueryChangeHistory',
'rows_emitted').inc(stream_rows)

streams_read += 1
_LOGGER.info(
@@ -838,16 +912,19 @@
_utc(element.range_end),
table_key)

# Release the storage client so the gRPC channel doesn't go stale
# between process() calls. _ensure_client() will create a fresh one.
self._storage_client = None

# Emit cleanup signal. Every split that reads at least one stream
# reports how many it read.
if streams_read > 0:
_LOGGER.info(
'[Read] Emitting cleanup signal for %s: '
'streams_read=%d, total_streams=%d, total_rows=%d',
'streams_read=%d, total_streams=%d',
table_key,
streams_read,
total_streams,
total_rows)
total_streams)
yield beam.pvalue.TaggedOutput(
_CLEANUP_TAG, (table_key, (streams_read, total_streams)))

@@ -863,7 +940,7 @@ def _create_read_session(self, table_ref: 'bigquery.TableReference') -> Any:
requested_session.data_format = bq_storage.types.DataFormat.ARROW
read_options = requested_session.read_options
read_options.arrow_serialization_options.buffer_compression = (
bq_storage.types.ArrowSerializationOptions.CompressionCodec.LZ4_FRAME)
bq_storage.types.ArrowSerializationOptions.CompressionCodec.ZSTD)

session = self._storage_client.create_read_session(
parent=f'projects/{table_ref.projectId}',
@@ -879,7 +956,7 @@ def _read_stream(self, stream_name: str) -> Iterable[Dict[str, Any]]:
"""Read all rows from a single Storage API stream as dicts.

When batch_arrow_read is enabled, converts entire Arrow RecordBatches
at once using to_pydict() instead of calling .as_py() on each cell
at once using to_pylist() instead of calling .as_py() on each cell
individually. This is ~1.5x faster for large tables at the cost of ~2x
peak memory per batch.
"""
@@ -925,6 +1002,56 @@ def _read_stream_batch(self, stream_name: str) -> Iterable[Dict[str, Any]]:
elapsed,
row_count / elapsed if elapsed > 0 else 0)

def _read_stream_raw(self, stream_name: str) -> Iterable[Tuple[bytes, bytes]]:
"""Yield raw (schema_bytes, batch_bytes) without decompression.

Used when emit_raw_batches is enabled to defer decompression and
Arrow-to-Python conversion to a downstream DoFn after reshuffling.
Schema bytes are included in each tuple so each batch is
self-contained and can be decoded independently.
"""
schema_bytes = b''
batch_count = 0
t0 = time.time()
for response in self._storage_client.read_rows(stream_name):
if not schema_bytes and response.arrow_schema.serialized_schema:
schema_bytes = bytes(response.arrow_schema.serialized_schema)
batch_bytes = response.arrow_record_batch.serialized_record_batch
if batch_bytes and schema_bytes:
yield (schema_bytes, bytes(batch_bytes))
batch_count += 1
elapsed = time.time() - t0
_LOGGER.info('[Read] raw_read: %d batches in %.2fs', batch_count, elapsed)
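
# --- Standalone sketch (hypothetical data): building the self-contained
# --- (schema_bytes, batch_bytes) tuple shape this method yields, using
# --- pyarrow's encapsulated IPC messages.
import pyarrow

_raw = pyarrow.RecordBatch.from_pydict(
    {'change_timestamp': [1, 2], 'id': [10, 11]})
schema_bytes = _raw.schema.serialize().to_pybytes()
batch_bytes = _raw.serialize().to_pybytes()
element = (schema_bytes, batch_bytes)  # decodable with no other state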


class _DecompressArrowBatchesFn(beam.DoFn):
"""Decompress and convert raw Arrow batches to timestamped row dicts.

Receives (schema_bytes, batch_bytes) tuples after Reshuffle and
converts each batch into row dicts whose event timestamps are
extracted from the change_timestamp column.
"""
def __init__(self, change_timestamp_column: str = 'change_timestamp') -> None:
self._change_timestamp_column = change_timestamp_column

def process(self, element: Tuple[bytes, bytes]) -> Iterable[Dict[str, Any]]:
schema_bytes, batch_bytes = element
schema = pyarrow.ipc.read_schema(pyarrow.py_buffer(schema_bytes))
batch = pyarrow.ipc.read_record_batch(
pyarrow.py_buffer(batch_bytes), schema)

rows = batch.to_pylist()
for row in rows:
ts = row.get(self._change_timestamp_column)
if ts is None:
raise ValueError(
'Row missing %r column. Row keys: %s' %
(self._change_timestamp_column, list(row.keys())))
if isinstance(ts, datetime.datetime):
ts = Timestamp.from_utc_datetime(ts)
yield TimestampedValue(row, ts)
Metrics.counter('BigQueryChangeHistory', 'rows_emitted').inc(len(rows))
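
# --- Companion sketch to the one after _read_stream_raw (standalone,
# --- hypothetical data; the same pyarrow calls process() uses above):
# --- round-tripping one (schema_bytes, batch_bytes) element.
import pyarrow

_src = pyarrow.RecordBatch.from_pydict({'change_timestamp': [1], 'id': [7]})
_schema_bytes = _src.schema.serialize().to_pybytes()
_batch_bytes = _src.serialize().to_pybytes()
_schema = pyarrow.ipc.read_schema(pyarrow.py_buffer(_schema_bytes))
_decoded = pyarrow.ipc.read_record_batch(
    pyarrow.py_buffer(_batch_bytes), _schema)
assert _decoded.to_pylist() == _src.to_pylist()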


# =============================================================================
# Cleanup: _CleanupTempTablesFn
@@ -1038,9 +1165,21 @@ class ReadBigQueryChangeHistory(beam.PTransform):
on the CHANGES/APPENDS query. Do not include the WHERE keyword.
Example: ``'status = "active" AND region = "US"'``.
batch_arrow_read: If True (default), convert Arrow RecordBatches in
bulk using to_pydict() instead of per-cell .as_py() calls.
bulk using to_pylist() instead of per-cell .as_py() calls.
This is ~1.5x faster for large tables at the cost of ~2x peak
memory per RecordBatch. Set to False for minimal memory usage.
max_split_rounds: Maximum number of recursive SplitReadStream
rounds. Each round splits every stream at fraction=0.5,
potentially doubling the stream count (if BigQuery allows).
Defaults to 1 (one round of splitting). Set to 0 to disable
splitting entirely, or higher for very large tables where more
parallelism is needed.
reshuffle_decompress: If True (default), the Read SDF emits raw
compressed Arrow batches instead of decoded rows. The batches
are reshuffled for fan-out and then decoded in a separate DoFn.
This spreads decompression and Arrow-to-Python conversion CPU
across more workers. Set to False to decode rows inline within
the Read SDF.
"""
def __init__(
self,
@@ -1057,7 +1196,9 @@ def __init__(
change_timestamp_column: str = 'change_timestamp',
columns: Optional[List[str]] = None,
row_filter: Optional[str] = None,
batch_arrow_read: bool = True) -> None:
batch_arrow_read: bool = True,
max_split_rounds: int = 1,
reshuffle_decompress: bool = True) -> None:
super().__init__()
if bq_storage is None:
raise ImportError(
@@ -1091,6 +1232,8 @@ def __init__(
self._columns = columns
self._row_filter = row_filter
self._batch_arrow_read = batch_arrow_read
self._max_split_rounds = max_split_rounds
self._reshuffle_decompress = reshuffle_decompress

def expand(self, pbegin: beam.pvalue.PBegin) -> beam.PCollection:
project = self._project
@@ -1170,16 +1313,37 @@ def expand(self, pbegin: beam.pvalue.PBegin) -> beam.PCollection:
row_filter=self._row_filter))
| 'CommitQueryResults' >> beam.Reshuffle())

emit_raw = self._reshuffle_decompress

read_sdf = beam.ParDo(
_ReadStorageStreamsSDF(
batch_arrow_read=self._batch_arrow_read,
change_timestamp_column=self._change_timestamp_column,
max_split_rounds=self._max_split_rounds,
emit_raw_batches=emit_raw))
if emit_raw:
read_sdf = read_sdf.with_output_types(Tuple[bytes, bytes])
else:
read_sdf = read_sdf.with_output_types(Dict[str, Any])

read_outputs = (
query_results
| 'ReadStorageStreams' >> beam.ParDo(
_ReadStorageStreamsSDF(
batch_arrow_read=self._batch_arrow_read,
change_timestamp_column=self._change_timestamp_column)).
with_outputs(_CLEANUP_TAG, main='rows'))
| 'ReadStorageStreams' >> read_sdf.with_outputs(
_CLEANUP_TAG, main='rows'))

_ = (
read_outputs[_CLEANUP_TAG]
| 'CleanupTempTables' >> beam.ParDo(_CleanupTempTablesFn()))

return read_outputs['rows']
if emit_raw:
# Reshuffle raw Arrow batches for fan-out, then decompress and
# convert to timestamped row dicts in a separate DoFn.
rows = (
read_outputs['rows']
| 'ReshuffleForFanout' >> beam.Reshuffle()
| 'DecompressBatches' >> beam.ParDo(
_DecompressArrowBatchesFn(
change_timestamp_column=(self._change_timestamp_column))))
return rows
else:
return read_outputs['rows']
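
# --- Hypothetical usage sketch (kept as comments because the constructor's
# --- table/query arguments are elided from this diff; only keyword arguments
# --- visible above are shown): wiring the transform with the new knobs.
#
# with beam.Pipeline() as p:
#   rows = p | ReadBigQueryChangeHistory(
#       ...,  # table/query/window arguments elided in this diff
#       change_timestamp_column='change_timestamp',
#       batch_arrow_read=True,
#       max_split_rounds=2,         # up to 4x the CreateReadSession streams
#       reshuffle_decompress=True)  # decode rows after the Reshuffle fan-out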