22from typing import Tuple , TypeVar , Generic , Any , Callable
33from . import utils , io_unsafe , locking
44
5+ from contextlib import contextmanager
56
67
78T = TypeVar ("T" )
@@ -44,6 +45,28 @@ def write(self):
4445
4546
4647
48+ @contextmanager
49+ def safe_context (super , self , * , db_names_to_lock = None ):
50+ """
51+ If an exception happens in the context, the __exit__ method of the passed super
52+ class will be called.
53+ """
54+ super .__enter__ ()
55+ try :
56+ if isinstance (db_names_to_lock , str ):
57+ self .write_lock = locking .WriteLock (self .db_name )
58+ self .write_lock ._lock ()
59+ elif isinstance (db_names_to_lock , list ):
60+ self .write_lock = [locking .WriteLock (x ) for x in self .db_name ]
61+ for lock in self .write_lock :
62+ lock ._lock ()
63+ yield
64+ except BaseException as e :
65+ super .__exit__ (type (e ), e , e .__traceback__ )
66+ raise e
67+
68+
69+
4770class SessionFileFull (SessionBase , Generic [T ]):
4871 """
4972 Context manager for read-write access to a full file.
@@ -52,19 +75,10 @@ class SessionFileFull(SessionBase, Generic[T]):
5275 Reads and writes the entire file.
5376 """
5477
55- def __init__ (self , db_name : str , as_type : T = None ):
56- super ().__init__ (db_name , as_type )
57-
5878 def __enter__ (self ) -> Tuple [SessionFileFull , JSONSerializable | T ]:
59- super ().__enter__ ()
60- try :
61- self .write_lock = locking .WriteLock (self .db_name )
62- self .write_lock ._lock ()
79+ with safe_context (super (), self , db_names_to_lock = self .db_name ):
6380 self .data_handle = io_unsafe .read (self .db_name )
6481 return self , type_cast (self .data_handle , self .as_type )
65- except BaseException as e :
66- super ().__exit__ (type (e ), e , e .__traceback__ )
67- raise e
6882
6983 def write (self ):
7084 super ().write ()
@@ -82,21 +96,15 @@ class SessionFileKey(SessionBase, Generic[T]):
8296 the key-value are written.
8397 """
8498
85- def __init__ (self , db_name : str , key : str , as_type : T = None ):
99+ def __init__ (self , db_name : str , key : str , as_type : T ):
86100 super ().__init__ (db_name , as_type )
87101 self .key = key
88102
89103 def __enter__ (self ) -> Tuple [SessionFileKey , JSONSerializable | T ]:
90- super ().__enter__ ()
91- try :
92- self .write_lock = locking .WriteLock (self .db_name )
93- self .write_lock ._lock ()
104+ with safe_context (super (), self , db_names_to_lock = self .db_name ):
94105 self .partial_handle = io_unsafe .get_partial_file_handle (self .db_name , self .key )
95106 self .data_handle = self .partial_handle .partial_dict .value
96107 return self , type_cast (self .data_handle , self .as_type )
97- except BaseException as e :
98- super ().__exit__ (type (e ), e , e .__traceback__ )
99- raise e
100108
101109 def write (self ):
102110 super ().write ()
@@ -113,23 +121,17 @@ class SessionFileWhere(SessionBase, Generic[T]):
113121 Reads and writes the entire file, so it is not more efficient than
114122 SessionFileFull.
115123 """
116- def __init__ (self , db_name : str , where : Callable [[Any , Any ], bool ], as_type : T = None ):
124+ def __init__ (self , db_name : str , where : Callable [[Any , Any ], bool ], as_type : T ):
117125 super ().__init__ (db_name , as_type )
118126 self .where = where
119127
120128 def __enter__ (self ) -> Tuple [SessionFileWhere , JSONSerializable | T ]:
121- super ().__enter__ ()
122- try :
123- self .write_lock = locking .WriteLock (self .db_name )
124- self .write_lock ._lock ()
129+ with safe_context (super (), self , db_names_to_lock = self .db_name ):
125130 self .original_data = io_unsafe .read (self .db_name )
126131 for k , v in self .original_data .items ():
127132 if self .where (k , v ):
128133 self .data_handle [k ] = v
129134 return self , type_cast (self .data_handle , self .as_type )
130- except BaseException as e :
131- super ().__exit__ (type (e ), e , e .__traceback__ )
132- raise e
133135
134136 def write (self ):
135137 super ().write ()
@@ -147,20 +149,13 @@ class SessionDirFull(SessionBase, Generic[T]):
147149 Efficiency:
148150 Fully reads and writes all files.
149151 """
150- def __init__ (self , db_name : str , as_type : T = None ):
152+ def __init__ (self , db_name : str , as_type : T ):
151153 super ().__init__ (utils .find_all (db_name ), as_type )
152154
153155 def __enter__ (self ) -> Tuple [SessionDirFull , JSONSerializable | T ]:
154- super ().__enter__ ()
155- try :
156- self .write_lock = [locking .WriteLock (x ) for x in self .db_name ]
157- for lock in self .write_lock :
158- lock ._lock ()
156+ with safe_context (super (), self , db_names_to_lock = self .db_name ):
159157 self .data_handle = {n .split ("/" )[- 1 ]: io_unsafe .read (n ) for n in self .db_name }
160158 return self , type_cast (self .data_handle , self .as_type )
161- except BaseException as e :
162- super ().__exit__ (type (e ), e , e .__traceback__ )
163- raise e
164159
165160 def write (self ):
166161 super ().write ()
@@ -177,13 +172,12 @@ class SessionDirWhere(SessionBase, Generic[T]):
177172 Efficiency:
178173 Fully reads all files, but only writes the selected files.
179174 """
180- def __init__ (self , db_name : str , where : Callable [[Any , Any ], bool ], as_type : T = None ):
175+ def __init__ (self , db_name : str , where : Callable [[Any , Any ], bool ], as_type : T ):
181176 super ().__init__ (utils .find_all (db_name ), as_type )
182177 self .where = where
183178
184179 def __enter__ (self ) -> Tuple [SessionDirWhere , JSONSerializable | T ]:
185- super ().__enter__ ()
186- try :
180+ with safe_context (super (), self ):
187181 selected_db_names , write_lock = [], []
188182 for db_name in self .db_name :
189183 lock = locking .WriteLock (db_name )
@@ -198,9 +192,6 @@ def __enter__(self) -> Tuple[SessionDirWhere, JSONSerializable | T]:
198192 self .write_lock = write_lock
199193 self .db_name = selected_db_names
200194 return self , type_cast (self .data_handle , self .as_type )
201- except BaseException as e :
202- super ().__exit__ (type (e ), e , e .__traceback__ )
203- raise e
204195
205196 def write (self ):
206197 super ().write ()
0 commit comments