home · contact · privacy
Refactor request handler for deleting or retrieving items on POST.
[plomtask] / plomtask / db.py
1 """Database management."""
from __future__ import annotations
from contextlib import closing
from difflib import Differ
from os import listdir
from os.path import isfile
from sqlite3 import connect as sql_connect, Cursor, Row
from typing import Any, Self, TypeVar, Generic

from plomtask.exceptions import HandledException, NotFoundException
from plomtask.dating import valid_date
10
# Schema version this code expects; DB files at other versions need migrating.
EXPECTED_DB_VERSION = 5
# Directory holding the numbered migration scripts plus the full schema file.
MIGRATIONS_DIR = 'migrations'
# Full schema for EXPECTED_DB_VERSION, used to create and to validate DBs.
FILENAME_DB_SCHEMA = 'init_' + str(EXPECTED_DB_VERSION) + '.sql'
PATH_DB_SCHEMA = '/'.join((MIGRATIONS_DIR, FILENAME_DB_SCHEMA))
15
16
class UnmigratedDbException(HandledException):
    """Signal a DB file whose user_version differs from EXPECTED_DB_VERSION.

    Such a file can be brought up to date via DatabaseFile.migrate.
    """
19
20
class DatabaseFile:
    """Represents the sqlite3 database's file.

    Instantiation validates the file at path: it must exist, carry
    EXPECTED_DB_VERSION as its user_version, and match the schema stored at
    PATH_DB_SCHEMA. Use .create_at for new files, .migrate for outdated ones.

    Every sqlite3 connection opened here is wrapped in contextlib.closing,
    since a sqlite3 connection used as a context manager only commits or
    rolls back on exit and does not close itself.
    """
    # pylint: disable=too-few-public-methods

    def __init__(self, path: str) -> None:
        self.path = path
        self._check()

    @classmethod
    def create_at(cls, path: str) -> DatabaseFile:
        """Make new DB file at path from PATH_DB_SCHEMA, return it validated.

        The schema file is read before connecting, so a missing schema file
        does not leave an empty DB file behind.
        """
        with open(PATH_DB_SCHEMA, 'r', encoding='utf-8') as f:
            schema_sql = f.read()
        with closing(sql_connect(path)) as conn:
            with conn:
                conn.executescript(schema_sql)
                conn.execute(f'PRAGMA user_version = {EXPECTED_DB_VERSION}')
        return cls(path)

    @classmethod
    def migrate(cls, path: str) -> DatabaseFile:
        """Apply migrations from DB's current version to EXPECTED_DB_VERSION.

        Each migration script is run in turn, with the DB's user_version set
        right after it, so a failing script leaves the DB at the version
        reached by the last successful one.
        """
        migrations = cls._available_migrations()
        from_version = cls._get_version_of_db(path)
        # migrations[i] is the script leading up to version i, so start
        # right after the version the DB is already at.
        for version, filename in enumerate(migrations[from_version + 1:],
                                           start=from_version + 1):
            with open(f'{MIGRATIONS_DIR}/{filename}', 'r',
                      encoding='utf-8') as f:
                migration_sql = f.read()
            with closing(sql_connect(path)) as conn:
                with conn:
                    conn.executescript(migration_sql)
                    conn.execute(f'PRAGMA user_version = {version}')
        return cls(path)

    def _check(self) -> None:
        """Check file exists, and is of proper DB version and schema.

        Raises NotFoundException if there is no file at .path, and
        UnmigratedDbException if its user_version is below
        EXPECTED_DB_VERSION (one above makes ._get_version_of_db raise
        HandledException). A schema mismatch raises HandledException via
        ._validate_schema.
        """
        if not isfile(self.path):
            raise NotFoundException
        if self._user_version != EXPECTED_DB_VERSION:
            raise UnmigratedDbException()
        self._validate_schema()

    @staticmethod
    def _available_migrations() -> list[str]:
        """Validate migrations directory and return sorted entries.

        Besides FILENAME_DB_SCHEMA, MIGRATIONS_DIR must hold one file named
        '{i}_{rest}' for each i in 0..EXPECTED_DB_VERSION; the returned list
        holds these filenames, entry i at index i.
        """
        msg_too_big = 'Migration directory points beyond expected DB version.'
        msg_bad_entry = 'Migration directory contains unexpected entry: '
        msg_missing = 'Migration directory misses migration of number: '
        msg_duplicate = 'Migration directory contains duplicate migration ' \
            'of number: '
        migrations: dict[int, str] = {}
        for entry in listdir(MIGRATIONS_DIR):
            if entry == FILENAME_DB_SCHEMA:
                continue
            toks = entry.split('_', 1)
            if len(toks) < 2:
                raise HandledException(msg_bad_entry + entry)
            try:
                i = int(toks[0])
            except ValueError as e:
                raise HandledException(msg_bad_entry + entry) from e
            if i > EXPECTED_DB_VERSION:
                raise HandledException(msg_too_big)
            # listdir returns entries in arbitrary order, so silently keeping
            # one of two same-numbered files would make migrating unpredictable.
            if i in migrations:
                raise HandledException(msg_duplicate + str(i))
            migrations[i] = toks[1]
        migrations_list = []
        for i in range(EXPECTED_DB_VERSION + 1):
            if i not in migrations:
                raise HandledException(msg_missing + str(i))
            migrations_list += [f'{i}_{migrations[i]}']
        return migrations_list

    @staticmethod
    def _get_version_of_db(path: str) -> int:
        """Get DB user_version, fail if beyond EXPECTED_DB_VERSION."""
        sql_for_db_version = 'PRAGMA user_version'
        with closing(sql_connect(path)) as conn:
            db_version = conn.execute(sql_for_db_version).fetchone()[0]
        if db_version > EXPECTED_DB_VERSION:
            msg = f'Wrong DB version, expected '\
                    f'{EXPECTED_DB_VERSION}, got unknown {db_version}.'
            raise HandledException(msg)
        assert isinstance(db_version, int)
        return db_version

    @property
    def _user_version(self) -> int:
        """Get DB user_version."""
        return self._get_version_of_db(self.path)

    def _validate_schema(self) -> None:
        """Compare found schema with what's stored at PATH_DB_SCHEMA.

        Raises HandledException carrying a line diff (DB lines marked '-',
        stored-schema lines marked '+') if they differ.
        """

        def reformat_rows(rows: list[str]) -> list[str]:
            # Lay out each statement one clause per line: split every source
            # line at commas not inside parentheses opened on that same line,
            # and indent all resulting lines by four spaces except the first
            # and the last.
            new_rows = []
            for row in rows:
                new_row = []
                for subrow in row.split('\n'):
                    subrow = subrow.rstrip()
                    in_parentheses = 0
                    split_at = []
                    for i, c in enumerate(subrow):
                        if '(' == c:
                            in_parentheses += 1
                        elif ')' == c:
                            in_parentheses -= 1
                        elif ',' == c and 0 == in_parentheses:
                            split_at += [i + 1]
                    prev_split = 0
                    for i in split_at:
                        segment = subrow[prev_split:i].strip()
                        if len(segment) > 0:
                            new_row += [f'    {segment}']
                        prev_split = i
                    segment = subrow[prev_split:].strip()
                    if len(segment) > 0:
                        new_row += [f'    {segment}']
                new_row[0] = new_row[0].lstrip()
                new_row[-1] = new_row[-1].lstrip()
                # If the statement does not end on a lone ')' line, move its
                # final character onto a ')' line of its own, replacing the
                # line before (presumably a lone ',' split off above) and
                # adding the missing ',' to the line before that.
                # NOTE(review): this seems to target the layout SQLite leaves
                # after ALTER TABLE … ADD COLUMN — confirm. Statements of
                # fewer than three lines have no such tail and are skipped
                # rather than indexed out of range.
                if (len(new_row) >= 3 and new_row[-1] != ')'
                        and new_row[-3][-1] != ','):
                    new_row[-3] = new_row[-3] + ','
                    new_row[-2:] = ['    ' + new_row[-1][:-1]] + [')']
                new_rows += ['\n'.join(new_row)]
            return new_rows

        sql_for_schema = 'SELECT sql FROM sqlite_master ORDER BY sql'
        msg_err = 'Database has wrong tables schema. Diff:\n'
        with closing(sql_connect(self.path)) as conn:
            # skip sqlite_master entries whose sql is NULL or empty
            schema_rows = [r[0] for r in conn.execute(sql_for_schema) if r[0]]
        schema_rows = reformat_rows(schema_rows)
        retrieved_schema = ';\n'.join(schema_rows) + ';'
        with open(PATH_DB_SCHEMA, 'r', encoding='utf-8') as f:
            stored_schema = f.read().rstrip()
        if stored_schema != retrieved_schema:
            diff_msg = Differ().compare(retrieved_schema.splitlines(),
                                        stored_schema.splitlines())
            raise HandledException(msg_err + '\n'.join(diff_msg))
