home · contact · privacy
Extend Todo tests, overhaul Ctx library building.
[plomtask] / plomtask / db.py
1 """Database management."""
from __future__ import annotations
from contextlib import closing
from difflib import Differ
from os import listdir
from os.path import isfile
from sqlite3 import connect as sql_connect, Cursor, Row
from typing import Any, Self, TypeVar, Generic, Callable

from plomtask.exceptions import HandledException, NotFoundException
from plomtask.dating import valid_date
10
# DB schema version this code works with; DatabaseFile refuses files whose
# user_version differs (see DatabaseFile._check).
EXPECTED_DB_VERSION = 5
# Directory holding the migration scripts and the full-schema init script;
# relative paths, so resolved against the current working directory.
MIGRATIONS_DIR = 'migrations'
# Full-schema script used to create fresh DBs and validate existing ones.
FILENAME_DB_SCHEMA = f'init_{EXPECTED_DB_VERSION}.sql'
PATH_DB_SCHEMA = '/'.join((MIGRATIONS_DIR, FILENAME_DB_SCHEMA))
15
16
class UnmigratedDbException(HandledException):
    """Raised when a DB file's user_version is not EXPECTED_DB_VERSION.

    DatabaseFile.migrate exists to bring such a file up to date.
    """
19
20
class CtxReferences:
    """Collects references for future library building.

    Wraps a dict mapping class names to lists of IDs (ints or strs) of the
    objects of that class being referred to.
    """
    # pylint: disable=too-few-public-methods

    def __init__(self, d: dict[str, list[int | str]]) -> None:
        # NB: For tighter mypy testing, we might prefer the library argument
        # to be of type dict[str, list[int] | list[str] instead. But my
        # current coding knowledge only manages to make that work by turning
        # the code much more complex, so let's leave it at
        # that for now …
        self.d = d

    def update(self, other: CtxReferences) -> bool:
        """Merge entries of self into other, skipping IDs already there.

        Class names missing from other get an (initially empty) list first;
        order of other's existing entries is preserved, new IDs are appended.

        Returns True if other was modified (new class name or new ID),
        False otherwise.
        """
        changed = False
        for cls_name, id_list in self.d.items():
            if cls_name not in other.d:
                other.d[cls_name] = []
                changed = True
            target = other.d[cls_name]
            for id_ in id_list:
                if id_ not in target:
                    target.append(id_)
                    changed = True
        return changed
