"""Database engine, session helpers, bootstrap and lightweight migration.""" from __future__ import annotations from contextlib import contextmanager from typing import Iterator from sqlalchemy import inspect, text from sqlmodel import Session, SQLModel, create_engine, select from . import config from .auth import hash_password from .models import Settings, User engine = create_engine( config.DATABASE_URL, echo=False, connect_args={"check_same_thread": False}, ) def _migrate() -> None: """Add any model columns missing from existing tables (SQLite ALTER ADD). Keeps simple deployments upgradeable without a migration framework. New columns always have defaults, so a plain ADD COLUMN is sufficient. """ inspector = inspect(engine) existing_tables = set(inspector.get_table_names()) type_map = {"INTEGER": "INTEGER", "BOOLEAN": "BOOLEAN", "VARCHAR": "VARCHAR", "DATETIME": "DATETIME"} with engine.begin() as conn: for table in SQLModel.metadata.sorted_tables: if table.name not in existing_tables: continue have = {c["name"] for c in inspector.get_columns(table.name)} for column in table.columns: if column.name in have: continue col_type = type_map.get( column.type.__class__.__name__.upper(), "VARCHAR" ) default = column.default.arg if column.default is not None else None if isinstance(default, bool): default_sql = "1" if default else "0" elif isinstance(default, (int, float)): default_sql = str(default) elif isinstance(default, str): default_sql = f"'{default}'" else: default_sql = "NULL" conn.execute( text( f'ALTER TABLE "{table.name}" ' f'ADD COLUMN "{column.name}" {col_type} DEFAULT {default_sql}' ) ) def init_db() -> None: """Create tables, run migration, ensure settings + admin user exist.""" SQLModel.metadata.create_all(engine) _migrate() with Session(engine) as session: if session.get(Settings, 1) is None: session.add( Settings( id=1, default_ntfy_server=config.DEFAULT_NTFY_SERVER, check_interval=config.DEFAULT_CHECK_INTERVAL, auth_enabled=False, ) ) session.commit() # Bootstrap the first admin account if no users exist. if not session.exec(select(User)).first(): session.add( User( username=config.ADMIN_USERNAME, password_hash=hash_password(config.ADMIN_PASSWORD), role="admin", ) ) session.commit() def get_settings(session: Session) -> Settings: settings = session.get(Settings, 1) if settings is None: # safety net settings = Settings(id=1) session.add(settings) session.commit() session.refresh(settings) return settings @contextmanager def session_scope() -> Iterator[Session]: session = Session(engine) try: yield session session.commit() except Exception: session.rollback() raise finally: session.close() def get_session() -> Iterator[Session]: """FastAPI dependency.""" with Session(engine) as session: yield session __all__ = [ "engine", "init_db", "get_settings", "get_session", "session_scope", "select", ]