Skip to content
This repository was archived by the owner on Jun 30, 2022. It is now read-only.

Commit 2843cf9

Browse files
robertwbaaltay
authored and committed
Generalize base PTransform._extract_input_pvalues
The base case now understands tuples and dicts of pvalues, which eases writing multi-input composite transforms. ----Release Notes---- [] ------------- Created by MOE: https://github.com/google/moe MOE_MIGRATED_REVID=121960544
1 parent f0467cb commit 2843cf9

2 files changed

Lines changed: 24 additions & 1 deletion

File tree

google/cloud/dataflow/transforms/ptransform.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -421,6 +421,8 @@ def _extract_input_pvalues(self, pvalueish):
421421
Returns pvalueish as well as the flat inputs list as the input may have to
422422
be copied as inspection may be destructive.
423423
424+
By default, recursively extracts tuple components and dict values.
425+
424426
Generally only needs to be overridden for multi-input PTransforms.
425427
"""
426428
# pylint: disable=g-import-not-at-top
@@ -429,7 +431,18 @@ def _extract_input_pvalues(self, pvalueish):
429431
if isinstance(pvalueish, pipeline.Pipeline):
430432
pvalueish = pvalue.PBegin(pvalueish)
431433

432-
return pvalueish, (pvalueish,)
434+
def _dict_tuple_leaves(pvalueish):
435+
if isinstance(pvalueish, tuple):
436+
for a in pvalueish:
437+
for p in _dict_tuple_leaves(a):
438+
yield p
439+
elif isinstance(pvalueish, dict):
440+
for a in pvalueish.values():
441+
for p in _dict_tuple_leaves(a):
442+
yield p
443+
else:
444+
yield pvalueish
445+
return pvalueish, tuple(_dict_tuple_leaves(pvalueish))
433446

434447

435448
class ChainedPTransform(PTransform):

google/cloud/dataflow/transforms/ptransform_test.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -516,6 +516,16 @@ def test_apply_to_list(self):
516516
self.assertEqual([('k', (['a'], ['b', 'c']))],
517517
join_input | df.CoGroupByKey('join'))
518518

519+
def test_multi_input_ptransform(self):
520+
class DisjointUnion(PTransform):
521+
def apply(self, pcollections):
522+
return (pcollections
523+
| df.Flatten()
524+
| df.Map(lambda x: (x, None))
525+
| df.GroupByKey()
526+
| df.Map(lambda (x, _): x))
527+
self.assertEqual([1, 2, 3], sorted(([1, 2], [2, 3]) | DisjointUnion()))
528+
519529
def test_apply_to_crazy_pvaluish(self):
520530
class NestedFlatten(PTransform):
521531
"""A PTransform taking and returning nested PValueish.

0 commit comments

Comments
 (0)