Skip to content
This repository was archived by the owner on Jun 30, 2022. It is now read-only.

Commit c586bac

Browse files
committed
Use shelve as a disk backed dictionary optionally in PValueCache
A new DirectRunner-based DiskCachedPipelineRunner is introduced. Shelve will automatically spill dictionary entries to disk, reducing the memory requirement. For small pipelines the performance impact is minimal, as shelve keeps an in-memory cache of recently used objects. For large pipelines that require multiple PCollection objects to be in the cache at the same time, it prevents OOMs; however, it incurs a performance cost for such pipelines because of disk IO. The memory requirement of this new runner is capped by the single ptransform in the pipeline that consumes the largest total input (input + side inputs in aggregate). ----Release Notes---- [] ------------- Created by MOE: https://github.com/google/moe MOE_MIGRATED_REVID=123441597
1 parent f84b9d9 commit c586bac

4 files changed

Lines changed: 99 additions & 18 deletions

File tree

google/cloud/dataflow/pipeline_test.py

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
from google.cloud.dataflow.pipeline import PipelineVisitor
2525
from google.cloud.dataflow.pvalue import AsIter
2626
from google.cloud.dataflow.pvalue import SideOutputValue
27-
from google.cloud.dataflow.runners import DirectPipelineRunner
2827
from google.cloud.dataflow.transforms import CombinePerKey
2928
from google.cloud.dataflow.transforms import Create
3029
from google.cloud.dataflow.transforms import FlatMap
@@ -62,6 +61,9 @@ def reader(self):
6261

6362
class PipelineTest(unittest.TestCase):
6463

64+
def setUp(self):
65+
self.runner_name = 'DirectPipelineRunner'
66+
6567
@staticmethod
6668
def custom_callable(pcoll):
6769
return pcoll | FlatMap('+1', lambda x: [x + 1])
@@ -92,7 +94,7 @@ def leave_composite_transform(self, transform_node):
9294
self.leave_composite.append(transform_node)
9395

9496
def test_create(self):
95-
pipeline = Pipeline('DirectPipelineRunner')
97+
pipeline = Pipeline(self.runner_name)
9698
pcoll = pipeline | Create('label1', [1, 2, 3])
9799
assert_that(pcoll, equal_to([1, 2, 3]))
98100

@@ -103,20 +105,19 @@ def test_create(self):
103105
pipeline.run()
104106

105107
def test_create_singleton_pcollection(self):
106-
pipeline = Pipeline(DirectPipelineRunner())
108+
pipeline = Pipeline(self.runner_name)
107109
pcoll = pipeline | Create('label', [[1, 2, 3]])
108110
assert_that(pcoll, equal_to([[1, 2, 3]]))
109111
pipeline.run()
110112

111113
def test_read(self):
112-
pipeline = Pipeline('DirectPipelineRunner')
114+
pipeline = Pipeline(self.runner_name)
113115
pcoll = pipeline | Read('read', FakeSource([1, 2, 3]))
114116
assert_that(pcoll, equal_to([1, 2, 3]))
115117
pipeline.run()
116118

117119
def test_visit_entire_graph(self):
118-
119-
pipeline = Pipeline(DirectPipelineRunner())
120+
pipeline = Pipeline(self.runner_name)
120121
pcoll1 = pipeline | Create('pcoll', [1, 2, 3])
121122
pcoll2 = pcoll1 | FlatMap('do1', lambda x: [x + 1])
122123
pcoll3 = pcoll2 | FlatMap('do2', lambda x: [x + 1])
@@ -135,14 +136,14 @@ def test_visit_entire_graph(self):
135136
self.assertEqual(visitor.leave_composite[0].transform, transform)
136137

137138
def test_apply_custom_transform(self):
138-
pipeline = Pipeline(DirectPipelineRunner())
139+
pipeline = Pipeline(self.runner_name)
139140
pcoll = pipeline | Create('pcoll', [1, 2, 3])
140141
result = pcoll | PipelineTest.CustomTransform()
141142
assert_that(result, equal_to([2, 3, 4]))
142143
pipeline.run()
143144

