home · contact · privacy
Slightly improve and re-organize Condition tests.
[plomtask] / plomtask / db.py
1 """Database management."""
2 from __future__ import annotations
from contextlib import closing
from difflib import Differ
from os import listdir
from os.path import isfile
from sqlite3 import connect as sql_connect, Cursor, Row
from typing import Any, Self, TypeVar, Generic, Callable

from plomtask.exceptions import HandledException, NotFoundException
from plomtask.dating import valid_date
10
# DB user_version this code base expects to find in a DB file.
EXPECTED_DB_VERSION = 5
# Directory holding the numbered migration scripts as well as the full
# schema file for EXPECTED_DB_VERSION.
MIGRATIONS_DIR = 'migrations'
FILENAME_DB_SCHEMA = 'init_' + str(EXPECTED_DB_VERSION) + '.sql'
PATH_DB_SCHEMA = '/'.join((MIGRATIONS_DIR, FILENAME_DB_SCHEMA))
15
16
class UnmigratedDbException(HandledException):
    """Signals a DB file whose user_version is not EXPECTED_DB_VERSION."""
19
20
class DatabaseFile:
    """Represents the sqlite3 database's file.

    Instantiation validates that a file exists at path, is at
    EXPECTED_DB_VERSION, and matches the schema stored at PATH_DB_SCHEMA.
    """
    # pylint: disable=too-few-public-methods

    def __init__(self, path: str) -> None:
        self.path = path
        self._check()

    @classmethod
    def create_at(cls, path: str) -> DatabaseFile:
        """Make new DB file at path from PATH_DB_SCHEMA, return it validated.

        The schema script is read before connecting, so that a missing
        schema file does not leave an empty DB file behind at path.
        """
        with open(PATH_DB_SCHEMA, 'r', encoding='utf-8') as f:
            schema = f.read()
        # A sqlite3 connection's own context manager only commits (or rolls
        # back) and never closes; closing() makes sure the handle is released.
        with closing(sql_connect(path)) as conn:
            with conn:
                conn.executescript(schema)
                conn.execute(f'PRAGMA user_version = {EXPECTED_DB_VERSION}')
        return cls(path)

    @classmethod
    def migrate(cls, path: str) -> DatabaseFile:
        """Apply migrations from_version to EXPECTED_DB_VERSION.

        Runs, in order, every migration numbered above the DB's current
        user_version, each on its own connection, bumping user_version to
        the migration's number after each script. Returns the (validated)
        migrated DatabaseFile.
        """
        migrations = cls._available_migrations()
        from_version = cls._get_version_of_db(path)
        migrations_todo = migrations[from_version+1:]
        for j, filename in enumerate(migrations_todo):
            with open(f'{MIGRATIONS_DIR}/{filename}', 'r',
                      encoding='utf-8') as f:
                script = f.read()
            user_version = from_version + j + 1
            with closing(sql_connect(path)) as conn:
                with conn:
                    conn.executescript(script)
                with conn:
                    conn.execute(f'PRAGMA user_version = {user_version}')
        return cls(path)

    def _check(self) -> None:
        """Check file exists, and is of proper DB version and schema.

        Raises NotFoundException if there is no file at .path, and
        UnmigratedDbException if its user_version is below
        EXPECTED_DB_VERSION (one beyond it already fails with a
        HandledException in ._get_version_of_db). A schema mismatch raises
        HandledException via ._validate_schema.
        """
        if not isfile(self.path):
            raise NotFoundException
        if self._user_version != EXPECTED_DB_VERSION:
            raise UnmigratedDbException()
        self._validate_schema()

    @staticmethod
    def _available_migrations() -> list[str]:
        """Validate migrations directory and return sorted entries.

        Every entry of MIGRATIONS_DIR except FILENAME_DB_SCHEMA must be named
        '{number}_{rest}'. Each number from 0 to EXPECTED_DB_VERSION must
        occur exactly once. Otherwise, a HandledException is raised.
        """
        msg_too_big = 'Migration directory points beyond expected DB version.'
        msg_bad_entry = 'Migration directory contains unexpected entry: '
        msg_missing = 'Migration directory misses migration of number: '
        msg_duplicate = ('Migration directory contains more than one '
                         'migration of number: ')
        migrations: dict[int, str] = {}
        for entry in listdir(MIGRATIONS_DIR):
            if entry == FILENAME_DB_SCHEMA:
                continue
            toks = entry.split('_', 1)
            if len(toks) < 2:
                raise HandledException(msg_bad_entry + entry)
            try:
                i = int(toks[0])
            except ValueError as e:
                raise HandledException(msg_bad_entry + entry) from e
            if i > EXPECTED_DB_VERSION:
                raise HandledException(msg_too_big)
            # listdir() order is arbitrary, so letting one of two
            # same-numbered entries silently win would be unpredictable.
            if i in migrations:
                raise HandledException(msg_duplicate + str(i))
            migrations[i] = toks[1]
        migrations_list = []
        for i in range(EXPECTED_DB_VERSION + 1):
            if i not in migrations:
                raise HandledException(msg_missing + str(i))
            migrations_list += [f'{i}_{migrations[i]}']
        return migrations_list

    @staticmethod
    def _get_version_of_db(path: str) -> int:
        """Get DB user_version, fail if outside expected range."""
        sql_for_db_version = 'PRAGMA user_version'
        with closing(sql_connect(path)) as conn:
            db_version = list(conn.execute(sql_for_db_version))[0][0]
        assert isinstance(db_version, int)
        if db_version > EXPECTED_DB_VERSION:
            msg = f'Wrong DB version, expected '\
                    f'{EXPECTED_DB_VERSION}, got unknown {db_version}.'
            raise HandledException(msg)
        return db_version

    @property
    def _user_version(self) -> int:
        """Get DB user_version."""
        return self._get_version_of_db(self.path)

    @staticmethod
    def _reformat_schema_rows(rows: list[str]) -> list[str]:
        """Re-layout sqlite_master SQL rows for comparison to PATH_DB_SCHEMA.

        Each row is split into segments at its line breaks and at any comma
        outside parentheses. Parentheses are counted per line. Each non-empty
        segment is stripped and indented by four spaces, except that the
        first and last segments are left unindented.

        Some rows have at least three segments, do not end on a lone ')', and
        have a third-to-last segment without a trailing comma. These are
        assumed to look like "...\\n    col TYPE\\n, new_col TYPE)". This is
        presumably the text SQLite's ALTER TABLE ... ADD COLUMN leaves behind
        (TODO confirm). Such rows are rearranged to
        "...\\n    col TYPE,\\n    new_col TYPE\\n)".

        Rows of fewer than three segments, e.g. single-line CREATE INDEX
        statements, are left as they are instead of failing on IndexError.
        """
        new_rows = []
        for row in rows:
            new_row: list[str] = []
            for subrow in row.split('\n'):
                subrow = subrow.rstrip()
                in_parentheses = 0
                split_at = []
                for i, c in enumerate(subrow):
                    if '(' == c:
                        in_parentheses += 1
                    elif ')' == c:
                        in_parentheses -= 1
                    elif ',' == c and 0 == in_parentheses:
                        split_at += [i + 1]
                prev_split = 0
                for i in split_at:
                    segment = subrow[prev_split:i].strip()
                    if len(segment) > 0:
                        new_row += [f'    {segment}']
                    prev_split = i
                segment = subrow[prev_split:].strip()
                if len(segment) > 0:
                    new_row += [f'    {segment}']
            new_row[0] = new_row[0].lstrip()
            new_row[-1] = new_row[-1].lstrip()
            if (len(new_row) >= 3 and new_row[-1] != ')'
                    and new_row[-3][-1] != ','):
                new_row[-3] = new_row[-3] + ','
                new_row[-2:] = ['    ' + new_row[-1][:-1]] + [')']
            new_rows += ['\n'.join(new_row)]
        return new_rows

    def _validate_schema(self) -> None:
        """Compare found schema with what's stored at PATH_DB_SCHEMA.

        Raises HandledException carrying a line diff if they differ.
        """
        sql_for_schema = 'SELECT sql FROM sqlite_master ORDER BY sql'
        msg_err = 'Database has wrong tables schema. Diff:\n'
        with closing(sql_connect(self.path)) as conn:
            # skip sqlite_master entries whose sql is NULL or empty
            schema_rows = [r[0] for r in conn.execute(sql_for_schema) if r[0]]
        schema_rows = self._reformat_schema_rows(schema_rows)
        retrieved_schema = ';\n'.join(schema_rows) + ';'
        with open(PATH_DB_SCHEMA, 'r', encoding='utf-8') as f:
            stored_schema = f.read().rstrip()
        if stored_schema != retrieved_schema:
            diff_msg = Differ().compare(retrieved_schema.splitlines(),
                                        stored_schema.splitlines())
            raise HandledException(msg_err + '\n'.join(diff_msg))
