Skip to content
This repository was archived by the owner on Jun 30, 2022. It is now read-only.

Commit 493afc5

Browse files
robertwbgildea
authored and committed
Implement continuous combining in pre-shuffle combining table.
This is particularly useful for global combines. Fixed counters to be updated exactly once per output, rather than once for every consumer. Also optimized the callable combine fn and fixed tracebacks for the DoFn runner. ----Release Notes---- [] ------------- Created by MOE: https://github.com/google/moe MOE_MIGRATED_REVID=117146623
1 parent 5505162 commit 493afc5

4 files changed

Lines changed: 116 additions & 69 deletions

File tree

google/cloud/dataflow/runners/common.py

Lines changed: 12 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -93,25 +93,18 @@ def process(self, element):
9393
self.context.set_element(element)
9494
self._process_outputs(element, self.dofn.process(self.context))
9595
except BaseException as exn:
96-
raise self.augment_exception(exn)
97-
98-
def augment_exception(self, exn):
99-
try:
100-
if getattr(exn, '_tagged_with_step', False) or not self.step_name:
101-
return exn
102-
args = exn.args
103-
if args and isinstance(args[0], str):
104-
args = (args[0] + " [while running '%s']" % self.step_name,) + args[1:]
105-
# Poor man's exception chaining.
106-
try:
107-
raise type(exn), args, sys.exc_info()[2]
108-
except BaseException as new_exn:
109-
new_exn._tagged_with_step = True
110-
return new_exn
111-
else:
112-
return exn
113-
except:
114-
return exn
96+
self.reraise_augmented(exn)
97+
98+
def reraise_augmented(self, exn):
99+
if getattr(exn, '_tagged_with_step', False) or not self.step_name:
100+
raise
101+
args = exn.args
102+
if args and isinstance(args[0], str):
103+
args = (args[0] + " [while running '%s']" % self.step_name,) + args[1:]
104+
# Poor man's exception chaining.
105+
raise type(exn), args, sys.exc_info()[2]
106+
else:
107+
raise
115108

