home · contact · privacy
Include plomlib for its db.py, adapt DB code to it.
authorChristian Heller <c.heller@plomlompom.de>
Wed, 15 Jan 2025 14:27:59 +0000 (15:27 +0100)
committerChristian Heller <c.heller@plomlompom.de>
Wed, 15 Jan 2025 14:27:59 +0000 (15:27 +0100)
.gitmodules [new file with mode: 0644]
src/plomlib [new submodule]
src/run.py
src/ytplom/db.py
src/ytplom/http.py
src/ytplom/migrations.py
src/ytplom/misc.py

diff --git a/.gitmodules b/.gitmodules
new file mode 100644 (file)
index 0000000..95be54b
--- /dev/null
@@ -0,0 +1,3 @@
+[submodule "src/plomlib"]
+       path = src/plomlib
+       url = https://plomlompom.com/repos/clone/plomlib
diff --git a/src/plomlib b/src/plomlib
new file mode 160000 (submodule)
index 0000000..743dbe0
--- /dev/null
@@ -0,0 +1 @@
+Subproject commit 743dbe0d493ddeb47eca981fa5be6d78e4d754c9
index 4b1ddc31050096f6c847a2908e01271d0f6a54c2..49214ce2fb66cb4e4ef73dca2333ef3fdf1f5c51 100755 (executable)
@@ -6,7 +6,7 @@ from sys import argv, exit as sys_exit
 # ourselves
 from ytplom.db import DbFile
 from ytplom.primitives import HandledException
-from ytplom.migrations import migrate
+from ytplom.migrations import MIGRATIONS
 from ytplom.http import serve
 from ytplom.sync import sync
 
@@ -19,7 +19,7 @@ if __name__ == '__main__':
             case 'create_db':
                 DbFile.create()
             case 'migrate_db':
-                migrate()
+                DbFile(skip_validations=True).migrate(MIGRATIONS)
             case 'serve':
                 serve()
             case 'sync':
index 816e4012be1fa3a0456d26303c59853c257f65f6..3bf719863ee31a609d656e3ccaee894d34551527 100644 (file)
@@ -2,23 +2,20 @@
 
 # included libs
 from base64 import urlsafe_b64decode, urlsafe_b64encode
-from difflib import Differ
 from hashlib import file_digest
 from pathlib import Path
-from sqlite3 import (
-        connect as sql_connect, Cursor as SqlCursor, Row as SqlRow)
-from typing import Callable, Literal, NewType, Optional, Self
+from sqlite3 import Row as SqlRow
+from typing import Self
 # ourselves
+from plomlib.db import (
+        PlomDbConn, PlomDbFile, PlomDbMigration, TypePlomDbMigration)
 from ytplom.primitives import (
         HandledException, NotFoundException, PATH_APP_DATA)
 
 
-EXPECTED_DB_VERSION = 6
 PATH_DB = PATH_APP_DATA.joinpath('db.sql')
 
-SqlText = NewType('SqlText', str)
-
-_SQL_DB_VERSION = SqlText('PRAGMA user_version')
+_EXPECTED_DB_VERSION = 6
 _PATH_MIGRATIONS = PATH_APP_DATA.joinpath('migrations')
 _HASH_ALGO = 'sha512'
 _PATH_DB_SCHEMA = _PATH_MIGRATIONS.joinpath('new_init.sql')
@@ -58,124 +55,20 @@ class Hash:
         return urlsafe_b64encode(self.bytes).decode('utf8')
 
 
