home · contact · privacy
Turn TodoNode into full class with .as_dict, with result expand Day tests.
[plomtask] / plomtask / db.py
1 """Database management."""
2 from __future__ import annotations
3 from os import listdir
4 from os.path import isfile
5 from difflib import Differ
6 from sqlite3 import connect as sql_connect, Cursor, Row
7 from typing import Any, Self, TypeVar, Generic
8 from plomtask.exceptions import HandledException, NotFoundException
9 from plomtask.dating import valid_date
10
# user_version value a DB file must carry to be accepted; also the number
# of the highest migration expected in MIGRATIONS_DIR.
EXPECTED_DB_VERSION = 5
# Directory (relative path) listing the migration scripts and the schema.
MIGRATIONS_DIR = 'migrations'
# Name of the SQL script holding the full schema of EXPECTED_DB_VERSION.
FILENAME_DB_SCHEMA = f'init_{EXPECTED_DB_VERSION}.sql'
# Path to that schema script, used for DB creation and schema validation.
PATH_DB_SCHEMA = f'{MIGRATIONS_DIR}/{FILENAME_DB_SCHEMA}'
15
16
class UnmigratedDbException(HandledException):
    """To identify case of unmigrated DB file.

    Raised by DatabaseFile when the file's user_version differs from
    EXPECTED_DB_VERSION.
    """
