77from sqlalchemy .engine .interfaces import _CoreAnyExecuteParams
88from sqlalchemy .engine .url import URL
99from sqlalchemy .ext .asyncio import (
10+ AsyncSession ,
1011 async_scoped_session ,
1112 async_sessionmaker ,
1213 create_async_engine ,
@@ -33,8 +34,7 @@ class DBConnection:
3334 asyncio.run(conn.init_db()) # Initialize the database
3435
3536 If your base model is not ``ActiveRecordBaseModel`` you must
36- pass your base model class to the ``init_db`` method in the
37- ``base_model`` argument::
37+ pass your base model class to the ``init_db`` method::
3838
3939 from sqlactive import DBConnection, ActiveRecordBaseModel
4040
@@ -45,9 +45,12 @@ class BaseModel(ActiveRecordBaseModel):
4545 conn = DBConnection(DATABASE_URL, echo=True)
4646 asyncio.run(conn.init_db(BaseModel)) # Pass your base model
4747
48+ You can also initialize multiple base models at once::
49+
50+ asyncio.run(conn.init_db(BaseModel, AnotherBaseModel))
51+
4852 The ``close`` method can be called to close the database
49- connection. It also sets the ``session`` attribute of
50- the base model to ``None``::
53+ connection::
5154
5255 from sqlactive import DBConnection
5356
@@ -57,25 +60,6 @@ class BaseModel(ActiveRecordBaseModel):
5760 # Perform operations...
5861
5962 asyncio.run(conn.close()) # Close the connection
60-
61- If your base model is not ``ActiveRecordBaseModel``
62- you should pass your base model class to the ``close`` method
63- in the ``base_model`` argument::
64-
65- from sqlactive import DBConnection, ActiveRecordBaseModel
66-
67- # Note that it does not matter if your base model
68- # inherits from ``ActiveRecordBaseModel``, you still
69- # need to pass it to this method
70- class BaseModel(ActiveRecordBaseModel):
71- __abstract__ = True
72-
73- DATABASE_URL = 'sqlite+aiosqlite://'
74- conn = DBConnection(DATABASE_URL, echo=True)
75-
76- # Perform operations...
77-
78- asyncio.run(conn.close(BaseModel)) # Pass your base model
7963 """
8064
8165 def __init__ (self , url : str | URL , ** kw : Any ) -> None :
@@ -114,69 +98,24 @@ def __init__(self, url: str | URL, **kw: Any) -> None:
11498
11599 async def init_db (
116100 self ,
117- base_model : type [ActiveRecordBaseModel ] | None = None ,
118- ) -> None :
119- """Initialize the database tables.
120-
121- If your base model is not ``ActiveRecordBaseModel`` you
122- must pass your base model class to this method in the
123- ``base_model`` argument::
124-
125- from sqlactive import DBConnection, ActiveRecordBaseModel
126-
127- # Note that it does not matter if your base model
128- # inherits from ``ActiveRecordBaseModel``, you still
129- # need to pass it to this method
130- class BaseModel(ActiveRecordBaseModel):
131- __abstract__ = True
132-
133- DATABASE_URL = 'sqlite+aiosqlite://'
134- conn = DBConnection(DATABASE_URL, echo=True)
135- asyncio.run(conn.init_db(BaseModel)) # Pass your base model
136- """
137- if not base_model :
138- base_model = ActiveRecordBaseModel
139-
140- base_model .set_session (self .async_scoped_session )
141-
142- async with self .async_engine .begin () as conn :
143- await conn .run_sync (base_model .metadata .create_all )
144-
145- async def close (
146- self ,
147- base_model : type [ActiveRecordBaseModel ] | None = None ,
101+ * base_models : type [ActiveRecordBaseModel ],
148102 ) -> None :
149- """Close both the database connection and the session.
150-
151- If your base model is not ``ActiveRecordBaseModel``
152- you should pass your base model class to this method
153- in the ``base_model`` argument::
154-
155- from sqlactive import DBConnection, ActiveRecordBaseModel
156-
157- # Note that it does not matter if your base model
158- # inherits from ``ActiveRecordBaseModel``, you still
159- # need to pass it to this method
160- class BaseModel(ActiveRecordBaseModel):
161- __abstract__ = True
162-
163- DATABASE_URL = 'sqlite+aiosqlite://'
164- conn = DBConnection(DATABASE_URL, echo=True)
165- asyncio.run(conn.init_db(BaseModel))
166-
167- # Perform operations...
168-
169- asyncio.run(conn.close(BaseModel)) # Pass your base model
170- """
103+ """Initialize the database tables for the given base models."""
104+ for base_model in base_models or [ActiveRecordBaseModel ]:
105+ base_model .set_session (self .async_scoped_session )
106+ async with self .async_engine .begin () as conn :
107+ await conn .run_sync (base_model .metadata .create_all )
108+
109+ async def close (self ) -> None :
110+ """Close the database connection and remove the session."""
111+ await self .async_scoped_session .remove ()
112+ self .async_sessionmaker .configure (bind = None )
171113 await self .async_engine .dispose ()
172- if base_model :
173- base_model .close_session ()
174- ActiveRecordBaseModel .close_session ()
175114
176115
177116async def execute (
117+ async_scoped_session : async_scoped_session [AsyncSession ],
178118 statement : TypedReturnsRows [RowType ],
179- base_model : type [ActiveRecordBaseModel ] | None = None ,
180119 params : _CoreAnyExecuteParams | None = None ,
181120 ** kwargs ,
182121) -> Result [RowType ]:
@@ -187,42 +126,18 @@ async def execute(
187126 of the ``execute`` method of the
188127 ``sqlalchemy.ext.asyncio.AsyncSession`` class.
189128
190- If your base model is not ``ActiveRecordBaseModel``
191- you must pass your base model class to this method
192- in the ``base_model`` argument::
193-
194- # Note that it does not matter if your base model
195- # inherits from ``ActiveRecordBaseModel``, you still
196- # need to pass it to this method
197- class BaseModel(ActiveRecordBaseModel):
198- __abstract__ = True
199-
200- class User(BaseModel):
201- __tablename__ = 'users'
202- # ...
203-
204- query = select(User.age, func.count(User.id)).group_by(User.age)
205- result = await execute(query, BaseModel) # or execute(query, User)
206-
207- .. warning::
208- Your base model must have a session in order to use this method.
209- Otherwise, it will raise an ``NoSessionError`` exception.
210- If you are not using the ``DBConnection`` class to initialize
211- your base model, you can call its ``set_session`` method
212- to set the session.
213-
214129 Examples
215130 --------
131+ >>> from sqlactive import DBConnection
132+ >>> conn = DBConnection(DATABASE_URL, echo=True)
216133 >>> query = select(User.age, func.count(User.id)).group_by(User.age)
217- >>> result = await execute(query)
134+ >>> result = await execute(conn.async_scoped_session, query)
218135 >>> result
219136 <sqlalchemy.engine.result.Result object at 0x...>
220137 >>> users = result.all()
221138 >>> users
222139 [(20, 1), (22, 4), (25, 12)]
223140
224141 """
225- if not base_model :
226- base_model = ActiveRecordBaseModel
227- async with base_model .AsyncSession () as session :
142+ async with async_scoped_session () as session :
228143 return await session .execute (statement , params , ** kwargs )
0 commit comments