Skip to content
This repository was archived by the owner on Jun 30, 2022. It is now read-only.

Commit 9c6e898

Browse files
charlesccychenaaltay
authored andcommitted
Support large iterable side inputs
This change provides support in the DataflowPipelineRunner for large iterable side inputs that do not fit in memory. ----Release Notes---- [] ------------- Created by MOE: https://github.com/google/moe MOE_MIGRATED_REVID=121318997
1 parent c22a4c3 commit 9c6e898

7 files changed

Lines changed: 164 additions & 38 deletions

File tree

google/cloud/dataflow/pvalue.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,11 @@ class ListPCollectionView(PCollectionView):
263263
pass
264264

265265

266+
class DictPCollectionView(PCollectionView):
267+
"""A PCollectionView that can be treated as a dict."""
268+
pass
269+
270+
266271
def _get_cached_view(pipeline, key):
267272
return pipeline._view_cache.get(key, None) # pylint: disable=protected-access
268273

@@ -411,24 +416,20 @@ def AsList(pcoll, label=None): # pylint: disable=invalid-name
411416

412417
@can_take_label_as_first_argument
413418
def AsDict(pcoll, label=None): # pylint: disable=invalid-name
414-
"""Convenience function packaging an entire PCollection as a side input dict.
419+
"""Create a DictPCollectionView from the elements of input PCollection.
415420
416-
Intended for use in side-argument specification---the same places where
417-
AsSingleton and AsIter are used. Unlike those wrapper classes, AsDict (as
418-
implemented) is a function that schedules a Combiner to condense pcoll into a
419-
single dict, then wraps the resulting one-element PCollection in AsSingleton.
421+
The contents of the given PCollection whose elements are 2-tuples of key and
422+
value will be available as a dict-like object in PTransforms that use the
423+
returned PCollectionView as a side input.
420424
421425
Args:
422-
pcoll: Input pcollection. All elements should be key-value pairs (i.e.
423-
2-tuples) with unique keys.
426+
pcoll: Input pcollection containing 2-tuples of key and value.
424427
label: Label to be specified if several AsDict's for the same PCollection.
425428
426429
Returns:
427-
A singleton PCollectionView containing the dict as above.
430+
A dict PCollectionView containing the dict as above.
428431
"""
429432
label = label or _format_view_label(pcoll)
430-
singleton_label = 'ToDict(%s)' % label
431-
combine_label = 'CombineToDict(%s)' % label
432433

433434
# Don't recreate the view if it was already created.
434435
cache_key = (pcoll, AsDict)
@@ -439,8 +440,8 @@ def AsDict(pcoll, label=None): # pylint: disable=invalid-name
439440
# Local import is required due to dependency loop; even though the
440441
# implementation of this function requires concepts defined in modules that
441442
# depend on pvalue, it lives in this module to reduce user workload.
442-
from google.cloud.dataflow.transforms import combiners # pylint: disable=g-import-not-at-top
443-
view = AsSingleton(singleton_label, pcoll | combiners.ToDict(combine_label))
443+
from google.cloud.dataflow.transforms import sideinputs # pylint: disable=g-import-not-at-top
444+
view = (pcoll | sideinputs.ViewAsDict(label=label))
444445
_cache_view(pcoll.pipeline, cache_key, view)
445446
return view
446447

google/cloud/dataflow/runners/direct_runner.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from google.cloud.dataflow import coders
2929
from google.cloud.dataflow import error
3030
from google.cloud.dataflow.io import fileio
31+
from google.cloud.dataflow.pvalue import DictPCollectionView
3132
from google.cloud.dataflow.pvalue import EmptySideInput
3233
from google.cloud.dataflow.pvalue import IterablePCollectionView
3334
from google.cloud.dataflow.pvalue import ListPCollectionView
@@ -119,6 +120,8 @@ def run_CreatePCollectionView(self, transform_node):
119120
result = [v.value for v in values]
120121
elif isinstance(view, ListPCollectionView):
121122
result = [v.value for v in values]
123+
elif isinstance(view, DictPCollectionView):
124+
result = dict(v.value for v in values)
122125
else:
123126
raise NotImplementedError
124127

google/cloud/dataflow/transforms/sideinputs.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,3 +115,31 @@ def apply(self, pcoll):
115115
| CreatePCollectionView(pvalue.ListPCollectionView(pcoll.pipeline))
116116
.with_input_types(input_type)
117117
.with_output_types(output_type))
118+
119+
K = typehints.TypeVariable('K')
120+
V = typehints.TypeVariable('V')
121+
@typehints.with_input_types(typehints.Tuple[K, V])
122+
@typehints.with_output_types(typehints.Dict[K, V])
123+
class ViewAsDict(PTransform): # pylint: disable=g-wrong-blank-lines
124+
"""Transform to view PCollection as a dict PCollectionView.
125+
126+
Important: this transform is an implementation detail and should not be used
127+
directly by pipeline writers. Use pvalue.AsDict(...) instead.
128+
"""
129+
130+
def __init__(self, label=None):
131+
if label:
132+
label = 'ViewAsDict(%s)' % label
133+
super(ViewAsDict, self).__init__(label=label)
134+
135+
def apply(self, pcoll):
136+
self._check_pcollection(pcoll)
137+
input_type = pcoll.element_type
138+
key_type, value_type = (
139+
typehints.trivial_inference.key_value_types(input_type))
140+
output_type = typehints.Dict[key_type, value_type]
141+
return (pcoll
142+
| CreatePCollectionView(
143+
pvalue.DictPCollectionView(pcoll.pipeline))
144+
.with_input_types(input_type)
145+
.with_output_types(output_type))

google/cloud/dataflow/worker/executor.py

Lines changed: 27 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
from google.cloud.dataflow.worker import maptask
4444
from google.cloud.dataflow.worker import opcounters
4545
from google.cloud.dataflow.worker import shuffle
46+
from google.cloud.dataflow.worker import sideinputs
4647

4748

4849
class ReceiverSet(object):
@@ -201,14 +202,6 @@ def start(self):
201202
windowed_value = GlobalWindows.WindowedValue(value)
202203
self.output(windowed_value)
203204

204-
def side_read_all(self, singleton=False):
205-
# TODO(mairbek): Should we return WindowedValue here?
206-
with self.spec.source.reader() as reader:
207-
for value in reader:
208-
yield value
209-
if singleton:
210-
return
211-
212205
def request_dynamic_split(self, dynamic_split_request):
213206
if self._reader is not None:
214207
return self._reader.request_dynamic_split(dynamic_split_request)
@@ -423,36 +416,46 @@ def _read_side_inputs(self, tags_and_types):
423416
# specification. This can happen for instance if the source has been
424417
# sharded into several files.
425418
for side_tag, view_class, view_options in tags_and_types:
426-
# Note that currently, the implementation of Iterable and List views
427-
# are identical. This may change in the future once we allow very large
428-
# side input collections.
429-
is_singleton = view_class == pvalue.SingletonPCollectionView
419+
sources = []
430420
# Using the side_tag in the lambda below will trigger a pylint warning.
431421
# However in this case it is fine because the lambda is used right away
432422
# while the variable has the value assigned by the current iteration of
433423
# the for loop.
434424
# pylint: disable=cell-var-from-loop
435-
results = []
436425
for si in itertools.ifilter(
437426
lambda o: o.tag == side_tag, self.spec.side_inputs):
438-
if isinstance(si, maptask.WorkerSideInputSource):
439-
op = ReadOperation(si, self.counter_factory)
440-
else:
427+
if not isinstance(si, maptask.WorkerSideInputSource):
441428
raise NotImplementedError('Unknown side input type: %r' % si)
442-
for v in op.side_read_all(singleton=is_singleton):
443-
results.append(v)
444-
if is_singleton:
445-
break
446-
if is_singleton:
429+
sources.append(si.source)
430+
iterator_fn = sideinputs.get_iterator_fn_for_sources(sources)
431+
432+
if view_class == pvalue.SingletonPCollectionView:
447433
has_default, default = view_options
448-
if results:
449-
yield results[0]
434+
has_result = False
435+
result = None
436+
for v in iterator_fn():
437+
has_result = True
438+
result = v
439+
break
440+
if has_result:
441+
yield result
450442
elif has_default:
451443
yield default
452444
else:
453445
yield EmptySideInput()
446+
elif view_class == pvalue.IterablePCollectionView:
447+
yield sideinputs.EmulatedIterable(iterator_fn)
448+
elif view_class == pvalue.ListPCollectionView:
449+
# TODO(ccy): this is not yet suitable for lists that do not fit in
450+
# memory on a single machine.
451+
yield list(iterator_fn())
452+
elif view_class == pvalue.DictPCollectionView:
453+
# TODO(ccy): this is not yet suitable for dictionaries that do not fit
454+
# in memory on a single machine.
455+
yield dict(iterator_fn())
454456
else:
455-
yield results
457+
raise NotImplementedError('Unknown PCollectionView type: %s' %
458+
view_class)
456459

457460
def start(self):
458461
super(DoOperation, self).start()

google/cloud/dataflow/worker/executor_test.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -400,7 +400,9 @@ def test_create_do_with_singleton_side_bigquery_write(self):
400400
# Setup the reader so it will yield the values in 'side_elements'.
401401
reader_mock = mock_class.return_value
402402
reader_mock.__enter__.return_value = reader_mock
403-
reader_mock.__iter__.return_value = (x for x in side_elements)
403+
# Use a lambda so that multiple readers can be created, each reading the
404+
# entirety of the side elements.
405+
reader_mock.__iter__.side_effect = lambda: (x for x in side_elements)
404406

405407
pickled_elements = [pickler.dumps(e) for e in elements]
406408
executor.MapTaskExecutor().execute(make_map_task([
@@ -442,7 +444,9 @@ def test_create_do_with_collection_side_bigquery_write(self):
442444
# Setup the reader so it will yield the values in 'side_elements'.
443445
reader_mock = mock_class.return_value
444446
reader_mock.__enter__.return_value = reader_mock
445-
reader_mock.__iter__.return_value = (x for x in side_elements)
447+
# Use a lambda so that multiple readers can be created, each reading the
448+
# entirety of the side elements.
449+
reader_mock.__iter__.side_effect = lambda: (x for x in side_elements)
446450

447451
executor.MapTaskExecutor().execute(make_map_task([
448452
maptask.WorkerRead(
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# Copyright 2016 Google Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the 'License');
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an 'AS IS' BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Utilities for handling side inputs."""
16+
17+
import collections
18+
19+
20+
21+
def get_iterator_fn_for_sources(sources):
22+
"""Returns callable that returns iterator over elements for given sources."""
23+
def _inner():
24+
for source in sources:
25+
with source.reader() as reader:
26+
for value in reader:
27+
yield value
28+
return _inner
29+
30+
31+
class EmulatedIterable(collections.Iterable):
32+
"""Emulates an iterable for a side input."""
33+
34+
def __init__(self, iterator_fn):
35+
self.iterator_fn = iterator_fn
36+
37+
def __iter__(self):
38+
return self.iterator_fn()
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# Copyright 2016 Google Inc. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Tests for side input utilities."""
16+
17+
import logging
18+
import unittest
19+
20+
21+
from google.cloud.dataflow.worker import sideinputs
22+
23+
24+
class EmulatedCollectionsTest(unittest.TestCase):
25+
26+
def test_emulated_iterable(self):
27+
def _iterable_fn():
28+
for i in range(10):
29+
yield i
30+
iterable = sideinputs.EmulatedIterable(_iterable_fn)
31+
# Check that multiple iterations are supported.
32+
for _ in range(0, 5):
33+
for i, j in enumerate(iterable):
34+
self.assertEqual(i, j)
35+
36+
def test_large_iterable_values(self):
37+
def _iterable_fn():
38+
for i in range(10):
39+
yield ('%d' % i) * (200 * 1024 * 1024)
40+
iterable = sideinputs.EmulatedIterable(_iterable_fn)
41+
# Check that multiple iterations are supported.
42+
for _ in range(0, 3):
43+
for i, j in enumerate(iterable):
44+
self.assertEqual(('%d' % i) * (200 * 1024 * 1024), j)
45+
46+
47+
if __name__ == '__main__':
48+
logging.getLogger().setLevel(logging.INFO)
49+
unittest.main()

0 commit comments

Comments
 (0)