Skip to content
This repository was archived by the owner on Jun 30, 2022. It is now read-only.

Commit cb1a6bb

Browse files
gildea and silviulica
authored and committed
Allow operations to override the coder passed to update_counters
Some operations want to override, at read or write time, the coder that was declared when the work item was initialized, so that they can provide a coder specific to the element being processed.

----Release Notes----
[]
-------------
Created by MOE: https://github.com/google/moe
MOE_MIGRATED_REVID=120007941
1 parent 8484d34 commit cb1a6bb

3 files changed

Lines changed: 22 additions & 12 deletions

File tree

google/cloud/dataflow/worker/executor.pxd

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -28,7 +28,7 @@ cdef class Operation(object):
2828
cpdef finish(self)
2929

3030
@cython.locals(receiver=Operation)
31-
cpdef output(self, windowed_value, int output_index=*)
31+
cpdef output(self, windowed_value, object coder=*, int output_index=*)
3232

3333
cdef class ReadOperation(Operation):
3434
cdef object _current_progress
@@ -45,6 +45,7 @@ cdef class CombineOperation(Operation):
4545
cdef class ShuffleWriteOperation(Operation):
4646
cdef object shuffle_sink
4747
cdef object writer
48+
cdef object _write_coder
4849
cdef bint is_ungrouped
4950

5051
cdef class GroupedShuffleReadOperation(Operation):

google/cloud/dataflow/worker/executor.py

Lines changed: 16 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -23,6 +23,8 @@
2323

2424

2525
from google.cloud.dataflow.coders import BytesCoder
26+
from google.cloud.dataflow.coders import TupleCoder
27+
from google.cloud.dataflow.coders import WindowedValueCoder
2628
from google.cloud.dataflow.internal import pickler
2729
from google.cloud.dataflow.pvalue import EmptySideInput
2830
from google.cloud.dataflow.runners import common
@@ -65,15 +67,15 @@ def start(self, step_name):
6567
self.opcounter = opcounters.OperationCounters(
6668
self.counter_factory, step_name, self.coder, self.output_index)
6769

68-
def output(self, windowed_value):
69-
self.update_counters_start(windowed_value)
70+
def output(self, windowed_value, coder=None):
71+
self.update_counters_start(windowed_value, coder)
7072
for receiver in self.receivers:
7173
receiver.process(windowed_value)
7274
self.update_counters_finish()
7375

74-
def update_counters_start(self, windowed_value):
76+
def update_counters_start(self, windowed_value, coder=None):
7577
if self.opcounter:
76-
self.opcounter.update_from(windowed_value)
78+
self.opcounter.update_from(windowed_value, coder)
7779

7880
def update_counters_finish(self):
7981
if self.opcounter:
@@ -130,8 +132,8 @@ def process(self, o):
130132
"""Process element in operation."""
131133
pass
132134

133-
def output(self, windowed_value, output_index=0):
134-
self.receivers[output_index].output(windowed_value)
135+
def output(self, windowed_value, coder=None, output_index=0):
136+
self.receivers[output_index].output(windowed_value, coder)
135137

136138
def add_receiver(self, operation, output_index=0):
137139
"""Adds a receiver operation for the specified output."""
@@ -282,8 +284,10 @@ def __init__(self, spec, counter_factory, shuffle_source=None):
282284

283285
def start(self):
284286
super(GroupedShuffleReadOperation, self).start()
287+
write_coder = None
285288
if self.shuffle_source is None:
286289
coders = (self.spec.coder.key_coder(), self.spec.coder.value_coder())
290+
write_coder = WindowedValueCoder(TupleCoder(coders))
287291
self.shuffle_source = shuffle.GroupedShuffleSource(
288292
self.spec.shuffle_reader_config, coder=coders,
289293
start_position=self.spec.start_shuffle_position,
@@ -292,7 +296,7 @@ def start(self):
292296
for key, key_values in reader:
293297
self._reader = reader
294298
windowed_value = GlobalWindows.WindowedValue((key, key_values))
295-
self.output(windowed_value)
299+
self.output(windowed_value, coder=write_coder)
296300

297301
def get_progress(self):
298302
if self._reader is not None:
@@ -313,8 +317,10 @@ def __init__(self, spec, counter_factory, shuffle_source=None):
313317

314318
def start(self):
315319
super(UngroupedShuffleReadOperation, self).start()
320+
write_coder = None
316321
if self.shuffle_source is None:
317322
coders = (BytesCoder(), self.spec.coder)
323+
write_coder = WindowedValueCoder(TupleCoder(coders))
318324
self.shuffle_source = shuffle.UngroupedShuffleSource(
319325
self.spec.shuffle_reader_config, coder=coders,
320326
start_position=self.spec.start_shuffle_position,
@@ -323,7 +329,7 @@ def start(self):
323329
for value in reader:
324330
self._reader = reader
325331
windowed_value = GlobalWindows.WindowedValue(value)
326-
self.output(windowed_value)
332+
self.output(windowed_value, coder=write_coder)
327333

328334
def get_progress(self):
329335
# 'UngroupedShuffleReader' does not support progress reporting.
@@ -350,6 +356,7 @@ def start(self):
350356
coders = (BytesCoder(), coder)
351357
else:
352358
coders = (coder.key_coder(), coder.value_coder())
359+
self._write_coder = WindowedValueCoder(TupleCoder(coders))
353360
if self.shuffle_sink is None:
354361
self.shuffle_sink = shuffle.ShuffleSink(
355362
self.spec.shuffle_writer_config, coder=coders)
@@ -364,7 +371,7 @@ def process(self, o):
364371
if self.debug_logging_enabled:
365372
logging.debug('Processing [%s] in %s', o, self)
366373
assert isinstance(o, WindowedValue)
367-
self.receivers[0].update_counters_start(o)
374+
self.receivers[0].update_counters_start(o, coder=self._write_coder)
368375
# We typically write into shuffle key/value pairs. This is the reason why
369376
# the else branch below expects the value attribute of the WindowedValue
370377
# argument to be a KV pair. However the service may write to shuffle in

google/cloud/dataflow/worker/opcounters.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -31,15 +31,17 @@ def __init__(self, counter_factory, step_name, coder, output_index):
3131
self.coder = coder
3232
self._active_accumulators = []
3333

34-
def update_from(self, windowed_value):
34+
def update_from(self, windowed_value, coder=None):
3535
"""Add one value to this counter."""
3636
self.element_counter.update(1)
3737
byte_size_accumulator = Accumulator(self.mean_byte_counter.name)
3838
self._active_accumulators.append(byte_size_accumulator)
3939
# TODO(gildea):
4040
# Actually compute the encoded size of this value.
4141
# In spirit, something like this:
42-
# self.coder.store_estimated_size(windowed_value, byte_size_accumulator)
42+
# if coder is None:
43+
# coder = self.coder
44+
# coder.store_estimated_size(windowed_value, byte_size_accumulator)
4345
# but will need to do sampling.
4446

4547
def update_collect(self):

0 commit comments

Comments
 (0)