"""
Database engines and session factories (sync and async) for the app.
This module centralizes SQLAlchemy setup. It exposes reusable, process-wide
engine singletons and companion session factories for both synchronous and
asynchronous usage. When targeting SQLite, it applies pragmatic defaults to
improve concurrency (WAL) and resilience (busy timeout) via connection event
hooks. All ORM models inherit from the exported ``Base``.
See Also
--------
app.models : Declarative ORM models (bound to ``app.database.Base``).
app.schemas : Pydantic schemas mirroring ORM shapes.
app.main : FastAPI app wiring request-scoped DB sessions.
app.config : Source of ``settings.DATABASE_URL``.
Notes
-----
- Primary role: provide ready-to-use SQLAlchemy engines and session factories
and a declarative ``Base`` for model definitions.
- Key dependencies: ``app.config.settings.DATABASE_URL`` must be defined; SQLAlchemy
(sync/async) drivers must be available for the chosen URL.
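- Invariants: engines are module-level singletons created at import time,
unless the ``SKIP_ENGINE_INIT`` environment variable is set to a non-empty
value, in which case the engines and session factories are ``None``. The
synchronous SQLite engine uses ``check_same_thread=False`` and has PRAGMA
``journal_mode=WAL`` and ``busy_timeout`` set at connect time.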
Examples
--------
>>> # Ensure schema exists (tables are created if missing)
>>> from app.database import ensure_database_schema  # doctest: +SKIP
>>> ensure_database_schema()  # doctest: +SKIP
>>> # Sync usage
>>> from app.database import SessionLocal  # doctest: +SKIP
>>> with SessionLocal() as session:  # doctest: +SKIP
...     pass
>>> # Async usage
>>> from app.database import AsyncSessionLocal  # doctest: +SKIP
>>> async def get_count() -> int:  # doctest: +SKIP
...     async with AsyncSessionLocal() as session:
...         return 0
"""
import logging
import os
from typing import Any, Optional
from sqlalchemy import create_engine, event
from sqlalchemy.engine import Engine
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.ext.asyncio import (
AsyncEngine,
AsyncSession,
async_sessionmaker,
create_async_engine,
)
from sqlalchemy.orm import DeclarativeBase, sessionmaker
from .config import settings
logger = logging.getLogger(__name__)

# Connection string from application settings; both the sync and the async
# engine below are derived from this one value.
DATABASE_URL: str = settings.DATABASE_URL

# --- URL scheme prefixes ------------------------------------------------------
# ``SQLITE_URL_PREFIX`` is intentionally loose: a ``startswith`` check against
# it matches both ``sqlite:///`` and ``sqlite+aiosqlite:///`` URLs.
SQLITE_URL_PREFIX: str = "sqlite"
SQLITE_SYNC_URL_PREFIX: str = "sqlite:///"
SQLITE_ASYNC_URL_PREFIX: str = "sqlite+aiosqlite:///"
POSTGRESQL_SYNC_URL_PREFIX: str = "postgresql://"
POSTGRESQL_ASYNC_URL_PREFIX: str = "postgresql+asyncpg://"

# Extra DB-API ``connect()`` keyword arguments for SQLite. With the sqlite3
# default (True), using a connection outside the thread that opened it raises
# ProgrammingError; False disables that check.
SQLITE_CONNECT_ARGS: dict[str, Any] = dict(check_same_thread=False)

# --- SQLite PRAGMA values used by the connect-time listener -------------------
PRAGMA_JOURNAL_MODE: str = "WAL"
PRAGMA_BUSY_TIMEOUT_MS: int = 60_000  # milliseconds, i.e. one minute
def _redact_database_url(database_url: str) -> str:
    """Return ``database_url`` with any embedded password masked, for logging.

    Parameters
    ----------
    database_url : str
        Database URL that may contain credentials.

    Returns
    -------
    str
        The URL as rendered by SQLAlchemy with the password hidden, or a fixed
        placeholder if the URL cannot be parsed. An unparseable string is never
        echoed back, because it is unclear which part of it is secret.
    """
    # Local import keeps this helper self-contained; ``make_url`` is
    # SQLAlchemy's public URL parser.
    from sqlalchemy.engine import make_url

    try:
        return make_url(database_url).render_as_string(hide_password=True)
    except (SQLAlchemyError, ValueError):
        # ArgumentError (a SQLAlchemyError) for malformed URLs; ValueError for
        # e.g. a non-numeric port.
        return "<unparseable database URL>"


def _get_engine(database_url: str) -> Engine:
    """Create a SQLAlchemy engine with optional SQLite tuning.

    Parameters
    ----------
    database_url : str
        Database URL (e.g., ``sqlite:///...`` or ``postgresql://...``).

    Returns
    -------
    sqlalchemy.engine.Engine
        Configured engine instance with SQLite PRAGMAs applied when relevant.

    Raises
    ------
    ValueError
        If ``database_url`` is empty.
    SQLAlchemyError
        If engine creation fails (e.g., invalid URL or driver error).

    Notes
    -----
    - For SQLite URLs, ``check_same_thread=False`` is passed. This disables
      sqlite3's same-thread check, so a connection may be used from a thread
      other than the one that opened it, as in multi-threaded servers such as
      FastAPI. The WAL/busy-timeout PRAGMA listener is also registered.
    - The URL is only ever logged with its password masked; the raw URL is
      passed to ``create_engine`` and nowhere else.

    Examples
    --------
    >>> from app.database import _get_engine
    >>> engine = _get_engine("sqlite:///:memory:")
    >>> engine.dialect.name == "sqlite"
    True
    """
    if not database_url:
        # Explicit check rather than ``assert`` so it also runs under ``-O``.
        raise ValueError(
            f"Database URL must be a non-empty string, got {database_url!r}"
        )

    is_sqlite = database_url.startswith(SQLITE_URL_PREFIX)
    safe_url = _redact_database_url(database_url)
    # Copy the shared defaults so the module-level dict can never be mutated
    # through this engine's configuration.
    connect_args: dict[str, Any] = dict(SQLITE_CONNECT_ARGS) if is_sqlite else {}

    try:
        logger.debug("Creating SQLAlchemy engine for URL %s", safe_url)
        engine = create_engine(database_url, connect_args=connect_args)
        if is_sqlite:
            _configure_sqlite_engine(engine)
        return engine
    except SQLAlchemyError:
        logger.exception("Failed to create SQLAlchemy engine for URL %s", safe_url)
        raise
def _configure_sqlite_engine(engine: Engine) -> None:
    """Apply SQLite-specific PRAGMA statements via event listeners.

    Registers a ``"connect"`` listener on ``engine``. The listener runs
    ``PRAGMA journal_mode=<PRAGMA_JOURNAL_MODE>`` and
    ``PRAGMA busy_timeout=<PRAGMA_BUSY_TIMEOUT_MS>`` on each new DB-API
    connection the engine opens.

    Parameters
    ----------
    engine : sqlalchemy.engine.Engine
        Engine connected to a SQLite database.

    Notes
    -----
    - Sets ``journal_mode=WAL`` and a busy timeout to improve concurrency.
    - The DB-API cursor is closed explicitly instead of being used in a
      ``with`` statement. PEP 249 does not require cursors to be context
      managers and ``sqlite3.Cursor`` is not one, so ``with conn.cursor()``
      would raise on every new connection.

    Examples
    --------
    >>> from sqlalchemy import create_engine
    >>> from app.database import _configure_sqlite_engine
    >>> engine = create_engine("sqlite:///:memory:")
    >>> _configure_sqlite_engine(engine) # doctest: +SKIP
    """
    assert isinstance(
        engine, Engine
    ), f"Engine must be SQLAlchemy Engine, got {type(engine).__name__}"

    def set_sqlite_pragmas(dbapi_connection: Any, connection_record: Any) -> None:
        # ``connection_record`` is required by the event signature but unused.
        cursor = dbapi_connection.cursor()
        try:
            cursor.execute(f"PRAGMA journal_mode={PRAGMA_JOURNAL_MODE}")
            cursor.execute(f"PRAGMA busy_timeout={PRAGMA_BUSY_TIMEOUT_MS}")
        finally:
            cursor.close()

    event.listen(engine, "connect", set_sqlite_pragmas)
    logger.debug(
        "Registered SQLite PRAGMA listeners (journal_mode=%s, busy_timeout=%dms)",
        PRAGMA_JOURNAL_MODE,
        PRAGMA_BUSY_TIMEOUT_MS,
    )
# Synchronous engine and session factory. When SKIP_ENGINE_INIT holds any
# non-empty value, both names are bound to None and no engine is created.
if not os.environ.get("SKIP_ENGINE_INIT"):
    engine: Engine = _get_engine(DATABASE_URL)
    SessionLocal: sessionmaker[Any] = sessionmaker(
        bind=engine,
        autoflush=False,
        autocommit=False,
    )
else:
    engine: Optional[Engine] = None
    SessionLocal: Optional[sessionmaker[Any]] = None
# --- ASYNC SQLALCHEMY SUPPORT ---
def _get_async_database_url(database_url: str) -> str:
    """Convert a sync URL to an async-compatible SQLAlchemy URL.

    Only the leading scheme is rewritten. Everything after it (credentials,
    host, path, query string) is kept verbatim.

    Parameters
    ----------
    database_url : str
        Original database URL (sync or already async).

    Returns
    -------
    str
        Async driver URL (``sqlite+aiosqlite:///`` or ``postgresql+asyncpg://``).
        URLs that already start with one of those prefixes are returned
        unchanged.

    Raises
    ------
    ValueError
        If ``database_url`` is empty, or if it uses a scheme that has no async
        mapping here. The message names only the scheme, so credentials
        embedded in the URL are not leaked into logs or tracebacks.

    Examples
    --------
    >>> from app.database import _get_async_database_url
    >>> _get_async_database_url("sqlite:///test.db")
    'sqlite+aiosqlite:///test.db'
    >>> _get_async_database_url("postgresql://u:p@h/db")
    'postgresql+asyncpg://u:p@h/db'
    >>> _get_async_database_url("mysql://u:p@h/db")
    Traceback (most recent call last):
    ...
    ValueError: Unsupported database URL scheme for async: 'mysql'
    """
    if not database_url:
        # Explicit check rather than ``assert`` so it also runs under ``-O``.
        raise ValueError(
            f"Database URL must be a non-empty string, got {database_url!r}"
        )
    # ``removeprefix`` touches only the leading scheme; ``str.replace`` would
    # also rewrite any later occurrence of the prefix inside the URL.
    if database_url.startswith(SQLITE_SYNC_URL_PREFIX):
        return SQLITE_ASYNC_URL_PREFIX + database_url.removeprefix(
            SQLITE_SYNC_URL_PREFIX
        )
    if database_url.startswith(POSTGRESQL_SYNC_URL_PREFIX):
        return POSTGRESQL_ASYNC_URL_PREFIX + database_url.removeprefix(
            POSTGRESQL_SYNC_URL_PREFIX
        )
    if database_url.startswith((SQLITE_ASYNC_URL_PREFIX, POSTGRESQL_ASYNC_URL_PREFIX)):
        return database_url
    # Everything before the first ':' is the scheme and never contains secrets.
    scheme = database_url.partition(":")[0]
    raise ValueError(f"Unsupported database URL scheme for async: {scheme!r}")
# Asynchronous engine and session factory, derived from DATABASE_URL with the
# driver swapped for an async one. When SKIP_ENGINE_INIT holds any non-empty
# value, all three names are bound to None and no engine is created.
if os.environ.get("SKIP_ENGINE_INIT"):
    ASYNC_DATABASE_URL: Optional[str] = None
    async_engine: Optional[AsyncEngine] = None
    AsyncSessionLocal: Optional[async_sessionmaker[AsyncSession]] = None
else:
    ASYNC_DATABASE_URL: str = _get_async_database_url(DATABASE_URL)
    async_engine: AsyncEngine = create_async_engine(
        ASYNC_DATABASE_URL,
        echo=False,
        future=True,
    )
    if ASYNC_DATABASE_URL.startswith(SQLITE_URL_PREFIX):
        # Give async SQLite connections the same WAL / busy-timeout PRAGMAs as
        # the sync engine. busy_timeout is a per-connection setting, so without
        # this, async sessions keep the driver's default lock timeout.
        def _set_async_sqlite_pragmas(
            dbapi_connection: Any, connection_record: Any
        ) -> None:
            # The DB-API cursor is closed explicitly; PEP 249 does not require
            # cursors to support ``with``.
            cursor = dbapi_connection.cursor()
            try:
                cursor.execute(f"PRAGMA journal_mode={PRAGMA_JOURNAL_MODE}")
                cursor.execute(f"PRAGMA busy_timeout={PRAGMA_BUSY_TIMEOUT_MS}")
            finally:
                cursor.close()

        # Connection-level events for an AsyncEngine are registered on its
        # underlying synchronous Engine.
        event.listen(async_engine.sync_engine, "connect", _set_async_sqlite_pragmas)
    AsyncSessionLocal: async_sessionmaker[AsyncSession] = async_sessionmaker(
        bind=async_engine,
        expire_on_commit=False,
        autoflush=False,
        autocommit=False,
    )
class Base(DeclarativeBase):
    """Declarative base for all ORM models.

    All SQLAlchemy ORM models in this project must inherit from this base.
    :func:`ensure_database_schema` passes the ``metadata`` of this base to
    ``create_all``. Tables are therefore created only for subclasses that have
    been defined, i.e. whose modules were imported, by the time it runs.

    Notes
    -----
    - Exposed so that ``app.models`` can subclass it for all entities.
    - Keeps a single metadata registry for consistent schema management.
    """
def ensure_database_schema() -> None:
    """Create missing tables based on ORM metadata.

    Delegates to SQLAlchemy's ``Base.metadata.create_all`` using the configured
    synchronous engine. Existing tables are left untouched, so this is safe to
    call during startup or a migrations bootstrap. Only models already
    registered on ``Base.metadata`` (i.e. imported) get tables.

    Raises
    ------
    RuntimeError
        If the synchronous engine was not initialised, for example because
        ``SKIP_ENGINE_INIT`` was set when this module was imported.
    SQLAlchemyError
        If the schema creation step fails (for example due to missing
        privileges or an unavailable database).

    Examples
    --------
    >>> from app.database import ensure_database_schema # doctest: +SKIP
    >>> ensure_database_schema() # doctest: +SKIP

    See Also
    --------
    app.models : Entities whose tables are created from the metadata.
    app.main : Application startup that may call this function.
    """
    # Explicit check rather than ``assert``: ``engine`` is None when
    # SKIP_ENGINE_INIT is set, and that must fail clearly even under ``-O``.
    if not isinstance(engine, Engine):
        raise RuntimeError(
            "Database engine is not initialised (was SKIP_ENGINE_INIT set?); "
            f"expected a SQLAlchemy Engine, got {type(engine).__name__}"
        )
    logger.info("Ensuring database schema exists; creating tables if missing.")
    try:
        Base.metadata.create_all(engine)
    except SQLAlchemyError:
        logger.exception("Failed to create database schema")
        raise
# Public API of this module, in the original order: sync engine/session,
# shared declarative base and schema bootstrap, then the async counterparts.
__all__ = [
    "engine", "SessionLocal",
    "Base", "ensure_database_schema",
    "async_engine", "AsyncSessionLocal",
]