Skip to content
This repository was archived by the owner on Jun 30, 2022. It is now read-only.

Commit 0a52fae

Browse files
charlesccychen authored and silviulica committed
Fix incorrectly cached values in pvalue.AsList
----Release Notes----
Fixes an issue where a pipeline that used multiple list side inputs was constructed incorrectly.
[]
-------------
Created by MOE: https://github.com/google/moe
MOE_MIGRATED_REVID=120977589
1 parent 69889f6 commit 0a52fae

2 files changed

Lines changed: 20 additions & 14 deletions

File tree

google/cloud/dataflow/pvalue.py

Lines changed: 15 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -263,12 +263,12 @@ class ListPCollectionView(PCollectionView):
263263
pass
264264

265265

266-
def _get_cached_view(pcoll, key):
267-
return pcoll.pipeline._view_cache.get(key, None) # pylint: disable=protected-access
266+
def _get_cached_view(pipeline, key):
267+
return pipeline._view_cache.get(key, None) # pylint: disable=protected-access
268268

269269

270-
def _cache_view(pcoll, key, view):
271-
pcoll.pipeline._view_cache[key] = view # pylint: disable=protected-access
270+
def _cache_view(pipeline, key, view):
271+
pipeline._view_cache[key] = view # pylint: disable=protected-access
272272

273273

274274
def can_take_label_as_first_argument(callee):
@@ -333,7 +333,7 @@ def AsSingleton(pcoll, default_value=_SINGLETON_NO_DEFAULT, label=None): # pyli
333333
# Massage default value to treat as hash key.
334334
hashable_default_value = ('id', id(default_value))
335335
cache_key = (pcoll, AsSingleton, has_default, hashable_default_value)
336-
cached_view = _get_cached_view(pcoll, cache_key)
336+
cached_view = _get_cached_view(pcoll.pipeline, cache_key)
337337
if cached_view:
338338
return cached_view
339339

@@ -343,7 +343,7 @@ def AsSingleton(pcoll, default_value=_SINGLETON_NO_DEFAULT, label=None): # pyli
343343
from google.cloud.dataflow.transforms import sideinputs # pylint: disable=g-import-not-at-top
344344
view = (pcoll | sideinputs.ViewAsSingleton(has_default, default_value,
345345
label=label))
346-
_cache_view(pcoll, cache_key, view)
346+
_cache_view(pcoll.pipeline, cache_key, view)
347347
return view
348348

349349

@@ -365,7 +365,7 @@ def AsIter(pcoll, label=None): # pylint: disable=invalid-name
365365

366366
# Don't recreate the view if it was already created.
367367
cache_key = (pcoll, AsIter)
368-
cached_view = _get_cached_view(pcoll, cache_key)
368+
cached_view = _get_cached_view(pcoll.pipeline, cache_key)
369369
if cached_view:
370370
return cached_view
371371

@@ -374,7 +374,7 @@ def AsIter(pcoll, label=None): # pylint: disable=invalid-name
374374
# depend on pvalue, it lives in this module to reduce user workload.
375375
from google.cloud.dataflow.transforms import sideinputs # pylint: disable=g-import-not-at-top
376376
view = (pcoll | sideinputs.ViewAsIterable(label=label))
377-
_cache_view(pcoll, cache_key, view)
377+
_cache_view(pcoll.pipeline, cache_key, view)
378378
return view
379379

380380

@@ -395,8 +395,8 @@ def AsList(pcoll, label=None): # pylint: disable=invalid-name
395395
label = label or _format_view_label(pcoll)
396396

397397
# Don't recreate the view if it was already created.
398-
cache_key = AsList
399-
cached_view = _get_cached_view(pcoll, cache_key)
398+
cache_key = (pcoll, AsList)
399+
cached_view = _get_cached_view(pcoll.pipeline, cache_key)
400400
if cached_view:
401401
return cached_view
402402

@@ -405,7 +405,7 @@ def AsList(pcoll, label=None): # pylint: disable=invalid-name
405405
# depend on pvalue, it lives in this module to reduce user workload.
406406
from google.cloud.dataflow.transforms import sideinputs # pylint: disable=g-import-not-at-top
407407
view = (pcoll | sideinputs.ViewAsList(label=label))
408-
_cache_view(pcoll, cache_key, view)
408+
_cache_view(pcoll.pipeline, cache_key, view)
409409
return view
410410

411411

@@ -427,20 +427,21 @@ def AsDict(pcoll, label=None): # pylint: disable=invalid-name
427427
A singleton PCollectionView containing the dict as above.
428428
"""
429429
label = label or _format_view_label(pcoll)
430+
singleton_label = 'ToDict(%s)' % label
430431
combine_label = 'CombineToDict(%s)' % label
431432

432433
# Don't recreate the view if it was already created.
433434
cache_key = (pcoll, AsDict)
434-
cached_view = _get_cached_view(pcoll, cache_key)
435+
cached_view = _get_cached_view(pcoll.pipeline, cache_key)
435436
if cached_view:
436437
return cached_view
437438

438439
# Local import is required due to dependency loop; even though the
439440
# implementation of this function requires concepts defined in modules that
440441
# depend on pvalue, it lives in this module to reduce user workload.
441442
from google.cloud.dataflow.transforms import combiners # pylint: disable=g-import-not-at-top
442-
view = AsSingleton(label, pcoll | combiners.ToDict(combine_label))
443-
_cache_view(pcoll, cache_key, view)
443+
view = AsSingleton(singleton_label, pcoll | combiners.ToDict(combine_label))
444+
_cache_view(pcoll.pipeline, cache_key, view)
444445
return view
445446

446447

google/cloud/dataflow/pvalue_test.py

Lines changed: 5 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -53,6 +53,11 @@ def test_pcollectionview_not_recreated(self):
5353
self.assertEqual(AsList(value), AsList(value))
5454
self.assertEqual(AsDict(value2), AsDict(value2))
5555

56+
self.assertNotEqual(AsSingleton(value), AsSingleton(value2))
57+
self.assertNotEqual(AsIter(value), AsIter(value2))
58+
self.assertNotEqual(AsList(value), AsList(value2))
59+
self.assertNotEqual(AsDict(value), AsDict(value2))
60+
5661

5762
if __name__ == '__main__':
5863
unittest.main()

0 commit comments

Comments
 (0)