144145
def test_reuse_custom_transform_instance(self):
145-
pipeline = Pipeline(DirectPipelineRunner())
146+
pipeline = Pipeline(self.runner_name)
146147
pcoll1 = pipeline | Create('pcoll1', [1, 2, 3])
147148
pcoll2 = pipeline | Create('pcoll2', [4, 5, 6])
148149
transform = PipelineTest.CustomTransform()
@@ -157,7 +158,7 @@ def test_reuse_custom_transform_instance(self):
157158
'transform.clone("NEW LABEL").')
158159

159160
def test_reuse_cloned_custom_transform_instance(self):
160-
pipeline = Pipeline(DirectPipelineRunner())
161+
pipeline = Pipeline(self.runner_name)
161162
pcoll1 = pipeline | Create('pcoll1', [1, 2, 3])
162163
pcoll2 = pipeline | Create('pcoll2', [4, 5, 6])
163164
transform = PipelineTest.CustomTransform()
@@ -168,7 +169,7 @@ def test_reuse_cloned_custom_transform_instance(self):
168169
pipeline.run()
169170

170171
def test_apply_custom_callable(self):
171-
pipeline = Pipeline('DirectPipelineRunner')
172+
pipeline = Pipeline(self.runner_name)
172173
pcoll = pipeline | Create('pcoll', [1, 2, 3])
173174
result = pipeline.apply(PipelineTest.custom_callable, pcoll)
174175
assert_that(result, equal_to([2, 3, 4]))
@@ -249,6 +250,20 @@ def test_eager_pipeline(self):
249250
self.assertEqual([1, 4, 9], p | Create([1, 2, 3]) | Map(lambda x: x*x))
250251

251252

253+
class DiskCachedRunnerPipelineTest(PipelineTest):
254+
255+
def setUp(self):
256+
self.runner_name = 'DiskCachedPipelineRunner'
257+
258+
def test_cached_pvalues_are_refcounted(self):
259+
# Takes long with disk spilling.
260+
pass
261+
262+
def test_eager_pipeline(self):
263+
# Tests eager runner only
264+
pass
265+
266+
252267
class Bacon(PipelineOptions):
253268

254269
@classmethod

google/cloud/dataflow/runners/direct_runner.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,10 @@ def __init__(self, cache=None):
7070
self.debug_counters = {}
7171
self.debug_counters['element_counts'] = collections.Counter()
7272

73+
@property
74+
def cache(self):
75+
return self._cache
76+
7377
def get_pvalue(self, pvalue):
7478
"""Gets the PValue's computed value from the runner's cache."""
7579
try:
@@ -285,3 +289,38 @@ def run_transform(self, transform):
285289
if transform not in self._seen_transforms:
286290
self._seen_transforms.add(transform)
287291
super(EagerPipelineRunner, self).run_transform(transform)
292+
293+
294+
class DiskCachedPipelineRunner(DirectPipelineRunner):
295+
"""A DirectPipelineRunner that uses a disk backed cache.
296+
297+
DiskCachedPipelineRunner uses a temporary disk backed cache for running
298+
pipelines. This allows for running pipelines that will require more memory
299+
than is available; however, this comes with a performance cost due to disk
300+
IO.
301+
302+
Memory requirement for DiskCachedPipelineRunner is approximately capped by the
303+
single transform in the pipeline that consumes and outputs the largest total
304+
collection (i.e. inputs, side-inputs and outputs in aggregate). In the extreme
305+
case where a transform uses all previous intermediate values as input,
306+
memory requirements for DiskCachedPipelineRunner will be the same as
307+
DirectPipelineRunner.
308+
"""
309+
310+
def __init__(self):
311+
self._null_cache = ()
312+
super(DiskCachedPipelineRunner, self).__init__(self._null_cache)
313+
314+
def run(self, pipeline):
315+
try:
316+
self._cache = PValueCache(use_disk_backed_cache=True)
317+
return super(DirectPipelineRunner, self).run(pipeline)
318+
finally:
319+
del self._cache
320+
self._cache = self._null_cache
321+
322+
@property
323+
def cache(self):
324+
raise NotImplementedError(
325+
'DiskCachedPipelineRunner does not keep cache outside the scope of its '
326+
'run method.')

google/cloud/dataflow/runners/runner.py

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@
1717
from __future__ import absolute_import
1818

1919
import logging
20+
import os
21+
import shelve
22+
import shutil
23+
import tempfile
2024

2125