116109
def _process_outputs(self, element, results):
117110
"""Dispatch the result of computation to the appropriate receivers.

google/cloud/dataflow/transforms/combiners.py

Lines changed: 26 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -452,37 +452,42 @@ def extract_output(self, accumulator):
452452
return accumulator
453453

454454

455-
class PhasedCombineFnExecutor(object):
456-
"""Executor for phases of combine operations."""
455+
def curry_combine_fn(fn, args, kwargs):
456+
if not args and not kwargs:
457+
return fn
457458

458-
def __init__(self, phase, fn, args, kwargs):
459+
else:
459460

460-
if not args and not kwargs:
461-
self.combine_fn = fn
462-
else:
461+
class CurriedFn(core.CombineFn):
462+
"""CombineFn that applies extra arguments."""
463+
464+
def create_accumulator(self):
465+
return fn.create_accumulator(*args, **kwargs)
463466

464-
class CurriedFn(core.CombineFn):
465-
"""CombineFn that applies extra arguments."""
467+
def add_input(self, accumulator, element):
468+
return fn.add_input(accumulator, element, *args, **kwargs)
466469

467-
def create_accumulator(self):
468-
return fn.create_accumulator(*args, **kwargs)
470+
def add_inputs(self, accumulator, elements):
471+
return fn.add_inputs(accumulator, elements, *args, **kwargs)
469472

470-
def add_input(self, accumulator, element):
471-
return fn.add_input(accumulator, element, *args, **kwargs)
473+
def merge_accumulators(self, accumulators):
474+
return fn.merge_accumulators(accumulators, *args, **kwargs)
472475

473-
def add_inputs(self, accumulator, elements):
474-
return fn.add_inputs(accumulator, elements, *args, **kwargs)
476+
def extract_output(self, accumulator):
477+
return fn.extract_output(accumulator, *args, **kwargs)
475478

476-
def merge_accumulators(self, accumulators):
477-
return fn.merge_accumulators(accumulators, *args, **kwargs)
479+
def apply(self, elements):
480+
return fn.apply(elements, *args, **kwargs)
478481

479-
def extract_output(self, accumulator):
480-
return fn.extract_output(accumulator, *args, **kwargs)
482+
return CurriedFn()
481483

482-
def apply(self, elements):
483-
return fn.apply(elements, *args, **kwargs)
484484

485-
self.combine_fn = CurriedFn()
485+
class PhasedCombineFnExecutor(object):
486+
"""Executor for phases of combine operations."""
487+
488+
def __init__(self, phase, fn, args, kwargs):
489+
490+
self.combine_fn = curry_combine_fn(fn, args, kwargs)
486491

487492
if phase == 'all':
488493
self.apply = self.full_combine

google/cloud/dataflow/transforms/core.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -384,17 +384,22 @@ def create_accumulator(self, *args, **kwargs):
384384
return self._EMPTY
385385

386386
def add_input(self, accumulator, element, *args, **kwargs):
387-
return self.add_inputs([element], *args, **kwargs)
387+
if accumulator is self._EMPTY:
388+
return element
389+
else:
390+
return self._fn([accumulator, element], *args, **kwargs)
388391

389392
def add_inputs(self, accumulator, elements, *args, **kwargs):
390-
if accumulator is not self._EMPTY:
393+
if accumulator is self._EMPTY:
394+
return self._fn(elements, *args, **kwargs)
395+
elif isinstance(elements, (list, tuple)):
396+
return self._fn([accumulator] + elements, *args, **kwargs)
397+
else:
391398
def union():
392399
yield accumulator
393400
for e in elements:
394401
yield e
395-
else:
396-
union = lambda: elements
397-
return self._fn(union(), *args, **kwargs)
402+
return self._fn(union(), *args, **kwargs)
398403

399404
def merge_accumulators(self, accumulators, *args, **kwargs):
400405
# It's (weakly) assumed that self._fn is associative.

google/cloud/dataflow/worker/executor.py

Lines changed: 68 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
# cython: profile=True
16+
1517
"""Worker operations executor."""
1618

1719
import collections
@@ -24,8 +26,10 @@
2426
from google.cloud.dataflow.pvalue import EmptySideInput
2527
from google.cloud.dataflow.runners import common
2628
import google.cloud.dataflow.transforms as ptransform
29+
from google.cloud.dataflow.transforms import combiners
2730
from google.cloud.dataflow.transforms import trigger
2831
from google.cloud.dataflow.transforms import window
32+
from google.cloud.dataflow.transforms.combiners import curry_combine_fn
2933
from google.cloud.dataflow.transforms.combiners import PhasedCombineFnExecutor
3034
from google.cloud.dataflow.transforms.trigger import InMemoryUnmergedState
3135
from google.cloud.dataflow.transforms.window import GlobalWindows
@@ -147,8 +151,8 @@ def start(self):
147151
windowed_value = value
148152
else:
149153
windowed_value = GlobalWindows.WindowedValue(value)
154+
self.counters[0].update(windowed_value)
150155
for receiver in self.receivers[0]:
151-
self.counters[0].update(windowed_value)
152156
receiver.process(windowed_value)
153157

154158
def side_read_all(self, singleton=False):
@@ -239,9 +243,9 @@ def start(self):
239243
with self.shuffle_source.reader() as reader:
240244
for key, key_values in reader:
241245
self._reader = reader
246+
windowed_value = GlobalWindows.WindowedValue((key, key_values))
247+
self.counters[0].update(windowed_value)
242248
for receiver in self.receivers[0]:
243-
windowed_value = GlobalWindows.WindowedValue((key, key_values))
244-
self.counters[0].update(windowed_value)
245249
receiver.process(windowed_value)
246250

247251
def get_progress(self):
@@ -271,9 +275,9 @@ def start(self):
271275
with self.shuffle_source.reader() as reader:
272276
for value in reader:
273277
self._reader = reader
278+
windowed_value = GlobalWindows.WindowedValue(value)
279+
self.counters[0].update(windowed_value)
274280
for receiver in self.receivers[0]:
275-
windowed_value = GlobalWindows.WindowedValue(value)
276-
self.counters[0].update(windowed_value)
277281
receiver.process(windowed_value)
278282

279283
def get_progress(self):
@@ -463,11 +467,18 @@ def process(self, o):
463467
key, values = o.value
464468
windowed_result = WindowedValue(
465469
(key, self.phased_combine_fn.apply(values)), o.timestamp, o.windows)
470+
self.counters[0].update(windowed_result)
466471
for receiver in self.receivers[0]:
467-
self.counters[0].update(windowed_result)
468472
receiver.process(windowed_result)
469473

470474

475+
def create_pgbk_op(spec):
476+
if spec.combine_fn:
477+
return PGBKCVOperation(spec)
478+
else:
479+
return PGBKOperation(spec)
480+
481+
471482
class PGBKOperation(Operation):
472483
"""Partial group-by-key operation.
473484
@@ -478,16 +489,7 @@ class PGBKOperation(Operation):
478489

