home · contact · privacy
Improve template layouts.
[plomtask] / plomtask / db.py
1 """Database management."""
from __future__ import annotations
from contextlib import closing
from difflib import Differ
from os import listdir
from os.path import isfile
from sqlite3 import connect as sql_connect, Cursor, Row
from typing import Any, Self, TypeVar, Generic
from plomtask.exceptions import HandledException, NotFoundException
9
# DB schema version this code expects; compared to the file's user_version.
EXPECTED_DB_VERSION = 1
# Directory holding the schema script and the numbered migration scripts.
MIGRATIONS_DIR = 'migrations'
# File name of the complete schema script for the expected version.
FILENAME_DB_SCHEMA = 'init_' + str(EXPECTED_DB_VERSION) + '.sql'
# Relative path to that schema script.
PATH_DB_SCHEMA = '/'.join((MIGRATIONS_DIR, FILENAME_DB_SCHEMA))
14
15
class UnmigratedDbException(HandledException):
    """Raised when a DB file is not at the expected DB version."""
18
19
class DatabaseFile:  # pylint: disable=too-few-public-methods
    """Represents the sqlite3 database's file.

    Instantiating validates the file at the given path (existence, DB
    version, schema). create_at() and migrate() are alternative
    constructors that first write to the file and then validate it.
    """

    def __init__(self, path: str) -> None:
        """Store path and validate the file there; see _check for raises."""
        self.path = path
        self._check()

    @classmethod
    def create_at(cls, path: str) -> DatabaseFile:
        """Make new DB file at path from the schema at PATH_DB_SCHEMA.

        Sets the new file's user_version to EXPECTED_DB_VERSION and returns
        a validated DatabaseFile for it.
        """
        # Read the schema before connecting, so that a missing schema file
        # does not leave an empty DB file behind.
        with open(PATH_DB_SCHEMA, 'r', encoding='utf-8') as f:
            schema_sql = f.read()
        # sqlite3's own context manager only commits/rolls back; closing()
        # makes sure the connection itself is released afterwards.
        with closing(sql_connect(path)) as conn, conn:
            conn.executescript(schema_sql)
            conn.execute(f'PRAGMA user_version = {EXPECTED_DB_VERSION}')
        return cls(path)

    @classmethod
    def migrate(cls, path: str) -> DatabaseFile:
        """Apply migrations from_version to EXPECTED_DB_VERSION.

        Runs each pending migration script in order, bumping the DB's
        user_version after each one, then returns a validated DatabaseFile.
        """
        migrations = cls._available_migrations()
        from_version = cls.get_version_of_db(path)
        # The list index equals the version a migration leads to, so
        # everything after from_version is still to be applied.
        migrations_todo = migrations[from_version + 1:]
        for version, filename in enumerate(migrations_todo,
                                           start=from_version + 1):
            with open(f'{MIGRATIONS_DIR}/{filename}', 'r',
                      encoding='utf-8') as f:
                script = f.read()
            with closing(sql_connect(path)) as conn, conn:
                conn.executescript(script)
                conn.execute(f'PRAGMA user_version = {version}')
        return cls(path)

    def _check(self) -> None:
        """Check file exists, and is of proper DB version and schema.

        Raises NotFoundException if there is no file at self.path,
        UnmigratedDbException if its user_version is not
        EXPECTED_DB_VERSION, and HandledException on schema mismatch.
        """
        if not isfile(self.path):
            raise NotFoundException
        if self.user_version != EXPECTED_DB_VERSION:
            raise UnmigratedDbException()
        self._validate_schema()

    @staticmethod
    def _available_migrations() -> list[str]:
        """Validate migrations directory and return sorted entries.

        Every entry other than the schema file must be named
        '<number>_<rest>', with numbers covering 0..EXPECTED_DB_VERSION
        without gaps. Returns the file names ordered by number; raises
        HandledException on any violation.
        """
        msg_too_big = 'Migration directory points beyond expected DB version.'
        msg_bad_entry = 'Migration directory contains unexpected entry: '
        msg_missing = 'Migration directory misses migration of number: '
        migrations: dict[int, str] = {}
        for entry in listdir(MIGRATIONS_DIR):
            if entry == FILENAME_DB_SCHEMA:
                continue
            toks = entry.split('_', 1)
            if len(toks) < 2:
                raise HandledException(msg_bad_entry + entry)
            try:
                i = int(toks[0])
            except ValueError as e:
                raise HandledException(msg_bad_entry + entry) from e
            if i > EXPECTED_DB_VERSION:
                raise HandledException(msg_too_big)
            migrations[i] = toks[1]
        migrations_list = []
        for i in range(EXPECTED_DB_VERSION + 1):
            if i not in migrations:
                raise HandledException(msg_missing + str(i))
            migrations_list.append(f'{i}_{migrations[i]}')
        return migrations_list

    @staticmethod
    def get_version_of_db(path: str) -> int:
        """Get DB user_version, fail if outside expected range.

        Raises HandledException if the version found is greater than
        EXPECTED_DB_VERSION.
        """
        sql_for_db_version = 'PRAGMA user_version'
        with closing(sql_connect(path)) as conn:
            db_version = conn.execute(sql_for_db_version).fetchone()[0]
        # Narrow the type before the value is used in a comparison.
        assert isinstance(db_version, int)
        if db_version > EXPECTED_DB_VERSION:
            msg = f'Wrong DB version, expected '\
                    f'{EXPECTED_DB_VERSION}, got unknown {db_version}.'
            raise HandledException(msg)
        return db_version

    @property
    def user_version(self) -> int:
        """Get DB user_version."""
        return self.__class__.get_version_of_db(self.path)

    @staticmethod
    def _reformat_schema_rows(rows: list[str]) -> list[str]:
        """Normalize SQL statements for comparison with the schema file.

        Each statement is cut at its line breaks and additionally after
        every comma that is outside parentheses (nesting depth is counted
        per line, restarting at 0 on each line). Empty pieces are dropped,
        every piece gets a four-space indent, and finally the first and
        last piece of a statement are un-indented. Statements that yield
        no pieces at all are skipped.
        """
        new_rows = []
        for row in rows:
            new_row = []
            for subrow in row.split('\n'):
                subrow = subrow.rstrip()
                in_parentheses = 0
                split_at = []
                for i, c in enumerate(subrow):
                    if c == '(':
                        in_parentheses += 1
                    elif c == ')':
                        in_parentheses -= 1
                    elif c == ',' and in_parentheses == 0:
                        split_at.append(i + 1)
                prev_split = 0
                # The line's end acts as a final split position, so the
                # remainder after the last comma is handled like the rest.
                for i in split_at + [len(subrow)]:
                    segment = subrow[prev_split:i].strip()
                    if segment:
                        new_row.append(f'    {segment}')
                    prev_split = i
            if not new_row:  # nothing but whitespace; avoid IndexError below
                continue
            new_row[0] = new_row[0].lstrip()
            new_row[-1] = new_row[-1].lstrip()
            new_rows.append('\n'.join(new_row))
        return new_rows

    def _validate_schema(self) -> None:
        """Compare found schema with what's stored at PATH_DB_SCHEMA.

        Raises HandledException, with a line diff in its message, if the
        normalized schema read from the DB differs from the stored one.
        """
        sql_for_schema = 'SELECT sql FROM sqlite_master ORDER BY sql'
        msg_err = 'Database has wrong tables schema. Diff:\n'
        with closing(sql_connect(self.path)) as conn:
            # Skip entries that carry no SQL text.
            schema_rows = [r[0] for r in conn.execute(sql_for_schema) if r[0]]
        schema_rows = self._reformat_schema_rows(schema_rows)
        retrieved_schema = ';\n'.join(schema_rows) + ';'
        with open(PATH_DB_SCHEMA, 'r', encoding='utf-8') as f:
            stored_schema = f.read().rstrip()
        if stored_schema != retrieved_schema:
            diff_msg = Differ().compare(retrieved_schema.splitlines(),
                                        stored_schema.splitlines())
            raise HandledException(msg_err + '\n'.join(diff_msg))