2226
def create_runner(runner_name):
@@ -37,6 +41,10 @@ def create_runner(runner_name):
3741
if runner_name == 'DirectPipelineRunner':
3842
import google.cloud.dataflow.runners.direct_runner
3943
return google.cloud.dataflow.runners.direct_runner.DirectPipelineRunner()
44+
if runner_name == 'DiskCachedPipelineRunner':
45+
import google.cloud.dataflow.runners.direct_runner
46+
return google.cloud.dataflow.runners.direct_runner.DiskCachedPipelineRunner(
47+
)
4048
if runner_name == 'EagerPipelineRunner':
4149
import google.cloud.dataflow.runners.direct_runner
4250
return google.cloud.dataflow.runners.direct_runner.EagerPipelineRunner()
@@ -164,17 +172,32 @@ def run_transform(self, transform_node):
164172
class PValueCache(object):
165173
"""Local cache for arbitrary information computed for PValue objects."""
166174

167-
def __init__(self):
175+
def __init__(self, use_disk_backed_cache=False):
168176
# Cache of values computed while a runner executes a pipeline. This is a
169177
# dictionary of PValues and their computed values. Note that in principle
170178
# the runner could contain PValues from several pipelines without clashes
171179
# since a PValue is associated with one and only one pipeline. The keys of
172-
# the dictionary are PValue instance addresses obtained using id().
173-
self._cache = {}
180+
# the dictionary are tuple of PValue instance addresses obtained using id()
181+
# and tag names converted to strings.
182+
183+
self._use_disk_backed_cache = use_disk_backed_cache
184+
if use_disk_backed_cache:
185+
self._tempdir = tempfile.mkdtemp()
186+
self._cache = shelve.open(os.path.join(self._tempdir, 'shelve'))
187+
else:
188+
self._cache = {}
189+
190+
def __del__(self):
191+
if self._use_disk_backed_cache:
192+
self._cache.close()
193+
shutil.rmtree(self._tempdir)
174194

175195
def __len__(self):
176196
return len(self._cache)
177197

198+
def to_cache_key(self, transform, tag):
199+
return str((id(transform), tag))
200+
178201
def _ensure_pvalue_has_real_producer(self, pvalue):
179202
"""Ensure the passed-in PValue has the real_producer attribute.
180203
@@ -201,15 +224,16 @@ def is_cached(self, pobj):
201224
self._ensure_pvalue_has_real_producer(pobj)
202225
transform = pobj.real_producer
203226
tag = pobj.tag
204-
return (id(transform), tag) in self._cache
227+
return self.to_cache_key(transform, tag) in self._cache
205228

206229
def cache_output(self, transform, tag_or_value, value=None):
207230
if value is None:
208231
value = tag_or_value
209232
tag = None
210233
else:
211234
tag = tag_or_value
212-
self._cache[id(transform), tag] = [value, transform.refcounts[tag]]
235+
self._cache[
236+
self.to_cache_key(transform, tag)] = [value, transform.refcounts[tag]]
213237

214238
def get_pvalue(self, pvalue):
215239
"""Gets the value associated with a PValue from the cache."""
@@ -225,7 +249,7 @@ def get_pvalue(self, pvalue):
225249
return value_with_refcount[0]
226250
except KeyError:
227251
if (pvalue.tag is not None
228-
and (id(pvalue.real_producer), None) in self._cache):
252+
and self.to_cache_key(pvalue.real_producer, None) in self._cache):
229253
# This is an undeclared, empty side output of a DoFn executed
230254
# in the local runner before this side output referenced.
231255
return []
@@ -242,7 +266,7 @@ def clear_pvalue(self, pvalue):
242266

243267
def key(self, pobj):
244268
self._ensure_pvalue_has_real_producer(pobj)
245-
return id(pobj.real_producer), pobj.tag
269+
return self.to_cache_key(pobj.real_producer, pobj.tag)
246270

247271

248272
class PipelineState(object):

google/cloud/dataflow/transforms/ptransform.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -412,8 +412,11 @@ def __ror__(self, left):
412412
if deferred:
413413
return result
414414
else:
415+
# Get a reference to the runners internal cache, otherwise runner may
416+
# clean it after run.
417+
cache = p.runner.cache
415418
p.run()
416-
return _MaterializePValues(p.runner._cache).visit(result)
419+
return _MaterializePValues(cache).visit(result)
417420

418421
def _extract_input_pvalues(self, pvalueish):
419422
"""Extract all the pvalues contained in the input pvalueish.

0 commit comments

Comments
 (0)