home · contact · privacy
Minor refactoring.
[plomtask] / plomtask / db.py
1 """Database management."""
from __future__ import annotations
from contextlib import closing
from difflib import Differ
from os import listdir
from os.path import isfile
from sqlite3 import connect as sql_connect, Cursor, Row
from typing import Any, Self, TypeVar, Generic, Callable

from plomtask.exceptions import HandledException, NotFoundException
from plomtask.dating import valid_date
10
# Schema version (sqlite3 "PRAGMA user_version") this code works against.
EXPECTED_DB_VERSION = 5
# Directory holding numbered migration scripts plus the full schema file.
MIGRATIONS_DIR = 'migrations'
# Full schema of a DB at EXPECTED_DB_VERSION, and its path.
FILENAME_DB_SCHEMA = 'init_' + str(EXPECTED_DB_VERSION) + '.sql'
PATH_DB_SCHEMA = '/'.join((MIGRATIONS_DIR, FILENAME_DB_SCHEMA))
15
16
class UnmigratedDbException(HandledException):
    """Signals a DB file whose user_version is not EXPECTED_DB_VERSION.

    A dedicated type so callers can tell a DB file that merely needs
    migrating apart from other DB problems.
    """
19
20
class DatabaseFile:
    """Represents the sqlite3 database's file.

    Constructing an instance validates the file at path (see ._check), so
    at construction time a DatabaseFile refers to an up-to-date DB; use
    .create_at or .migrate to get one from a new or an outdated file.
    """
    # pylint: disable=too-few-public-methods

    def __init__(self, path: str) -> None:
        self.path = path
        self._check()

    @classmethod
    def create_at(cls, path: str) -> DatabaseFile:
        """Make new DB file at path from PATH_DB_SCHEMA, return it validated.

        The new DB's user_version is set to EXPECTED_DB_VERSION.
        """
        # Read the schema before connecting (connecting creates the file), so
        # a missing schema file does not leave an empty DB file at path.
        with open(PATH_DB_SCHEMA, 'r', encoding='utf-8') as f:
            schema_sql = f.read()
        # sqlite3's connection context manager only commits (or rolls back);
        # it never closes the connection, hence the wrapping closing().
        with closing(sql_connect(path)) as conn:
            with conn:
                conn.executescript(schema_sql)
                conn.execute(f'PRAGMA user_version = {EXPECTED_DB_VERSION}')
        return cls(path)

    @classmethod
    def migrate(cls, path: str) -> DatabaseFile:
        """Apply migrations from_version to EXPECTED_DB_VERSION.

        Runs each pending migration script in order, bumping user_version
        after each, so if a script raises, user_version still names the last
        migration that completed. Returns the validated migrated file.
        """
        migrations = cls._available_migrations()
        from_version = cls._get_version_of_db(path)
        migrations_todo = migrations[from_version+1:]
        for j, filename in enumerate(migrations_todo):
            migration_path = f'{MIGRATIONS_DIR}/{filename}'
            with open(migration_path, 'r', encoding='utf-8') as f:
                migration_sql = f.read()
            user_version = from_version + j + 1
            # closing() actually closes the connection; inner "with conn"
            # keeps the original commit-on-success behavior.
            with closing(sql_connect(path)) as conn:
                with conn:
                    conn.executescript(migration_sql)
                    conn.execute(f'PRAGMA user_version = {user_version}')
        return cls(path)

    def _check(self) -> None:
        """Check file exists, and is of proper DB version and schema.

        Raises NotFoundException if no file is at .path,
        UnmigratedDbException if its user_version is below
        EXPECTED_DB_VERSION (a higher one already raises HandledException in
        ._get_version_of_db), and HandledException if the schema differs
        from PATH_DB_SCHEMA's.
        """
        if not isfile(self.path):
            raise NotFoundException(f'no database file at: {self.path}')
        if self._user_version != EXPECTED_DB_VERSION:
            raise UnmigratedDbException()
        self._validate_schema()

    @staticmethod
    def _available_migrations() -> list[str]:
        """Validate migrations directory and return sorted entries.

        Besides FILENAME_DB_SCHEMA, expects one '{i}_{name}' entry for each
        i from 0 to EXPECTED_DB_VERSION, and returns them ordered by i, so
        that list index equals migration number.

        Raises HandledException on entries not of that form, on numbers
        beyond EXPECTED_DB_VERSION, on duplicate numbers, and on gaps.
        """
        msg_too_big = 'Migration directory points beyond expected DB version.'
        msg_bad_entry = 'Migration directory contains unexpected entry: '
        msg_missing = 'Migration directory misses migration of number: '
        msg_duplicate = 'Migration directory duplicates migration of number: '
        migrations: dict[int, str] = {}
        for entry in listdir(MIGRATIONS_DIR):
            if entry == FILENAME_DB_SCHEMA:
                continue
            toks = entry.split('_', 1)
            if len(toks) < 2:
                raise HandledException(msg_bad_entry + entry)
            try:
                i = int(toks[0])
            except ValueError as e:
                raise HandledException(msg_bad_entry + entry) from e
            if i > EXPECTED_DB_VERSION:
                raise HandledException(msg_too_big)
            # listdir order is arbitrary, so letting one of two same-numbered
            # entries silently win would make migrating unpredictable
            if i in migrations:
                raise HandledException(msg_duplicate + str(i))
            migrations[i] = toks[1]
        for i in range(EXPECTED_DB_VERSION + 1):
            if i not in migrations:
                raise HandledException(msg_missing + str(i))
        return [f'{i}_{migrations[i]}'
                for i in range(EXPECTED_DB_VERSION + 1)]

    @staticmethod
    def _get_version_of_db(path: str) -> int:
        """Get DB user_version, fail if outside expected range.

        Raises HandledException if the version exceeds EXPECTED_DB_VERSION.
        """
        sql_for_db_version = 'PRAGMA user_version'
        with closing(sql_connect(path)) as conn:
            db_version = conn.execute(sql_for_db_version).fetchone()[0]
        assert isinstance(db_version, int)
        if db_version > EXPECTED_DB_VERSION:
            msg = f'Wrong DB version, expected '\
                    f'{EXPECTED_DB_VERSION}, got unknown {db_version}.'
            raise HandledException(msg)
        return db_version

    @property
    def _user_version(self) -> int:
        """Get DB user_version."""
        return self._get_version_of_db(self.path)

    def _validate_schema(self) -> None:
        """Compare found schema with what's stored at PATH_DB_SCHEMA.

        Raises HandledException with a line diff if they differ.
        """

        def reformat_rows(rows: list[str]) -> list[str]:
            # Re-layout each SQL statement from sqlite_master: every item
            # separated by a comma outside parentheses goes on its own line,
            # indented by four spaces; first and last line stay unindented.
            new_rows = []
            for row in rows:
                new_row = []
                for subrow in row.split('\n'):
                    subrow = subrow.rstrip()
                    in_parentheses = 0
                    split_at = []
                    for i, c in enumerate(subrow):
                        if '(' == c:
                            in_parentheses += 1
                        elif ')' == c:
                            in_parentheses -= 1
                        elif ',' == c and 0 == in_parentheses:
                            split_at += [i + 1]
                    prev_split = 0
                    for i in split_at:
                        segment = subrow[prev_split:i].strip()
                        if len(segment) > 0:
                            new_row += [f'    {segment}']
                        prev_split = i
                    segment = subrow[prev_split:].strip()
                    if len(segment) > 0:
                        new_row += [f'    {segment}']
                new_row[0] = new_row[0].lstrip()
                new_row[-1] = new_row[-1].lstrip()
                # If the last line is not a lone ')' and the line three from
                # the end lacks a trailing comma, append that comma, and
                # replace the last two lines by the last one (minus its final
                # char, re-indented) plus a lone ')'. The second-to-last line
                # is dropped, presumably a bare ',' segment as left by SQLite
                # appending ", col TYPE)" on ALTER TABLE ADD COLUMN - TODO
                # confirm. NOTE(review): raises IndexError for statements
                # yielding fewer than three lines, e.g. a one-line
                # CREATE INDEX.
                if new_row[-1] != ')' and new_row[-3][-1] != ',':
                    new_row[-3] = new_row[-3] + ','
                    new_row[-2:] = ['    ' + new_row[-1][:-1]] + [')']
                new_rows += ['\n'.join(new_row)]
            return new_rows

        sql_for_schema = 'SELECT sql FROM sqlite_master ORDER BY sql'
        msg_err = 'Database has wrong tables schema. Diff:\n'
        with closing(sql_connect(self.path)) as conn:
            # skip NULL sql entries (presumably SQLite's automatic indexes)
            schema_rows = [r[0] for r in conn.execute(sql_for_schema) if r[0]]
        schema_rows = reformat_rows(schema_rows)
        retrieved_schema = ';\n'.join(schema_rows) + ';'
        with open(PATH_DB_SCHEMA, 'r', encoding='utf-8') as f:
            stored_schema = f.read().rstrip()
        if stored_schema != retrieved_schema:
            diff_msg = Differ().compare(retrieved_schema.splitlines(),
                                        stored_schema.splitlines())
            raise HandledException(msg_err + '\n'.join(diff_msg))