19
20
class DatabaseFile:
    """Represents the sqlite3 database's file.

    Instantiating validates the file: it must exist, carry the
    user_version EXPECTED_DB_VERSION, and match the schema stored at
    PATH_DB_SCHEMA.  Use .create_at or .migrate to get a file into that
    state.

    All sqlite3 connections opened here are closed explicitly: using a
    sqlite3 connection as context manager only commits or rolls back, it
    does not close the connection.
    """
    # pylint: disable=too-few-public-methods

    def __init__(self, path: str) -> None:
        """Store path, then validate the file found there.

        Raises NotFoundException if there is no file at path,
        UnmigratedDbException on unexpected user_version, HandledException
        on schema mismatch.
        """
        self.path = path
        self._check()

    @classmethod
    def create_at(cls, path: str) -> DatabaseFile:
        """Make new DB file at path from the schema at PATH_DB_SCHEMA."""
        # Read the script before connecting, so that a missing schema file
        # does not leave an empty DB file behind at path.
        with open(PATH_DB_SCHEMA, 'r', encoding='utf-8') as f:
            schema_script = f.read()
        conn = sql_connect(path)
        try:
            with conn:
                conn.executescript(schema_script)
                conn.execute(f'PRAGMA user_version = {EXPECTED_DB_VERSION}')
        finally:
            conn.close()
        return cls(path)

    @classmethod
    def migrate(cls, path: str) -> DatabaseFile:
        """Apply migrations from DB's version up to EXPECTED_DB_VERSION.

        After each migration script the DB's user_version is set to that
        migration's number, so an interrupted run can be continued.
        """
        migrations = cls._available_migrations()
        from_version = cls._get_version_of_db(path)
        todo = enumerate(migrations[from_version+1:], start=from_version+1)
        for user_version, filename in todo:
            with open(f'{MIGRATIONS_DIR}/{filename}', 'r',
                      encoding='utf-8') as f:
                script = f.read()
            conn = sql_connect(path)
            try:
                with conn:
                    conn.executescript(script)
                    conn.execute(f'PRAGMA user_version = {user_version}')
            finally:
                conn.close()
        return cls(path)

    def _check(self) -> None:
        """Check file exists, and is of proper DB version and schema."""
        if not isfile(self.path):
            raise NotFoundException
        if self._user_version != EXPECTED_DB_VERSION:
            raise UnmigratedDbException()
        self._validate_schema()

    @staticmethod
    def _available_migrations() -> list[str]:
        """Validate migrations directory and return sorted entries.

        Entries other than FILENAME_DB_SCHEMA must be named
        "{number}_{rest}"; for each number from 0 to EXPECTED_DB_VERSION
        there must be one.  Returns the file names ordered by number, so
        that the list index equals the migration number.  Raises
        HandledException on any violation.
        """
        msg_too_big = 'Migration directory points beyond expected DB version.'
        msg_bad_entry = 'Migration directory contains unexpected entry: '
        msg_missing = 'Migration directory misses migration of number: '
        migrations = {}
        for entry in listdir(MIGRATIONS_DIR):
            if entry == FILENAME_DB_SCHEMA:
                continue
            toks = entry.split('_', 1)
            if len(toks) < 2:
                raise HandledException(msg_bad_entry + entry)
            try:
                i = int(toks[0])
            except ValueError as e:
                raise HandledException(msg_bad_entry + entry) from e
            if i > EXPECTED_DB_VERSION:
                raise HandledException(msg_too_big)
            migrations[i] = toks[1]
        migrations_list = []
        for i in range(EXPECTED_DB_VERSION + 1):
            if i not in migrations:
                raise HandledException(msg_missing + str(i))
            migrations_list.append(f'{i}_{migrations[i]}')
        return migrations_list

    @staticmethod
    def _get_version_of_db(path: str) -> int:
        """Get DB user_version, fail if beyond EXPECTED_DB_VERSION."""
        sql_for_db_version = 'PRAGMA user_version'
        conn = sql_connect(path)
        try:
            db_version = list(conn.execute(sql_for_db_version))[0][0]
        finally:
            conn.close()
        # Ensure the type before comparing, not only afterwards.
        assert isinstance(db_version, int)
        if db_version > EXPECTED_DB_VERSION:
            msg = f'Wrong DB version, expected '\
                    f'{EXPECTED_DB_VERSION}, got unknown {db_version}.'
            raise HandledException(msg)
        return db_version

    @property
    def _user_version(self) -> int:
        """Get DB user_version."""
        return self._get_version_of_db(self.path)

    def _validate_schema(self) -> None:
        """Compare found schema with what's stored at PATH_DB_SCHEMA.

        Raises HandledException, with a diff in its message, if the
        (re-formatted) statements found in sqlite_master differ from the
        stored schema script.
        """

        def reformat_rows(rows: list[str]) -> list[str]:
            """Normalize statements' layout: one top-level item per line.

            Each line of a statement is split after every comma outside
            of parentheses (nesting is counted per line); resulting
            segments are stripped and indented by four spaces, except for
            the statement's first and last line.
            """
            new_rows = []
            for row in rows:
                new_row = []
                for subrow in row.split('\n'):
                    subrow = subrow.rstrip()
                    in_parentheses = 0
                    split_at = []
                    for i, c in enumerate(subrow):
                        if '(' == c:
                            in_parentheses += 1
                        elif ')' == c:
                            in_parentheses -= 1
                        elif ',' == c and 0 == in_parentheses:
                            split_at.append(i + 1)
                    prev_split = 0
                    for i in split_at:
                        segment = subrow[prev_split:i].strip()
                        if len(segment) > 0:
                            new_row.append(f'    {segment}')
                        prev_split = i
                    segment = subrow[prev_split:].strip()
                    if len(segment) > 0:
                        new_row.append(f'    {segment}')
                new_row[0] = new_row[0].lstrip()
                new_row[-1] = new_row[-1].lstrip()
                # Case of a final line like ", col TYPE)" following a line
                # without trailing comma (presumably what sqlite3 stores
                # after an ALTER TABLE … ADD COLUMN -- TODO confirm): move
                # the comma up, and put the closing parenthesis on a line
                # of its own.  Statements of fewer than three lines cannot
                # have that shape; the length check spares them an
                # IndexError.
                if (len(new_row) >= 3 and new_row[-1] != ')'
                        and new_row[-3][-1] != ','):
                    new_row[-3] = new_row[-3] + ','
                    new_row[-2:] = ['    ' + new_row[-1][:-1]] + [')']
                new_rows.append('\n'.join(new_row))
            return new_rows

        sql_for_schema = 'SELECT sql FROM sqlite_master ORDER BY sql'
        msg_err = 'Database has wrong tables schema. Diff:\n'
        conn = sql_connect(self.path)
        try:
            # Skip entries without SQL text (empty or NULL "sql" column).
            schema_rows = [r[0] for r in conn.execute(sql_for_schema) if r[0]]
        finally:
            conn.close()
        schema_rows = reformat_rows(schema_rows)
        retrieved_schema = ';\n'.join(schema_rows) + ';'
        with open(PATH_DB_SCHEMA, 'r', encoding='utf-8') as f:
            stored_schema = f.read().rstrip()
        if stored_schema != retrieved_schema:
            diff_msg = Differ().compare(retrieved_schema.splitlines(),
                                        stored_schema.splitlines())
            raise HandledException(msg_err + '\n'.join(diff_msg))