154
155
class DatabaseConnection:
    """Hold one sqlite3 connection to a DatabaseFile's path.

    Statements issued via .exec are added to the SQL transaction, which
    .commit commits; .close ends the connection.
    """

    def __init__(self, db_file: DatabaseFile) -> None:
        path = db_file.path
        self.conn = sql_connect(path)

    def commit(self) -> None:
        """Commit the pending SQL transaction."""
        self.conn.commit()

    def exec(self, code: str, inputs: tuple[Any, ...] = tuple()) -> Cursor:
        """Run code with inputs bound to its '?' placeholders."""
        return self.conn.execute(code, inputs)

    def exec_on_vals(self, code: str, inputs: tuple[Any, ...]) -> Cursor:
        """Run code suffixed by ' (?,?,…)', one placeholder per input."""
        placeholders = ','.join('?' * len(inputs))
        return self.exec(f'{code} ({placeholders})', inputs)

    def close(self) -> None:
        """Close the underlying sqlite3 connection."""
        self.conn.close()

    def rewrite_relations(self, table_name: str, key: str, target: int | str,
                          rows: list[list[Any]], key_index: int = 0) -> None:
        # pylint: disable=too-many-arguments
        """Replace all rows of table_name where key == target by rows.

        Entries of rows leave out the key column: target is spliced into
        each of them at position key_index before it gets inserted.
        """
        self.delete_where(table_name, key, target)
        insert_sql = f'INSERT INTO {table_name} VALUES'
        for row in rows:
            head, tail = row[:key_index], row[key_index:]
            self.exec_on_vals(insert_sql, tuple(head + [target] + tail))

    def row_where(self, table_name: str, key: str,
                  target: int | str) -> list[Row]:
        """Return all rows of table_name whose key column equals target."""
        sql = f'SELECT * FROM {table_name} WHERE {key} = ?'
        return self.exec(sql, (target,)).fetchall()

    def column_where(self, table_name: str, column: str, key: str,
                     target: int | str) -> list[Any]:
        """Return column's values for rows of table_name with key == target."""
        sql = f'SELECT {column} FROM {table_name} WHERE {key} = ?'
        return [found[0] for found in self.exec(sql, (target,))]

    def column_all(self, table_name: str, column: str) -> list[Any]:
        """Return column's values across all rows of table_name."""
        sql = f'SELECT {column} FROM {table_name}'
        return [found[0] for found in self.exec(sql)]

    def delete_where(self, table_name: str, key: str,
                     target: int | str) -> None:
        """Delete all rows of table_name whose key column equals target."""
        sql = f'DELETE FROM {table_name} WHERE {key} = ?'
        self.exec(sql, (target,))
