Skip to content
This repository was archived by the owner on Jun 30, 2022. It is now read-only.

Commit 30aac61

Browse files
committed
Add reference counting for consumers of AppliedPTransform outputs
This is used by DirectPipelineRunner to delete cached values aggressively after all their respective consumers have used them. Without such a feature the runner can get into out of memory situations. ----Release Notes---- Improve memory footprint for DirrectPipelineRunner. [] ------------- Created by MOE: https://github.com/google/moe MOE_MIGRATED_REVID=118410157
1 parent d7a2b0a commit 30aac61

3 files changed

Lines changed: 80 additions & 5 deletions

File tree

google/cloud/dataflow/pipeline.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939

4040
from __future__ import absolute_import
4141

42+
import collections
4243
import logging
4344
import os
4445
import shutil
@@ -292,6 +293,7 @@ def apply(self, transform, pvalueish=None):
292293
'output type-hint was found for the '
293294
'PTransform %s' % ptransform_name)
294295

296+
child.update_input_refcounts()
295297
self.transforms_stack.pop()
296298
return pvalueish_result
297299

@@ -357,6 +359,26 @@ def __init__(self, parent, transform, full_label, inputs):
357359
self.outputs = []
358360
self.parts = []
359361

362+
# Per tag refcount dictionary for PValues for which this node is a
363+
# root producer.
364+
self.refcounts = collections.defaultdict(int)
365+
366+
def update_input_refcounts(self):
367+
"""Increment refcounts for all transforms providing inputs."""
368+
369+
def real_producer(pv):
370+
real = pv.producer
371+
while real.parts:
372+
real = real.parts[-1]
373+
return real
374+
375+
if not self.is_composite():
376+
for main_input in self.inputs:
377+
if not isinstance(main_input, pvalue.PBegin):
378+
real_producer(main_input).refcounts[main_input.tag] += 1
379+
for side_input in self.side_inputs:
380+
real_producer(side_input.pvalue).refcounts[side_input.pvalue.tag] += 1
381+
360382
def add_output(self, output):
361383
assert (isinstance(output, pvalue.PValue) or
362384
isinstance(output, pvalue.DoOutputsTuple))

google/cloud/dataflow/pipeline_test.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,21 @@
1414

1515
"""Unit tests for the Pipeline class."""
1616

17+
import gc
18+
import logging
1719
import unittest
1820

1921
from google.cloud.dataflow.io.iobase import Source
2022
from google.cloud.dataflow.pipeline import Pipeline
2123
from google.cloud.dataflow.pipeline import PipelineOptions
2224
from google.cloud.dataflow.pipeline import PipelineVisitor
25+
from google.cloud.dataflow.pvalue import AsIter
26+
from google.cloud.dataflow.pvalue import SideOutputValue
2327
from google.cloud.dataflow.runners import DirectPipelineRunner
28+
from google.cloud.dataflow.transforms import CombinePerKey
2429
from google.cloud.dataflow.transforms import Create
2530
from google.cloud.dataflow.transforms import FlatMap
31+
from google.cloud.dataflow.transforms import Flatten
2632
from google.cloud.dataflow.transforms import Map
2733
from google.cloud.dataflow.transforms import PTransform
2834
from google.cloud.dataflow.transforms import Read
@@ -194,6 +200,47 @@ def apply(self, pcoll):
194200
['a-x', 'b-x', 'c-x'],
195201
sorted(['a', 'b', 'c'] | AddSuffix('-x')))
196202

203+
def test_cached_pvalues_are_refcounted(self):
204+
"""Test that cached PValues are refcounted and deleted.
205+
206+
The intermediary PValues computed by the workflow below contain
207+
one million elements so if the refcounting does not work the number of
208+
objects tracked by the garbage collector will increase by a few millions
209+
by the time we execute the final Map checking the objects tracked.
210+
Anything that is much larger than what we started with will fail the test.
211+
"""
212+
def check_memory(value, count_threshold):
213+
gc.collect()
214+
objects_count = len(gc.get_objects())
215+
if objects_count > count_threshold:
216+
raise RuntimeError(
217+
'PValues are not refcounted: %s, %s' % (
218+
objects_count, count_threshold))
219+
return value
220+
221+
def create_dupes(o, _):
222+
yield o
223+
yield SideOutputValue('side', o)
224+
225+
pipeline = Pipeline('DirectPipelineRunner')
226+
227+
gc.collect()
228+
count_threshold = len(gc.get_objects()) + 10000
229+
biglist = pipeline | Create('oom:create', ['x'] * 1000000)
230+
dupes = (
231+
biglist
232+
| Map('oom:addone', lambda x: (x, 1))
233+
| FlatMap('oom:dupes', create_dupes,
234+
AsIter(biglist)).with_outputs('side', main='main'))
235+
result = (
236+
(dupes.side, dupes.main, dupes.side)
237+
| Flatten('oom:flatten')
238+
| CombinePerKey('oom:combine', sum)
239+
| Map('oom:check', check_memory, count_threshold))
240+
241+
assert_that(result, equal_to([('x', 3000000)]))
242+
pipeline.run()
243+
197244

198245
class Bacon(PipelineOptions):
199246

@@ -264,4 +311,5 @@ def test_dir(self):
264311

265312

266313
if __name__ == '__main__':
314+
logging.getLogger().setLevel(logging.INFO)
267315
unittest.main()

google/cloud/dataflow/runners/runner.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -184,35 +184,40 @@ def _ensure_pvalue_has_real_producer(self, pvalue):
184184
composite transform we need to find the output of its rightmost transform
185185
part.
186186
"""
187-
if not hasattr(pvalue, 'read_producer'):
187+
if not hasattr(pvalue, 'real_producer'):
188188
real_producer = pvalue.producer
189189
while real_producer.parts:
190190
real_producer = real_producer.parts[-1]
191191
pvalue.real_producer = real_producer
192192

193193
def is_cached(self, pobj):
194-
# Import here to avoid circular dependencies.
195194
from google.cloud.dataflow.pipeline import AppliedPTransform
196195
if isinstance(pobj, AppliedPTransform):
197196
transform = pobj
197+
tag = None
198198
else:
199199
self._ensure_pvalue_has_real_producer(pobj)
200200
transform = pobj.real_producer
201-
return (id(transform), None) in self._cache
201+
tag = pobj.tag
202+
return (id(transform), tag) in self._cache
202203

203204
def cache_output(self, transform, tag_or_value, value=None):
204205
if value is None:
205206
value = tag_or_value
206207
tag = None
207208
else:
208209
tag = tag_or_value
209-
self._cache[id(transform), tag] = value
210+
self._cache[id(transform), tag] = [value, transform.refcounts[tag]]
210211

211212
def get_pvalue(self, pvalue):
212213
"""Gets the value associated with a PValue from the cache."""
213214
self._ensure_pvalue_has_real_producer(pvalue)
214215
try:
215-
return self._cache[self.key(pvalue)]
216+
value_with_refcount = self._cache[self.key(pvalue)]
217+
value_with_refcount[1] -= 1
218+
if value_with_refcount[1] <= 0:
219+
self.clear_pvalue(pvalue)
220+
return value_with_refcount[0]
216221
except KeyError:
217222
if (pvalue.tag is not None
218223
and (id(pvalue.real_producer), None) in self._cache):

0 commit comments

Comments
 (0)