479490
def __init__(self, spec):
480491
super(PGBKOperation, self).__init__(spec)
481-
self.phased_combine_fn = None
482-
if self.spec.combine_fn:
483-
# Combiners do not accept deferred side-inputs (the ignored fourth
484-
# argument) and therefore the code to handle the extra args/kwargs is
485-
# simpler than for the DoFn's of ParDo.
486-
#
487-
# TODO(ccy): Combine as we go for each key instead of storing up state
488-
# for combination when flushing.
489-
fn, args, kwargs = pickler.loads(self.spec.combine_fn)[:3]
490-
self.phased_combine_fn = PhasedCombineFnExecutor('add', fn, args, kwargs)
492+
assert not self.spec.combine_fn
491493
self.table = collections.defaultdict(list)
492494
self.size = 0
493495
# TODO(robertwb) Make this configurable.
@@ -512,16 +514,58 @@ def flush(self, target):
512514
del self.table[kw]
513515
key, windows = kw
514516
output_value = [v.value[1] for v in vs]
515-
if self.phased_combine_fn:
516-
output_value = self.phased_combine_fn.apply(output_value)
517517
windowed_value = WindowedValue(
518518
(key, output_value),
519519
vs[0].timestamp, windows)
520+
self.counters[0].update(windowed_value)
520521
for receiver in self.receivers[0]:
521-
self.counters[0].update(windowed_value)
522522
receiver.process(windowed_value)
523523

524524

525+
class PGBKCVOperation(Operation):
526+
527+
def __init__(self, spec):
528+
super(PGBKCVOperation, self).__init__(spec)
529+
# Combiners do not accept deferred side-inputs (the ignored fourth
530+
# argument) and therefore the code to handle the extra args/kwargs is
531+
# simpler than for the DoFn's of ParDo.
532+
fn, args, kwargs = pickler.loads(self.spec.combine_fn)[:3]
533+
self.combine_fn = curry_combine_fn(fn, args, kwargs)
534+
# Optimization for the (known tiny accumulator, often wide keyspace)
535+
# count function.
536+
# TODO(robertwb): Bound by in-memory size rather than key count.
537+
self.max_keys = (
538+
1000000 if isinstance(fn, combiners.CountCombineFn) else 10000)
539+
self.key_count = 0
540+
self.table = {}
541+
542+
def process(self, wkv):
543+
key, value = wkv.value
544+
wkey = tuple(wkv.windows), key
545+
entry = self.table.get(wkey, None)
546+
if entry is None:
547+
if self.key_count >= self.max_keys:
548+
old_wkey = self.table.iterkeys().next() # Any key, could use LRU
549+
self.output(old_wkey, self.table.pop(old_wkey)[0])
550+
else:
551+
self.key_count += 1
552+
entry = self.table[wkey] = [self.combine_fn.create_accumulator()]
553+
entry[0] = self.combine_fn.add_inputs(entry[0], [value])
554+
555+
def finish(self):
556+
for wkey, value in self.table.iteritems():
557+
self.output(wkey, value[0])
558+
self.entries = {}
559+
self.key_count = 0
560+
561+
def output(self, wkey, value):
562+
windows, key = wkey
563+
windowed_value = WindowedValue((key, value), windows[0].end, windows)
564+
self.counters[0].update(windowed_value)
565+
for receiver in self.receivers[0]:
566+
receiver.process(windowed_value)
567+
568+
525569
class FlattenOperation(Operation):
526570
"""Flatten operation.
527571
@@ -533,8 +577,8 @@ def process(self, o):
533577
logging.debug('Processing [%s] in %s', o, self)
534578
assert isinstance(o, WindowedValue)
535579
windowed_result = WindowedValue(o.value, o.timestamp, o.windows)
580+
self.counters[0].update(windowed_result)
536581
for receiver in self.receivers[0]:
537-
self.counters[0].update(windowed_result)
538582
receiver.process(windowed_result)
539583

540584

@@ -559,8 +603,8 @@ def process(self, o):
559603
o.timestamp, o.windows))
560604

561605
def output(self, windowed_result):
606+
self.counters[0].update(windowed_result)
562607
for receiver in self.receivers[0]:
563-
self.counters[0].update(windowed_result)
564608
receiver.process(windowed_result)
565609

566610

@@ -608,8 +652,8 @@ def process(self, o):
608652
window.WindowedValue((k, values), timestamp, [out_window]))
609653

610654
def output(self, windowed_result):
655+
self.counters[0].update(windowed_result)
611656
for receiver in self.receivers[0]:
612-
self.counters[0].update(windowed_result)
613657
receiver.process(windowed_result)
614658

615659

@@ -646,8 +690,8 @@ def process(self, o):
646690
[out_window]))
647691

648692
def output(self, windowed_result):
693+
self.counters[0].update(windowed_result)
649694
for receiver in self.receivers[0]:
650-
self.counters[0].update(windowed_result)
651695
receiver.process(windowed_result)
652696

653697

@@ -706,7 +750,7 @@ def execute(self, map_task, test_shuffle_source=None, test_shuffle_sink=None):
706750
elif isinstance(spec, maptask.WorkerCombineFn):
707751
op = CombineOperation(spec)
708752
elif isinstance(spec, maptask.WorkerPartialGroupByKey):
709-
op = PGBKOperation(spec)
753+
op = create_pgbk_op(spec)
710754
elif isinstance(spec, maptask.WorkerDoFn):
711755
op = DoOperation(spec)
712756
elif isinstance(spec, maptask.WorkerGroupingShuffleRead):

0 commit comments

Comments (0)