149
150
class DatabaseConnection:
    """A single connection to the database.

    Table, column and key names passed to the helper methods are
    interpolated into the SQL text as-is; only the values they are matched
    against or filled with are passed as bound parameters.
    """

    def __init__(self, db_file: DatabaseFile) -> None:
        """Open a connection to the file at db_file.path."""
        self.file = db_file
        self.conn = sql_connect(self.file.path)

    def commit(self) -> None:
        """Commit SQL transaction."""
        self.conn.commit()

    def exec(self, code: str, inputs: tuple[Any, ...] = tuple()) -> Cursor:
        """Add commands to SQL transaction.

        Executes code with inputs bound to its placeholders and returns the
        resulting cursor.
        """
        return self.conn.execute(code, inputs)

    def close(self) -> None:
        """Close DB connection."""
        self.conn.close()

    def rewrite_relations(self, table_name: str, key: str, target: int | str,
                          rows: list[list[Any]]) -> None:
        """Rewrite relations in table_name to target, with rows values.

        Deletes all rows where key == target, then inserts one row per
        entry of rows, each prefixed with target as its first value.
        """
        self.delete_where(table_name, key, target)
        for row in rows:
            values = (target, *row)
            q_marks = self.__class__.q_marks_from_values(values)
            self.exec(f'INSERT INTO {table_name} VALUES {q_marks}', values)

    def row_where(self, table_name: str, key: str,
                  target: int | str) -> list[Row]:
        """Return list of Rows at table where key == target."""
        return list(self.exec(f'SELECT * FROM {table_name} WHERE {key} = ?',
                              (target,)))

    def column_where(self, table_name: str, column: str, key: str,
                     target: int | str) -> list[Any]:
        """Return column of table where key == target."""
        return [row[0] for row in
                self.exec(f'SELECT {column} FROM {table_name} '
                          f'WHERE {key} = ?', (target,))]

    def column_all(self, table_name: str, column: str) -> list[Any]:
        """Return complete column of table."""
        return [row[0] for row in
                self.exec(f'SELECT {column} FROM {table_name}')]

    def delete_where(self, table_name: str, key: str,
                     target: int | str) -> None:
        """Delete from table where key == target."""
        self.exec(f'DELETE FROM {table_name} WHERE {key} = ?', (target,))

    @staticmethod
    def q_marks_from_values(values: tuple[Any, ...]) -> str:
        """Return placeholder to insert values into SQL code.

        The result is a parenthesized, comma-separated list with one '?'
        per element of values, e.g. '(?,?,?)' for three values.
        """
        return '(' + ','.join(['?'] * len(values)) + ')'
