home · contact · privacy
b4dc3e982c496833e7962ab02dc643c027235c1e
[plomtask] / plomtask / db.py
1 """Database management."""
from __future__ import annotations
from contextlib import closing
from difflib import Differ
from os import listdir
from os.path import isfile
from sqlite3 import connect as sql_connect, Cursor, Row
from typing import Any, Self, TypeVar, Generic

from plomtask.exceptions import HandledException, NotFoundException
9
# DB user_version this code expects; also the highest migration number
# tolerated in MIGRATIONS_DIR.
EXPECTED_DB_VERSION = 3
# Directory holding the numbered migration scripts and the schema file.
MIGRATIONS_DIR = 'migrations'
# Schema file matching EXPECTED_DB_VERSION; used both to create fresh DB
# files and to validate the schema of existing ones.
FILENAME_DB_SCHEMA = 'init_' + str(EXPECTED_DB_VERSION) + '.sql'
PATH_DB_SCHEMA = '/'.join((MIGRATIONS_DIR, FILENAME_DB_SCHEMA))
14
15
class UnmigratedDbException(HandledException):
    """To identify case of unmigrated DB file.

    Raised by DatabaseFile._check when the file's user_version differs from
    EXPECTED_DB_VERSION.
    """
18
19
class DatabaseFile:  # pylint: disable=too-few-public-methods
    """Represents the sqlite3 database's file.

    Instantiation validates the file at path (existence, user_version,
    schema); use create_at() to make a fresh file, or migrate() to upgrade
    one of a lower user_version.

    Every sqlite3 connection opened by this class is wrapped in closing(),
    since a connection used as context manager only commits or rolls back
    but does not close itself.
    """

    def __init__(self, path: str) -> None:
        """Store path and validate the file there via _check()."""
        self.path = path
        self._check()

    @classmethod
    def create_at(cls, path: str) -> DatabaseFile:
        """Make new DB file at path.

        Runs the schema script stored at PATH_DB_SCHEMA, sets user_version
        to EXPECTED_DB_VERSION, and returns the (validated) DatabaseFile.
        """
        # Read schema first, so a missing schema file fails before sqlite3
        # gets to touch path.
        with open(PATH_DB_SCHEMA, 'r', encoding='utf-8') as f:
            schema = f.read()
        with closing(sql_connect(path)) as conn, conn:
            conn.executescript(schema)
            conn.execute(f'PRAGMA user_version = {EXPECTED_DB_VERSION}')
        return cls(path)

    @classmethod
    def migrate(cls, path: str) -> DatabaseFile:
        """Apply migrations from_version to EXPECTED_DB_VERSION.

        The file's user_version is raised after every single migration
        script, rather than once at the end. Returns the (validated)
        DatabaseFile.
        """
        migrations = cls._available_migrations()
        from_version = cls.get_version_of_db(path)
        # List index equals migration number, so skip all up to and
        # including the one the file already is at.
        migrations_todo = migrations[from_version+1:]
        for j, filename in enumerate(migrations_todo):
            with open(f'{MIGRATIONS_DIR}/{filename}', 'r',
                      encoding='utf-8') as f:
                script = f.read()
            user_version = from_version + j + 1
            with closing(sql_connect(path)) as conn, conn:
                conn.executescript(script)
                conn.execute(f'PRAGMA user_version = {user_version}')
        return cls(path)

    def _check(self) -> None:
        """Check file exists, and is of proper DB version and schema.

        Raises NotFoundException if there is no file at self.path,
        UnmigratedDbException if its user_version is not the expected one,
        and whatever _validate_schema() raises on schema mismatch.
        """
        if not isfile(self.path):
            raise NotFoundException
        if self.user_version != EXPECTED_DB_VERSION:
            raise UnmigratedDbException()
        self._validate_schema()

    @staticmethod
    def _available_migrations() -> list[str]:
        """Validate migrations directory and return sorted entries.

        Entries (other than FILENAME_DB_SCHEMA) must be named
        "<number>_<rest>", with numbers covering 0 to EXPECTED_DB_VERSION
        without gaps; otherwise raises HandledException. In the returned
        list, each filename sits at the index of its migration number.
        """
        msg_too_big = 'Migration directory points beyond expected DB version.'
        msg_bad_entry = 'Migration directory contains unexpected entry: '
        msg_missing = 'Migration directory misses migration of number: '
        migrations: dict[int, str] = {}
        for entry in listdir(MIGRATIONS_DIR):
            if entry == FILENAME_DB_SCHEMA:
                continue
            toks = entry.split('_', 1)
            if len(toks) < 2:
                raise HandledException(msg_bad_entry + entry)
            try:
                i = int(toks[0])
            except ValueError as e:
                raise HandledException(msg_bad_entry + entry) from e
            if i > EXPECTED_DB_VERSION:
                raise HandledException(msg_too_big)
            migrations[i] = toks[1]
        migrations_list = []
        for i in range(EXPECTED_DB_VERSION + 1):
            if i not in migrations:
                raise HandledException(msg_missing + str(i))
            migrations_list.append(f'{i}_{migrations[i]}')
        return migrations_list

    @staticmethod
    def get_version_of_db(path: str) -> int:
        """Get DB user_version, fail if outside expected range.

        Raises HandledException if the found user_version is higher than
        EXPECTED_DB_VERSION.
        """
        sql_for_db_version = 'PRAGMA user_version'
        with closing(sql_connect(path)) as conn:
            db_version = list(conn.execute(sql_for_db_version))[0][0]
        assert isinstance(db_version, int)
        if db_version > EXPECTED_DB_VERSION:
            msg = f'Wrong DB version, expected '\
                    f'{EXPECTED_DB_VERSION}, got unknown {db_version}.'
            raise HandledException(msg)
        return db_version

    @property
    def user_version(self) -> int:
        """Get DB user_version."""
        return self.__class__.get_version_of_db(self.path)

    def _validate_schema(self) -> None:
        """Compare found schema with what's stored at PATH_DB_SCHEMA.

        Collects the SQL statements stored in sqlite_master, normalizes
        their layout, and compares the result textually to the schema
        file; on any difference raises HandledException carrying a diff.
        """

        def reformat_rows(rows: list[str]) -> list[str]:
            """Normalize line layout of each SQL statement in rows.

            Each statement is broken at line breaks and at commas outside
            of parentheses, with every resulting segment put on its own
            line, indented by four spaces (except for first and last one).
            """
            new_rows = []
            for row in rows:
                new_row = []
                for subrow in row.split('\n'):
                    subrow = subrow.rstrip()
                    # Collect positions right after each top-level comma.
                    in_parentheses = 0
                    split_at = []
                    for i, c in enumerate(subrow):
                        if '(' == c:
                            in_parentheses += 1
                        elif ')' == c:
                            in_parentheses -= 1
                        elif ',' == c and 0 == in_parentheses:
                            split_at += [i + 1]
                    prev_split = 0
                    for i in split_at:
                        segment = subrow[prev_split:i].strip()
                        if len(segment) > 0:
                            new_row += [f'    {segment}']
                        prev_split = i
                    segment = subrow[prev_split:].strip()
                    if len(segment) > 0:
                        new_row += [f'    {segment}']
                new_row[0] = new_row[0].lstrip()
                new_row[-1] = new_row[-1].lstrip()
                # If statement does not end in a lone ')' and the segment
                # two before the last lacks its trailing comma, add that
                # comma, and move the closing parenthesis onto its own
                # line. (Presumably repairs the layout sqlite leaves after
                # ALTER TABLE ... ADD COLUMN -- TODO confirm.) Statements
                # of fewer than three segments are left as they are, rather
                # than crashing on the new_row[-3] lookup.
                if (len(new_row) >= 3 and new_row[-1] != ')'
                        and new_row[-3][-1] != ','):
                    new_row[-3] = new_row[-3] + ','
                    new_row[-2:] = ['    ' + new_row[-1][:-1]] + [')']
                new_rows += ['\n'.join(new_row)]
            return new_rows

        sql_for_schema = 'SELECT sql FROM sqlite_master ORDER BY sql'
        msg_err = 'Database has wrong tables schema. Diff:\n'
        with closing(sql_connect(self.path)) as conn:
            # Skip sqlite_master entries that carry no SQL text.
            schema_rows = [r[0] for r in conn.execute(sql_for_schema) if r[0]]
        schema_rows = reformat_rows(schema_rows)
        retrieved_schema = ';\n'.join(schema_rows) + ';'
        with open(PATH_DB_SCHEMA, 'r', encoding='utf-8') as f:
            stored_schema = f.read().rstrip()
        if stored_schema != retrieved_schema:
            diff_msg = Differ().compare(retrieved_schema.splitlines(),
                                        stored_schema.splitlines())
            raise HandledException(msg_err + '\n'.join(diff_msg))
