Skip to content

Commit f0608bc

Browse files
committed
refactor: use safe context
1 parent 0b90200 commit f0608bc

1 file changed

Lines changed: 32 additions & 41 deletions

File tree

dictdatabase/sessions.py

Lines changed: 32 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from typing import Tuple, TypeVar, Generic, Any, Callable
33
from . import utils, io_unsafe, locking
44

5+
from contextlib import contextmanager
56

67

78
T = TypeVar("T")
@@ -44,6 +45,28 @@ def write(self):
4445

4546

4647

48+
@contextmanager
49+
def safe_context(super, self, *, db_names_to_lock=None):
50+
"""
51+
If an exception happens in the context, the __exit__ method of the passed super
52+
class will be called.
53+
"""
54+
super.__enter__()
55+
try:
56+
if isinstance(db_names_to_lock, str):
57+
self.write_lock = locking.WriteLock(self.db_name)
58+
self.write_lock._lock()
59+
elif isinstance(db_names_to_lock, list):
60+
self.write_lock = [locking.WriteLock(x) for x in self.db_name]
61+
for lock in self.write_lock:
62+
lock._lock()
63+
yield
64+
except BaseException as e:
65+
super.__exit__(type(e), e, e.__traceback__)
66+
raise e
67+
68+
69+
4770
class SessionFileFull(SessionBase, Generic[T]):
4871
"""
4972
Context manager for read-write access to a full file.
@@ -52,19 +75,10 @@ class SessionFileFull(SessionBase, Generic[T]):
5275
Reads and writes the entire file.
5376
"""
5477

55-
def __init__(self, db_name: str, as_type: T = None):
56-
super().__init__(db_name, as_type)
57-
5878
def __enter__(self) -> Tuple[SessionFileFull, JSONSerializable | T]:
59-
super().__enter__()
60-
try:
61-
self.write_lock = locking.WriteLock(self.db_name)
62-
self.write_lock._lock()
79+
with safe_context(super(), self, db_names_to_lock=self.db_name):
6380
self.data_handle = io_unsafe.read(self.db_name)
6481
return self, type_cast(self.data_handle, self.as_type)
65-
except BaseException as e:
66-
super().__exit__(type(e), e, e.__traceback__)
67-
raise e
6882

6983
def write(self):
7084
super().write()
@@ -82,21 +96,15 @@ class SessionFileKey(SessionBase, Generic[T]):
8296
the key-value are written.
8397
"""
8498

85-
def __init__(self, db_name: str, key: str, as_type: T = None):
99+
def __init__(self, db_name: str, key: str, as_type: T):
86100
super().__init__(db_name, as_type)
87101
self.key = key
88102

89103
def __enter__(self) -> Tuple[SessionFileKey, JSONSerializable | T]:
90-
super().__enter__()
91-
try:
92-
self.write_lock = locking.WriteLock(self.db_name)
93-
self.write_lock._lock()
104+
with safe_context(super(), self, db_names_to_lock=self.db_name):
94105
self.partial_handle = io_unsafe.get_partial_file_handle(self.db_name, self.key)
95106
self.data_handle = self.partial_handle.partial_dict.value
96107
return self, type_cast(self.data_handle, self.as_type)
97-
except BaseException as e:
98-
super().__exit__(type(e), e, e.__traceback__)
99-
raise e
100108

101109
def write(self):
102110
super().write()
@@ -113,23 +121,17 @@ class SessionFileWhere(SessionBase, Generic[T]):
113121
Reads and writes the entire file, so it is not more efficient than
114122
SessionFileFull.
115123
"""
116-
def __init__(self, db_name: str, where: Callable[[Any, Any], bool], as_type: T = None):
124+
def __init__(self, db_name: str, where: Callable[[Any, Any], bool], as_type: T):
117125
super().__init__(db_name, as_type)
118126
self.where = where
119127

120128
def __enter__(self) -> Tuple[SessionFileWhere, JSONSerializable | T]:
121-
super().__enter__()
122-
try:
123-
self.write_lock = locking.WriteLock(self.db_name)
124-
self.write_lock._lock()
129+
with safe_context(super(), self, db_names_to_lock=self.db_name):
125130
self.original_data = io_unsafe.read(self.db_name)
126131
for k, v in self.original_data.items():
127132
if self.where(k, v):
128133
self.data_handle[k] = v
129134
return self, type_cast(self.data_handle, self.as_type)
130-
except BaseException as e:
131-
super().__exit__(type(e), e, e.__traceback__)
132-
raise e
133135

134136
def write(self):
135137
super().write()
@@ -147,20 +149,13 @@ class SessionDirFull(SessionBase, Generic[T]):
147149
Efficiency:
148150
Fully reads and writes all files.
149151
"""
150-
def __init__(self, db_name: str, as_type: T = None):
152+
def __init__(self, db_name: str, as_type: T):
151153
super().__init__(utils.find_all(db_name), as_type)
152154

153155
def __enter__(self) -> Tuple[SessionDirFull, JSONSerializable | T]:
154-
super().__enter__()
155-
try:
156-
self.write_lock = [locking.WriteLock(x) for x in self.db_name]
157-
for lock in self.write_lock:
158-
lock._lock()
156+
with safe_context(super(), self, db_names_to_lock=self.db_name):
159157
self.data_handle = {n.split("/")[-1]: io_unsafe.read(n) for n in self.db_name}
160158
return self, type_cast(self.data_handle, self.as_type)
161-
except BaseException as e:
162-
super().__exit__(type(e), e, e.__traceback__)
163-
raise e
164159

165160
def write(self):
166161
super().write()
@@ -177,13 +172,12 @@ class SessionDirWhere(SessionBase, Generic[T]):
177172
Efficiency:
178173
Fully reads all files, but only writes the selected files.
179174
"""
180-
def __init__(self, db_name: str, where: Callable[[Any, Any], bool], as_type: T = None):
175+
def __init__(self, db_name: str, where: Callable[[Any, Any], bool], as_type: T):
181176
super().__init__(utils.find_all(db_name), as_type)
182177
self.where = where
183178

184179
def __enter__(self) -> Tuple[SessionDirWhere, JSONSerializable | T]:
185-
super().__enter__()
186-
try:
180+
with safe_context(super(), self):
187181
selected_db_names, write_lock = [], []
188182
for db_name in self.db_name:
189183
lock = locking.WriteLock(db_name)
@@ -198,9 +192,6 @@ def __enter__(self) -> Tuple[SessionDirWhere, JSONSerializable | T]:
198192
self.write_lock = write_lock
199193
self.db_name = selected_db_names
200194
return self, type_cast(self.data_handle, self.as_type)
201-
except BaseException as e:
202-
super().__exit__(type(e), e, e.__traceback__)
203-
raise e
204195

205196
def write(self):
206197
super().write()

0 commit comments

Comments
 (0)