home · contact · privacy
054060e13b7c1fa4aa27274efab13ff1c2828815
[plomtask] / plomtask / db.py
1 """Database management."""
from __future__ import annotations
from contextlib import closing
from os import listdir
from os.path import isfile
from difflib import Differ
from sqlite3 import connect as sql_connect, Cursor, Row
from typing import Any, Self, TypeVar, Generic
from plomtask.exceptions import HandledException, NotFoundException
from plomtask.dating import valid_date
10
# user_version a DB file must carry to be accepted as up to date.
EXPECTED_DB_VERSION = 5
# Directory (relative to the working directory) of migration scripts.
MIGRATIONS_DIR = 'migrations'
# Script that initializes a fresh DB directly at EXPECTED_DB_VERSION.
FILENAME_DB_SCHEMA = 'init_' + str(EXPECTED_DB_VERSION) + '.sql'
PATH_DB_SCHEMA = '/'.join((MIGRATIONS_DIR, FILENAME_DB_SCHEMA))
15
16
class UnmigratedDbException(HandledException):
    """Signals a DB file whose version does not match EXPECTED_DB_VERSION."""
19
20
class DatabaseFile:
    """Represents the sqlite3 database's file.

    Constructing an instance validates the file at path: it must exist,
    carry EXPECTED_DB_VERSION as its user_version, and match the schema
    stored at PATH_DB_SCHEMA.

    sqlite3's connection context manager only commits or rolls back; it
    does not close the connection. Every short-lived connection here is
    therefore additionally wrapped in contextlib.closing().
    """
    # pylint: disable=too-few-public-methods

    def __init__(self, path: str) -> None:
        self.path = path
        self._check()

    @classmethod
    def create_at(cls, path: str) -> DatabaseFile:
        """Make new DB file at path, initialized to EXPECTED_DB_VERSION.

        Runs the script at PATH_DB_SCHEMA on it, sets its user_version, and
        returns it as a validated DatabaseFile.
        """
        # read the schema first, so a missing schema file does not leave
        # behind an empty DB file at path
        with open(PATH_DB_SCHEMA, 'r', encoding='utf-8') as f:
            schema_sql = f.read()
        with closing(sql_connect(path)) as conn:
            with conn:
                conn.executescript(schema_sql)
                conn.execute(f'PRAGMA user_version = {EXPECTED_DB_VERSION}')
        return cls(path)

    @classmethod
    def migrate(cls, path: str) -> DatabaseFile:
        """Apply migrations from the DB's version up to EXPECTED_DB_VERSION.

        Migration scripts numbered above the DB's current user_version are
        run in order; after each one, user_version is set to its number.
        Returns the migrated file as a validated DatabaseFile.

        Raises NotFoundException if no file exists at path. This is checked
        first because merely connecting to a missing path would make sqlite3
        create an empty DB file there.
        """
        if not isfile(path):
            raise NotFoundException
        migrations = cls._available_migrations()
        from_version = cls._get_version_of_db(path)
        # list index == migration number (see ._available_migrations)
        for version in range(from_version + 1, len(migrations)):
            with open(f'{MIGRATIONS_DIR}/{migrations[version]}', 'r',
                      encoding='utf-8') as f:
                migration_sql = f.read()
            with closing(sql_connect(path)) as conn:
                with conn:
                    conn.executescript(migration_sql)
                    conn.execute(f'PRAGMA user_version = {version}')
        return cls(path)

    def _check(self) -> None:
        """Check file exists, and is of proper DB version and schema."""
        if not isfile(self.path):
            raise NotFoundException
        if self._user_version != EXPECTED_DB_VERSION:
            raise UnmigratedDbException()
        self._validate_schema()

    @staticmethod
    def _available_migrations() -> list[str]:
        """Validate migrations directory and return sorted entries.

        Every entry of MIGRATIONS_DIR except FILENAME_DB_SCHEMA must be
        named '{number}_{rest}', with number in 0..EXPECTED_DB_VERSION, and
        each number must occur exactly once. Returns the entries' filenames
        ordered so that list index equals migration number.

        Raises HandledException on any malformed, negative, too-big,
        duplicate, or missing migration number.
        """
        msg_too_big = 'Migration directory points beyond expected DB version.'
        msg_bad_entry = 'Migration directory contains unexpected entry: '
        msg_missing = 'Migration directory misses migration of number: '
        msg_duplicate = 'Migration directory has duplicate number in: '
        migrations: dict[int, str] = {}
        for entry in listdir(MIGRATIONS_DIR):
            if entry == FILENAME_DB_SCHEMA:
                continue
            toks = entry.split('_', 1)
            if len(toks) < 2:
                raise HandledException(msg_bad_entry + entry)
            try:
                i = int(toks[0])
            except ValueError as e:
                raise HandledException(msg_bad_entry + entry) from e
            if i < 0:
                raise HandledException(msg_bad_entry + entry)
            if i > EXPECTED_DB_VERSION:
                raise HandledException(msg_too_big)
            if i in migrations:
                raise HandledException(msg_duplicate + entry)
            # keep the real filename (e.g. zero-padded numbers stay intact)
            migrations[i] = entry
        for i in range(EXPECTED_DB_VERSION + 1):
            if i not in migrations:
                raise HandledException(msg_missing + str(i))
        return [migrations[i] for i in range(EXPECTED_DB_VERSION + 1)]

    @staticmethod
    def _get_version_of_db(path: str) -> int:
        """Get DB user_version, fail if above EXPECTED_DB_VERSION."""
        sql_for_db_version = 'PRAGMA user_version'
        with closing(sql_connect(path)) as conn:
            db_version = conn.execute(sql_for_db_version).fetchone()[0]
        if db_version > EXPECTED_DB_VERSION:
            msg = f'Wrong DB version, expected '\
                    f'{EXPECTED_DB_VERSION}, got unknown {db_version}.'
            raise HandledException(msg)
        assert isinstance(db_version, int)
        return db_version

    @property
    def _user_version(self) -> int:
        """Get DB user_version."""
        return self._get_version_of_db(self.path)

    @staticmethod
    def _reformat_schema_rows(rows: list[str]) -> list[str]:
        """Re-layout sqlite_master SQL rows to the style of PATH_DB_SCHEMA.

        Each row is cut at its newlines and at commas outside parentheses
        (parentheses are counted per line only) into one segment per output
        line. All but the first and last segment are indented by four
        spaces.

        If there are at least three segments, the last is not a lone ')',
        and the segment two before the last lacks a trailing comma:
        - that comma is added;
        - the last two segments are replaced by the last one, indented and
          minus its final character (presumed ')');
        - a lone ')' line follows.
        Presumably this normalizes columns that sqlite appends as
        ', col TYPE)' after ALTER TABLE ... ADD COLUMN.

        Rows of fewer than three segments (e.g. single-line CREATE INDEX
        statements) skip that last adjustment.
        """
        new_rows = []
        for row in rows:
            new_row = []
            for subrow in row.split('\n'):
                subrow = subrow.rstrip()
                in_parentheses = 0
                split_at = []
                for i, c in enumerate(subrow):
                    if '(' == c:
                        in_parentheses += 1
                    elif ')' == c:
                        in_parentheses -= 1
                    elif ',' == c and 0 == in_parentheses:
                        split_at += [i + 1]
                prev_split = 0
                for i in split_at:
                    segment = subrow[prev_split:i].strip()
                    if len(segment) > 0:
                        new_row += [f'    {segment}']
                    prev_split = i
                segment = subrow[prev_split:].strip()
                if len(segment) > 0:
                    new_row += [f'    {segment}']
            new_row[0] = new_row[0].lstrip()
            new_row[-1] = new_row[-1].lstrip()
            if len(new_row) >= 3 \
                    and new_row[-1] != ')' and new_row[-3][-1] != ',':
                new_row[-3] = new_row[-3] + ','
                new_row[-2:] = ['    ' + new_row[-1][:-1]] + [')']
            new_rows += ['\n'.join(new_row)]
        return new_rows

    def _validate_schema(self) -> None:
        """Compare found schema with what's stored at PATH_DB_SCHEMA.

        Raises HandledException carrying a line diff if they differ.
        """
        sql_for_schema = 'SELECT sql FROM sqlite_master ORDER BY sql'
        msg_err = 'Database has wrong tables schema. Diff:\n'
        with closing(sql_connect(self.path)) as conn:
            # skip entries whose sql is NULL (e.g. sqlite's automatic indexes)
            schema_rows = [r[0] for r in conn.execute(sql_for_schema) if r[0]]
        schema_rows = self._reformat_schema_rows(schema_rows)
        retrieved_schema = ';\n'.join(schema_rows) + ';'
        with open(PATH_DB_SCHEMA, 'r', encoding='utf-8') as f:
            stored_schema = f.read().rstrip()
        if stored_schema != retrieved_schema:
            diff_msg = Differ().compare(retrieved_schema.splitlines(),
                                        stored_schema.splitlines())
            raise HandledException(msg_err + '\n'.join(diff_msg))