151             diff_msg = Differ().compare(retrieved_schema.splitlines(),
152                                         stored_schema.splitlines())
153             raise HandledException(msg_err + '\n'.join(diff_msg))
154
155
class DatabaseConnection:
    """Holds one sqlite3 connection to a DatabaseFile, plus query helpers.

    NOTE: table, column and key names are interpolated into the SQL text
    directly; only values are bound via '?' placeholders.
    """

    def __init__(self, db_file: DatabaseFile) -> None:
        db_path = db_file.path
        self.conn = sql_connect(db_path)

    def commit(self) -> None:
        """Commit the pending SQL transaction."""
        self.conn.commit()

    def exec(self, code: str, inputs: tuple[Any, ...] = tuple()) -> Cursor:
        """Execute code, binding inputs to its '?' placeholders."""
        cursor = self.conn.execute(code, inputs)
        return cursor

    def exec_on_vals(self, code: str, inputs: tuple[Any, ...]) -> Cursor:
        """Execute code suffixed by " (?,…)", one '?' per item of inputs."""
        placeholders = ','.join('?' for _ in inputs)
        return self.exec(code + ' (' + placeholders + ')', inputs)

    def close(self) -> None:
        """Close the DB connection."""
        self.conn.close()

    def rewrite_relations(self, table_name: str, key: str, target: int | str,
                          rows: list[list[Any]], key_index: int = 0) -> None:
        # pylint: disable=too-many-arguments
        """Replace all rows of table_name where key == target by rows.

        Each of rows lacks the key column's value; target is inserted into
        each at position key_index before it is written.
        """
        self.delete_where(table_name, key, target)
        insert_sql = f'INSERT INTO {table_name} VALUES'
        for partial_row in rows:
            full_row = [*partial_row[:key_index], target,
                        *partial_row[key_index:]]
            self.exec_on_vals(insert_sql, tuple(full_row))

    def row_where(self, table_name: str, key: str,
                  target: int | str) -> list[Row]:
        """Return all rows of table_name whose key column equals target."""
        sql = f'SELECT * FROM {table_name} WHERE {key} = ?'
        return list(self.exec(sql, (target,)))

    # def column_where_pattern(self,
    #                          table_name: str,
    #                          column: str,
    #                          pattern: str,
    #                          keys: list[str]) -> list[Any]:
    #     """Return column of rows where one of keys matches pattern."""
    #     targets = tuple([f'%{pattern}%'] * len(keys))
    #     haystack = ' OR '.join([f'{k} LIKE ?' for k in keys])
    #     sql = f'SELECT {column} FROM {table_name} WHERE {haystack}'
    #     return [row[0] for row in self.exec(sql, targets)]

    def column_where(self, table_name: str, column: str, key: str,
                     target: int | str) -> list[Any]:
        """Return column's values from rows of table_name where key == target."""
        sql = f'SELECT {column} FROM {table_name} WHERE {key} = ?'
        cursor = self.exec(sql, (target,))
        return [first for first, *_ in cursor]

    def column_all(self, table_name: str, column: str) -> list[Any]:
        """Return column's values from every row of table_name."""
        cursor = self.exec(f'SELECT {column} FROM {table_name}')
        return [first for first, *_ in cursor]

    def delete_where(self, table_name: str, key: str,
                     target: int | str) -> None:
        """Delete all rows of table_name whose key column equals target."""
        sql = f'DELETE FROM {table_name} WHERE {key} = ?'
        self.exec(sql, (target,))