-class DbFile:
-    """Wrapper around the file of a sqlite3 database."""
-
-    def __init__(self,
-                 path: Path = PATH_DB,
-                 version_to_validate: int = EXPECTED_DB_VERSION
-                 ) -> None:
-        self.path = path
-        if not self.path.is_file():
-            raise HandledException(f'no DB file at {self.path}')
-        if version_to_validate < 0:
-            return
-        if (user_version := self._get_user_version()) != version_to_validate:
-            raise HandledException(
-                f'wrong DB version {user_version} (!= {version_to_validate})')
-        with DbConn(self) as conn:
-            self._validate_schema(conn)
-
-    @staticmethod
-    def _validate_schema(conn: 'DbConn'):
-        schema_rows_normed = []
-        indent = '  '
-        for row in [
-                r[0] for r in conn.exec(SqlText(
-                    'SELECT sql FROM sqlite_master ORDER BY sql'))
-                if r[0]]:
-            row_normed = []
-            for subrow in [sr.rstrip() for sr in row.split('\n')]:
-                in_parentheses = 0
-                split_at = []
-                for i, c in enumerate(subrow):
-                    if '(' == c:
-                        in_parentheses += 1
-                    elif ')' == c:
-                        in_parentheses -= 1
-                    elif ',' == c and 0 == in_parentheses:
-                        split_at += [i + 1]
-                prev_split = 0
-                for i in split_at:
-                    if segment := subrow[prev_split:i].strip():
-                        row_normed += [f'{indent}{segment}']
-                    prev_split = i
-                if segment := subrow[prev_split:].strip():
-                    row_normed += [f'{indent}{segment}']
-            row_normed[0] = row_normed[0].lstrip()  # no indent for opening …
-            row_normed[-1] = row_normed[-1].lstrip()  # … and closing line
-            if row_normed[-1] != ')' and row_normed[-3][-1] != ',':
-                row_normed[-3] = row_normed[-3] + ','
-                row_normed[-2:] = [indent + row_normed[-1][:-1]] + [')']
-            row_normed[-1] = row_normed[-1] + ';'
-            schema_rows_normed += row_normed
-        if ((expected_rows :=
-             _PATH_DB_SCHEMA.read_text(encoding='utf8').rstrip().splitlines()
-             ) != schema_rows_normed):
-            raise HandledException(
-                'Unexpected tables schema. Diff to {_PATH_DB_SCHEMA}:\n' +
-                '\n'.join(Differ().compare(schema_rows_normed, expected_rows)))
-
-    def _get_user_version(self) -> int:
-        with sql_connect(self.path) as conn:
-            return list(conn.execute(_SQL_DB_VERSION))[0][0]
-
-    @staticmethod
-    def create(path: Path = PATH_DB) -> None:
-        """Create DB file at path according to _PATH_DB_SCHEMA."""
-        if path.exists():
-            raise HandledException(
-                    f'There already exists a node at {path}.')
-        if not path.parent.is_dir():
-            raise HandledException(
-                    f'No directory {path.parent} found to write into.')
-        with sql_connect(path) as conn:
-            conn.executescript(_PATH_DB_SCHEMA.read_text(encoding='utf8'))
-            conn.execute(f'{_SQL_DB_VERSION} = {EXPECTED_DB_VERSION}')
-
-    def migrate(self, migrations: set['DbMigration']) -> None:
-        """Migrate self towards EXPECTED_DB_VERSION"""
-        start_version = self._get_user_version()
-        if start_version == EXPECTED_DB_VERSION:
-            raise HandledException(
-                f'Already at {EXPECTED_DB_VERSION}, nothing to migrate.')
-        if start_version > EXPECTED_DB_VERSION:
-            raise HandledException(
-                f'Cannot migrate backwards from {start_version}'
-                f'to {EXPECTED_DB_VERSION}.')
-        with DbConn(self) as conn:
-            for migration in DbMigration.from_to_in_set(
-                    start_version, EXPECTED_DB_VERSION, migrations):
-                migration.perform(conn)
-            self._validate_schema(conn)
-            conn.commit()
-
-
-class DbMigration:
-    """Representation of DbFile migration data."""
-
-    def __init__(self,
-                 version: int,
-                 sql_path: Optional[Path] = None,
-                 after_sql_steps: Optional[Callable[['DbConn'], None]] = None
-                 ) -> None:
-        if sql_path:
-            start_tok = str(sql_path).split('_', maxsplit=1)[0]
-            if (not start_tok.isdigit()) or int(start_tok) != version:
-                raise HandledException(
-                        f'migration {version} mapped to bad path {sql_path}')
-        self._version = version
-        self._sql_path = sql_path
-        self._after_sql_steps = after_sql_steps
+class DbMigration(PlomDbMigration):
+    """Collects and enacts DbFile migration commands."""
+    migs_dir_path = _PATH_MIGRATIONS
 
     @classmethod