43
44
class DatabaseFile:
    """Represents the sqlite3 database's file.

    Instantiation runs ._check: the file must exist, carry
    EXPECTED_DB_VERSION as its user_version, and have a schema matching the
    one stored at PATH_DB_SCHEMA.

    Short-lived connections opened here are wrapped in contextlib.closing,
    because sqlite3's own connection context manager only commits/rolls
    back and never closes the connection.
    """
    # pylint: disable=too-few-public-methods

    def __init__(self, path: str) -> None:
        self.path = path
        self._check()

    @classmethod
    def create_at(cls, path: str) -> DatabaseFile:
        """Make new DB file at path from PATH_DB_SCHEMA, return it checked.

        The new DB's user_version is set to EXPECTED_DB_VERSION.
        """
        # outer closing() closes the connection, inner "conn" commits
        with closing(sql_connect(path)) as conn, conn:
            with open(PATH_DB_SCHEMA, 'r', encoding='utf-8') as f:
                conn.executescript(f.read())
            conn.execute(f'PRAGMA user_version = {EXPECTED_DB_VERSION}')
        return cls(path)

    @classmethod
    def migrate(cls, path: str) -> DatabaseFile:
        """Apply migrations from the DB's version up to EXPECTED_DB_VERSION.

        For each pending migration N, its script is run and only then is
        user_version set to N, so user_version names the last migration
        whose script ran to completion. Returns the checked DatabaseFile.
        """
        migrations = cls._available_migrations()
        from_version = cls._get_version_of_db(path)
        first_todo = from_version + 1
        for user_version, filename in enumerate(migrations[first_todo:],
                                                start=first_todo):
            # outer closing() closes the connection, inner "conn" commits
            with closing(sql_connect(path)) as conn, conn:
                with open(f'{MIGRATIONS_DIR}/{filename}', 'r',
                          encoding='utf-8') as f:
                    conn.executescript(f.read())
                conn.execute(f'PRAGMA user_version = {user_version}')
        return cls(path)

    def _check(self) -> None:
        """Check file exists, and is of proper DB version and schema.

        Raises NotFoundException if no file at .path, UnmigratedDbException
        on version mismatch, HandledException on schema mismatch.
        """
        if not isfile(self.path):
            raise NotFoundException
        if self._user_version != EXPECTED_DB_VERSION:
            raise UnmigratedDbException()
        self._validate_schema()

    @staticmethod
    def _available_migrations() -> list[str]:
        """Validate migrations directory and return sorted entries.

        Entries other than FILENAME_DB_SCHEMA must be named '{N}_{rest}' with
        integer N <= EXPECTED_DB_VERSION, and each N from 0 through
        EXPECTED_DB_VERSION must appear exactly once. The returned list holds
        the actual filename of migration N at index N.

        Raises HandledException on any violation.
        """
        msg_too_big = 'Migration directory points beyond expected DB version.'
        msg_bad_entry = 'Migration directory contains unexpected entry: '
        msg_missing = 'Migration directory misses migration of number: '
        msg_duplicate = 'Migration directory contains duplicate migration '\
            'of number: '
        migrations: dict[int, str] = {}
        for entry in listdir(MIGRATIONS_DIR):
            if entry == FILENAME_DB_SCHEMA:
                continue
            toks = entry.split('_', 1)
            if len(toks) < 2:
                raise HandledException(msg_bad_entry + entry)
            try:
                i = int(toks[0])
            except ValueError as e:
                raise HandledException(msg_bad_entry + entry) from e
            if i > EXPECTED_DB_VERSION:
                raise HandledException(msg_too_big)
            if i in migrations:
                # listdir order is arbitrary, so keeping either file would
                # make the applied migration nondeterministic
                raise HandledException(msg_duplicate + str(i))
            # keep the real name (e.g. zero-padded '03_…') rather than
            # rebuilding it from int(N), which could point to no file
            migrations[i] = entry
        migrations_list = []
        for i in range(EXPECTED_DB_VERSION + 1):
            if i not in migrations:
                raise HandledException(msg_missing + str(i))
            migrations_list += [migrations[i]]
        return migrations_list

    @staticmethod
    def _get_version_of_db(path: str) -> int:
        """Get DB user_version, fail if beyond EXPECTED_DB_VERSION.

        Raises HandledException if the version exceeds EXPECTED_DB_VERSION.
        """
        sql_for_db_version = 'PRAGMA user_version'
        with closing(sql_connect(path)) as conn:
            db_version = conn.execute(sql_for_db_version).fetchone()[0]
        if db_version > EXPECTED_DB_VERSION:
            msg = f'Wrong DB version, expected '\
                    f'{EXPECTED_DB_VERSION}, got unknown {db_version}.'
            raise HandledException(msg)
        assert isinstance(db_version, int)
        return db_version

    @property
    def _user_version(self) -> int:
        """Get DB user_version."""
        return self._get_version_of_db(self.path)

    def _validate_schema(self) -> None:
        """Compare found schema with what's stored at PATH_DB_SCHEMA.

        Raises HandledException carrying a line diff on mismatch.
        """

        def reformat_rows(rows: list[str]) -> list[str]:
            # Normalize each stored CREATE statement's layout: one line per
            # comma-separated item at parenthesis depth 0 (depth is counted
            # per source line), four-space indented, first and last line
            # unindented.
            new_rows = []
            for row in rows:
                new_row = []
                for subrow in row.split('\n'):
                    subrow = subrow.rstrip()
                    in_parentheses = 0
                    split_at = []
                    for i, c in enumerate(subrow):
                        if '(' == c:
                            in_parentheses += 1
                        elif ')' == c:
                            in_parentheses -= 1
                        elif ',' == c and 0 == in_parentheses:
                            split_at += [i + 1]
                    prev_split = 0
                    for i in split_at:
                        segment = subrow[prev_split:i].strip()
                        if len(segment) > 0:
                            new_row += [f'    {segment}']
                        prev_split = i
                    segment = subrow[prev_split:].strip()
                    if len(segment) > 0:
                        new_row += [f'    {segment}']
                new_row[0] = new_row[0].lstrip()
                new_row[-1] = new_row[-1].lstrip()
                # NOTE(review): presumably repairs the ", col …)" tail left by
                # ALTER TABLE ADD COLUMN into "…,\n    col …\n)" — confirm.
                if new_row[-1] != ')' and new_row[-3][-1] != ',':
                    new_row[-3] = new_row[-3] + ','
                    new_row[-2:] = ['    ' + new_row[-1][:-1]] + [')']
                new_rows += ['\n'.join(new_row)]
            return new_rows

        sql_for_schema = 'SELECT sql FROM sqlite_master ORDER BY sql'
        msg_err = 'Database has wrong tables schema. Diff:\n'
        with closing(sql_connect(self.path)) as conn:
            # NULL sql entries (e.g. automatic indexes) are skipped
            schema_rows = [r[0] for r in conn.execute(sql_for_schema) if r[0]]
        schema_rows = reformat_rows(schema_rows)
        retrieved_schema = ';\n'.join(schema_rows) + ';'
        with open(PATH_DB_SCHEMA, 'r', encoding='utf-8') as f:
            stored_schema = f.read().rstrip()
        if stored_schema != retrieved_schema:
            diff_msg = Differ().compare(retrieved_schema.splitlines(),
                                        stored_schema.splitlines())
            raise HandledException(msg_err + '\n'.join(diff_msg))