154
155
class DatabaseConnection:
    """Thin wrapper around one sqlite3 connection to a DatabaseFile."""

    def __init__(self, db_file: DatabaseFile) -> None:
        """Open connection to the file at db_file.path."""
        self.conn = sql_connect(db_file.path)

    def commit(self) -> None:
        """Commit the connection's current SQL transaction."""
        self.conn.commit()

    def exec(self, code: str, inputs: tuple[Any, ...] = tuple()) -> Cursor:
        """Execute code with inputs as parameters, return the Cursor."""
        return self.conn.execute(code, inputs)

    def exec_on_vals(self, code: str, inputs: tuple[Any, ...]) -> Cursor:
        """Run .exec on code extended by " (?, …)", one "?" per input."""
        placeholders = ','.join('?' for _ in inputs)
        return self.exec(f'{code} ({placeholders})', inputs)

    def close(self) -> None:
        """Close the underlying connection."""
        self.conn.close()

    def rewrite_relations(self, table_name: str, key: str, target: int | str,
                          rows: list[list[Any]], key_index: int = 0) -> None:
        # pylint: disable=too-many-arguments
        """Replace all rows of table_name where key == target by rows.

        Each of rows is to be provided without the value for the key
        column; target is inserted into it at position key_index before
        the row is written.
        """
        self.delete_where(table_name, key, target)
        insert_sql = f'INSERT INTO {table_name} VALUES'
        for row in rows:
            head, tail = row[:key_index], row[key_index:]
            self.exec_on_vals(insert_sql, tuple(head + [target] + tail))

    def row_where(self, table_name: str, key: str,
                  target: int | str) -> list[Row]:
        """Collect all rows of table_name where key == target."""
        sql = f'SELECT * FROM {table_name} WHERE {key} = ?'
        return self.exec(sql, (target,)).fetchall()

    # def column_where_pattern(self,
    #                          table_name: str,
    #                          column: str,
    #                          pattern: str,
    #                          keys: list[str]) -> list[Any]:
    #     """Return column of rows where one of keys matches pattern."""
    #     targets = tuple([f'%{pattern}%'] * len(keys))
    #     haystack = ' OR '.join([f'{k} LIKE ?' for k in keys])
    #     sql = f'SELECT {column} FROM {table_name} WHERE {haystack}'
    #     return [row[0] for row in self.exec(sql, targets)]

    def column_where(self, table_name: str, column: str, key: str,
                     target: int | str) -> list[Any]:
        """Collect column's values from rows where key == target."""
        sql = f'SELECT {column} FROM {table_name} WHERE {key} = ?'
        cursor = self.exec(sql, (target,))
        return [found[0] for found in cursor]

    def column_all(self, table_name: str, column: str) -> list[Any]:
        """Collect column's values from all rows of table_name."""
        cursor = self.exec(f'SELECT {column} FROM {table_name}')
        return [found[0] for found in cursor]

    def delete_where(self, table_name: str, key: str,
                     target: int | str) -> None:
        """Delete those rows of table_name where key == target."""
        sql = f'DELETE FROM {table_name} WHERE {key} = ?'
        self.exec(sql, (target,))