-    def from_to_in_set(
-            cls, from_version: int, to_version: int, migs_set: set[Self]
-            ) -> list[Self]:
-        """From migs_set make sorted unbroken list from_version to_version."""
+    def gather(cls,
+               from_version: int,
+               base_set: set[TypePlomDbMigration]
+               ) -> list[TypePlomDbMigration]:
         selected_migs = []
-        for version in [n+1 for n in range(from_version, to_version)]:
-            matching_migs = [m for m in migs_set if version == m._version]
+        for version in [n+1 for n in range(from_version,
+                                           _EXPECTED_DB_VERSION)]:
+            matching_migs = [m for m in base_set  # cls.collection
+                             if version == m.target_version]
             if not matching_migs:
                 raise HandledException(f'Missing migration of v{version}')
             if len(matching_migs) > 1:
@@ -183,43 +76,19 @@ class DbMigration:
             selected_migs += [matching_migs[0]]
         return selected_migs
 
-    def perform(self, conn: 'DbConn') -> None:
-        """Do 1) script at sql_path, 2) after_sql_steps, 3) versino setting."""
-        if self._sql_path:
-            conn.exec_script(
-                SqlText(_PATH_MIGRATIONS.joinpath(self._sql_path)
-                        .read_text(encoding='utf8')))
-        if self._after_sql_steps:
-            self._after_sql_steps(conn)
-        conn.exec(SqlText(f'{_SQL_DB_VERSION} = {self._version}'))
-
-
-class DbConn:
-    """Wrapper for sqlite3 connections."""
-
-    def __init__(self, db_file: Optional[DbFile] = None) -> None:
-        self._conn = sql_connect((db_file or DbFile()).path, autocommit=False)
-        self.commit = self._conn.commit
-
-    def __enter__(self) -> Self:
-        return self
 
-    def __exit__(self, exc_type, exc_value, traceback) -> Literal[False]:
-        self._conn.close()
-        return False
+class DbFile(PlomDbFile):
+    """File readable as DB of expected schema, user version."""
+    indent_n = 2
+    target_version = _EXPECTED_DB_VERSION
+    path_schema = _PATH_DB_SCHEMA
+    default_path = PATH_DB
+    mig_class = DbMigration
 
-    def exec(self, sql: SqlText, inputs: tuple = tuple()
-             ) -> SqlCursor:
-        """Wrapper around sqlite3.Connection.execute, building '?' if inputs"""
-        if len(inputs) > 0:
-            q_marks = ('?' if len(inputs) == 1
-                       else '(' + ','.join(['?'] * len(inputs)) + ')')
-            return self._conn.execute(SqlText(f'{sql} {q_marks}'), inputs)
-        return self._conn.execute(sql)
 
-    def exec_script(self, sql: SqlText) -> None:
-        """Wrapper around sqlite3.Connection.executescript."""
-        self._conn.executescript(sql)
+class DbConn(PlomDbConn):
+    """SQL connection to DbFile."""
+    default_path = PATH_DB
 
 
 class DbData:
@@ -253,7 +122,7 @@ class DbData:
     @classmethod
     def get_one(cls, conn: DbConn, id_: str | Hash) -> Self:
         """Return single entry of id_ from DB."""
-        sql = SqlText(f'SELECT * FROM {cls._table_name} WHERE {cls.id_name} =')
+        sql = f'SELECT * FROM {cls._table_name} WHERE {cls.id_name} ='
         id__ = id_.bytes if isinstance(id_, Hash) else id_
         row = conn.exec(sql, (id__,)).fetchone()
         if not row:
@@ -264,7 +133,7 @@ class DbData:
     @classmethod
     def get_all(cls, conn: DbConn) -> list[Self]:
         """Return all entries from DB."""
-        sql = SqlText(f'SELECT * FROM {cls._table_name}')
+        sql = f'SELECT * FROM {cls._table_name}'
         rows = conn.exec(sql).fetchall()
         return [cls._from_table_row(row) for row in rows]
 