178
179
class DatabaseConnection:
    """A single connection to the database.

    Table, column and key names handed to the helpers below are pasted into
    the SQL text verbatim; only target values and row values are bound as
    parameters, so callers must pass trusted identifiers only.
    """

    def __init__(self, db_file: DatabaseFile) -> None:
        self.conn = sql_connect(db_file.path)

    def commit(self) -> None:
        """Commit the pending SQL transaction."""
        self.conn.commit()

    def exec(self, code: str, inputs: tuple[Any, ...] = tuple()) -> Cursor:
        """Execute code with inputs bound to its placeholders."""
        return self.conn.execute(code, inputs)

    def exec_on_vals(self, code: str, inputs: tuple[Any, ...]) -> Cursor:
        """Execute code suffixed by " (?,…,?)", one placeholder per input."""
        placeholders = ','.join('?' for _ in inputs)
        return self.exec(f'{code} ({placeholders})', inputs)

    def close(self) -> None:
        """Close DB connection."""
        self.conn.close()

    def rewrite_relations(self, table_name: str, key: str, target: int | str,
                          rows: list[list[Any]], key_index: int = 0) -> None:
        """Replace all rows of table_name whose key equals target.

        Each entry of rows lacks the key column; target is spliced into it
        at position key_index before insertion.
        """
        # pylint: disable=too-many-arguments
        self.delete_where(table_name, key, target)
        insert_sql = f'INSERT INTO {table_name} VALUES'
        for row in rows:
            full_row = row[:key_index] + [target] + row[key_index:]
            self.exec_on_vals(insert_sql, tuple(full_row))

    def row_where(self, table_name: str, key: str,
                  target: int | str) -> list[Row]:
        """Return all rows of table_name whose key equals target."""
        query = f'SELECT * FROM {table_name} WHERE {key} = ?'
        return self.exec(query, (target,)).fetchall()

    def column_where(self, table_name: str, column: str, key: str,
                     target: int | str) -> list[Any]:
        """Return column's values from rows whose key equals target."""
        query = f'SELECT {column} FROM {table_name} WHERE {key} = ?'
        return [record[0] for record in self.exec(query, (target,))]

    def column_all(self, table_name: str, column: str) -> list[Any]:
        """Return column's values from every row of table_name."""
        query = f'SELECT {column} FROM {table_name}'
        return [record[0] for record in self.exec(query)]

    def delete_where(self, table_name: str, key: str,
                     target: int | str) -> None:
        """Delete all rows of table_name whose key equals target."""
        query = f'DELETE FROM {table_name} WHERE {key} = ?'
        self.exec(query, (target,))
