2323
2424
2525from google .cloud .dataflow .coders import BytesCoder
26+ from google .cloud .dataflow .coders import TupleCoder
27+ from google .cloud .dataflow .coders import WindowedValueCoder
2628from google .cloud .dataflow .internal import pickler
2729from google .cloud .dataflow .pvalue import EmptySideInput
2830from google .cloud .dataflow .runners import common
@@ -65,15 +67,15 @@ def start(self, step_name):
6567 self .opcounter = opcounters .OperationCounters (
6668 self .counter_factory , step_name , self .coder , self .output_index )
6769
def output(self, windowed_value, coder=None):
    """Deliver a windowed value to every receiver, updating counters around it.

    Args:
      windowed_value: The value handed to each receiver's ``process``.
      coder: Optional coder passed through to ``update_counters_start``;
        defaults to None.
    """
    self.update_counters_start(windowed_value, coder)
    for receiver in self.receivers:
        receiver.process(windowed_value)
    # Reached only when every receiver processed the value without raising.
    self.update_counters_finish()
7375
def update_counters_start(self, windowed_value, coder=None):
    """Record a windowed value on the operation counter, if one is set.

    Does nothing when ``self.opcounter`` is falsy.

    Args:
      windowed_value: The value passed to ``opcounter.update_from``.
      coder: Optional coder forwarded to ``opcounter.update_from``;
        defaults to None.
    """
    if self.opcounter:
        self.opcounter.update_from(windowed_value, coder)
7779
7880 def update_counters_finish (self ):
7981 if self .opcounter :
@@ -130,8 +132,8 @@ def process(self, o):
130132 """Process element in operation."""
131133 pass
132134
def output(self, windowed_value, coder=None, output_index=0):
    """Send a windowed value to the receiver registered at ``output_index``.

    Args:
      windowed_value: The value to emit.
      coder: Optional coder forwarded to the receiver's ``output``;
        defaults to None.
      output_index: Index into ``self.receivers`` selecting the target;
        defaults to 0.
    """
    # NOTE(review): ``coder`` now precedes ``output_index``, so any caller
    # that passed ``output_index`` positionally would bind it to ``coder``
    # instead -- confirm all such callers use the keyword form.
    self.receivers[output_index].output(windowed_value, coder)
135137
136138 def add_receiver (self , operation , output_index = 0 ):
137139 """Adds a receiver operation for the specified output."""
@@ -282,8 +284,10 @@ def __init__(self, spec, counter_factory, shuffle_source=None):
282284
283285 def start (self ):
284286 super (GroupedShuffleReadOperation , self ).start ()
287+ write_coder = None
285288 if self .shuffle_source is None :
286289 coders = (self .spec .coder .key_coder (), self .spec .coder .value_coder ())
290+ write_coder = WindowedValueCoder (TupleCoder (coders ))
287291 self .shuffle_source = shuffle .GroupedShuffleSource (
288292 self .spec .shuffle_reader_config , coder = coders ,
289293 start_position = self .spec .start_shuffle_position ,
@@ -292,7 +296,7 @@ def start(self):
292296 for key , key_values in reader :
293297 self ._reader = reader
294298 windowed_value = GlobalWindows .WindowedValue ((key , key_values ))
295- self .output (windowed_value )
299+ self .output (windowed_value , coder = write_coder )
296300
297301 def get_progress (self ):
298302 if self ._reader is not None :
@@ -313,8 +317,10 @@ def __init__(self, spec, counter_factory, shuffle_source=None):
313317
314318 def start (self ):
315319 super (UngroupedShuffleReadOperation , self ).start ()
320+ write_coder = None
316321 if self .shuffle_source is None :
317322 coders = (BytesCoder (), self .spec .coder )
323+ write_coder = WindowedValueCoder (TupleCoder (coders ))
318324 self .shuffle_source = shuffle .UngroupedShuffleSource (
319325 self .spec .shuffle_reader_config , coder = coders ,
320326 start_position = self .spec .start_shuffle_position ,
@@ -323,7 +329,7 @@ def start(self):
323329 for value in reader :
324330 self ._reader = reader
325331 windowed_value = GlobalWindows .WindowedValue (value )
326- self .output (windowed_value )
332+ self .output (windowed_value , coder = write_coder )
327333
328334 def get_progress (self ):
329335 # 'UngroupedShuffleReader' does not support progress reporting.
@@ -350,6 +356,7 @@ def start(self):
350356 coders = (BytesCoder (), coder )
351357 else :
352358 coders = (coder .key_coder (), coder .value_coder ())
359+ self ._write_coder = WindowedValueCoder (TupleCoder (coders ))
353360 if self .shuffle_sink is None :
354361 self .shuffle_sink = shuffle .ShuffleSink (
355362 self .spec .shuffle_writer_config , coder = coders )
@@ -364,7 +371,7 @@ def process(self, o):
364371 if self .debug_logging_enabled :
365372 logging .debug ('Processing [%s] in %s' , o , self )
366373 assert isinstance (o , WindowedValue )
367- self .receivers [0 ].update_counters_start (o )
374+ self .receivers [0 ].update_counters_start (o , coder = self . _write_coder )
368375 # We typically write into shuffle key/value pairs. This is the reason why
369376 # the else branch below expects the value attribute of the WindowedValue
370377 # argument to be a KV pair. However the service may write to shuffle in
0 commit comments