152
153
class DatabaseConnection:
    """Wraps one sqlite3 connection to a DatabaseFile."""

    def __init__(self, db_file: DatabaseFile) -> None:
        """Open connection to the file at db_file.path."""
        self.file = db_file
        self.conn = sql_connect(self.file.path)

    def commit(self) -> None:
        """Commit the current SQL transaction."""
        self.conn.commit()

    def exec(self, code: str, inputs: tuple[Any, ...] = tuple()) -> Cursor:
        """Execute code with inputs on the connection, return its Cursor."""
        cursor = self.conn.execute(code, inputs)
        return cursor

    def close(self) -> None:
        """Close the underlying sqlite3 connection."""
        self.conn.close()

    def rewrite_relations(self, table_name: str, key: str, target: int | str,
                          rows: list[list[Any]]) -> None:
        """Replace table_name's rows of key == target with rows.

        Old rows matching target are deleted first; each of rows is then
        inserted with target prepended as first value.
        """
        self.delete_where(table_name, key, target)
        make_q_marks = type(self).q_marks_from_values
        for row in rows:
            values = tuple([target] + row)
            placeholders = make_q_marks(values)
            sql = f'INSERT INTO {table_name} VALUES {placeholders}'
            self.exec(sql, values)

    def row_where(self, table_name: str, key: str,
                  target: int | str) -> list[Row]:
        """Collect all rows of table_name where key == target."""
        sql = f'SELECT * FROM {table_name} WHERE {key} = ?'
        cursor = self.exec(sql, (target,))
        return list(cursor)

    def column_where(self, table_name: str, column: str, key: str,
                     target: int | str) -> list[Any]:
        """Collect column's values from table_name where key == target."""
        sql = f'SELECT {column} FROM {table_name} WHERE {key} = ?'
        cursor = self.exec(sql, (target,))
        return [found[0] for found in cursor]

    def column_all(self, table_name: str, column: str) -> list[Any]:
        """Collect all of column's values from table_name."""
        sql = f'SELECT {column} FROM {table_name}'
        cursor = self.exec(sql)
        return [found[0] for found in cursor]

    def delete_where(self, table_name: str, key: str,
                     target: int | str) -> None:
        """Delete table_name's rows where key == target."""
        sql = f'DELETE FROM {table_name} WHERE {key} = ?'
        self.exec(sql, (target,))

    @staticmethod
    def q_marks_from_values(values: tuple[Any]) -> str:
        """Build parenthesized "?" placeholder list, one "?" per value."""
        placeholders = ','.join('?' * len(values))
        return f'({placeholders})'