226
227
# Model IDs are ints, or strs where models create their own IDs (see .save).
BaseModelId = TypeVar('BaseModelId', int, str)
# Any concrete BaseModel (sub)class instance, for classmethod return types.
BaseModelInstance = TypeVar(
    'BaseModelInstance',
    bound='BaseModel[Any]',
)
230
231
class BaseModel(Generic[BaseModelId]):
    """Template for most of the models we use/derive from the DB.

    Subclasses configure persistence via class attributes: table_name names
    the DB table whose rows hold .id_ followed by the to_save attributes;
    to_save_versioned names attributes persisted via their own .save/.remove;
    to_save_relations tuples (table, column, attribute name, key_index)
    describe tables pairing this object's ID (in column) with the IDs of the
    objects listed in that attribute.
    """
    table_name = ''
    # attribute names stored as table columns, in column order after the ID
    to_save: list[str] = []
    # attribute names of versioned attributes (.history, .save, .remove, …)
    to_save_versioned: list[str] = []
    # (table, key column, attribute name, key index) per relations table
    to_save_relations: list[tuple[str, str, str, int]] = []
    # names of iterable attributes whose items .as_dict also exports
    add_to_dict: list[str] = []
    id_: None | BaseModelId
    cache_: dict[BaseModelId, Self]
    # dot-separated attribute paths searched by .matching
    to_search: list[str] = []
    can_create_by_id = False
    _exists = True

    def __init__(self, id_: BaseModelId | None) -> None:
        if isinstance(id_, int) and id_ < 1:
            msg = f'illegal {self.__class__.__name__} ID, must be >=1: {id_}'
            raise HandledException(msg)
        if isinstance(id_, str) and "" == id_:
            msg = f'illegal {self.__class__.__name__} ID, must be non-empty'
            raise HandledException(msg)
        self.id_ = id_

    def _identity_values(self) -> tuple[Any, ...]:
        """Return the values that define this object's hash and equality."""
        values: list[Any] = [self.id_]
        values += [getattr(self, name) for name in self.to_save]
        for definition in self.to_save_relations:
            attr = getattr(self, definition[2])
            values += [tuple(rel.id_ for rel in attr)]
        for name in self.to_save_versioned:
            values += [hash(getattr(self, name))]
        return tuple(values)

    def __hash__(self) -> int:
        return hash(self._identity_values())

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, self.__class__):
            return False
        # Compare the values themselves rather than their hashes: distinct
        # values may share a hash, which would make unequal objects equal.
        return self._identity_values() == other._identity_values()

    def __lt__(self, other: Any) -> bool:
        if not isinstance(other, self.__class__):
            msg = 'cannot compare to object of different class'
            raise HandledException(msg)
        assert isinstance(self.id_, int)
        assert isinstance(other.id_, int)
        return self.id_ < other.id_

    @property
    def as_dict(self) -> dict[str, object]:
        """Return self as (json.dumps-compatible) dict.

        Referenced model objects are replaced by their IDs, their own dicts
        being collected under '_library' by class name and ID.
        """
        library: dict[str, dict[str | int, object]] = {}
        d: dict[str, object] = {'id': self.id_, '_library': library}
        for to_save in self.to_save:
            attr = getattr(self, to_save)
            if hasattr(attr, 'as_dict_into_reference'):
                d[to_save] = attr.as_dict_into_reference(library)
            else:
                d[to_save] = attr
        if self.to_save_versioned:
            d['_versioned'] = {k: getattr(self, k).history
                               for k in self.to_save_versioned}
        for _, _, attr_name, _ in self.to_save_relations:
            d[attr_name] = [rel.as_dict_into_reference(library)
                            for rel in getattr(self, attr_name)]
        for k in self.add_to_dict:
            d[k] = [x.as_dict_into_reference(library)
                    for x in getattr(self, k)]
        return d

    def as_dict_into_reference(self,
                               library: dict[str, dict[str | int, object]]
                               ) -> int | str:
        """Return self.id_ while writing .as_dict into library.

        Raises HandledException if library already holds a different dict
        for the same class name and ID.
        """
        def into_library(library: dict[str, dict[str | int, object]],
                         cls_name: str,
                         id_: str | int,
                         d: dict[str, object]
                         ) -> None:
            if cls_name not in library:
                library[cls_name] = {}
            if id_ in library[cls_name]:
                if library[cls_name][id_] != d:
                    msg = 'Unexpected inequality of entries for ' +\
                            f'_library at: {cls_name}/{id_}'
                    raise HandledException(msg)
            else:
                library[cls_name][id_] = d
        as_dict = self.as_dict
        assert isinstance(as_dict['_library'], dict)
        for cls_name, dict_of_objs in as_dict['_library'].items():
            for id_, obj in dict_of_objs.items():
                into_library(library, cls_name, id_, obj)
        del as_dict['_library']
        assert self.id_ is not None
        into_library(library, self.__class__.__name__, self.id_, as_dict)
        assert isinstance(as_dict['id'], (int, str))
        return as_dict['id']

    @classmethod
    def name_lowercase(cls) -> str:
        """Convenience method to return cls' name in lowercase."""
        return cls.__name__.lower()

    # cache management
    # (we primarily use the cache to ensure we work on the same object in
    # memory no matter where and how we retrieve it, e.g. we don't want
    # .by_id() calls to create a new object each time, but rather a pointer
    # to the one already instantiated)

    def __getattribute__(self, name: str) -> Any:
        """Ensure fail if ._disappear() was called, except to check ._exists"""
        if name != '_exists' and not super().__getattribute__('_exists'):
            raise HandledException('Object does not exist.')
        return super().__getattribute__(name)

    def _disappear(self) -> None:
        """Invalidate object, make future use raise exceptions."""
        assert self.id_ is not None
        # Only un-cache if the cache points to this very object, not to some
        # other instance that merely shares its ID.
        if self._get_cached(self.id_) is self:
            self._uncache()
        to_kill = list(self.__dict__.keys())
        for attr in to_kill:
            delattr(self, attr)
        self._exists = False

    @classmethod
    def empty_cache(cls) -> None:
        """Empty class's cache, and disappear all former inhabitants."""
        # pylint: disable=protected-access
        # (cause we remain within the class)
        # Look only at cls' own namespace: hasattr would also find a cache_
        # inherited from a parent class, which is not cls' own cache.
        if 'cache_' in cls.__dict__:
            for item in list(cls.cache_.values()):
                item._disappear()
        cls.cache_ = {}

    @classmethod
    def get_cache(cls: type[BaseModelInstance]) -> dict[Any, BaseModel[Any]]:
        """Get cache dictionary, create it if not yet existing.

        Each class gets its own cache, even when a parent class has one.
        """
        if 'cache_' not in cls.__dict__:
            d: dict[Any, BaseModel[Any]] = {}
            cls.cache_ = d
        return cls.cache_

    @classmethod
    def _get_cached(cls: type[BaseModelInstance],
                    id_: BaseModelId) -> BaseModelInstance | None:
        """Get object of id_ from class's cache, or None if not found."""
        cache = cls.get_cache()
        if id_ in cache:
            obj = cache[id_]
            assert isinstance(obj, cls)
            return obj
        return None

    def cache(self) -> None:
        """Update object in class's cache.

        Also calls ._disappear if cache holds older reference to object of same
        ID, but different memory address, to avoid doing anything with
        dangling leftovers.
        """
        if self.id_ is None:
            raise HandledException('Cannot cache object without ID.')
        cache = self.get_cache()
        old_cached = self._get_cached(self.id_)
        if old_cached and id(old_cached) != id(self):
            # pylint: disable=protected-access
            # (cause we remain within the class)
            old_cached._disappear()
        cache[self.id_] = self

    def _uncache(self) -> None:
        """Remove self from cache."""
        if self.id_ is None:
            raise HandledException('Cannot un-cache object without ID.')
        cache = self.get_cache()
        del cache[self.id_]

    # object retrieval and generation

    @classmethod
    def from_table_row(cls: type[BaseModelInstance],
                       db_conn: DatabaseConnection,
                       row: Row | list[Any]) -> BaseModelInstance:
        """Make from DB row (sans relations), update DB cache with it.

        Versioned attributes get their history loaded from the rows of
        their own tables whose 'parent' column matches the new object's ID.
        """
        obj = cls(*row)
        assert obj.id_ is not None
        for attr_name in cls.to_save_versioned:
            attr = getattr(obj, attr_name)
            table_name = attr.table_name
            for row_ in db_conn.row_where(table_name, 'parent', obj.id_):
                attr.history_from_row(row_)
        obj.cache()
        return obj

    @classmethod
    def by_id(cls, db_conn: DatabaseConnection, id_: BaseModelId) -> Self:
        """Retrieve by id_, on failure throw NotFoundException.

        First try to get from cls.cache_, only then check DB; if found,
        put into cache.
        """
        obj = None
        if id_ is not None:
            obj = cls._get_cached(id_)
            if obj is None:
                for row in db_conn.row_where(cls.table_name, 'id', id_):
                    obj = cls.from_table_row(db_conn, row)
                    break
        if obj is not None:
            return obj
        raise NotFoundException(f'found no object of ID {id_}')

    @classmethod
    def by_id_or_create(cls, db_conn: DatabaseConnection,
                        id_: BaseModelId | None
                        ) -> Self:
        """Wrapper around .by_id, creating (not caching/saving) if not find."""
        if not cls.can_create_by_id:
            raise HandledException('Class cannot .by_id_or_create.')
        if id_ is None:
            return cls(None)
        try:
            return cls.by_id(db_conn, id_)
        except NotFoundException:
            return cls(id_)

    @classmethod
    def all(cls: type[BaseModelInstance],
            db_conn: DatabaseConnection) -> list[BaseModelInstance]:
        """Collect all objects of class into list.

        Note that this primarily returns the contents of the cache, and only
        _expands_ that by additional findings in the DB. This assumes the
        cache is always instantly cleaned of any items that would be removed
        from the DB.
        """
        items: dict[BaseModelId, BaseModelInstance] = {}
        for k, v in cls.get_cache().items():
            assert isinstance(v, cls)
            items[k] = v
        already_recorded = items.keys()
        for id_ in db_conn.column_all(cls.table_name, 'id'):
            if id_ not in already_recorded:
                item = cls.by_id(db_conn, id_)
                assert item.id_ is not None
                items[item.id_] = item
        return list(items.values())

    @classmethod
    def by_date_range_with_limits(cls: type[BaseModelInstance],
                                  db_conn: DatabaseConnection,
                                  date_range: tuple[str, str],
                                  date_col: str = 'day'
                                  ) -> tuple[list[BaseModelInstance], str,
                                             str]:
        """Return items whose date_col lies within date_range, both inclusive.

        If no range values provided, defaults them to 'yesterday' and
        'tomorrow'; valid_date is relied on to interpret these and 'today'.
        Returns the items plus the start and end dates as given by valid_date.
        """
        start_str = date_range[0] if date_range[0] else 'yesterday'
        end_str = date_range[1] if date_range[1] else 'tomorrow'
        start_date = valid_date(start_str)
        end_date = valid_date(end_str)
        items: list[BaseModelInstance] = []
        sql = f'SELECT id FROM {cls.table_name} '
        sql += f'WHERE {date_col} >= ? AND {date_col} <= ?'
        for row in db_conn.exec(sql, (start_date, end_date)):
            items += [cls.by_id(db_conn, row[0])]
        return items, start_date, end_date

    @classmethod
    def matching(cls: type[BaseModelInstance], db_conn: DatabaseConnection,
                 pattern: str) -> list[BaseModelInstance]:
        """Return all objects whose .to_search match pattern.

        Each .to_search entry is a dot-separated attribute path resolved on
        the item; an item matches if pattern is contained ('in') in any of
        the resolved values. An empty pattern returns all items.
        """
        items = cls.all(db_conn)
        if not pattern:
            return items
        filtered = []
        for item in items:
            for attr_path in cls.to_search:
                attr: Any = item
                for tok in attr_path.split('.'):
                    attr = getattr(attr, tok)
                if pattern in attr:
                    filtered += [item]
                    break
        return filtered

    # database writing

    def save(self, db_conn: DatabaseConnection) -> None:
        """Write self to DB and cache and ensure .id_.

        Write both to DB, and to cache. To DB, write .id_ and attributes
        listed in cls.to_save[_versioned|_relations].

        Ensure self.id_ by setting it to what the DB command returns as the
        last saved row's ID (cursor.lastrowid), EXCEPT if self.id_ already
        exists as a 'str', which implies we do our own ID creation (so far
        only the case with the Day class, where it's to be a date string).
        """
        values = tuple([self.id_] + [getattr(self, key)
                                     for key in self.to_save])
        table_name = self.table_name
        cursor = db_conn.exec_on_vals(f'REPLACE INTO {table_name} VALUES',
                                      values)
        if not isinstance(self.id_, str):
            self.id_ = cursor.lastrowid  # type: ignore[assignment]
        self.cache()
        for attr_name in self.to_save_versioned:
            getattr(self, attr_name).save(db_conn)
        for table, column, attr_name, key_index in self.to_save_relations:
            assert isinstance(self.id_, (int, str))
            db_conn.rewrite_relations(table, column, self.id_,
                                      [[i.id_] for i
                                       in getattr(self, attr_name)], key_index)

    def remove(self, db_conn: DatabaseConnection) -> None:
        """Remove from DB and cache, including dependencies.

        Raises HandledException if self has no ID or none of its ID is
        cached, i.e. it was never saved.
        """
        if self.id_ is None or self._get_cached(self.id_) is None:
            raise HandledException('cannot remove unsaved item')
        for attr_name in self.to_save_versioned:
            getattr(self, attr_name).remove(db_conn)
        for table, column, _, _ in self.to_save_relations:
            db_conn.delete_where(table, column, self.id_)
        self._uncache()
        db_conn.delete_where(self.table_name, 'id', self.id_)
        self._disappear()