+"""Database access and management code."""
+from pathlib import Path
+from sqlite3 import (
+ connect as sql_connect, Connection as SqlConnection, Cursor, Row)
+from typing import Any, Literal, NewType, Self
+from ytplom.primitives import (
+ HandledException, NotFoundException, PATH_APP_DATA)
+
+SqlText = NewType('SqlText', str)
+
+EXPECTED_DB_VERSION = 3
+PATH_DB = PATH_APP_DATA.joinpath('db.sql')
+SQL_DB_VERSION = SqlText('PRAGMA user_version')
+PATH_MIGRATIONS = PATH_APP_DATA.joinpath('migrations')
+_PATH_DB_SCHEMA = PATH_MIGRATIONS.joinpath('new_init.sql')
+_NAME_INSTALLER = Path('install.sh')
+
+
+def get_db_version(db_path: Path) -> int:
+ """Return user_version value of DB at db_path."""
+ with sql_connect(db_path) as conn:
+ return list(conn.execute(SQL_DB_VERSION))[0][0]
+
+
+class BaseDbConn:
+ """Wrapper for pre-established sqlite3.Connection."""
+
+ def __init__(self, sql_conn: SqlConnection) -> None:
+ self._conn = sql_conn
+
+ def exec(self, sql: SqlText, inputs: tuple[Any, ...] = tuple()) -> Cursor:
+ """Wrapper around sqlite3.Connection.execute."""
+ return self._conn.execute(sql, inputs)
+
+ def commit(self) -> None:
+ """Commit changes (i.e. DbData.save() calls) to database."""
+ self._conn.commit()
+
+
+class DbConn(BaseDbConn):
+ """Like parent, but opening (and as context mgr: closing) connection."""
+
+ def __init__(self, path: Path = PATH_DB) -> None:
+ if not path.is_file():
+ if path.exists():
+ raise HandledException(f'no DB at {path}; would create, '
+ 'but something\'s already there?')
+ if not path.parent.is_dir():
+ raise HandledException(
+ f'cannot find {path.parent} as directory to put '
+ f'DB into, did you run {_NAME_INSTALLER}?')
+ with sql_connect(path) as conn:
+ conn.executescript(_PATH_DB_SCHEMA.read_text(encoding='utf8'))
+ conn.execute(f'{SQL_DB_VERSION} = {EXPECTED_DB_VERSION}')
+ cur_version = get_db_version(path)
+ if cur_version != EXPECTED_DB_VERSION:
+ raise HandledException(
+ f'wrong database version {cur_version}, expected: '
+ f'{EXPECTED_DB_VERSION} – run "migrate"?')
+ super().__init__(sql_connect(path, autocommit=False))
+
+ def __enter__(self) -> Self:
+ return self
+
+ def __exit__(self, exc_type, exc_value, traceback) -> Literal[False]:
+ self._conn.close()
+ return False
+
+
+class DbData:
+ """Abstraction of common DB operation."""
+ id_name: str = 'id'
+ _table_name: str
+ _cols: tuple[str, ...]
+
+ def __eq__(self, other: Any) -> bool:
+ if not isinstance(other, self.__class__):
+ return False
+ for attr_name in self._cols:
+ if getattr(self, attr_name) != getattr(other, attr_name):
+ return False
+ return True
+
+ @classmethod
+ def _from_table_row(cls, row: Row) -> Self:
+ kwargs = {}
+ for i, col_name in enumerate(cls._cols):
+ kwargs[col_name] = row[i]
+ for attr_name, type_ in cls.__annotations__.items():
+ if attr_name in kwargs:
+ kwargs[attr_name] = type_(kwargs[attr_name])
+ return cls(**kwargs)
+
+ @classmethod
+ def get_one(cls, conn: BaseDbConn, id_: str) -> Self:
+ """Return single entry of id_ from DB."""
+ sql = SqlText(f'SELECT * FROM {cls._table_name} '
+ f'WHERE {cls.id_name} = ?')
+ row = conn.exec(sql, (id_,)).fetchone()
+ if not row:
+ msg = f'no entry found for ID "{id_}" in table {cls._table_name}'
+ raise NotFoundException(msg)
+ return cls._from_table_row(row)
+
+ @classmethod
+ def get_all(cls, conn: BaseDbConn) -> list[Self]:
+ """Return all entries from DB."""
+ sql = SqlText(f'SELECT * FROM {cls._table_name}')
+ rows = conn.exec(sql).fetchall()
+ return [cls._from_table_row(row) for row in rows]
+
+ def save(self, conn: BaseDbConn) -> Cursor:
+ """Save entry to DB."""
+ vals = [getattr(self, col_name) for col_name in self._cols]
+ q_marks = '(' + ','.join(['?'] * len(vals)) + ')'
+ sql = SqlText(f'REPLACE INTO {self._table_name} VALUES {q_marks}')
+ return conn.exec(sql, tuple(str(v) if isinstance(v, Path) else v
+ for v in vals))