223                      target: int | str) -> None:
224         """Delete from table where key == target."""
225         self.exec(f'DELETE FROM {table_name} WHERE {key} = ?', (target,))
226
227
228 BaseModelId = TypeVar('BaseModelId', int, str)
229 BaseModelInstance = TypeVar('BaseModelInstance', bound='BaseModel[Any]')
230
231
232 class BaseModel(Generic[BaseModelId]):
233     """Template for most of the models we use/derive from the DB."""
234     table_name = ''
235     to_save_simples: list[str] = []
236     to_save_relations: list[tuple[str, str, str, int]] = []
237     versioned_defaults: dict[str, str | float] = {}
238     add_to_dict: list[str] = []
239     id_: None | BaseModelId
240     cache_: dict[BaseModelId, Self]
241     to_search: list[str] = []
242     can_create_by_id = False
243     _exists = True
244     sorters: dict[str, Callable[..., Any]] = {}
245
246     def __init__(self, id_: BaseModelId | None) -> None:
247         if isinstance(id_, int) and id_ < 1:
248             msg = f'illegal {self.__class__.__name__} ID, must be >=1: {id_}'
249             raise HandledException(msg)
250         if isinstance(id_, str) and "" == id_:
251             msg = f'illegal {self.__class__.__name__} ID, must be non-empty'
252             raise HandledException(msg)
253         self.id_ = id_
254
255     def __hash__(self) -> int:
256         hashable = [self.id_] + [getattr(self, name)
257                                  for name in self.to_save_simples]
258         for definition in self.to_save_relations:
259             attr = getattr(self, definition[2])
260             hashable += [tuple(rel.id_ for rel in attr)]
261         for name in self.to_save_versioned():
262             hashable += [hash(getattr(self, name))]
263         return hash(tuple(hashable))
264
265     def __eq__(self, other: object) -> bool:
266         if not isinstance(other, self.__class__):
267             return False
268         return hash(self) == hash(other)
269
270     def __lt__(self, other: Any) -> bool:
271         if not isinstance(other, self.__class__):
272             msg = 'cannot compare to object of different class'
273             raise HandledException(msg)
274         assert isinstance(self.id_, int)
275         assert isinstance(other.id_, int)
276         return self.id_ < other.id_
277
278     @classmethod
279     def to_save_versioned(cls) -> list[str]:
280         """Return keys of cls.versioned_defaults assuming we wanna save 'em."""
281         return list(cls.versioned_defaults.keys())
282
283     @property
284     def as_dict(self) -> dict[str, object]:
285         """Return self as (json.dumps-compatible) dict."""
286         library: dict[str, dict[str | int, object]] = {}
287         d: dict[str, object] = {'id': self.id_, '_library': library}
288         for to_save in self.to_save_simples:
289             attr = getattr(self, to_save)
290             if hasattr(attr, 'as_dict_into_reference'):
291                 d[to_save] = attr.as_dict_into_reference(library)
292             else:
293                 d[to_save] = attr
294         if len(self.to_save_versioned()) > 0:
295             d['_versioned'] = {}
296         for k in self.to_save_versioned():
297             attr = getattr(self, k)
298             assert isinstance(d['_versioned'], dict)
299             d['_versioned'][k] = attr.history
300         for r in self.to_save_relations:
301             attr_name = r[2]
302             l: list[int | str] = []
303             for rel in getattr(self, attr_name):
304                 l += [rel.as_dict_into_reference(library)]
305             d[attr_name] = l
306         for k in self.add_to_dict:
307             d[k] = [x.as_dict_into_reference(library)
308                     for x in getattr(self, k)]
309         return d
310
311     def as_dict_into_reference(self,
312                                library: dict[str, dict[str | int, object]]
313                                ) -> int | str:
314         """Return self.id_ while writing .as_dict into library."""
315         def into_library(library: dict[str, dict[str | int, object]],
316                          cls_name: str,
317                          id_: str | int,
318                          d: dict[str, object]
319                          ) -> None:
320             if cls_name not in library:
321                 library[cls_name] = {}
322             if id_ in library[cls_name]:
323                 if library[cls_name][id_] != d:
324                     msg = 'Unexpected inequality of entries for ' +\
325                             f'_library at: {cls_name}/{id_}'
326                     raise HandledException(msg)
327             else:
328                 library[cls_name][id_] = d
329         as_dict = self.as_dict
330         assert isinstance(as_dict['_library'], dict)
331         for cls_name, dict_of_objs in as_dict['_library'].items():
332             for id_, obj in dict_of_objs.items():
333                 into_library(library, cls_name, id_, obj)
334         del as_dict['_library']
335         assert self.id_ is not None
336         into_library(library, self.__class__.__name__, self.id_, as_dict)
337         assert isinstance(as_dict['id'], (int, str))
338         return as_dict['id']
339
340     @classmethod
341     def name_lowercase(cls) -> str:
342         """Convenience method to return cls' name in lowercase."""
343         return cls.__name__.lower()
344
345     @classmethod
346     def sort_by(cls, seq: list[Any], sort_key: str, default: str = 'title'
347                 ) -> str:
348         """Sort cls list by cls.sorters[sort_key] (reverse if '-'-prefixed).
349
350         Before cls.sorters[sort_key] is applied, seq is sorted by .id_, to
351         ensure predictability where parts of seq are of same sort value.
352         """
353         reverse = False
354         if len(sort_key) > 1 and '-' == sort_key[0]:
355             sort_key = sort_key[1:]
356             reverse = True
357         if sort_key not in cls.sorters:
358             sort_key = default
359         seq.sort(key=lambda x: x.id_, reverse=reverse)
360         sorter: Callable[..., Any] = cls.sorters[sort_key]
361         seq.sort(key=sorter, reverse=reverse)
362         if reverse:
363             sort_key = f'-{sort_key}'
364         return sort_key
365
366     # cache management
367     # (we primarily use the cache to ensure we work on the same object in
368     # memory no matter where and how we retrieve it, e.g. we don't want
369     # .by_id() calls to create a new object each time, but rather a pointer
370     # to the one already instantiated)
371
372     def __getattribute__(self, name: str) -> Any:
373         """Ensure fail if ._disappear() was called, except to check ._exists"""
374         if name != '_exists' and not super().__getattribute__('_exists'):
375             raise HandledException('Object does not exist.')
376         return super().__getattribute__(name)
377
378     def _disappear(self) -> None:
379         """Invalidate object, make future use raise exceptions."""
380         assert self.id_ is not None
381         if self._get_cached(self.id_):
382             self._uncache()
383         to_kill = list(self.__dict__.keys())
384         for attr in to_kill:
385             delattr(self, attr)
386         self._exists = False
387
388     @classmethod
389     def empty_cache(cls) -> None:
390         """Empty class's cache, and disappear all former inhabitants."""
391         # pylint: disable=protected-access
392         # (cause we remain within the class)
393         if hasattr(cls, 'cache_'):
394             to_disappear = list(cls.cache_.values())
395             for item in to_disappear:
396                 item._disappear()
397         cls.cache_ = {}
398
399     @classmethod
400     def get_cache(cls: type[BaseModelInstance]) -> dict[Any, BaseModel[Any]]:
401         """Get cache dictionary, create it if not yet existing."""
402         if not hasattr(cls, 'cache_'):
403             d: dict[Any, BaseModel[Any]] = {}
404             cls.cache_ = d
405         return cls.cache_
406
407     @classmethod
408     def _get_cached(cls: type[BaseModelInstance],
409                     id_: BaseModelId) -> BaseModelInstance | None:
410         """Get object of id_ from class's cache, or None if not found."""
411         cache = cls.get_cache()
412         if id_ in cache:
413             obj = cache[id_]
414             assert isinstance(obj, cls)
415             return obj
416         return None
417
418     def cache(self) -> None:
419         """Update object in class's cache.
420
421         Also calls ._disappear if cache holds older reference to object of same
422         ID, but different memory address, to avoid doing anything with
423         dangling leftovers.
424         """
425         if self.id_ is None:
426             raise HandledException('Cannot cache object without ID.')
427         cache = self.get_cache()
428         old_cached = self._get_cached(self.id_)
429         if old_cached and id(old_cached) != id(self):
430             # pylint: disable=protected-access
431             # (cause we remain within the class)
432             old_cached._disappear()
433         cache[self.id_] = self
434
435     def _uncache(self) -> None:
436         """Remove self from cache."""
437         if self.id_ is None:
438             raise HandledException('Cannot un-cache object without ID.')
439         cache = self.get_cache()
440         del cache[self.id_]
441
442     # object retrieval and generation
443
444     @classmethod
445     def from_table_row(cls: type[BaseModelInstance],
446                        # pylint: disable=unused-argument
447                        db_conn: DatabaseConnection,
448                        row: Row | list[Any]) -> BaseModelInstance:
449         """Make from DB row (sans relations), update DB cache with it."""
450         obj = cls(*row)
451         assert obj.id_ is not None
452         for attr_name in cls.to_save_versioned():
453             attr = getattr(obj, attr_name)
454             table_name = attr.table_name
455             for row_ in db_conn.row_where(table_name, 'parent', obj.id_):
456                 attr.history_from_row(row_)
457         obj.cache()
458         return obj
459
460     @classmethod
461     def by_id(cls, db_conn: DatabaseConnection, id_: BaseModelId) -> Self:
462         """Retrieve by id_, on failure throw NotFoundException.
463
464         First try to get from cls.cache_, only then check DB; if found,
465         put into cache.
466         """
467         obj = None
468         if id_ is not None:
469             obj = cls._get_cached(id_)
470             if not obj:
471                 for row in db_conn.row_where(cls.table_name, 'id', id_):
472                     obj = cls.from_table_row(db_conn, row)
473                     break
474         if obj:
475             return obj
476         raise NotFoundException(f'found no object of ID {id_}')
477
478     @classmethod
479     def by_id_or_create(cls, db_conn: DatabaseConnection,
480                         id_: BaseModelId | None
481                         ) -> Self:
482         """Wrapper around .by_id, creating (not caching/saving) if not find."""
483         if not cls.can_create_by_id:
484             raise HandledException('Class cannot .by_id_or_create.')
485         if id_ is None:
486             return cls(None)
487         try:
488             return cls.by_id(db_conn, id_)
489         except NotFoundException:
490             return cls(id_)
491
492     @classmethod
493     def all(cls: type[BaseModelInstance],
494             db_conn: DatabaseConnection) -> list[BaseModelInstance]:
495         """Collect all objects of class into list.
496
497         Note that this primarily returns the contents of the cache, and only
498         _expands_ that by additional findings in the DB. This assumes the
499         cache is always instantly cleaned of any items that would be removed
500         from the DB.
501         """
502         items: dict[BaseModelId, BaseModelInstance] = {}
503         for k, v in cls.get_cache().items():
504             assert isinstance(v, cls)
505             items[k] = v
506         already_recorded = items.keys()
507         for id_ in db_conn.column_all(cls.table_name, 'id'):
508             if id_ not in already_recorded:
509                 item = cls.by_id(db_conn, id_)
510                 assert item.id_ is not None
511                 items[item.id_] = item
512         return list(items.values())
513
514     @classmethod
515     def by_date_range_with_limits(cls: type[BaseModelInstance],
516                                   db_conn: DatabaseConnection,
517                                   date_range: tuple[str, str],
518                                   date_col: str = 'day'
519                                   ) -> tuple[list[BaseModelInstance], str,
520                                              str]:
521         """Return list of items in DB within (closed) date_range interval.
522
523         If no range values provided, defaults them to 'yesterday' and
524         'tomorrow'. Knows to properly interpret these and 'today' as value.
525         """
526         start_str = date_range[0] if date_range[0] else 'yesterday'
527         end_str = date_range[1] if date_range[1] else 'tomorrow'
528         start_date = valid_date(start_str)
529         end_date = valid_date(end_str)
530         items = []
531         sql = f'SELECT id FROM {cls.table_name} '
532         sql += f'WHERE {date_col} >= ? AND {date_col} <= ?'
533         for row in db_conn.exec(sql, (start_date, end_date)):
534             items += [cls.by_id(db_conn, row[0])]
535         return items, start_date, end_date
536
537     @classmethod
538     def matching(cls: type[BaseModelInstance], db_conn: DatabaseConnection,
539                  pattern: str) -> list[BaseModelInstance]:
540         """Return all objects whose .to_search match pattern."""
541         items = cls.all(db_conn)
542         if pattern:
543             filtered = []
544             for item in items:
545                 for attr_name in cls.to_search:
546                     toks = attr_name.split('.')
547                     parent = item
548                     for tok in toks:
549                         attr = getattr(parent, tok)
550                         parent = attr
551                     if pattern in attr:
552                         filtered += [item]
553                         break
554             return filtered
555         return items
556
557     # database writing
558
559     def save(self, db_conn: DatabaseConnection) -> None:
560         """Write self to DB and cache and ensure .id_.
561
562         Write both to DB, and to cache. To DB, write .id_ and attributes
563         listed in cls.to_save_[simples|versioned|_relations].
564
565         Ensure self.id_ by setting it to what the DB command returns as the
566         last saved row's ID (cursor.lastrowid), EXCEPT if self.id_ already
567         exists as a 'str', which implies we do our own ID creation (so far
568         only the case with the Day class, where it's to be a date string.
569         """
570         values = tuple([self.id_] + [getattr(self, key)
571                                      for key in self.to_save_simples])
572         table_name = self.table_name
573         cursor = db_conn.exec_on_vals(f'REPLACE INTO {table_name} VALUES',
574                                       values)
575         if not isinstance(self.id_, str):
576             self.id_ = cursor.lastrowid  # type: ignore[assignment]
577         self.cache()
578         for attr_name in self.to_save_versioned():
579             getattr(self, attr_name).save(db_conn)
580         for table, column, attr_name, key_index in self.to_save_relations:
581             assert isinstance(self.id_, (int, str))
582             db_conn.rewrite_relations(table, column, self.id_,
583                                       [[i.id_] for i
584                                        in getattr(self, attr_name)], key_index)
585
586     def remove(self, db_conn: DatabaseConnection) -> None:
587         """Remove from DB and cache, including dependencies."""
588         if self.id_ is None or self._get_cached(self.id_) is None:
589             raise HandledException('cannot remove unsaved item')
590         for attr_name in self.to_save_versioned():
591             getattr(self, attr_name).remove(db_conn)
592         for table, column, attr_name, _ in self.to_save_relations:
593             db_conn.delete_where(table, column, self.id_)
594         self._uncache()
595         db_conn.delete_where(self.table_name, 'id', self.id_)
596         self._disappear()