Skip to content
This repository was archived by the owner on Jun 30, 2022. It is now read-only.

Commit eb4f582

Browse files
charlesccychenaaltay
authored andcommitted
Don't use KV coder for ungrouped shuffle reads/writes
This fixes an issue where ungrouped reshuffles were incorrectly requiring KV coders. ----Release Notes---- [] ------------- Created by MOE: https://github.com/google/moe MOE_MIGRATED_REVID=117872394
1 parent a44f1e2 commit eb4f582

4 files changed

Lines changed: 40 additions & 26 deletions

File tree

google/cloud/dataflow/worker/executor.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import random
2323

2424

25+
from google.cloud.dataflow.coders import BytesCoder
2526
from google.cloud.dataflow.internal import pickler
2627
from google.cloud.dataflow.pvalue import EmptySideInput
2728
from google.cloud.dataflow.runners import common
@@ -278,8 +279,9 @@ def __init__(self, spec, shuffle_source=None):
278279
def start(self):
279280
super(GroupedShuffleReadOperation, self).start()
280281
if self.shuffle_source is None:
282+
coders = (self.spec.coder.key_coder(), self.spec.coder.value_coder())
281283
self.shuffle_source = shuffle.GroupedShuffleSource(
282-
self.spec.shuffle_reader_config, coder=self.spec.coders,
284+
self.spec.shuffle_reader_config, coder=coders,
283285
start_position=self.spec.start_shuffle_position,
284286
end_position=self.spec.end_shuffle_position)
285287
with self.shuffle_source.reader() as reader:
@@ -308,8 +310,9 @@ def __init__(self, spec, shuffle_source=None):
308310
def start(self):
309311
super(UngroupedShuffleReadOperation, self).start()
310312
if self.shuffle_source is None:
313+
coders = (BytesCoder(), self.spec.coder)
311314
self.shuffle_source = shuffle.UngroupedShuffleSource(
312-
self.spec.shuffle_reader_config, coder=self.spec.coders,
315+
self.spec.shuffle_reader_config, coder=coders,
313316
start_position=self.spec.start_shuffle_position,
314317
end_position=self.spec.end_shuffle_position)
315318
with self.shuffle_source.reader() as reader:
@@ -337,13 +340,17 @@ def __init__(self, spec, shuffle_sink=None):
337340

338341
def start(self):
339342
super(ShuffleWriteOperation, self).start()
340-
# TODO(silviuc): Shuffle 'kind' is ignored!
343+
self.is_ungrouped = self.spec.shuffle_kind == 'ungrouped'
344+
coder = self.spec.coder
345+
if self.is_ungrouped:
346+
coders = (BytesCoder(), coder)
347+
else:
348+
coders = (coder.key_coder(), coder.value_coder())
341349
if self.shuffle_sink is None:
342350
self.shuffle_sink = shuffle.ShuffleSink(
343-
self.spec.shuffle_writer_config, coder=self.spec.coders)
351+
self.spec.shuffle_writer_config, coder=coders)
344352
self.writer = self.shuffle_sink.writer()
345353
self.writer.__enter__()
346-
self.is_ungrouped = self.spec.shuffle_kind == 'ungrouped'
347354

348355
def finish(self):
349356
logging.debug('Finishing %s', self)

google/cloud/dataflow/worker/executor_test.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def reader(self):
9898

9999
class ExecutorTest(unittest.TestCase):
100100

101-
SHUFFLE_CODERS = (coders.PickleCoder(), coders.PickleCoder())
101+
SHUFFLE_CODER = coders.PickleCoder()
102102

