36 lines
1.0 KiB
Python
36 lines
1.0 KiB
Python
from __future__ import annotations
|
|
|
|
from collections.abc import AsyncIterator
|
|
from contextlib import asynccontextmanager
|
|
|
|
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine
|
|
|
|
from app.config import Settings
|
|
from app.db.base import Base
|
|
|
|
|
|
def create_engine(settings: Settings) -> AsyncEngine:
    """Build the application's async SQLAlchemy engine.

    Args:
        settings: Application settings; only ``settings.database_url`` is read.

    Returns:
        An ``AsyncEngine`` for that URL with SQL echo logging turned off.
    """
    url = settings.database_url
    return create_async_engine(url, echo=False, future=True)
|
|
|
|
|
|
def create_session_factory(engine: AsyncEngine) -> async_sessionmaker[AsyncSession]:
    """Return a factory that opens ``AsyncSession`` objects bound to *engine*.

    Sessions produced by the factory keep their loaded attributes after a
    commit (``expire_on_commit=False``), so objects remain readable once the
    transaction has ended without triggering a refresh.

    Args:
        engine: The async engine new sessions will be bound to.

    Returns:
        A configured ``async_sessionmaker``.
    """
    return async_sessionmaker(
        bind=engine,
        class_=AsyncSession,
        expire_on_commit=False,
    )
|
|
|
|
|
|
async def init_db(engine: AsyncEngine) -> None:
    """Create the tables registered on ``Base.metadata``.

    Opens a connection with a transaction via ``engine.begin()`` and runs the
    synchronous ``MetaData.create_all`` through ``run_sync``. The transaction
    is committed when the block exits normally.

    Args:
        engine: The async engine to create the schema on.
    """
    async with engine.begin() as connection:
        # create_all is a sync API; run_sync hands it a sync-style connection.
        create_tables = Base.metadata.create_all
        await connection.run_sync(create_tables)
|
|
|
|
|
|
@asynccontextmanager
async def session_scope(factory: async_sessionmaker[AsyncSession]) -> AsyncIterator[AsyncSession]:
    """Provide a transactional scope around a series of database operations.

    Opens a new session from *factory* and yields it to the ``async with``
    body. If the body finishes normally the session is committed. If the body
    (or the commit itself) raises an ``Exception``, the session is rolled back
    and the exception is re-raised. The session is closed on every exit path.

    Args:
        factory: Session factory used to open the session.

    Yields:
        The open ``AsyncSession``.

    Raises:
        Exception: Whatever the body, ``commit()`` or ``rollback()`` raises.
    """
    # Entering the session as an async context manager (rather than a manual
    # ``finally: await session.close()``) lets SQLAlchemy's
    # ``AsyncSession.__aexit__`` shield the close from task cancellation, so
    # the connection is still released if the surrounding task is cancelled
    # during cleanup.
    async with factory() as session:
        try:
            yield session
            await session.commit()
        except Exception:
            # Only Exception subclasses trigger an explicit rollback;
            # BaseException exits (e.g. asyncio.CancelledError) fall through
            # to the session close performed by the context manager.
            await session.rollback()
            raise
|