154
155
class DatabaseConnection:
    """Hold one sqlite3 connection to a DatabaseFile's path.

    Offers thin helpers that assemble simple SELECT/INSERT/DELETE statements
    from table and column names. Values are bound as '?' parameters, but
    table and column names are interpolated into the SQL text as-is, so
    they must come from trusted code.
    """

    def __init__(self, db_file: DatabaseFile) -> None:
        self.conn = sql_connect(db_file.path)

    def commit(self) -> None:
        """Commit the connection's pending SQL transaction."""
        self.conn.commit()

    def exec(self, code: str, inputs: tuple[Any, ...] = tuple()) -> Cursor:
        """Execute code on the connection, binding inputs to its '?'s."""
        return self.conn.execute(code, inputs)

    def exec_on_vals(self, code: str, inputs: tuple[Any, ...]) -> Cursor:
        """Call .exec with code extended by " (?, …)", one '?' per input."""
        placeholders = ','.join('?' for _ in inputs)
        return self.exec(f'{code} ({placeholders})', inputs)

    def close(self) -> None:
        """Close the DB connection."""
        self.conn.close()

    def rewrite_relations(self, table_name: str, key: str, target: int | str,
                          rows: list[list[Any]], key_index: int = 0) -> None:
        # pylint: disable=too-many-arguments
        """Replace all rows of table_name where key == target by rows.

        Each entry of rows comes without the key column; target is spliced
        into it at position key_index before it gets inserted.
        """
        self.delete_where(table_name, key, target)
        insert_sql = f'INSERT INTO {table_name} VALUES'
        for row in rows:
            full_row = row[:key_index] + [target] + row[key_index:]
            self.exec_on_vals(insert_sql, tuple(full_row))

    def row_where(self, table_name: str, key: str,
                  target: int | str) -> list[Row]:
        """Return all rows of table_name whose key column equals target."""
        sql = f'SELECT * FROM {table_name} WHERE {key} = ?'
        return self.exec(sql, (target,)).fetchall()

    # def column_where_pattern(self,
    #                          table_name: str,
    #                          column: str,
    #                          pattern: str,
    #                          keys: list[str]) -> list[Any]:
    #     """Return column of rows where one of keys matches pattern."""
    #     targets = tuple([f'%{pattern}%'] * len(keys))
    #     haystack = ' OR '.join([f'{k} LIKE ?' for k in keys])
    #     sql = f'SELECT {column} FROM {table_name} WHERE {haystack}'
    #     return [row[0] for row in self.exec(sql, targets)]

    def column_where(self, table_name: str, column: str, key: str,
                     target: int | str) -> list[Any]:
        """Return values of column in table_name where key == target."""
        sql = f'SELECT {column} FROM {table_name} WHERE {key} = ?'
        cursor = self.exec(sql, (target,))
        return [record[0] for record in cursor]

    def column_all(self, table_name: str, column: str) -> list[Any]:
        """Return every value of column in table_name."""
        cursor = self.exec(f'SELECT {column} FROM {table_name}')
        return [record[0] for record in cursor]

    def delete_where(self, table_name: str, key: str,
                     target: int | str) -> None:
        """Delete all rows of table_name whose key column equals target."""
        sql = f'DELETE FROM {table_name} WHERE {key} = ?'
        self.exec(sql, (target,))