209
210
211 BaseModelId = TypeVar('BaseModelId', int, str)
212 BaseModelInstance = TypeVar('BaseModelInstance', bound='BaseModel[Any]')
213
214
215 class BaseModel(Generic[BaseModelId]):
216     """Template for most of the models we use/derive from the DB."""
217     table_name = ''
218     to_save: list[str] = []
219     to_save_versioned: list[str] = []
220     to_save_relations: list[tuple[str, str, str]] = []
221     id_: None | BaseModelId
222     cache_: dict[BaseModelId, Self]
223
224     def __init__(self, id_: BaseModelId | None) -> None:
225         if isinstance(id_, int) and id_ < 1:
226             msg = f'illegal {self.__class__.__name__} ID, must be >=1: {id_}'
227             raise HandledException(msg)
228         self.id_ = id_
229
230     def __eq__(self, other: object) -> bool:
231         if not isinstance(other, self.__class__):
232             return False
233         to_hash_me = tuple([self.id_] +
234                            [getattr(self, name) for name in self.to_save])
235         to_hash_other = tuple([other.id_] +
236                               [getattr(other, name) for name in other.to_save])
237         return hash(to_hash_me) == hash(to_hash_other)
238
239     def __lt__(self, other: Any) -> bool:
240         if not isinstance(other, self.__class__):
241             msg = 'cannot compare to object of different class'
242             raise HandledException(msg)
243         assert isinstance(self.id_, int)
244         assert isinstance(other.id_, int)
245         return self.id_ < other.id_
246
247     @classmethod
248     def get_cached(cls: type[BaseModelInstance],
249                    id_: BaseModelId) -> BaseModelInstance | None:
250         """Get object of id_ from class's cache, or None if not found."""
251         # pylint: disable=consider-iterating-dictionary
252         cache = cls.get_cache()
253         if id_ in cache.keys():
254             obj = cache[id_]
255             assert isinstance(obj, cls)
256             return obj
257         return None
258
259     @classmethod
260     def empty_cache(cls) -> None:
261         """Empty class's cache."""
262         cls.cache_ = {}
263
264     @classmethod
265     def get_cache(cls: type[BaseModelInstance]) -> dict[Any, BaseModel[Any]]:
266         """Get cache dictionary, create it if not yet existing."""
267         if not hasattr(cls, 'cache_'):
268             d: dict[Any, BaseModel[Any]] = {}
269             cls.cache_ = d
270         return cls.cache_
271
272     def cache(self) -> None:
273         """Update object in class's cache."""
274         if self.id_ is None:
275             raise HandledException('Cannot cache object without ID.')
276         cache = self.__class__.get_cache()
277         cache[self.id_] = self
278
279     def uncache(self) -> None:
280         """Remove self from cache."""
281         if self.id_ is None:
282             raise HandledException('Cannot un-cache object without ID.')
283         cache = self.__class__.get_cache()
284         del cache[self.id_]
285
286     @classmethod
287     def from_table_row(cls: type[BaseModelInstance],
288                        # pylint: disable=unused-argument
289                        db_conn: DatabaseConnection,
290                        row: Row | list[Any]) -> BaseModelInstance:
291         """Make from DB row, write to DB cache."""
292         obj = cls(*row)
293         obj.cache()
294         return obj
295
296     @classmethod
297     def by_id(cls, db_conn: DatabaseConnection,
298               id_: BaseModelId | None,
299               # pylint: disable=unused-argument
300               create: bool = False) -> Self:
301         """Retrieve by id_, on failure throw NotFoundException.
302
303         First try to get from cls.cache_, only then check DB; if found,
304         put into cache.
305
306         If create=True, make anew (but do not cache yet).
307         """
308         obj = None
309         if id_ is not None:
310             obj = cls.get_cached(id_)
311             if not obj:
312                 for row in db_conn.row_where(cls.table_name, 'id', id_):
313                     obj = cls.from_table_row(db_conn, row)
314                     obj.cache()
315                     break
316         if obj:
317             return obj
318         if create:
319             obj = cls(id_)
320             return obj
321         raise NotFoundException(f'found no object of ID {id_}')
322
323     @classmethod
324     def all(cls: type[BaseModelInstance],
325             db_conn: DatabaseConnection) -> list[BaseModelInstance]:
326         """Collect all objects of class into list.
327
328         Note that this primarily returns the contents of the cache, and only
329         _expands_ that by additional findings in the DB. This assumes the
330         cache is always instantly cleaned of any items that would be removed
331         from the DB.
332         """
333         items: dict[BaseModelId, BaseModelInstance] = {}
334         for k, v in cls.get_cache().items():
335             assert isinstance(v, cls)
336             items[k] = v
337         already_recorded = items.keys()
338         for id_ in db_conn.column_all(cls.table_name, 'id'):
339             if id_ not in already_recorded:
340                 item = cls.by_id(db_conn, id_)
341                 assert item.id_ is not None
342                 items[item.id_] = item
343         return list(items.values())
344
345     def save(self, db_conn: DatabaseConnection) -> None:
346         """Write self to DB and cache and ensure .id_.
347
348         Write both to DB, and to cache. To DB, write .id_ and attributes
349         listed in cls.to_save[_versioned|_relations].
350
351         Ensure self.id_ by setting it to what the DB command returns as the
352         last saved row's ID (cursor.lastrowid), EXCEPT if self.id_ already
353         exists as a 'str', which implies we do our own ID creation (so far
354         only the case with the Day class, where it's to be a date string.
355         """
356         values = tuple([self.id_] + [getattr(self, key)
357                                      for key in self.to_save])
358         q_marks = DatabaseConnection.q_marks_from_values(values)
359         table_name = self.table_name
360         cursor = db_conn.exec(f'REPLACE INTO {table_name} VALUES {q_marks}',
361                               values)
362         if not isinstance(self.id_, str):
363             self.id_ = cursor.lastrowid  # type: ignore[assignment]
364         self.cache()
365         for attr_name in self.to_save_versioned:
366             getattr(self, attr_name).save(db_conn)
367         for table, column, attr_name in self.to_save_relations:
368             assert isinstance(self.id_, (int, str))
369             db_conn.rewrite_relations(table, column, self.id_,
370                                       [[i.id_] for i
371                                        in getattr(self, attr_name)])
372
373     def remove(self, db_conn: DatabaseConnection) -> None:
374         """Remove from DB and cache, including dependencies."""
375         if self.id_ is None or self.__class__.get_cached(self.id_) is None:
376             raise HandledException('cannot remove unsaved item')
377         for attr_name in self.to_save_versioned:
378             getattr(self, attr_name).remove(db_conn)
379         for table, column, attr_name in self.to_save_relations:
380             db_conn.delete_where(table, column, self.id_)
381         self.uncache()
382         db_conn.delete_where(self.table_name, 'id', self.id_)