Skip to content

Commit 755d7b6

Browse files
authored
Updates from orm branch (#36)
* improve sessions * improve tests
1 parent ae732a9 commit 755d7b6

5 files changed

Lines changed: 233 additions & 136 deletions

File tree

dictdatabase/models.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22
from typing import TypeVar, Any, Callable
33
from . import utils, io_safe, config
4-
from . session import DDBSession
4+
from . sessions import SessionFileFull, SessionFileKey, SessionFileWhere, SessionDirFull, SessionDirWhere
55

66
T = TypeVar("T")
77

@@ -200,7 +200,7 @@ def type_cast(value):
200200
return type_cast(data)
201201

202202

203-
def session(self, as_type: T = None) -> DDBSession[T]:
203+
def session(self, as_type: T = None) -> SessionFileFull[T] | SessionFileKey[T] | SessionFileWhere[T] | SessionDirFull[T] | SessionDirWhere[T]:
204204
"""
205205
Opens a session to the selected file(s) or folder, depending on previous
206206
`.at(...)` selection. Inside the with block, you have exclusive access
@@ -215,4 +215,13 @@ def session(self, as_type: T = None) -> DDBSession[T]:
215215
- `FileNotFoundError`: If the file does not exist.
216216
- `KeyError`: If a key is specified and it does not exist.
217217
"""
218-
return DDBSession(self.path, self.op_type, self.key, self.where, as_type)
218+
if self.op_type.file_normal:
219+
return SessionFileFull(self.path, as_type)
220+
if self.op_type.file_key:
221+
return SessionFileKey(self.path, self.key, as_type)
222+
if self.op_type.file_where:
223+
return SessionFileWhere(self.path, self.where, as_type)
224+
if self.op_type.dir_normal:
225+
return SessionDirFull(self.path, as_type)
226+
if self.op_type.dir_where:
227+
return SessionDirWhere(self.path, self.where, as_type)

dictdatabase/session.py

Lines changed: 0 additions & 130 deletions
This file was deleted.

dictdatabase/sessions.py

Lines changed: 211 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,211 @@
1+
from __future__ import annotations
2+
from typing import Tuple, TypeVar, Generic, Any, Callable
3+
from . import utils, io_unsafe, locking
4+
5+
from contextlib import contextmanager
6+
7+
8+
T = TypeVar("T")
9+
JSONSerializable = TypeVar("JSONSerializable", str, int, float, bool, None, list, dict)
10+
11+
12+
13+
def type_cast(obj, as_type):
14+
return obj if as_type is None else as_type(obj)
15+
16+
17+
18+
class SessionBase:
19+
in_session: bool
20+
db_name: str
21+
as_type: T
22+
23+
def __init__(self, db_name: str, as_type):
24+
self.in_session = False
25+
self.db_name = db_name
26+
self.as_type = as_type
27+
28+
def __enter__(self):
29+
self.in_session = True
30+
self.data_handle = {}
31+
32+
def __exit__(self, type, value, tb):
33+
write_lock = getattr(self, "write_lock", None)
34+
if write_lock is not None:
35+
if isinstance(write_lock, list):
36+
for lock in write_lock:
37+
lock._unlock()
38+
else:
39+
write_lock._unlock()
40+
self.write_lock, self.in_session = None, False
41+
42+
def write(self):
43+
if not self.in_session:
44+
raise PermissionError("Only call write() inside a with statement.")
45+
46+
47+
48+
@contextmanager
49+
def safe_context(super, self, *, db_names_to_lock=None):
50+
"""
51+
If an exception happens in the context, the __exit__ method of the passed super
52+
class will be called.
53+
"""
54+
super.__enter__()
55+
try:
56+
if isinstance(db_names_to_lock, str):
57+
self.write_lock = locking.WriteLock(self.db_name)
58+
self.write_lock._lock()
59+
elif isinstance(db_names_to_lock, list):
60+
self.write_lock = [locking.WriteLock(x) for x in self.db_name]
61+
for lock in self.write_lock:
62+
lock._lock()
63+
yield
64+
except BaseException as e:
65+
super.__exit__(type(e), e, e.__traceback__)
66+
raise e
67+
68+
69+
70+
########################################################################################
71+
#### File sessions
72+
########################################################################################
73+
74+
75+
76+
class SessionFileFull(SessionBase, Generic[T]):
77+
"""
78+
Context manager for read-write access to a full file.
79+
80+
Efficiency:
81+
Reads and writes the entire file.
82+
"""
83+
84+
def __enter__(self) -> Tuple[SessionFileFull, JSONSerializable | T]:
85+
with safe_context(super(), self, db_names_to_lock=self.db_name):
86+
self.data_handle = io_unsafe.read(self.db_name)
87+
return self, type_cast(self.data_handle, self.as_type)
88+
89+
def write(self):
90+
super().write()
91+
io_unsafe.write(self.db_name, self.data_handle)
92+
93+
94+
95+
class SessionFileKey(SessionBase, Generic[T]):
96+
"""
97+
Context manager for read-write access to a single key-value item in a file.
98+
99+
Efficiency:
100+
Uses partial reading, which allows only reading the bytes of the key-value item.
101+
When writing, only the bytes of the key-value and the bytes of the file after
102+
the key-value are written.
103+
"""
104+
105+
def __init__(self, db_name: str, key: str, as_type: T):
106+
super().__init__(db_name, as_type)
107+
self.key = key
108+
109+
def __enter__(self) -> Tuple[SessionFileKey, JSONSerializable | T]:
110+
with safe_context(super(), self, db_names_to_lock=self.db_name):
111+
self.partial_handle = io_unsafe.get_partial_file_handle(self.db_name, self.key)
112+
self.data_handle = self.partial_handle.partial_dict.value
113+
return self, type_cast(self.data_handle, self.as_type)
114+
115+
def write(self):
116+
super().write()
117+
io_unsafe.partial_write(self.partial_handle)
118+
119+
120+
121+
class SessionFileWhere(SessionBase, Generic[T]):
122+
"""
123+
Context manager for read-write access to selection of key-value items in a file.
124+
The where callable is called with the key and value of each item in the file.
125+
126+
Efficiency:
127+
Reads and writes the entire file, so it is not more efficient than
128+
SessionFileFull.
129+
"""
130+
def __init__(self, db_name: str, where: Callable[[Any, Any], bool], as_type: T):
131+
super().__init__(db_name, as_type)
132+
self.where = where
133+
134+
def __enter__(self) -> Tuple[SessionFileWhere, JSONSerializable | T]:
135+
with safe_context(super(), self, db_names_to_lock=self.db_name):
136+
self.original_data = io_unsafe.read(self.db_name)
137+
for k, v in self.original_data.items():
138+
if self.where(k, v):
139+
self.data_handle[k] = v
140+
return self, type_cast(self.data_handle, self.as_type)
141+
142+
def write(self):
143+
super().write()
144+
self.original_data.update(self.data_handle)
145+
io_unsafe.write(self.db_name, self.original_data)
146+
147+
148+
149+
########################################################################################
150+
#### File sessions
151+
########################################################################################
152+
153+
154+
155+
class SessionDirFull(SessionBase, Generic[T]):
156+
"""
157+
Context manager for read-write access to all files in a directory.
158+
They are provided as a dict of {str(file_name): dict(file_content)}, where the
159+
file name does not contain the directory name nor the file extension.
160+
161+
Efficiency:
162+
Fully reads and writes all files.
163+
"""
164+
def __init__(self, db_name: str, as_type: T):
165+
super().__init__(utils.find_all(db_name), as_type)
166+
167+
def __enter__(self) -> Tuple[SessionDirFull, JSONSerializable | T]:
168+
with safe_context(super(), self, db_names_to_lock=self.db_name):
169+
self.data_handle = {n.split("/")[-1]: io_unsafe.read(n) for n in self.db_name}
170+
return self, type_cast(self.data_handle, self.as_type)
171+
172+
def write(self):
173+
super().write()
174+
for name in self.db_name:
175+
io_unsafe.write(name, self.data_handle[name.split("/")[-1]])
176+
177+
178+
179+
class SessionDirWhere(SessionBase, Generic[T]):
180+
"""
181+
Context manager for read-write access to selection of files in a directory.
182+
The where callable is called with the file name and parsed content of each file.
183+
184+
Efficiency:
185+
Fully reads all files, but only writes the selected files.
186+
"""
187+
def __init__(self, db_name: str, where: Callable[[Any, Any], bool], as_type: T):
188+
super().__init__(utils.find_all(db_name), as_type)
189+
self.where = where
190+
191+
def __enter__(self) -> Tuple[SessionDirWhere, JSONSerializable | T]:
192+
with safe_context(super(), self):
193+
selected_db_names, write_lock = [], []
194+
for db_name in self.db_name:
195+
lock = locking.WriteLock(db_name)
196+
lock._lock()
197+
k, v = db_name.split("/")[-1], io_unsafe.read(db_name)
198+
if self.where(k, v):
199+
self.data_handle[k] = v
200+
write_lock.append(lock)
201+
selected_db_names.append(db_name)
202+
else:
203+
lock._unlock()
204+
self.write_lock = write_lock
205+
self.db_name = selected_db_names
206+
return self, type_cast(self.data_handle, self.as_type)
207+
208+
def write(self):
209+
super().write()
210+
for name in self.db_name:
211+
io_unsafe.write(name, self.data_handle[name.split("/")[-1]])

tests/conftest.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,23 @@
22
from tests import TEST_DIR
33
import pytest
44
import shutil
5+
import os
56

67

78
@pytest.fixture(scope="session")
89
def use_test_dir(request):
910
DDB.config.storage_directory = TEST_DIR
11+
os.makedirs(TEST_DIR, exist_ok=True)
1012
request.addfinalizer(lambda: shutil.rmtree(TEST_DIR))
1113

1214

1315

16+
@pytest.fixture(scope="function")
17+
def name_of_test(request):
18+
return request.function.__name__
19+
20+
21+
1422
@pytest.fixture(params=[True, False])
1523
def use_compression(request):
1624
DDB.config.use_compression = request.param

0 commit comments

Comments
 (0)