Skip to content

Commit e4c74fc

Browse files
authored
Merge pull request #7 from mynhardtburger/typing
Add type hints
2 parents 17b9e5b + a1531b7 commit e4c74fc

1 file changed

Lines changed: 99 additions & 74 deletions

File tree

aconfig/aconfig.py

Lines changed: 99 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,27 @@
1-
#*****************************************************************#
1+
# ****************************************************************#
22
# (C) Copyright IBM Corporation 2020. #
33
# #
44
# The source code for this program is not published or otherwise #
55
# divested of its trade secrets, irrespective of what has been #
66
# deposited with the U.S. Copyright Office. #
7-
#*****************************************************************#
8-
'''Handle all config-based operations.
9-
'''
7+
# ****************************************************************#
8+
"""Handle all config-based operations.
9+
"""
1010

11+
import copy
1112
import os
1213
import re
13-
import copy
14-
import typing
14+
from typing import Any, Dict, List, NoReturn, Optional, Type, Union
1515

16-
from yaml.representer import SafeRepresenter
1716
import yaml
17+
from yaml.representer import SafeRepresenter
18+
1819

20+
class AttributeAccessDict(Dict[str, Any]):
21+
"""Wrapper around Python dict to make it accessible like an object."""
1922

20-
class AttributeAccessDict(dict):
21-
'''Wrapper around Python dict to make it accessible like an object.
22-
'''
23-
def __init__(self, input_map):
24-
'''Recursively assign attribute access and call parent __init__ as well.
23+
def __init__(self, input_map: Dict[str, Any]):
24+
"""Recursively assign attribute access and call parent __init__ as well.
2525
2626
Args:
2727
input_map: dict
@@ -33,10 +33,13 @@ def __init__(self, input_map):
3333
Instance of class that allows for attribute-like access or Python dict-like, and
3434
overrides dict's methods to enable this. Can be modified later on and keep the same
3535
behavior.
36-
'''
36+
"""
3737
if not isinstance(input_map, dict):
38-
raise TypeError('`input_map` argument should be of type dict, but found type: <{0}>'.format(
39-
type(input_map)))
38+
raise TypeError(
39+
"`input_map` argument should be of type dict, but found type: <{0}>".format(
40+
type(input_map)
41+
)
42+
)
4043

4144
# copy so as not to modify passed in dictionary
4245
copied_map = copy.deepcopy(input_map)
@@ -49,7 +52,9 @@ def __init__(self, input_map):
4952
super().__init__(**copied_map)
5053

5154
@classmethod
52-
def _make_attribute_access_dict(cls, value):
55+
def _make_attribute_access_dict(
56+
cls, value: Any
57+
) -> Union["AttributeAccessDict", List[Any], Any]:
5358
"""Recursively walk down any `dict`s or `list`s and build attribute access dicts
5459
🌶️: This is a classmethod so that inheritance is respected.
5560
🌶️🌶️🌶️: We don't call the `cls` initializer directly for the recursion, because we
@@ -67,36 +72,36 @@ def _make_attribute_access_dict(cls, value):
6772
return value
6873

6974
@classmethod
70-
def _recursive_dict_class(cls) -> typing.Type['AttributeAccessDict']:
75+
def _recursive_dict_class(cls) -> Type["AttributeAccessDict"]:
7176
"""Returns the class to be used to recursively build the config object"""
7277
return AttributeAccessDict
7378

7479
# BELOW MAKES INSTANCE ACCESSIBLE VIA NATIVE PYTHON DICT METHODS ###############################
7580

76-
def __getattr__(self, key, default=None):
81+
def __getattr__(self, key: str, default: Any = None) -> Any:
7782
return super().get(key, default)
7883

79-
def __setattr__(self, key, value):
84+
def __setattr__(self, key: str, value: Any) -> None:
8085
if isinstance(value, AttributeAccessDict):
8186
value = value
8287
elif isinstance(value, dict):
8388
value = AttributeAccessDict(value)
8489
super().__setitem__(key, value)
8590