206
207
208 BaseModelId = TypeVar('BaseModelId', int, str)
209 BaseModelInstance = TypeVar('BaseModelInstance', bound='BaseModel[Any]')
210
211
212 class BaseModel(Generic[BaseModelId]):
213     """Template for most of the models we use/derive from the DB."""
214     table_name = ''
215     to_save: list[str] = []
216     to_save_versioned: list[str] = []
217     to_save_relations: list[tuple[str, str, str]] = []
218     id_: None | BaseModelId
219     cache_: dict[BaseModelId, Self]
220
221     def __init__(self, id_: BaseModelId | None) -> None:
222         if isinstance(id_, int) and id_ < 1:
223             msg = f'illegal {self.__class__.__name__} ID, must be >=1: {id_}'
224             raise HandledException(msg)
225         self.id_ = id_
226
227     def __eq__(self, other: object) -> bool:
228         if not isinstance(other, self.__class__):
229             return False
230         to_hash_me = tuple([self.id_] +
231                            [getattr(self, name) for name in self.to_save])
232         to_hash_other = tuple([other.id_] +
233                               [getattr(other, name) for name in other.to_save])
234         return hash(to_hash_me) == hash(to_hash_other)
235
236     def __lt__(self, other: Any) -> bool:
237         if not isinstance(other, self.__class__):
238             msg = 'cannot compare to object of different class'
239             raise HandledException(msg)
240         assert isinstance(self.id_, int)
241         assert isinstance(other.id_, int)
242         return self.id_ < other.id_
243
244     @classmethod
245     def get_cached(cls: type[BaseModelInstance],
246                    id_: BaseModelId) -> BaseModelInstance | None:
247         """Get object of id_ from class's cache, or None if not found."""
248         # pylint: disable=consider-iterating-dictionary
249         cache = cls.get_cache()
250         if id_ in cache.keys():
251             obj = cache[id_]
252             assert isinstance(obj, cls)
253             return obj
254         return None
255
256     @classmethod
257     def empty_cache(cls) -> None:
258         """Empty class's cache."""
259         cls.cache_ = {}
260
261     @classmethod
262     def get_cache(cls: type[BaseModelInstance]) -> dict[Any, BaseModel[Any]]:
263         """Get cache dictionary, create it if not yet existing."""
264         if not hasattr(cls, 'cache_'):
265             d: dict[Any, BaseModel[Any]] = {}
266             cls.cache_ = d
267         return cls.cache_
268
269     def cache(self) -> None:
270         """Update object in class's cache."""
271         if self.id_ is None:
272             raise HandledException('Cannot cache object without ID.')
273         cache = self.__class__.get_cache()
274         cache[self.id_] = self
275
276     def uncache(self) -> None:
277         """Remove self from cache."""
278         if self.id_ is None:
279             raise HandledException('Cannot un-cache object without ID.')
280         cache = self.__class__.get_cache()
281         del cache[self.id_]
282
283     @classmethod
284     def from_table_row(cls: type[BaseModelInstance],
285                        # pylint: disable=unused-argument
286                        db_conn: DatabaseConnection,
287                        row: Row | list[Any]) -> BaseModelInstance:
288         """Make from DB row, write to DB cache."""
289         obj = cls(*row)
290         obj.cache()
291         return obj
292
293     @classmethod
294     def by_id(cls, db_conn: DatabaseConnection,
295               id_: BaseModelId | None,
296               # pylint: disable=unused-argument
297               create: bool = False) -> Self:
298         """Retrieve by id_, on failure throw NotFoundException.
299
300         First try to get from cls.cache_, only then check DB; if found,
301         put into cache.
302
303         If create=True, make anew (but do not cache yet).
304         """
305         obj = None
306         if id_ is not None:
307             obj = cls.get_cached(id_)
308             if not obj:
309                 for row in db_conn.row_where(cls.table_name, 'id', id_):
310                     obj = cls.from_table_row(db_conn, row)
311                     obj.cache()
312                     break
313         if obj:
314             return obj
315         if create:
316             obj = cls(id_)
317             return obj
318         raise NotFoundException(f'found no object of ID {id_}')
319
320     @classmethod
321     def all(cls: type[BaseModelInstance],
322             db_conn: DatabaseConnection) -> list[BaseModelInstance]:
323         """Collect all objects of class into list.
324
325         Note that this primarily returns the contents of the cache, and only
326         _expands_ that by additional findings in the DB. This assumes the
327         cache is always instantly cleaned of any items that would be removed
328         from the DB.
329         """
330         items: dict[BaseModelId, BaseModelInstance] = {}
331         for k, v in cls.get_cache().items():
332             assert isinstance(v, cls)
333             items[k] = v
334         already_recorded = items.keys()
335         for id_ in db_conn.column_all(cls.table_name, 'id'):
336             if id_ not in already_recorded:
337                 item = cls.by_id(db_conn, id_)
338                 assert item.id_ is not None
339                 items[item.id_] = item
340         return list(items.values())
341
342     def save(self, db_conn: DatabaseConnection) -> None:
343         """Write self to DB and cache and ensure .id_.
344
345         Write both to DB, and to cache. To DB, write .id_ and attributes
346         listed in cls.to_save[_versioned|_relations].
347
348         Ensure self.id_ by setting it to what the DB command returns as the
349         last saved row's ID (cursor.lastrowid), EXCEPT if self.id_ already
350         exists as a 'str', which implies we do our own ID creation (so far
351         only the case with the Day class, where it's to be a date string.
352         """
353         values = tuple([self.id_] + [getattr(self, key)
354                                      for key in self.to_save])
355         q_marks = DatabaseConnection.q_marks_from_values(values)
356         table_name = self.table_name
357         cursor = db_conn.exec(f'REPLACE INTO {table_name} VALUES {q_marks}',
358                               values)
359         if not isinstance(self.id_, str):
360             self.id_ = cursor.lastrowid  # type: ignore[assignment]
361         self.cache()
362         for attr_name in self.to_save_versioned:
363             getattr(self, attr_name).save(db_conn)
364         for table, column, attr_name in self.to_save_relations:
365             assert isinstance(self.id_, (int, str))
366             db_conn.rewrite_relations(table, column, self.id_,
367                                       [[i.id_] for i
368                                        in getattr(self, attr_name)])
369
370     def remove(self, db_conn: DatabaseConnection) -> None:
371         """Remove from DB and cache, including dependencies."""
372         if self.id_ is None or self.__class__.get_cached(self.id_) is None:
373             raise HandledException('cannot remove unsaved item')
374         for attr_name in self.to_save_versioned:
375             getattr(self, attr_name).remove(db_conn)
376         for table, column, attr_name in self.to_save_relations:
377             db_conn.delete_where(table, column, self.id_)
378         self.uncache()
379         db_conn.delete_where(self.table_name, 'id', self.id_)