103103
def create_temp_file(self, content_text):
104104
"""Creates a temporary file with content and returns the path to it."""
@@ -178,7 +178,7 @@ def test_read_do_shuffle_write(self):
178178
maptask.WorkerShuffleWrite(shuffle_kind='group_keys',
179179
shuffle_writer_config='none',
180180
input=(1, 0),
181-
coders=self.SHUFFLE_CODERS)
181+
coder=self.SHUFFLE_CODER)
182182
]
183183
shuffle_sink_mock = mock.MagicMock()
184184
executor.MapTaskExecutor().execute(
@@ -195,7 +195,7 @@ def test_shuffle_read_do_write(self):
195195
maptask.WorkerGroupingShuffleRead(shuffle_reader_config='none',
196196
start_shuffle_position='aaa',
197197
end_shuffle_position='zzz',
198-
coders=self.SHUFFLE_CODERS),
198+
coder=self.SHUFFLE_CODER),
199199
maptask.WorkerDoFn(serialized_fn=pickle_with_side_inputs(
200200
ptransform.CallableWrapperDoFn(
201201
lambda (k, vs): [str((k, v)) for v in vs])),
@@ -223,7 +223,7 @@ def test_ungrouped_shuffle_read_and_write(self):
223223
maptask.WorkerUngroupedShuffleRead(shuffle_reader_config='none',
224224
start_shuffle_position='aaa',
225225
end_shuffle_position='zzz',
226-
coders=self.SHUFFLE_CODERS),
226+
coder=self.SHUFFLE_CODER),
227227
maptask.WorkerWrite(
228228
fileio.TextFileSink(file_path_prefix=output_path,
229229
append_trailing_newlines=True,

google/cloud/dataflow/worker/maptask.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def worker_object_to_string(worker_object):
9797
WorkerGroupingShuffleRead = build_worker_instruction(
9898
'WorkerGroupingShuffleRead',
9999
['start_shuffle_position', 'end_shuffle_position',
100-
'shuffle_reader_config', 'coders'])
100+
'shuffle_reader_config', 'coder'])
101101
"""Worker details needed to read from a grouping shuffle source.
102102
103103
Attributes:
@@ -108,14 +108,14 @@ def worker_object_to_string(worker_object):
108108
shuffle_reader_config: An opaque string used to initialize the shuffle
109109
reader. Contains things like connection endpoints for the shuffle
110110
server appliance and various options.
111-
coders: A 2-tuple of coders (key, value) to decode shuffle entries.
111+
coder: The KV coder used to decode shuffle entries.
112112
"""
113113

114114

115115
WorkerUngroupedShuffleRead = build_worker_instruction(
116116
'WorkerUngroupedShuffleRead',
117117
['start_shuffle_position', 'end_shuffle_position',
118-
'shuffle_reader_config', 'coders'])
118+
'shuffle_reader_config', 'coder'])
119119
"""Worker details needed to read from an ungrouped shuffle source.
120120
121121
Attributes:
@@ -126,7 +126,7 @@ def worker_object_to_string(worker_object):
126126
shuffle_reader_config: An opaque string used to initialize the shuffle
127127
reader. Contains things like connection endpoints for the shuffle
128128
server appliance and various options.
129-
coders: A 2-tuple of coders (key, value) to decode shuffle entries.
129+
coder: The value coder used to decode shuffle entries.
130130
"""
131131

132132

@@ -160,7 +160,7 @@ def worker_object_to_string(worker_object):
160160

161161
WorkerShuffleWrite = build_worker_instruction(
162162
'WorkerShuffleWrite',
163-
['shuffle_kind', 'shuffle_writer_config', 'input', 'coders'])
163+
['shuffle_kind', 'shuffle_writer_config', 'input', 'coder'])
164164
"""Worker details needed to write to a shuffle sink.
165165
166166
Attributes:
@@ -173,7 +173,8 @@ def worker_object_to_string(worker_object):
173173
input: A (producer index, output index) tuple representing the
174174
ParallelInstruction operation whose output feeds into this operation.
175175
The output index is 0 except for multi-output operations (like ParDo).
176-
coders: A 2-tuple of coders (key, value) to encode shuffle entries.
176+
coder: The coder for input elements. If the shuffle_kind is grouping, this is
177+
expected to be a KV coder.
177178
"""
178179

179180

@@ -370,20 +371,23 @@ def get_read_work_item(work, env, context):
370371
if source:
371372
return WorkerRead(source, tag=None)
372373