86-
def __setitem__(self, key, value):
91+
def __setitem__(self, key: str, value: Any) -> None:
8792
if isinstance(value, dict):
8893
value = AttributeAccessDict(value)
8994
super().__setitem__(key, value)
9095

91-
def __delattr__(self, key):
96+
def __delattr__(self, key: str) -> None:
9297
super().__delitem__(key)
9398

9499
# ABOVE MAKES INSTANCE ACCESSIBLE VIA NATIVE PYTHON DICT METHODS ###############################
95100

96-
def __deepcopy__(self, memo):
97-
'''This enables deepcopy to successfully copy a Config object, despite
101+
def __deepcopy__(self, memo) -> "AttributeAccessDict":
102+
"""This enables deepcopy to successfully copy a Config object, despite
98103
the default value semantics
99-
'''
104+
"""
100105
return self.__class__(copy.deepcopy(dict(self)))
101106

102107

@@ -108,37 +113,42 @@ class ImmutableAttributeAccessDict(AttributeAccessDict):
108113
AttributeAccessDict, while maintaining nested immutability.
109114
"""
110115

111-
def __init__(self, input_map, *_):
116+
def __init__(self, input_map: Dict[str, Any], *_) -> None:
112117
"""See :func:`~aconfig.aconfig.AttributeAccessDict.__init__`"""
113118
if not isinstance(input_map, dict):
114-
raise TypeError('`input_map` argument should be of type dict, but found type: <{0}>'.format(
115-
type(input_map)))
119+
raise TypeError(
120+
"`input_map` argument should be of type dict, but found type: <{0}>".format(
121+
type(input_map)
122+
)
123+
)
116124
# 🌶️🌶️🌶️: we explicitly cast back down to `dict` for the immutable case
117125
# If we were to build an immutable dict from the top-down, that would
118126
# obviously fail.
119127
input_map = dict(input_map)
120128
# Invoke the AttributeAccessDict initializer
121129
super().__init__(input_map)
122130

123-
def __setitem__(self, key, value):
131+
def __setitem__(self, key: str, value: Any) -> NoReturn:
124132
raise TypeError("ImmutableAttributeAccessDict does not support item assignment")
125133

126-
def __setattr__(self, key, value):
127-
raise AttributeError("ImmutableAttributeAccessDict does not support attribute assignment")
134+
def __setattr__(self, key: str, value: Any) -> NoReturn:
135+
raise AttributeError(
136+
"ImmutableAttributeAccessDict does not support attribute assignment"
137+
)
128138

129139
@classmethod
130-
def _recursive_dict_class(cls) -> typing.Type['AttributeAccessDict']:
140+
def _recursive_dict_class(cls) -> Type["AttributeAccessDict"]:
131141
"""Make this class available to recursively build a full config"""
132142
return ImmutableAttributeAccessDict
133143

134144

135145
class Config(AttributeAccessDict):
136-
'''Config which holds the configurations at the given config location.
137-
'''
138-
_search_pattern = re.compile('[.-]')
146+
"""Config which holds the configurations at the given config location."""
139147