154
155
class DatabaseConnection:
    """Wrapper around one sqlite3 connection to a DatabaseFile's path."""

    def __init__(self, db_file: DatabaseFile) -> None:
        self.conn = sql_connect(db_file.path)

    def commit(self) -> None:
        """Commit the pending SQL transaction."""
        self.conn.commit()

    def exec(self, code: str, inputs: tuple[Any, ...] = tuple()) -> Cursor:
        """Execute code with inputs bound to its '?'s, return the cursor."""
        return self.conn.execute(code, inputs)

    def exec_on_vals(self, code: str, inputs: tuple[Any, ...]) -> Cursor:
        """Like .exec, but first append a " (?,…)" group sized to inputs.

        E.g. code 'INSERT INTO foo VALUES' with three inputs gets executed
        as 'INSERT INTO foo VALUES (?,?,?)'.
        """
        placeholders = ','.join('?' * len(inputs))
        return self.exec(f'{code} ({placeholders})', inputs)

    def close(self) -> None:
        """Close the underlying sqlite3 connection."""
        self.conn.close()

    def rewrite_relations(self, table_name: str, key: str, target: int | str,
                          rows: list[list[Any]], key_index: int = 0) -> None:
        # pylint: disable=too-many-arguments
        """Replace all table_name rows whose key equals target by rows.

        Each entry of rows lacks the key column; target gets spliced into it
        at position key_index before insertion.
        """
        self.delete_where(table_name, key, target)
        insert_sql = f'INSERT INTO {table_name} VALUES'
        for partial in rows:
            full = partial[:key_index] + [target] + partial[key_index:]
            self.exec_on_vals(insert_sql, tuple(full))

    def row_where(self, table_name: str, key: str,
                  target: int | str) -> list[Row]:
        """Return all rows of table_name whose key column equals target."""
        sql = f'SELECT * FROM {table_name} WHERE {key} = ?'
        cursor = self.exec(sql, (target,))
        return list(cursor)

    def column_where(self, table_name: str, column: str, key: str,
                     target: int | str) -> list[Any]:
        """Return column's values from rows of table_name where key == target.
        """
        sql = f'SELECT {column} FROM {table_name} WHERE {key} = ?'
        cursor = self.exec(sql, (target,))
        return [record[0] for record in cursor]

    def column_all(self, table_name: str, column: str) -> list[Any]:
        """Return column's values from every row of table_name."""
        cursor = self.exec(f'SELECT {column} FROM {table_name}')
        return [record[0] for record in cursor]

    def delete_where(self, table_name: str, key: str,
                     target: int | str) -> None:
        """Delete all rows of table_name whose key column equals target."""
        sql = f'DELETE FROM {table_name} WHERE {key} = ?'
        self.exec(sql, (target,))
