Skip to content

Commit d7455f4

Browse files
committed
refactor: Make init_db method support multiple base models initialization and enforce a stronger shutdown flow
1 parent 1316aef commit d7455f4

1 file changed

Lines changed: 23 additions & 108 deletions

File tree

sqlactive/conn.py

Lines changed: 23 additions & 108 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from sqlalchemy.engine.interfaces import _CoreAnyExecuteParams
88
from sqlalchemy.engine.url import URL
99
from sqlalchemy.ext.asyncio import (
10+
AsyncSession,
1011
async_scoped_session,
1112
async_sessionmaker,
1213
create_async_engine,
@@ -33,8 +34,7 @@ class DBConnection:
3334
asyncio.run(conn.init_db()) # Initialize the database
3435
3536
If your base model is not ``ActiveRecordBaseModel`` you must
36-
pass your base model class to the ``init_db`` method in the
37-
``base_model`` argument::
37+
pass your base model class to the ``init_db`` method::
3838
3939
from sqlactive import DBConnection, ActiveRecordBaseModel
4040
@@ -45,9 +45,12 @@ class BaseModel(ActiveRecordBaseModel):
4545
conn = DBConnection(DATABASE_URL, echo=True)
4646
asyncio.run(conn.init_db(BaseModel)) # Pass your base model
4747
48+
You can also initialize multiple base models at once::
49+
50+
asyncio.run(conn.init_db(BaseModel, AnotherBaseModel))
51+
4852
The ``close`` method can be called to close the database
49-
connection. It also sets the ``session`` attribute of
50-
the base model to ``None``::
53+
connection::
5154
5255
from sqlactive import DBConnection
5356
@@ -57,25 +60,6 @@ class BaseModel(ActiveRecordBaseModel):
5760
# Perform operations...
5861
5962
asyncio.run(conn.close()) # Close the connection
60-
61-
If your base model is not ``ActiveRecordBaseModel``
62-
you should pass your base model class to the ``close`` method
63-
in the ``base_model`` argument::
64-
65-
from sqlactive import DBConnection, ActiveRecordBaseModel
66-
67-
# Note that it does not matter if your base model
68-
# inherits from ``ActiveRecordBaseModel``, you still
69-
# need to pass it to this method
70-
class BaseModel(ActiveRecordBaseModel):
71-
__abstract__ = True
72-
73-
DATABASE_URL = 'sqlite+aiosqlite://'
74-
conn = DBConnection(DATABASE_URL, echo=True)
75-
76-
# Perform operations...
77-
78-
asyncio.run(conn.close(BaseModel)) # Pass your base model
7963
"""
8064

8165
def __init__(self, url: str | URL, **kw: Any) -> None:
@@ -114,69 +98,24 @@ def __init__(self, url: str | URL, **kw: Any) -> None:
11498

11599
async def init_db(
116100
self,
117-
base_model: type[ActiveRecordBaseModel] | None = None,
118-
) -> None:
119-
"""Initialize the database tables.
120-
121-
If your base model is not ``ActiveRecordBaseModel`` you
122-
must pass your base model class to this method in the
123-
``base_model`` argument::
124-
125-
from sqlactive import DBConnection, ActiveRecordBaseModel
126-
127-
# Note that it does not matter if your base model
128-
# inherits from ``ActiveRecordBaseModel``, you still
129-
# need to pass it to this method
130-
class BaseModel(ActiveRecordBaseModel):
131-
__abstract__ = True
132-
133-
DATABASE_URL = 'sqlite+aiosqlite://'
134-
conn = DBConnection(DATABASE_URL, echo=True)
135-
asyncio.run(conn.init_db(BaseModel)) # Pass your base model
136-
"""
137-
if not base_model:
138-
base_model = ActiveRecordBaseModel
139-
140-
base_model.set_session(self.async_scoped_session)
141-
142-
async with self.async_engine.begin() as conn:
143-
await conn.run_sync(base_model.metadata.create_all)
144-
145-
async def close(
146-
self,
147-
base_model: type[ActiveRecordBaseModel] | None = None,
101+
*base_models: type[ActiveRecordBaseModel],
148102
) -> None:
149-
"""Close both the database connection and the session.
150-
151-
If your base model is not ``ActiveRecordBaseModel``
152-
you should pass your base model class to this method
153-
in the ``base_model`` argument::
154-
155-
from sqlactive import DBConnection, ActiveRecordBaseModel
156-
157-
# Note that it does not matter if your base model
158-
# inherits from ``ActiveRecordBaseModel``, you still
159-
# need to pass it to this method
160-
class BaseModel(ActiveRecordBaseModel):
161-
__abstract__ = True
162-
163-
DATABASE_URL = 'sqlite+aiosqlite://'
164-
conn = DBConnection(DATABASE_URL, echo=True)
165-
asyncio.run(conn.init_db(BaseModel))
166-
167-
# Perform operations...
168-
169-
asyncio.run(conn.close(BaseModel)) # Pass your base model
170-
"""
103+
"""Initialize the database tables for the given base models."""
104+
for base_model in base_models or [ActiveRecordBaseModel]:
105+
base_model.set_session(self.async_scoped_session)
106+
async with self.async_engine.begin() as conn:
107+
await conn.run_sync(base_model.metadata.create_all)
108+
109+
async def close(self) -> None:
110+
"""Close the database connection and remove the session."""
111+
await self.async_scoped_session.remove()
112+
self.async_sessionmaker.configure(bind=None)
171113
await self.async_engine.dispose()
172-
if base_model:
173-
base_model.close_session()
174-
ActiveRecordBaseModel.close_session()
175114

176115

177116
async def execute(
117+
async_scoped_session: async_scoped_session[AsyncSession],
178118
statement: TypedReturnsRows[RowType],
179-
base_model: type[ActiveRecordBaseModel] | None = None,
180119
params: _CoreAnyExecuteParams | None = None,
181120
**kwargs,
182121
) -> Result[RowType]:
@@ -187,42 +126,18 @@ async def execute(
187126
of the ``execute`` method of the
188127
``sqlalchemy.ext.asyncio.AsyncSession`` class.
189128
190-
If your base model is not ``ActiveRecordBaseModel``
191-
you must pass your base model class to this method
192-
in the ``base_model`` argument::
193-
194-
# Note that it does not matter if your base model
195-
# inherits from ``ActiveRecordBaseModel``, you still
196-
# need to pass it to this method
197-
class BaseModel(ActiveRecordBaseModel):
198-
__abstract__ = True
199-
200-
class User(BaseModel):
201-
__tablename__ = 'users'
202-
# ...
203-
204-
query = select(User.age, func.count(User.id)).group_by(User.age)
205-
result = await execute(query, BaseModel) # or execute(query, User)
206-
207-
.. warning::
208-
Your base model must have a session in order to use this method.
209-
Otherwise, it will raise an ``NoSessionError`` exception.
210-
If you are not using the ``DBConnection`` class to initialize
211-
your base model, you can call its ``set_session`` method
212-
to set the session.
213-
214129
Examples
215130
--------
131+
>>> from sqlactive import DBConnection
132+
>>> conn = DBConnection(DATABASE_URL, echo=True)
216133
>>> query = select(User.age, func.count(User.id)).group_by(User.age)
217-
>>> result = await execute(query)
134+
>>> result = await execute(conn.async_scoped_session, query)
218135
>>> result
219136
<sqlalchemy.engine.result.Result object at 0x...>
220137
>>> users = result.all()
221138
>>> users
222139
[(20, 1), (22, 4), (25, 12)]
223140
224141
"""
225-
if not base_model:
226-
base_model = ActiveRecordBaseModel
227-
async with base_model.AsyncSession() as session:
142+
async with async_scoped_session() as session:
228143
return await session.execute(statement, params, **kwargs)

0 commit comments

Comments
 (0)