140-
def __init__(self, config, override_env_vars=True):
141-
'''
148+
_search_pattern = re.compile("[.-]")
149+
150+
def __init__(self, config: Dict[str, Any], override_env_vars=True) -> None:
151+
"""
142152
143153
NOTE:
144154
It is recommended NOT to use lists/arrays in .yaml files because lists cannot be
@@ -159,7 +169,7 @@ def __init__(self, config, override_env_vars=True):
159169
Note:
160170
Loaded Config will be available on `self` -- this class has no provided
161171
attributes in itself outside of the loaded config.
162-
'''
172+
"""
163173
# override with retrieved environment variable values if they exist
164174
updated_config = {key: value for key, value in config.items()}
165175
if override_env_vars:
@@ -169,8 +179,8 @@ def __init__(self, config, override_env_vars=True):
169179
super().__init__(updated_config)
170180

171181
@classmethod
172-
def from_yaml(cls, config_location=None, **kwargs):
173-
'''Load a config definition at specified location, parse it, and get environment var's.
182+
def from_yaml(cls, config_location: str, **kwargs: Any) -> "Config":
183+
"""Load a config definition at specified location, parse it, and get environment var's.
174184
175185
Args:
176186
config_location: str
@@ -184,46 +194,55 @@ def from_yaml(cls, config_location=None, **kwargs):
184194
Wrapped via internal methods so that it can be accessed using normal Python
185195
dictionary methods or nested attribute like syntax, for example:
186196
`config.timeout.downstream_10`
187-
'''
197+
"""
188198
# validate before moving forward: will raise exceptions if invalid
189199
config_location = cls._verify_config_location(config_location)
190200

191201
loaded_config = cls._load_yaml_file(config_location)
192202
return cls(loaded_config, **kwargs)
193203

194204
@staticmethod
195-
def _verify_config_location(config_location):
196-
'''Check to see if config location exists and is a .yaml file.
205+
def _verify_config_location(config_location: str) -> str:
206+
"""Check to see if config location exists and is a .yaml file.
197207
NOTE: enforces .yaml extension.
198208
199209
Args:
200210
config_location: str
201211
Location of .yaml to parse where desired configurations exist.
202212
203213
Returns:
204-
config_location: bool
214+
config_location: str
205215
Correct config_location relative to this file if the file exists and is .yaml file,
206216
otherwise raises AssertionError if config_location is not a Python str or if it is
207217
not a .yml/.yaml file or cannot be found.
208-
'''
209-
assert isinstance(config_location, str), \
210-
'config_location must be str, but you sent in type: <{0}>'.format(type(config_location))
218+
"""
219+
assert isinstance(
220+
config_location, str
221+
), "config_location must be str, but you sent in type: <{0}>".format(
222+
type(config_location)
223+
)
211224

212225
# cross-platform location relative to this file
213226
config_location = os.path.normpath(config_location)
214227

215-
assert (config_location.endswith('.yml') or config_location.endswith('.yaml')), \
216-
'Must send in a .yaml or .yaml file, you sent in: <{0}>'.format(config_location)
228+
assert config_location.endswith(".yml") or config_location.endswith(
229+
".yaml"
230+
), "Must send in a .yaml or .yaml file, you sent in: <{0}>".format(
231+
config_location
232+
)
217233

218-
assert os.path.exists(config_location), \
219-
'config_location <{0}> does not exist or cannot be found!'.format(config_location)
234+
assert os.path.exists(
235+
config_location
236+
), "config_location <{0}> does not exist or cannot be found!".format(
237+
config_location
238+
)
220239

221240
# finally found it's valid
222241
return config_location
223242