226
227
228 BaseModelId = TypeVar('BaseModelId', int, str)
229 BaseModelInstance = TypeVar('BaseModelInstance', bound='BaseModel[Any]')
230
231
232 class BaseModel(Generic[BaseModelId]):
233     """Template for most of the models we use/derive from the DB."""
234     table_name = ''
235     to_save: list[str] = []
236     to_save_versioned: list[str] = []
237     to_save_relations: list[tuple[str, str, str, int]] = []
238     id_: None | BaseModelId
239     cache_: dict[BaseModelId, Self]
240     to_search: list[str] = []
241     _exists = True
242
243     def __init__(self, id_: BaseModelId | None) -> None:
244         if isinstance(id_, int) and id_ < 1:
245             msg = f'illegal {self.__class__.__name__} ID, must be >=1: {id_}'
246             raise HandledException(msg)
247         if isinstance(id_, str) and "" == id_:
248             msg = f'illegal {self.__class__.__name__} ID, must be non-empty'
249             raise HandledException(msg)
250         self.id_ = id_
251
252     def __hash__(self) -> int:
253         hashable = [self.id_] + [getattr(self, name) for name in self.to_save]
254         for definition in self.to_save_relations:
255             attr = getattr(self, definition[2])
256             hashable += [tuple(rel.id_ for rel in attr)]
257         for name in self.to_save_versioned:
258             hashable += [hash(getattr(self, name))]
259         return hash(tuple(hashable))
260
261     def __eq__(self, other: object) -> bool:
262         if not isinstance(other, self.__class__):
263             return False
264         return hash(self) == hash(other)
265
266     def __lt__(self, other: Any) -> bool:
267         if not isinstance(other, self.__class__):
268             msg = 'cannot compare to object of different class'
269             raise HandledException(msg)
270         assert isinstance(self.id_, int)
271         assert isinstance(other.id_, int)
272         return self.id_ < other.id_
273
274     @property
275     def as_dict(self) -> dict[str, object]:
276         """Return self as (json.dumps-coompatible) dict."""
277         d: dict[str, object] = {'id': self.id_}
278         if len(self.to_save_versioned) > 0:
279             d['_versioned'] = {}
280         for k in self.to_save:
281             attr = getattr(self, k)
282             if hasattr(attr, 'as_dict'):
283                 d[k] = attr.as_dict
284             d[k] = attr
285         for k in self.to_save_versioned:
286             attr = getattr(self, k)
287             assert isinstance(d['_versioned'], dict)
288             d['_versioned'][k] = attr.history
289         for r in self.to_save_relations:
290             attr_name = r[2]
291             d[attr_name] = [x.as_dict for x in getattr(self, attr_name)]
292         return d
293
294     # cache management
295     # (we primarily use the cache to ensure we work on the same object in
296     # memory no matter where and how we retrieve it, e.g. we don't want
297     # .by_id() calls to create a new object each time, but rather a pointer
298     # to the one already instantiated)
299
300     def __getattribute__(self, name: str) -> Any:
301         """Ensure fail if ._disappear() was called, except to check ._exists"""
302         if name != '_exists' and not super().__getattribute__('_exists'):
303             raise HandledException('Object does not exist.')
304         return super().__getattribute__(name)
305
306     def _disappear(self) -> None:
307         """Invalidate object, make future use raise exceptions."""
308         assert self.id_ is not None
309         if self._get_cached(self.id_):
310             self._uncache()
311         to_kill = list(self.__dict__.keys())
312         for attr in to_kill:
313             delattr(self, attr)
314         self._exists = False
315
316     @classmethod
317     def empty_cache(cls) -> None:
318         """Empty class's cache."""
319         cls.cache_ = {}
320
321     @classmethod
322     def get_cache(cls: type[BaseModelInstance]) -> dict[Any, BaseModel[Any]]:
323         """Get cache dictionary, create it if not yet existing."""
324         if not hasattr(cls, 'cache_'):
325             d: dict[Any, BaseModel[Any]] = {}
326             cls.cache_ = d
327         return cls.cache_
328
329     @classmethod
330     def _get_cached(cls: type[BaseModelInstance],
331                     id_: BaseModelId) -> BaseModelInstance | None:
332         """Get object of id_ from class's cache, or None if not found."""
333         # pylint: disable=consider-iterating-dictionary
334         cache = cls.get_cache()
335         if id_ in cache.keys():
336             obj = cache[id_]
337             assert isinstance(obj, cls)
338             return obj
339         return None
340
341     def _cache(self) -> None:
342         """Update object in class's cache.
343
344         Also calls ._disappear if cache holds older reference to object of same
345         ID, but different memory address, to avoid doing anything with
346         dangling leftovers.
347         """
348         if self.id_ is None:
349             raise HandledException('Cannot cache object without ID.')
350         cache = self.get_cache()
351         old_cached = self._get_cached(self.id_)
352         if old_cached and id(old_cached) != id(self):
353             # pylint: disable=protected-access
354             # (cause we remain within the class)
355             old_cached._disappear()
356         cache[self.id_] = self
357
358     def _uncache(self) -> None:
359         """Remove self from cache."""
360         if self.id_ is None:
361             raise HandledException('Cannot un-cache object without ID.')
362         cache = self.get_cache()
363         del cache[self.id_]
364
365     # object retrieval and generation
366
367     @classmethod
368     def from_table_row(cls: type[BaseModelInstance],
369                        # pylint: disable=unused-argument
370                        db_conn: DatabaseConnection,
371                        row: Row | list[Any]) -> BaseModelInstance:
372         """Make from DB row, update DB cache with it."""
373         obj = cls(*row)
374         obj._cache()
375         return obj
376
377     @classmethod
378     def by_id(cls, db_conn: DatabaseConnection,
379               id_: BaseModelId | None,
380               # pylint: disable=unused-argument
381               create: bool = False) -> Self:
382         """Retrieve by id_, on failure throw NotFoundException.
383
384         First try to get from cls.cache_, only then check DB; if found,
385         put into cache.
386
387         If create=True, make anew (but do not cache yet).
388         """
389         obj = None
390         if id_ is not None:
391             obj = cls._get_cached(id_)
392             if not obj:
393                 for row in db_conn.row_where(cls.table_name, 'id', id_):
394                     obj = cls.from_table_row(db_conn, row)
395                     break
396         if obj:
397             return obj
398         if create:
399             obj = cls(id_)
400             return obj
401         raise NotFoundException(f'found no object of ID {id_}')
402
403     @classmethod
404     def all(cls: type[BaseModelInstance],
405             db_conn: DatabaseConnection) -> list[BaseModelInstance]:
406         """Collect all objects of class into list.
407
408         Note that this primarily returns the contents of the cache, and only
409         _expands_ that by additional findings in the DB. This assumes the
410         cache is always instantly cleaned of any items that would be removed
411         from the DB.
412         """
413         items: dict[BaseModelId, BaseModelInstance] = {}
414         for k, v in cls.get_cache().items():
415             assert isinstance(v, cls)
416             items[k] = v
417         already_recorded = items.keys()
418         for id_ in db_conn.column_all(cls.table_name, 'id'):
419             if id_ not in already_recorded:
420                 item = cls.by_id(db_conn, id_)
421                 assert item.id_ is not None
422                 items[item.id_] = item
423         return list(items.values())
424
425     @classmethod
426     def by_date_range_with_limits(cls: type[BaseModelInstance],
427                                   db_conn: DatabaseConnection,
428                                   date_range: tuple[str, str],
429                                   date_col: str = 'day'
430                                   ) -> tuple[list[BaseModelInstance], str,
431                                              str]:
432         """Return list of items in database within (open) date_range interval.
433
434         If no range values provided, defaults them to 'yesterday' and
435         'tomorrow'. Knows to properly interpret these and 'today' as value.
436         """
437         start_str = date_range[0] if date_range[0] else 'yesterday'
438         end_str = date_range[1] if date_range[1] else 'tomorrow'
439         start_date = valid_date(start_str)
440         end_date = valid_date(end_str)
441         items = []
442         sql = f'SELECT id FROM {cls.table_name} '
443         sql += f'WHERE {date_col} >= ? AND {date_col} <= ?'
444         for row in db_conn.exec(sql, (start_date, end_date)):
445             items += [cls.by_id(db_conn, row[0])]
446         return items, start_date, end_date
447
448     @classmethod
449     def matching(cls: type[BaseModelInstance], db_conn: DatabaseConnection,
450                  pattern: str) -> list[BaseModelInstance]:
451         """Return all objects whose .to_search match pattern."""
452         items = cls.all(db_conn)
453         if pattern:
454             filtered = []
455             for item in items:
456                 for attr_name in cls.to_search:
457                     toks = attr_name.split('.')
458                     parent = item
459                     for tok in toks:
460                         attr = getattr(parent, tok)
461                         parent = attr
462                     if pattern in attr:
463                         filtered += [item]
464                         break
465             return filtered
466         return items
467
468     # database writing
469
470     def save(self, db_conn: DatabaseConnection) -> None:
471         """Write self to DB and cache and ensure .id_.
472
473         Write both to DB, and to cache. To DB, write .id_ and attributes
474         listed in cls.to_save[_versioned|_relations].
475
476         Ensure self.id_ by setting it to what the DB command returns as the
477         last saved row's ID (cursor.lastrowid), EXCEPT if self.id_ already
478         exists as a 'str', which implies we do our own ID creation (so far
479         only the case with the Day class, where it's to be a date string.
480         """
481         values = tuple([self.id_] + [getattr(self, key)
482                                      for key in self.to_save])
483         table_name = self.table_name
484         cursor = db_conn.exec_on_vals(f'REPLACE INTO {table_name} VALUES',
485                                       values)
486         if not isinstance(self.id_, str):
487             self.id_ = cursor.lastrowid  # type: ignore[assignment]
488         self._cache()
489         for attr_name in self.to_save_versioned:
490             getattr(self, attr_name).save(db_conn)
491         for table, column, attr_name, key_index in self.to_save_relations:
492             assert isinstance(self.id_, (int, str))
493             db_conn.rewrite_relations(table, column, self.id_,
494                                       [[i.id_] for i
495                                        in getattr(self, attr_name)], key_index)
496
497     def remove(self, db_conn: DatabaseConnection) -> None:
498         """Remove from DB and cache, including dependencies."""
499         if self.id_ is None or self._get_cached(self.id_) is None:
500             raise HandledException('cannot remove unsaved item')
501         for attr_name in self.to_save_versioned:
502             getattr(self, attr_name).remove(db_conn)
503         for table, column, attr_name, _ in self.to_save_relations:
504             db_conn.delete_where(table, column, self.id_)
505         self._uncache()
506         db_conn.delete_where(self.table_name, 'id', self.id_)
507         self._disappear()