227
228 BaseModelId = TypeVar('BaseModelId', int, str)
229 BaseModelInstance = TypeVar('BaseModelInstance', bound='BaseModel[Any]')
230
231
232 class BaseModel(Generic[BaseModelId]):
233     """Template for most of the models we use/derive from the DB."""
234     table_name = ''
235     to_save: list[str] = []
236     to_save_versioned: list[str] = []
237     to_save_relations: list[tuple[str, str, str, int]] = []
238     add_to_dict: list[str] = []
239     id_: None | BaseModelId
240     cache_: dict[BaseModelId, Self]
241     to_search: list[str] = []
242     can_create_by_id = False
243     _exists = True
244     sorters: dict[str, Callable[..., Any]] = {}
245
246     def __init__(self, id_: BaseModelId | None) -> None:
247         if isinstance(id_, int) and id_ < 1:
248             msg = f'illegal {self.__class__.__name__} ID, must be >=1: {id_}'
249             raise HandledException(msg)
250         if isinstance(id_, str) and "" == id_:
251             msg = f'illegal {self.__class__.__name__} ID, must be non-empty'
252             raise HandledException(msg)
253         self.id_ = id_
254
255     def __hash__(self) -> int:
256         hashable = [self.id_] + [getattr(self, name) for name in self.to_save]
257         for definition in self.to_save_relations:
258             attr = getattr(self, definition[2])
259             hashable += [tuple(rel.id_ for rel in attr)]
260         for name in self.to_save_versioned:
261             hashable += [hash(getattr(self, name))]
262         return hash(tuple(hashable))
263
264     def __eq__(self, other: object) -> bool:
265         if not isinstance(other, self.__class__):
266             return False
267         return hash(self) == hash(other)
268
269     def __lt__(self, other: Any) -> bool:
270         if not isinstance(other, self.__class__):
271             msg = 'cannot compare to object of different class'
272             raise HandledException(msg)
273         assert isinstance(self.id_, int)
274         assert isinstance(other.id_, int)
275         return self.id_ < other.id_
276
277     @property
278     def as_dict(self) -> dict[str, object]:
279         """Return self as (json.dumps-compatible) dict."""
280         library: dict[str, dict[str | int, object]] = {}
281         d: dict[str, object] = {'id': self.id_, '_library': library}
282         for to_save in self.to_save:
283             attr = getattr(self, to_save)
284             if hasattr(attr, 'as_dict_into_reference'):
285                 d[to_save] = attr.as_dict_into_reference(library)
286             else:
287                 d[to_save] = attr
288         if len(self.to_save_versioned) > 0:
289             d['_versioned'] = {}
290         for k in self.to_save_versioned:
291             attr = getattr(self, k)
292             assert isinstance(d['_versioned'], dict)
293             d['_versioned'][k] = attr.history
294         for r in self.to_save_relations:
295             attr_name = r[2]
296             l: list[int | str] = []
297             for rel in getattr(self, attr_name):
298                 l += [rel.as_dict_into_reference(library)]
299             d[attr_name] = l
300         for k in self.add_to_dict:
301             d[k] = [x.as_dict_into_reference(library)
302                     for x in getattr(self, k)]
303         return d
304
305     def as_dict_into_reference(self,
306                                library: dict[str, dict[str | int, object]]
307                                ) -> int | str:
308         """Return self.id_ while writing .as_dict into library."""
309         def into_library(library: dict[str, dict[str | int, object]],
310                          cls_name: str,
311                          id_: str | int,
312                          d: dict[str, object]
313                          ) -> None:
314             if cls_name not in library:
315                 library[cls_name] = {}
316             if id_ in library[cls_name]:
317                 if library[cls_name][id_] != d:
318                     msg = 'Unexpected inequality of entries for ' +\
319                             f'_library at: {cls_name}/{id_}'
320                     raise HandledException(msg)
321             else:
322                 library[cls_name][id_] = d
323         as_dict = self.as_dict
324         assert isinstance(as_dict['_library'], dict)
325         for cls_name, dict_of_objs in as_dict['_library'].items():
326             for id_, obj in dict_of_objs.items():
327                 into_library(library, cls_name, id_, obj)
328         del as_dict['_library']
329         assert self.id_ is not None
330         into_library(library, self.__class__.__name__, self.id_, as_dict)
331         assert isinstance(as_dict['id'], (int, str))
332         return as_dict['id']
333
334     @classmethod
335     def name_lowercase(cls) -> str:
336         """Convenience method to return cls' name in lowercase."""
337         return cls.__name__.lower()
338
339     @classmethod
340     def sort_by(cls, seq: list[Any], sort_key: str, default: str = 'title'
341                 ) -> str:
342         """Sort cls list by cls.sorters[sort_key] (reverse if '-'-prefixed)."""
343         reverse = False
344         if len(sort_key) > 1 and '-' == sort_key[0]:
345             sort_key = sort_key[1:]
346             reverse = True
347         if sort_key not in cls.sorters:
348             sort_key = default
349         sorter: Callable[..., Any] = cls.sorters[sort_key]
350         seq.sort(key=sorter, reverse=reverse)
351         if reverse:
352             sort_key = f'-{sort_key}'
353         return sort_key
354
355     # cache management
356     # (we primarily use the cache to ensure we work on the same object in
357     # memory no matter where and how we retrieve it, e.g. we don't want
358     # .by_id() calls to create a new object each time, but rather a pointer
359     # to the one already instantiated)
360
361     def __getattribute__(self, name: str) -> Any:
362         """Ensure fail if ._disappear() was called, except to check ._exists"""
363         if name != '_exists' and not super().__getattribute__('_exists'):
364             raise HandledException('Object does not exist.')
365         return super().__getattribute__(name)
366
367     def _disappear(self) -> None:
368         """Invalidate object, make future use raise exceptions."""
369         assert self.id_ is not None
370         if self._get_cached(self.id_):
371             self._uncache()
372         to_kill = list(self.__dict__.keys())
373         for attr in to_kill:
374             delattr(self, attr)
375         self._exists = False
376
377     @classmethod
378     def empty_cache(cls) -> None:
379         """Empty class's cache, and disappear all former inhabitants."""
380         # pylint: disable=protected-access
381         # (cause we remain within the class)
382         if hasattr(cls, 'cache_'):
383             to_disappear = list(cls.cache_.values())
384             for item in to_disappear:
385                 item._disappear()
386         cls.cache_ = {}
387
388     @classmethod
389     def get_cache(cls: type[BaseModelInstance]) -> dict[Any, BaseModel[Any]]:
390         """Get cache dictionary, create it if not yet existing."""
391         if not hasattr(cls, 'cache_'):
392             d: dict[Any, BaseModel[Any]] = {}
393             cls.cache_ = d
394         return cls.cache_
395
396     @classmethod
397     def _get_cached(cls: type[BaseModelInstance],
398                     id_: BaseModelId) -> BaseModelInstance | None:
399         """Get object of id_ from class's cache, or None if not found."""
400         cache = cls.get_cache()
401         if id_ in cache:
402             obj = cache[id_]
403             assert isinstance(obj, cls)
404             return obj
405         return None
406
407     def cache(self) -> None:
408         """Update object in class's cache.
409
410         Also calls ._disappear if cache holds older reference to object of same
411         ID, but different memory address, to avoid doing anything with
412         dangling leftovers.
413         """
414         if self.id_ is None:
415             raise HandledException('Cannot cache object without ID.')
416         cache = self.get_cache()
417         old_cached = self._get_cached(self.id_)
418         if old_cached and id(old_cached) != id(self):
419             # pylint: disable=protected-access
420             # (cause we remain within the class)
421             old_cached._disappear()
422         cache[self.id_] = self
423
424     def _uncache(self) -> None:
425         """Remove self from cache."""
426         if self.id_ is None:
427             raise HandledException('Cannot un-cache object without ID.')
428         cache = self.get_cache()
429         del cache[self.id_]
430
431     # object retrieval and generation
432
433     @classmethod
434     def from_table_row(cls: type[BaseModelInstance],
435                        # pylint: disable=unused-argument
436                        db_conn: DatabaseConnection,
437                        row: Row | list[Any]) -> BaseModelInstance:
438         """Make from DB row (sans relations), update DB cache with it."""
439         obj = cls(*row)
440         assert obj.id_ is not None
441         for attr_name in cls.to_save_versioned:
442             attr = getattr(obj, attr_name)
443             table_name = attr.table_name
444             for row_ in db_conn.row_where(table_name, 'parent', obj.id_):
445                 attr.history_from_row(row_)
446         obj.cache()
447         return obj
448
449     @classmethod
450     def by_id(cls, db_conn: DatabaseConnection, id_: BaseModelId) -> Self:
451         """Retrieve by id_, on failure throw NotFoundException.
452
453         First try to get from cls.cache_, only then check DB; if found,
454         put into cache.
455         """
456         obj = None
457         if id_ is not None:
458             obj = cls._get_cached(id_)
459             if not obj:
460                 for row in db_conn.row_where(cls.table_name, 'id', id_):
461                     obj = cls.from_table_row(db_conn, row)
462                     break
463         if obj:
464             return obj
465         raise NotFoundException(f'found no object of ID {id_}')
466
467     @classmethod
468     def by_id_or_create(cls, db_conn: DatabaseConnection,
469                         id_: BaseModelId | None
470                         ) -> Self:
471         """Wrapper around .by_id, creating (not caching/saving) if not find."""
472         if not cls.can_create_by_id:
473             raise HandledException('Class cannot .by_id_or_create.')
474         if id_ is None:
475             return cls(None)
476         try:
477             return cls.by_id(db_conn, id_)
478         except NotFoundException:
479             return cls(id_)
480
481     @classmethod
482     def all(cls: type[BaseModelInstance],
483             db_conn: DatabaseConnection) -> list[BaseModelInstance]:
484         """Collect all objects of class into list.
485
486         Note that this primarily returns the contents of the cache, and only
487         _expands_ that by additional findings in the DB. This assumes the
488         cache is always instantly cleaned of any items that would be removed
489         from the DB.
490         """
491         items: dict[BaseModelId, BaseModelInstance] = {}
492         for k, v in cls.get_cache().items():
493             assert isinstance(v, cls)
494             items[k] = v
495         already_recorded = items.keys()
496         for id_ in db_conn.column_all(cls.table_name, 'id'):
497             if id_ not in already_recorded:
498                 item = cls.by_id(db_conn, id_)
499                 assert item.id_ is not None
500                 items[item.id_] = item
501         return list(items.values())
502
503     @classmethod
504     def by_date_range_with_limits(cls: type[BaseModelInstance],
505                                   db_conn: DatabaseConnection,
506                                   date_range: tuple[str, str],
507                                   date_col: str = 'day'
508                                   ) -> tuple[list[BaseModelInstance], str,
509                                              str]:
510         """Return list of items in database within (open) date_range interval.
511
512         If no range values provided, defaults them to 'yesterday' and
513         'tomorrow'. Knows to properly interpret these and 'today' as value.
514         """
515         start_str = date_range[0] if date_range[0] else 'yesterday'
516         end_str = date_range[1] if date_range[1] else 'tomorrow'
517         start_date = valid_date(start_str)
518         end_date = valid_date(end_str)
519         items = []
520         sql = f'SELECT id FROM {cls.table_name} '
521         sql += f'WHERE {date_col} >= ? AND {date_col} <= ?'
522         for row in db_conn.exec(sql, (start_date, end_date)):
523             items += [cls.by_id(db_conn, row[0])]
524         return items, start_date, end_date
525
526     @classmethod
527     def matching(cls: type[BaseModelInstance], db_conn: DatabaseConnection,
528                  pattern: str) -> list[BaseModelInstance]:
529         """Return all objects whose .to_search match pattern."""
530         items = cls.all(db_conn)
531         if pattern:
532             filtered = []
533             for item in items:
534                 for attr_name in cls.to_search:
535                     toks = attr_name.split('.')
536                     parent = item
537                     for tok in toks:
538                         attr = getattr(parent, tok)
539                         parent = attr
540                     if pattern in attr:
541                         filtered += [item]
542                         break
543             return filtered
544         return items
545
546     # database writing
547
548     def save(self, db_conn: DatabaseConnection) -> None:
549         """Write self to DB and cache and ensure .id_.
550
551         Write both to DB, and to cache. To DB, write .id_ and attributes
552         listed in cls.to_save[_versioned|_relations].
553
554         Ensure self.id_ by setting it to what the DB command returns as the
555         last saved row's ID (cursor.lastrowid), EXCEPT if self.id_ already
556         exists as a 'str', which implies we do our own ID creation (so far
557         only the case with the Day class, where it's to be a date string.
558         """
559         values = tuple([self.id_] + [getattr(self, key)
560                                      for key in self.to_save])
561         table_name = self.table_name
562         cursor = db_conn.exec_on_vals(f'REPLACE INTO {table_name} VALUES',
563                                       values)
564         if not isinstance(self.id_, str):
565             self.id_ = cursor.lastrowid  # type: ignore[assignment]
566         self.cache()
567         for attr_name in self.to_save_versioned:
568             getattr(self, attr_name).save(db_conn)
569         for table, column, attr_name, key_index in self.to_save_relations:
570             assert isinstance(self.id_, (int, str))
571             db_conn.rewrite_relations(table, column, self.id_,
572                                       [[i.id_] for i
573                                        in getattr(self, attr_name)], key_index)
574
575     def remove(self, db_conn: DatabaseConnection) -> None:
576         """Remove from DB and cache, including dependencies."""
577         if self.id_ is None or self._get_cached(self.id_) is None:
578             raise HandledException('cannot remove unsaved item')
579         for attr_name in self.to_save_versioned:
580             getattr(self, attr_name).remove(db_conn)
581         for table, column, attr_name, _ in self.to_save_relations:
582             db_conn.delete_where(table, column, self.id_)
583         self._uncache()
584         db_conn.delete_where(self.table_name, 'id', self.id_)
585         self._disappear()