226
227
228 BaseModelId = TypeVar('BaseModelId', int, str)
229 BaseModelInstance = TypeVar('BaseModelInstance', bound='BaseModel[Any]')
230
231
232 class BaseModel(Generic[BaseModelId]):
233     """Template for most of the models we use/derive from the DB."""
234     table_name = ''
235     to_save: list[str] = []
236     to_save_versioned: list[str] = []
237     to_save_relations: list[tuple[str, str, str, int]] = []
238     id_: None | BaseModelId
239     cache_: dict[BaseModelId, Self]
240     to_search: list[str] = []
241     can_create_by_id = False
242     _exists = True
243
244     def __init__(self, id_: BaseModelId | None) -> None:
245         if isinstance(id_, int) and id_ < 1:
246             msg = f'illegal {self.__class__.__name__} ID, must be >=1: {id_}'
247             raise HandledException(msg)
248         if isinstance(id_, str) and "" == id_:
249             msg = f'illegal {self.__class__.__name__} ID, must be non-empty'
250             raise HandledException(msg)
251         self.id_ = id_
252
253     def __hash__(self) -> int:
254         hashable = [self.id_] + [getattr(self, name) for name in self.to_save]
255         for definition in self.to_save_relations:
256             attr = getattr(self, definition[2])
257             hashable += [tuple(rel.id_ for rel in attr)]
258         for name in self.to_save_versioned:
259             hashable += [hash(getattr(self, name))]
260         return hash(tuple(hashable))
261
262     def __eq__(self, other: object) -> bool:
263         if not isinstance(other, self.__class__):
264             return False
265         return hash(self) == hash(other)
266
267     def __lt__(self, other: Any) -> bool:
268         if not isinstance(other, self.__class__):
269             msg = 'cannot compare to object of different class'
270             raise HandledException(msg)
271         assert isinstance(self.id_, int)
272         assert isinstance(other.id_, int)
273         return self.id_ < other.id_
274
275     @property
276     def as_dict(self) -> dict[str, object]:
277         """Return self as (json.dumps-compatible) dict."""
278         library: dict[str, dict[str | int, object]] = {}
279         d: dict[str, object] = {'id': self.id_, '_library': library}
280         for to_save in self.to_save:
281             attr = getattr(self, to_save)
282             if hasattr(attr, 'as_dict_into_reference'):
283                 d[to_save] = attr.as_dict_into_reference(library)
284             else:
285                 d[to_save] = attr
286         if len(self.to_save_versioned) > 0:
287             d['_versioned'] = {}
288         for k in self.to_save_versioned:
289             attr = getattr(self, k)
290             assert isinstance(d['_versioned'], dict)
291             d['_versioned'][k] = attr.history
292         for r in self.to_save_relations:
293             attr_name = r[2]
294             l: list[int | str] = []
295             for rel in getattr(self, attr_name):
296                 l += [rel.as_dict_into_reference(library)]
297             d[attr_name] = l
298         return d
299
300     def as_dict_into_reference(self,
301                                library: dict[str, dict[str | int, object]]
302                                ) -> int | str:
303         """Return self.id_ while writing .as_dict into library."""
304         def into_library(library: dict[str, dict[str | int, object]],
305                          cls_name: str,
306                          id_: str | int,
307                          d: dict[str, object]
308                          ) -> None:
309             if cls_name not in library:
310                 library[cls_name] = {}
311             if id_ in library[cls_name]:
312                 if library[cls_name][id_] != d:
313                     msg = 'Unexpected inequality of entries for ' +\
314                             f'_library at: {cls_name}/{id_}'
315                     raise HandledException(msg)
316             else:
317                 library[cls_name][id_] = d
318         as_dict = self.as_dict
319         assert isinstance(as_dict['_library'], dict)
320         for cls_name, dict_of_objs in as_dict['_library'].items():
321             for id_, obj in dict_of_objs.items():
322                 into_library(library, cls_name, id_, obj)
323         del as_dict['_library']
324         assert self.id_ is not None
325         into_library(library, self.__class__.__name__, self.id_, as_dict)
326         assert isinstance(as_dict['id'], (int, str))
327         return as_dict['id']
328
329     # cache management
330     # (we primarily use the cache to ensure we work on the same object in
331     # memory no matter where and how we retrieve it, e.g. we don't want
332     # .by_id() calls to create a new object each time, but rather a pointer
333     # to the one already instantiated)
334
335     def __getattribute__(self, name: str) -> Any:
336         """Ensure fail if ._disappear() was called, except to check ._exists"""
337         if name != '_exists' and not super().__getattribute__('_exists'):
338             raise HandledException('Object does not exist.')
339         return super().__getattribute__(name)
340
341     def _disappear(self) -> None:
342         """Invalidate object, make future use raise exceptions."""
343         assert self.id_ is not None
344         if self._get_cached(self.id_):
345             self._uncache()
346         to_kill = list(self.__dict__.keys())
347         for attr in to_kill:
348             delattr(self, attr)
349         self._exists = False
350
351     @classmethod
352     def empty_cache(cls) -> None:
353         """Empty class's cache, and disappear all former inhabitants."""
354         # pylint: disable=protected-access
355         # (cause we remain within the class)
356         if hasattr(cls, 'cache_'):
357             to_disappear = list(cls.cache_.values())
358             for item in to_disappear:
359                 item._disappear()
360         cls.cache_ = {}
361
362     @classmethod
363     def get_cache(cls: type[BaseModelInstance]) -> dict[Any, BaseModel[Any]]:
364         """Get cache dictionary, create it if not yet existing."""
365         if not hasattr(cls, 'cache_'):
366             d: dict[Any, BaseModel[Any]] = {}
367             cls.cache_ = d
368         return cls.cache_
369
370     @classmethod
371     def _get_cached(cls: type[BaseModelInstance],
372                     id_: BaseModelId) -> BaseModelInstance | None:
373         """Get object of id_ from class's cache, or None if not found."""
374         cache = cls.get_cache()
375         if id_ in cache:
376             obj = cache[id_]
377             assert isinstance(obj, cls)
378             return obj
379         return None
380
381     def cache(self) -> None:
382         """Update object in class's cache.
383
384         Also calls ._disappear if cache holds older reference to object of same
385         ID, but different memory address, to avoid doing anything with
386         dangling leftovers.
387         """
388         if self.id_ is None:
389             raise HandledException('Cannot cache object without ID.')
390         cache = self.get_cache()
391         old_cached = self._get_cached(self.id_)
392         if old_cached and id(old_cached) != id(self):
393             # pylint: disable=protected-access
394             # (cause we remain within the class)
395             old_cached._disappear()
396         cache[self.id_] = self
397
398     def _uncache(self) -> None:
399         """Remove self from cache."""
400         if self.id_ is None:
401             raise HandledException('Cannot un-cache object without ID.')
402         cache = self.get_cache()
403         del cache[self.id_]
404
405     # object retrieval and generation
406
407     @classmethod
408     def from_table_row(cls: type[BaseModelInstance],
409                        # pylint: disable=unused-argument
410                        db_conn: DatabaseConnection,
411                        row: Row | list[Any]) -> BaseModelInstance:
412         """Make from DB row (sans relations), update DB cache with it."""
413         obj = cls(*row)
414         assert obj.id_ is not None
415         for attr_name in cls.to_save_versioned:
416             attr = getattr(obj, attr_name)
417             table_name = attr.table_name
418             for row_ in db_conn.row_where(table_name, 'parent', obj.id_):
419                 attr.history_from_row(row_)
420         obj.cache()
421         return obj
422
423     @classmethod
424     def by_id(cls, db_conn: DatabaseConnection, id_: BaseModelId) -> Self:
425         """Retrieve by id_, on failure throw NotFoundException.
426
427         First try to get from cls.cache_, only then check DB; if found,
428         put into cache.
429         """
430         obj = None
431         if id_ is not None:
432             obj = cls._get_cached(id_)
433             if not obj:
434                 for row in db_conn.row_where(cls.table_name, 'id', id_):
435                     obj = cls.from_table_row(db_conn, row)
436                     break
437         if obj:
438             return obj
439         raise NotFoundException(f'found no object of ID {id_}')
440
441     @classmethod
442     def by_id_or_create(cls, db_conn: DatabaseConnection,
443                         id_: BaseModelId | None
444                         ) -> Self:
445         """Wrapper around .by_id, creating (not caching/saving) if not find."""
446         if not cls.can_create_by_id:
447             raise HandledException('Class cannot .by_id_or_create.')
448         if id_ is None:
449             return cls(None)
450         try:
451             return cls.by_id(db_conn, id_)
452         except NotFoundException:
453             return cls(id_)
454
455     @classmethod
456     def all(cls: type[BaseModelInstance],
457             db_conn: DatabaseConnection) -> list[BaseModelInstance]:
458         """Collect all objects of class into list.
459
460         Note that this primarily returns the contents of the cache, and only
461         _expands_ that by additional findings in the DB. This assumes the
462         cache is always instantly cleaned of any items that would be removed
463         from the DB.
464         """
465         items: dict[BaseModelId, BaseModelInstance] = {}
466         for k, v in cls.get_cache().items():
467             assert isinstance(v, cls)
468             items[k] = v
469         already_recorded = items.keys()
470         for id_ in db_conn.column_all(cls.table_name, 'id'):
471             if id_ not in already_recorded:
472                 item = cls.by_id(db_conn, id_)
473                 assert item.id_ is not None
474                 items[item.id_] = item
475         return list(items.values())
476
477     @classmethod
478     def by_date_range_with_limits(cls: type[BaseModelInstance],
479                                   db_conn: DatabaseConnection,
480                                   date_range: tuple[str, str],
481                                   date_col: str = 'day'
482                                   ) -> tuple[list[BaseModelInstance], str,
483                                              str]:
484         """Return list of items in database within (open) date_range interval.
485
486         If no range values provided, defaults them to 'yesterday' and
487         'tomorrow'. Knows to properly interpret these and 'today' as value.
488         """
489         start_str = date_range[0] if date_range[0] else 'yesterday'
490         end_str = date_range[1] if date_range[1] else 'tomorrow'
491         start_date = valid_date(start_str)
492         end_date = valid_date(end_str)
493         items = []
494         sql = f'SELECT id FROM {cls.table_name} '
495         sql += f'WHERE {date_col} >= ? AND {date_col} <= ?'
496         for row in db_conn.exec(sql, (start_date, end_date)):
497             items += [cls.by_id(db_conn, row[0])]
498         return items, start_date, end_date
499
500     @classmethod
501     def matching(cls: type[BaseModelInstance], db_conn: DatabaseConnection,
502                  pattern: str) -> list[BaseModelInstance]:
503         """Return all objects whose .to_search match pattern."""
504         items = cls.all(db_conn)
505         if pattern:
506             filtered = []
507             for item in items:
508                 for attr_name in cls.to_search:
509                     toks = attr_name.split('.')
510                     parent = item
511                     for tok in toks:
512                         attr = getattr(parent, tok)
513                         parent = attr
514                     if pattern in attr:
515                         filtered += [item]
516                         break
517             return filtered
518         return items
519
520     # database writing
521
522     def save(self, db_conn: DatabaseConnection) -> None:
523         """Write self to DB and cache and ensure .id_.
524
525         Write both to DB, and to cache. To DB, write .id_ and attributes
526         listed in cls.to_save[_versioned|_relations].
527
528         Ensure self.id_ by setting it to what the DB command returns as the
529         last saved row's ID (cursor.lastrowid), EXCEPT if self.id_ already
530         exists as a 'str', which implies we do our own ID creation (so far
531         only the case with the Day class, where it's to be a date string.
532         """
533         values = tuple([self.id_] + [getattr(self, key)
534                                      for key in self.to_save])
535         table_name = self.table_name
536         cursor = db_conn.exec_on_vals(f'REPLACE INTO {table_name} VALUES',
537                                       values)
538         if not isinstance(self.id_, str):
539             self.id_ = cursor.lastrowid  # type: ignore[assignment]
540         self.cache()
541         for attr_name in self.to_save_versioned:
542             getattr(self, attr_name).save(db_conn)
543         for table, column, attr_name, key_index in self.to_save_relations:
544             assert isinstance(self.id_, (int, str))
545             db_conn.rewrite_relations(table, column, self.id_,
546                                       [[i.id_] for i
547                                        in getattr(self, attr_name)], key_index)
548
549     def remove(self, db_conn: DatabaseConnection) -> None:
550         """Remove from DB and cache, including dependencies."""
551         if self.id_ is None or self._get_cached(self.id_) is None:
552             raise HandledException('cannot remove unsaved item')
553         for attr_name in self.to_save_versioned:
554             getattr(self, attr_name).remove(db_conn)
555         for table, column, attr_name, _ in self.to_save_relations:
556             db_conn.delete_where(table, column, self.id_)
557         self._uncache()
558         db_conn.delete_where(self.table_name, 'id', self.id_)
559         self._disappear()