224243
@staticmethod
225-
def _load_yaml_file(config_location):
226-
'''Helper to load .yaml file at location. Assumes file location has been validated.
244+
def _load_yaml_file(config_location: str) -> Dict[Any, Any]:
245+
"""Helper to load .yaml file at location. Assumes file location has been validated.
227246
228247
Args:
229248
config_location: str
@@ -232,16 +251,16 @@ def _load_yaml_file(config_location):
232251
Returns:
233252
loaded_config: dict
234253
Config definition in a Python dictionary.
235-
'''
236-
with open(config_location, 'r', encoding='utf8') as config_handle:
254+
"""
255+
with open(config_location, "r", encoding="utf8") as config_handle:
237256
loaded_config = yaml.safe_load(config_handle)
238257

239258
# verify it is *definitely* a dict -- likely overkill
240259
return dict(loaded_config)
241260

242261
@staticmethod
243-
def _eval_value(candidate_value):
244-
'''Logic to convert str version of given value into Python data type. Used for env. var's.
262+
def _eval_value(candidate_value: str) -> Union[int, float, str, bool]:
263+
"""Logic to convert str version of given value into Python data type. Used for env. var's.
245264
246265
Args:
247266
candidate_value: str
@@ -251,18 +270,20 @@ def _eval_value(candidate_value):
251270
converted_value: bool, str, int, or float
252271
Value converted to its correct type. Leading/trailing whitespace stripped. If you
253272
did not pass in a Python str, throws TypeError.
254-
'''
273+
"""
255274
if not isinstance(candidate_value, str):
256275
raise TypeError(
257-
'Must pass in a str as candidate_value. You passed in type: <{0}>'.format(
258-
type(candidate_value)))
276+
"Must pass in a str as candidate_value. You passed in type: <{0}>".format(
277+
type(candidate_value)
278+
)
279+
)
259280

260281
# try to get bool
261282
candidate_value = candidate_value.strip()
262283

263-
if candidate_value.lower() == 'true':
284+
if candidate_value.lower() == "true":
264285
return True
265-
if candidate_value.lower() == 'false':
286+
if candidate_value.lower() == "false":
266287
return False
267288

268289
# try to get an integer or a float
@@ -281,8 +302,10 @@ def _eval_value(candidate_value):
281302
# last chance -- return as string
282303
return str(candidate_value)
283304

284-
def _update_with_env_vars(self, default_dict, prefix=None):
285-
'''Recursively update defaults with env. var's. Used for nested updating of dictionaries.
305+
def _update_with_env_vars(
306+
self, default_dict: Dict[str, Any], prefix: Optional[str] = None
307+
) -> Dict[str, Any]:
308+
"""Recursively update defaults with env. var's. Used for nested updating of dictionaries.
286309
287310
Args:
288311
default_dict: dict
@@ -292,18 +315,19 @@ def _update_with_env_vars(self, default_dict, prefix=None):
292315
updated_default_dict: dict
293316
Same as default_dict where the default values are updated with env. var's
294317
if they were found and they are the same type.
295-
'''
318+
"""
296319
for default_key, default_val in default_dict.items():
297-
298320
# step 1: Create the "full key" using the provided prefix
299321
if prefix:
300-
full_key = '.'.join([prefix, default_key])
322+
full_key = ".".join([prefix, default_key])
301323
else:
302324
full_key = default_key
303325

304326
# step 2: call recursively if necessary; skip empty dict's
305327
if isinstance(default_val, dict) and default_val:
306-
default_dict[default_key] = self._update_with_env_vars(default_val, full_key)
328+
default_dict[default_key] = self._update_with_env_vars(
329+
default_val, full_key
330+
)
307331

308332
# step 3: skip env. var. process for lists
309333
elif isinstance(default_val, list):
@@ -320,29 +344,30 @@ def _update_with_env_vars(self, default_dict, prefix=None):
320344
# step 5: update default_dict with value
321345
default_dict[default_key] = env_var_val
322346

323-
# values have now been overriden where possible!
347+
# values have now been overridden where possible!
324348
return default_dict
325349

326-
def _env_var_from_key(self, config_key):
327-
'''Convert a config key to the corresponding env var to check for.
350+
def _env_var_from_key(self, config_key: str) -> str:
351+
"""Convert a config key to the corresponding env var to check for.
328352
329353
Args:
330354
config_key: str
331355
Key in config file that should be converted to the environment variable to search
332-
for to override cofig_key's value with.
356+
for to override config_key's value with.
333357
334358
Returns:
335359
env_var_key: str
336360
Environment variable key to attempt to retrieve; converted from config_key.
337361
Replaced "." and "-" with "_" & upper-cased the key.
338-
'''
362+
"""
339363

340-
return re.sub(self._search_pattern, '_', config_key.upper())
364+
return re.sub(self._search_pattern, "_", config_key.upper())
341365

342366

343367
class ImmutableConfig(ImmutableAttributeAccessDict, Config):
344368
"""This class is the Immutable version of Config"""
345-
def __init__(self, config, override_env_vars=True):
369+
370+
def __init__(self, config: Dict[str, Any], override_env_vars=True) -> None:
346371
"""See :func:`~aconfig.aconfig.Config.__init__`"""
347372
if not isinstance(config, dict):
348373
raise TypeError("config must be a dict")

0 commit comments

Comments
 (0)