250
251
# Type of a model's ID: either int or str (BaseModel.save notes str IDs are
# used where the class creates its own IDs, e.g. Day's date strings).
BaseModelId = TypeVar('BaseModelId', int, str)
# Any BaseModel (subclass) instance; lets classmethods typed on cls return
# instances of that same class.
BaseModelInstance = TypeVar('BaseModelInstance', bound='BaseModel[Any]')
254
255
class BaseModel(Generic[BaseModelId]):
    """Template for most of the models we use/derive from the DB.

    Subclasses describe their persistence through the class attributes
    below. Instances are kept in the class's cache_ dict so that retrieving
    the same ID yields the same object in memory.
    """
    # Name of the DB table holding the model's own rows.
    table_name = ''
    # Attributes saved, in this order, as the row's columns after the ID.
    to_save_simples: list[str] = []
    # Per relations table: (table name, key column, name of attribute holding
    # the related models, index at which the key goes into each row).
    to_save_relations: list[tuple[str, str, str, int]] = []
    # Keys name versioned attributes (objects offering .history, .table_name,
    # .history_from_row, .save, .remove); values presumably their defaults
    # as used by subclasses — TODO confirm.
    versioned_defaults: dict[str, str | float] = {}
    # Attributes (iterables of models) rendered in .as_dict via
    # .into_reference.
    add_to_dict: list[str] = []
    id_: None | BaseModelId
    cache_: dict[BaseModelId, Self]
    # Dotted attribute paths searched by .matching.
    to_search: list[str] = []
    # Whether .by_id_or_create may create new instances.
    can_create_by_id = False
    # Set to False by ._disappear; any later attribute access then fails.
    _exists = True
    # Sort key names mapped to key functions, for .sort_by.
    sorters: dict[str, Callable[..., Any]] = {}

    def __init__(self, id_: BaseModelId | None) -> None:
        if isinstance(id_, int) and id_ < 1:
            msg = f'illegal {self.__class__.__name__} ID, must be >=1: {id_}'
            raise HandledException(msg)
        if isinstance(id_, str) and "" == id_:
            msg = f'illegal {self.__class__.__name__} ID, must be non-empty'
            raise HandledException(msg)
        self.id_ = id_

    def __hash__(self) -> int:
        hashable = [self.id_] + [getattr(self, name)
                                 for name in self.to_save_simples]
        for definition in self.to_save_relations:
            attr = getattr(self, definition[2])
            hashable += [tuple(rel.id_ for rel in attr)]
        for name in self.to_save_versioned():
            hashable += [hash(getattr(self, name))]
        return hash(tuple(hashable))

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, self.__class__):
            return False
        return hash(self) == hash(other)

    def __lt__(self, other: Any) -> bool:
        if not isinstance(other, self.__class__):
            msg = 'cannot compare to object of different class'
            raise HandledException(msg)
        assert isinstance(self.id_, int)
        assert isinstance(other.id_, int)
        return self.id_ < other.id_

    @classmethod
    def to_save_versioned(cls) -> list[str]:
        """Return keys of cls.versioned_defaults assuming we wanna save 'em."""
        return list(cls.versioned_defaults.keys())

    @property
    def as_dict(self) -> dict[str, object]:
        """Return self as (json.dumps-compatible) dict.

        Related models appear by ID only; their class names and IDs are
        collected into the CtxReferences stored under '_references'.
        """
        references = CtxReferences({})
        d: dict[str, object] = {'id': self.id_, '_references': references}
        for to_save in self.to_save_simples:
            d[to_save] = getattr(self, to_save)
        if len(self.to_save_versioned()) > 0:
            d['_versioned'] = {}
        for k in self.to_save_versioned():
            attr = getattr(self, k)
            assert isinstance(d['_versioned'], dict)
            d['_versioned'][k] = attr.history
        for r in self.to_save_relations:
            attr_name = r[2]
            l: list[int | str] = []
            for rel in getattr(self, attr_name):
                cls_name = rel.__class__.__name__
                if cls_name not in references.d:
                    references.d[cls_name] = []
                l += [rel.id_]
                references.d[cls_name] += [rel.id_]
            d[attr_name] = l
        for k in self.add_to_dict:
            d[k] = [x.into_reference(references)
                    for x in getattr(self, k)]
        return d

    def into_reference(self, references: CtxReferences) -> int | str:
        """Return self.id_, record it (and own references) in references."""
        cls_name = self.__class__.__name__
        if cls_name not in references.d:
            references.d[cls_name] = []
        assert self.id_ is not None
        references.d[cls_name] += [self.id_]
        own_refs = self.as_dict['_references']
        assert isinstance(own_refs, CtxReferences)
        own_refs.update(references)
        return self.id_

    @classmethod
    def name_lowercase(cls) -> str:
        """Convenience method to return cls' name in lowercase."""
        return cls.__name__.lower()

    @classmethod
    def sort_by(cls, seq: list[Any], sort_key: str, default: str = 'title'
                ) -> str:
        """Sort cls list by cls.sorters[sort_key] (reverse if '-'-prefixed).

        Before cls.sorters[sort_key] is applied, seq is sorted by .id_, to
        ensure predictability where parts of seq are of same sort value.
        Unknown sort keys fall back to default (which must be in
        cls.sorters). Returns the sort key actually applied, '-'-prefixed if
        reversed.
        """
        reverse = False
        if len(sort_key) > 1 and '-' == sort_key[0]:
            sort_key = sort_key[1:]
            reverse = True
        if sort_key not in cls.sorters:
            sort_key = default
        seq.sort(key=lambda x: x.id_, reverse=reverse)
        sorter: Callable[..., Any] = cls.sorters[sort_key]
        seq.sort(key=sorter, reverse=reverse)
        if reverse:
            sort_key = f'-{sort_key}'
        return sort_key

    # cache management
    # (we primarily use the cache to ensure we work on the same object in
    # memory no matter where and how we retrieve it, e.g. we don't want
    # .by_id() calls to create a new object each time, but rather a pointer
    # to the one already instantiated)

    def __getattribute__(self, name: str) -> Any:
        """Ensure fail if ._disappear() was called, except to check ._exists"""
        if name != '_exists' and not super().__getattribute__('_exists'):
            raise HandledException('Object does not exist.')
        return super().__getattribute__(name)

    def _disappear(self) -> None:
        """Invalidate object, make future use raise exceptions.

        Only uncaches if the cache entry for .id_ is this very object, so a
        stale duplicate disappearing cannot evict the live cached instance.
        """
        assert self.id_ is not None
        if self._get_cached(self.id_) is self:
            self._uncache()
        to_kill = list(self.__dict__.keys())
        for attr in to_kill:
            delattr(self, attr)
        self._exists = False

    @classmethod
    def empty_cache(cls) -> None:
        """Empty class's cache, and disappear all former inhabitants."""
        # pylint: disable=protected-access
        # (cause we remain within the class)
        if hasattr(cls, 'cache_'):
            to_disappear = list(cls.cache_.values())
            for item in to_disappear:
                item._disappear()
        cls.cache_ = {}

    @classmethod
    def get_cache(cls: type[BaseModelInstance]) -> dict[Any, BaseModel[Any]]:
        """Get cache dictionary, create it if not yet existing."""
        if not hasattr(cls, 'cache_'):
            d: dict[Any, BaseModel[Any]] = {}
            cls.cache_ = d
        return cls.cache_

    @classmethod
    def _get_cached(cls: type[BaseModelInstance],
                    id_: BaseModelId) -> BaseModelInstance | None:
        """Get object of id_ from class's cache, or None if not found."""
        cache = cls.get_cache()
        if id_ in cache:
            obj = cache[id_]
            assert isinstance(obj, cls)
            return obj
        return None

    def cache(self) -> None:
        """Update object in class's cache.

        Also calls ._disappear if cache holds older reference to object of same
        ID, but different memory address, to avoid doing anything with
        dangling leftovers.
        """
        if self.id_ is None:
            raise HandledException('Cannot cache object without ID.')
        cache = self.get_cache()
        old_cached = self._get_cached(self.id_)
        if old_cached is not None and old_cached is not self:
            # pylint: disable=protected-access
            # (cause we remain within the class)
            old_cached._disappear()
        cache[self.id_] = self

    def _uncache(self) -> None:
        """Remove self from cache."""
        if self.id_ is None:
            raise HandledException('Cannot un-cache object without ID.')
        cache = self.get_cache()
        del cache[self.id_]

    # object retrieval and generation

    @classmethod
    def from_table_row(cls: type[BaseModelInstance],
                       db_conn: DatabaseConnection,
                       row: Row | list[Any]) -> BaseModelInstance:
        """Make from DB row (sans relations), update DB cache with it.

        Versioned attributes' histories are loaded from their own tables'
        rows whose 'parent' equals the new object's ID.
        """
        obj = cls(*row)
        assert obj.id_ is not None
        for attr_name in cls.to_save_versioned():
            attr = getattr(obj, attr_name)
            table_name = attr.table_name
            for row_ in db_conn.row_where(table_name, 'parent', obj.id_):
                attr.history_from_row(row_)
        obj.cache()
        return obj

    @classmethod
    def by_id(cls, db_conn: DatabaseConnection, id_: BaseModelId) -> Self:
        """Retrieve by id_, on failure throw NotFoundException.

        First try to get from cls.cache_, only then check DB; if found,
        put into cache.
        """
        obj = None
        if id_ is not None:
            obj = cls._get_cached(id_)
            if obj is None:
                for row in db_conn.row_where(cls.table_name, 'id', id_):
                    obj = cls.from_table_row(db_conn, row)
                    break
        if obj is not None:
            return obj
        raise NotFoundException(f'found no object of ID {id_}')

    @classmethod
    def by_id_or_create(cls, db_conn: DatabaseConnection,
                        id_: BaseModelId | None
                        ) -> Self:
        """Wrapper around .by_id, creating (not caching/saving) if not found.

        Raises HandledException unless cls.can_create_by_id.
        """
        if not cls.can_create_by_id:
            raise HandledException('Class cannot .by_id_or_create.')
        if id_ is None:
            return cls(None)
        try:
            return cls.by_id(db_conn, id_)
        except NotFoundException:
            return cls(id_)

    @classmethod
    def all(cls: type[BaseModelInstance],
            db_conn: DatabaseConnection) -> list[BaseModelInstance]:
        """Collect all objects of class into list.

        Note that this primarily returns the contents of the cache, and only
        _expands_ that by additional findings in the DB. This assumes the
        cache is always instantly cleaned of any items that would be removed
        from the DB.
        """
        items: dict[BaseModelId, BaseModelInstance] = {}
        for k, v in cls.get_cache().items():
            assert isinstance(v, cls)
            items[k] = v
        already_recorded = items.keys()
        for id_ in db_conn.column_all(cls.table_name, 'id'):
            if id_ not in already_recorded:
                item = cls.by_id(db_conn, id_)
                assert item.id_ is not None
                items[item.id_] = item
        return list(items.values())

    @classmethod
    def by_date_range_with_limits(cls: type[BaseModelInstance],
                                  db_conn: DatabaseConnection,
                                  date_range: tuple[str, str],
                                  date_col: str = 'day'
                                  ) -> tuple[list[BaseModelInstance], str,
                                             str]:
        """Return list of items in DB within (closed) date_range interval.

        If no range values provided, defaults them to 'yesterday' and
        'tomorrow'. Knows to properly interpret these and 'today' as value.
        Returns the items plus the resolved start and end dates.
        """
        start_str = date_range[0] if date_range[0] else 'yesterday'
        end_str = date_range[1] if date_range[1] else 'tomorrow'
        start_date = valid_date(start_str)
        end_date = valid_date(end_str)
        items = []
        sql = f'SELECT id FROM {cls.table_name} '
        sql += f'WHERE {date_col} >= ? AND {date_col} <= ?'
        for row in db_conn.exec(sql, (start_date, end_date)):
            items += [cls.by_id(db_conn, row[0])]
        return items, start_date, end_date

    @classmethod
    def matching(cls: type[BaseModelInstance], db_conn: DatabaseConnection,
                 pattern: str) -> list[BaseModelInstance]:
        """Return all objects whose .to_search match pattern.

        Each .to_search entry is a dotted attribute path; an item matches if
        pattern is contained in any of the resolved values. An empty pattern
        returns all items.
        """
        items = cls.all(db_conn)
        if pattern:
            filtered = []
            for item in items:
                for attr_name in cls.to_search:
                    toks = attr_name.split('.')
                    parent = item
                    for tok in toks:
                        attr = getattr(parent, tok)
                        parent = attr
                    if pattern in attr:
                        filtered += [item]
                        break
            return filtered
        return items

    # database writing

    def save(self, db_conn: DatabaseConnection) -> None:
        """Write self to DB and cache and ensure .id_.

        Write both to DB, and to cache. To DB, write .id_ and attributes
        listed in cls.to_save_[simples|versioned|_relations].

        Ensure self.id_ by setting it to what the DB command returns as the
        last saved row's ID (cursor.lastrowid), EXCEPT if self.id_ already
        exists as a 'str', which implies we do our own ID creation (so far
        only the case with the Day class, where it's to be a date string).
        """
        values = tuple([self.id_] + [getattr(self, key)
                                     for key in self.to_save_simples])
        table_name = self.table_name
        cursor = db_conn.exec_on_vals(f'REPLACE INTO {table_name} VALUES',
                                      values)
        if not isinstance(self.id_, str):
            self.id_ = cursor.lastrowid  # type: ignore[assignment]
        self.cache()
        for attr_name in self.to_save_versioned():
            getattr(self, attr_name).save(db_conn)
        for table, column, attr_name, key_index in self.to_save_relations:
            assert isinstance(self.id_, (int, str))
            db_conn.rewrite_relations(table, column, self.id_,
                                      [[i.id_] for i
                                       in getattr(self, attr_name)], key_index)

    def remove(self, db_conn: DatabaseConnection) -> None:
        """Remove from DB and cache, including dependencies.

        Raises HandledException if self has no ID or none cached under it.
        """
        if self.id_ is None or self._get_cached(self.id_) is None:
            raise HandledException('cannot remove unsaved item')
        for attr_name in self.to_save_versioned():
            getattr(self, attr_name).remove(db_conn)
        for table, column, _, _ in self.to_save_relations:
            db_conn.delete_where(table, column, self.id_)
        self._uncache()
        db_conn.delete_where(self.table_name, 'id', self.id_)
        self._disappear()