|
20 | 20 | from google.cloud.dataflow.utils.options import PipelineOptions |
21 | 21 |
|
22 | 22 |
|
23 | | -class SetupTest(unittest.TestCase): |
24 | | - |
25 | | - def test_get_unknown_args(self): |
26 | | - |
27 | | - # Used for testing newly added flags. |
28 | | - class MockOptions(PipelineOptions): |
29 | | - |
30 | | - @classmethod |
31 | | - def _add_argparse_args(cls, parser): |
32 | | - parser.add_argument('--mock_flag', |
33 | | - action='store_true', |
34 | | - help='Enable work item profiling') |
35 | | - |
36 | | - test_cases = [ |
37 | | - {'flags': ['--num_workers', '5'], |
38 | | - 'expected': {'num_workers': 5, 'mock_flag': False}}, |
39 | | - { |
40 | | - 'flags': [ |
41 | | - '--profile', '--profile_location', 'gs://bucket/', 'ignored'], |
42 | | - 'expected': { |
43 | | - 'profile': True, 'profile_location': 'gs://bucket/', |
44 | | - 'mock_flag': False} |
45 | | - }, |
46 | | - {'flags': ['--num_workers', '5', '--mock_flag'], |
47 | | - 'expected': {'num_workers': 5, 'mock_flag': True}}, |
48 | | - ] |
49 | | - |
50 | | - for case in test_cases: |
| 23 | +class PipelineOptionsTest(unittest.TestCase): |
| 24 | + |
| 25 | + TEST_CASES = [ |
| 26 | + {'flags': ['--num_workers', '5'], |
| 27 | + 'expected': {'num_workers': 5, 'mock_flag': False, 'mock_option': None}}, |
| 28 | + { |
| 29 | + 'flags': [ |
| 30 | + '--profile', '--profile_location', 'gs://bucket/', 'ignored'], |
| 31 | + 'expected': { |
| 32 | + 'profile': True, 'profile_location': 'gs://bucket/', |
| 33 | + 'mock_flag': False, 'mock_option': None} |
| 34 | + }, |
| 35 | + {'flags': ['--num_workers', '5', '--mock_flag'], |
| 36 | + 'expected': {'num_workers': 5, 'mock_flag': True, 'mock_option': None}}, |
| 37 | + {'flags': ['--mock_option', 'abc'], |
| 38 | + 'expected': {'mock_flag': False, 'mock_option': 'abc'}}, |
| 39 | + {'flags': ['--mock_option', ' abc def '], |
| 40 | + 'expected': {'mock_flag': False, 'mock_option': ' abc def '}}, |
| 41 | + {'flags': ['--mock_option= abc xyz '], |
| 42 | + 'expected': {'mock_flag': False, 'mock_option': ' abc xyz '}}, |
| 43 | + {'flags': ['--mock_option=gs://my bucket/my folder/my file'], |
| 44 | + 'expected': {'mock_flag': False, |
| 45 | + 'mock_option': 'gs://my bucket/my folder/my file'}}, |
| 46 | + ] |
| 47 | + |
| 48 | + # Used for testing newly added flags. |
| 49 | + class MockOptions(PipelineOptions): |
| 50 | + |
| 51 | + @classmethod |
| 52 | + def _add_argparse_args(cls, parser): |
| 53 | + parser.add_argument('--mock_flag', action='store_true', help='mock flag') |
| 54 | + parser.add_argument('--mock_option', help='mock option') |
| 55 | + parser.add_argument('--option with space', help='mock option with space') |
| 56 | + |
| 57 | + def test_get_all_options(self): |
| 58 | + for case in PipelineOptionsTest.TEST_CASES: |
51 | 59 | options = PipelineOptions(flags=case['flags']) |
52 | 60 | self.assertDictContainsSubset(case['expected'], options.get_all_options()) |
53 | | - self.assertEqual(options.view_as(MockOptions).mock_flag, |
| 61 | + self.assertEqual(options.view_as( |
| 62 | + PipelineOptionsTest.MockOptions).mock_flag, |
54 | 63 | case['expected']['mock_flag']) |
| 64 | + self.assertEqual(options.view_as( |
| 65 | + PipelineOptionsTest.MockOptions).mock_option, |
| 66 | + case['expected']['mock_option']) |
| 67 | + |
| 68 | + def test_from_dictionary(self): |
| 69 | + for case in PipelineOptionsTest.TEST_CASES: |
| 70 | + options = PipelineOptions(flags=case['flags']) |
| 71 | + all_options_dict = options.get_all_options() |
| 72 | + options_from_dict = PipelineOptions.from_dictionary(all_options_dict) |
| 73 | + self.assertEqual(options_from_dict.view_as( |
| 74 | + PipelineOptionsTest.MockOptions).mock_flag, |
| 75 | + case['expected']['mock_flag']) |
| 76 | + self.assertEqual(options.view_as( |
| 77 | + PipelineOptionsTest.MockOptions).mock_option, |
| 78 | + case['expected']['mock_option']) |
| 79 | + |
| 80 | + def test_option_with_spcae(self): |
| 81 | + options = PipelineOptions(flags=['--option with space= value with space']) |
| 82 | + self.assertEqual( |
| 83 | + getattr(options.view_as(PipelineOptionsTest.MockOptions), |
| 84 | + 'option with space'), ' value with space') |
| 85 | + options_from_dict = PipelineOptions.from_dictionary( |
| 86 | + options.get_all_options()) |
| 87 | + self.assertEqual( |
| 88 | + getattr(options_from_dict.view_as(PipelineOptionsTest.MockOptions), |
| 89 | + 'option with space'), ' value with space') |
| 90 | + |
| 91 | + def test_override_options(self): |
| 92 | + base_flags = ['--num_workers', '5'] |
| 93 | + options = PipelineOptions(base_flags) |
| 94 | + self.assertEqual(options.get_all_options()['num_workers'], 5) |
| 95 | + self.assertEqual(options.get_all_options()['mock_flag'], False) |
| 96 | + |
| 97 | + options.view_as(PipelineOptionsTest.MockOptions).mock_flag = True |
| 98 | + self.assertEqual(options.get_all_options()['num_workers'], 5) |
| 99 | + self.assertEqual(options.get_all_options()['mock_flag'], True) |
| 100 | + |
55 | 101 |
|
56 | 102 | if __name__ == '__main__': |
57 | 103 | logging.getLogger().setLevel(logging.INFO) |
|
0 commit comments