@@ -277,5 +146,5 @@ class DbData:
             elif isinstance(val, Hash):
                 val = val.bytes
             vals += [val]
-        conn.exec(SqlText(f'REPLACE INTO {self._table_name} VALUES'),
+        conn.exec(f'REPLACE INTO {self._table_name} VALUES',
                   tuple(vals))
index 5fef79d112e7d166702d1213e3272cb95167e50e..c57dfaea10a034f9f898dd000ec13ea43102625a 100644 (file)
@@ -367,7 +367,6 @@ class _TaskHandler(BaseHTTPRequestHandler):
             self._send_http(f.read(), [(_HEADER_CONTENT_TYPE, 'image/jpg')])
 
     def _send_yt_result(self, video_id: YoutubeId) -> None:
-        conn = DbConn()
         with DbConn() as conn:
             linked_queries = YoutubeQuery.get_all_for_video(conn, video_id)
             try:
index f075cdf19617b6758f28fb264bfce0c88a938f06..aadec7806fffc7b8f1370398e953015f92fff522 100644 (file)
@@ -4,7 +4,7 @@
 from pathlib import Path
 from typing import Callable
 # ourselves
-from ytplom.db import DbConn, DbFile, DbMigration, SqlText
+from ytplom.db import DbConn, DbMigration
 from ytplom.primitives import HandledException
 
 
@@ -14,10 +14,10 @@ _LEGIT_YES = 'YES!!'
 def _rewrite_files_last_field_processing_first_field(conn: DbConn,
                                                      cb: Callable
                                                      ) -> None:
-    rows = conn.exec(SqlText('SELECT * FROM files')).fetchall()
+    rows = conn.exec('SELECT * FROM files').fetchall()
     for row in [list(r) for r in rows]:
         row[-1] = cb(row[0])
-        conn.exec(SqlText('REPLACE INTO files VALUES'), tuple(row))
+        conn.exec('REPLACE INTO files VALUES', tuple(row))
 
 
 def _mig_2_calc_digests(conn: DbConn) -> None:
@@ -27,7 +27,7 @@ def _mig_2_calc_digests(conn: DbConn) -> None:
     from ytplom.misc import PATH_DOWNLOADS
     rel_paths = [
             p[0] for p
-            in conn.exec(SqlText('SELECT rel_path FROM files')).fetchall()]
+            in conn.exec('SELECT rel_path FROM files').fetchall()]
     missing = [p for p in rel_paths
                if not Path(PATH_DOWNLOADS).joinpath(p).exists()]
     if missing:
@@ -40,7 +40,7 @@ def _mig_2_calc_digests(conn: DbConn) -> None:
         if _LEGIT_YES != reply:
             raise HandledException('Migration aborted!')
         for path in missing:
-            conn.exec(SqlText('DELETE FROM files WHERE rel_path ='), (path,))
+            conn.exec('DELETE FROM files WHERE rel_path =', (path,))
 
     def hexdigest_file(path):
         print(f'Calculating digest for: {path}')
@@ -55,17 +55,13 @@ def _mig_4_convert_digests(conn: DbConn) -> None:
     _rewrite_files_last_field_processing_first_field(conn, bytes.fromhex)
 
 