373-
# TODO(mairbek) create Shuffler Source/Reader
374-
kv_coders = get_coder_from_spec(codec_specs, kv_pair=True)
374+
coder = get_coder_from_spec(codec_specs)
375+
# TODO(ccy): Reconcile WindowedValueCoder wrappings for sources with custom
376+
# coders so this special case won't be necessary.
377+
if isinstance(coder, coders.WindowedValueCoder):
378+
coder = coder.wrapped_value_coder
375379
if specs['@type'] == 'GroupingShuffleSource':
376380
return WorkerGroupingShuffleRead(
377381
start_shuffle_position=specs['start_shuffle_position']['value'],
378382
end_shuffle_position=specs['end_shuffle_position']['value'],
379383
shuffle_reader_config=specs['shuffle_reader_config']['value'],
380-
coders=kv_coders)
384+
coder=coder)
381385
elif specs['@type'] == 'UngroupedShuffleSource':
382386
return WorkerUngroupedShuffleRead(
383387
start_shuffle_position=specs['start_shuffle_position']['value'],
384388
end_shuffle_position=specs['end_shuffle_position']['value'],
385389
shuffle_reader_config=specs['shuffle_reader_config']['value'],
386-
coders=kv_coders)
390+
coder=coder)
387391
else:
388392
raise NotImplementedError('Unknown source type: %r' % specs)
389393

@@ -452,14 +456,17 @@ def get_write_work_item(work, env, context):
452456
sink = env.parse_sink(specs, codec_specs, context)
453457
if sink:
454458
return WorkerWrite(sink, input=get_input_spec(work.write.input))
455-
# TODO(mairbek) create Shuffler Sink/Writer
456459
if specs['@type'] == 'ShuffleSink':
457-
kv_coders = get_coder_from_spec(codec_specs, kv_pair=True)
460+
coder = get_coder_from_spec(codec_specs)
461+
# TODO(ccy): Reconcile WindowedValueCoder wrappings for sources with custom
462+
# coders so this special case won't be necessary.
463+
if isinstance(coder, coders.WindowedValueCoder):
464+
coder = coder.wrapped_value_coder
458465
return WorkerShuffleWrite(
459466
shuffle_kind=specs['shuffle_kind']['value'],
460467
shuffle_writer_config=specs['shuffle_writer_config']['value'],
461468
input=get_input_spec(work.write.input),
462-
coders=kv_coders)
469+
coder=coder)
463470
else:
464471
raise NotImplementedError('Unknown sink type: %r' % specs)
465472

google/cloud/dataflow/worker/workitem_test.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -413,7 +413,7 @@ def test_concat_source_to_shuffle_sink(self):
413413
shuffle_kind='group_keys',
414414
shuffle_writer_config='opaque',
415415
input=(1, 0),
416-
coders=(CODER.key_coder(), CODER.value_coder()))]))
416+
coder=CODER)]))
417417

418418
def test_text_source_to_shuffle_sink(self):
419419
work = workitem.get_work_items(get_text_source_to_shuffle_sink_message())
@@ -433,7 +433,7 @@ def test_text_source_to_shuffle_sink(self):
433433
shuffle_kind='group_keys',
434434
shuffle_writer_config='opaque',
435435
input=(1, 0),
436-
coders=(CODER.key_coder(), CODER.value_coder()))]))
436+
coder=CODER)]))
437437

438438
def test_shuffle_source_to_text_sink(self):
439439
work = workitem.get_work_items(
@@ -445,7 +445,7 @@ def test_shuffle_source_to_text_sink(self):
445445
start_shuffle_position='opaque',
446446
end_shuffle_position='opaque',
447447
shuffle_reader_config='opaque',
448-
coders=(CODER.key_coder(), CODER.value_coder())),
448+
coder=CODER),
449449
maptask.WorkerWrite(io.TextFileSink(
450450
file_path_prefix='gs://somefile',
451451
append_trailing_newlines=True,
@@ -461,7 +461,7 @@ def test_ungrouped_shuffle_source_to_text_sink(self):
461461
start_shuffle_position='opaque',
462462
end_shuffle_position='opaque',
463463
shuffle_reader_config='opaque',
464-
coders=(CODER.key_coder(), CODER.value_coder())),
464+
coder=CODER),
465465
maptask.WorkerWrite(io.TextFileSink(
466466
file_path_prefix='gs://somefile',
467467
append_trailing_newlines=True,

0 commit comments

Comments
 (0)