-_MIGRATIONS: set[DbMigration] = {
+MIGRATIONS: set[DbMigration] = {
     DbMigration(0, Path('0_init.sql'), None),
     DbMigration(1, Path('1_add_files_last_updated.sql'), None),
     DbMigration(2, Path('2_add_files_sha512.sql'), _mig_2_calc_digests),
     DbMigration(3, Path('3_files_redo.sql'), None),
-    DbMigration(4, Path('4_add_files_sha512_blob.sql'), _mig_4_convert_digests),
+    DbMigration(4, Path('4_add_files_sha512_blob.sql'),
+                _mig_4_convert_digests),
     DbMigration(5, Path('5_files_redo.sql'), None),
     DbMigration(6, Path('6_add_files_tags.sql'), None)
 }
-
-
-def migrate():
-    """Migrate DB file at expected default path to most recent version."""
-    DbFile(version_to_validate=-1).migrate(_MIGRATIONS)
index 5dc6cf5b1764d3f784feb60f570433725b893b6d..44ab385ca80afcf3c8d481ca26df950653183c7a 100644 (file)
@@ -18,7 +18,7 @@ import googleapiclient.discovery  # type: ignore
 from mpv import MPV  # type: ignore
 from yt_dlp import YoutubeDL  # type: ignore
 # ourselves
-from ytplom.db import DbConn, DbData, Hash, SqlText
+from ytplom.db import DbConn, DbData, Hash
 from ytplom.primitives import HandledException, NotFoundException
 
 
@@ -254,8 +254,7 @@ class YoutubeQuery(DbData):
     def get_all_for_video(cls, conn: DbConn, video_id: YoutubeId
                           ) -> list[Self]:
         """Return YoutubeQueries containing YoutubeVideo's ID in results."""
-        sql = SqlText('SELECT query_id FROM '
-                      'yt_query_results WHERE video_id =')
+        sql = 'SELECT query_id FROM yt_query_results WHERE video_id ='
         query_ids = conn.exec(sql, (video_id,)).fetchall()
         return [cls.get_one(conn, query_id_tup[0])
                 for query_id_tup in query_ids]
@@ -305,16 +304,14 @@ class YoutubeVideo(DbData):
     @classmethod
     def get_all_for_query(cls, conn: DbConn, query_id: QueryId) -> list[Self]:
         """Return all videos for query of query_id."""
-        sql = SqlText('SELECT video_id '
-                      'FROM yt_query_results WHERE query_id =')
+        sql = 'SELECT video_id FROM yt_query_results WHERE query_id ='
         video_ids = conn.exec(sql, (query_id,)).fetchall()
         return [cls.get_one(conn, video_id_tup[0])
                 for video_id_tup in video_ids]
 
     def save_to_query(self, conn: DbConn, query_id: QueryId) -> None:
         """Save inclusion of self in results to query of query_id."""
-        conn.exec(SqlText('REPLACE INTO yt_query_results VALUES'),
-                  (query_id, self.id_))
+        conn.exec('REPLACE INTO yt_query_results VALUES', (query_id, self.id_))
 
 
 class VideoFile(DbData):
@@ -373,8 +370,8 @@ class VideoFile(DbData):
     @classmethod
     def get_by_yt_id(cls, conn: DbConn, yt_id: YoutubeId) -> Self:
         """Return VideoFile of .yt_id."""
-        sql = SqlText(f'SELECT * FROM {cls._table_name} WHERE yt_id =')
-        row = conn.exec(sql, (yt_id,)).fetchone()
+        row = conn.exec(f'SELECT * FROM {cls._table_name} WHERE yt_id =',
+                        (yt_id,)).fetchone()
         if not row:
             raise NotFoundException(f'no entry for file to Youtube ID {yt_id}')
         return cls._from_table_row(row)
@@ -489,9 +486,8 @@ class VideoFile(DbData):
             if file.present:
                 file.unlink_locally()
             print(f'SYNC: purging off DB: {file.digest.b64} ({file.rel_path})')
-            conn.exec(
-                    SqlText(f'DELETE FROM {cls._table_name} WHERE digest ='),
-                    (file.digest.bytes,))
+            conn.exec(f'DELETE FROM {cls._table_name} WHERE digest =',
+                      (file.digest.bytes,))
 
 
 class QuotaLog(DbData):
@@ -528,8 +524,8 @@ class QuotaLog(DbData):
     @classmethod
     def _remove_old(cls, conn: DbConn) -> None:
         cutoff = datetime.now() - timedelta(days=1)
-        sql = SqlText(f'DELETE FROM {cls._table_name} WHERE timestamp <')
-        conn.exec(SqlText(sql), (cutoff.strftime(TIMESTAMP_FMT),))
+        conn.exec(f'DELETE FROM {cls._table_name} WHERE timestamp <',
+                  (cutoff.strftime(TIMESTAMP_FMT),))
 
 
 class Player: