home · contact · privacy
Fix bug of /day POSTS breaking on empty new_todo fields.
[plomtask] / plomtask / db.py
1 """Database management."""
from __future__ import annotations
from contextlib import closing
from os import listdir
from os.path import isfile
from difflib import Differ
from sqlite3 import connect as sql_connect, Cursor, Row
from typing import Any, Self, TypeVar, Generic, Callable

from plomtask.exceptions import (HandledException, NotFoundException,
                                 BadFormatException)
from plomtask.dating import valid_date
11
# DB schema version (sqlite3's "PRAGMA user_version") this code works with.
EXPECTED_DB_VERSION = 5
# Directory holding the migration scripts and the full-schema init script.
MIGRATIONS_DIR = 'migrations'
# Name and path of the script describing the schema at EXPECTED_DB_VERSION.
FILENAME_DB_SCHEMA = 'init_' + str(EXPECTED_DB_VERSION) + '.sql'
PATH_DB_SCHEMA = '/'.join((MIGRATIONS_DIR, FILENAME_DB_SCHEMA))
16
17
class UnmigratedDbException(HandledException):
    """Raised if a DB file's user_version is not EXPECTED_DB_VERSION."""
20
21
class DatabaseFile:
    """Represents the sqlite3 database's file.

    Instantiation validates the file at path: it must exist, carry
    EXPECTED_DB_VERSION as its user_version, and match the schema stored at
    PATH_DB_SCHEMA.

    NB: sqlite3 connections used as context managers only commit or roll
    back their transaction, they do not close. All short-lived connections
    in this class are therefore additionally wrapped into closing().
    """
    # pylint: disable=too-few-public-methods

    def __init__(self, path: str) -> None:
        self.path = path
        self._check()

    @classmethod
    def create_at(cls, path: str) -> DatabaseFile:
        """Make new DB file at path, from the schema at PATH_DB_SCHEMA.

        Sets the new file's user_version to EXPECTED_DB_VERSION, and returns
        a (thereby validated) DatabaseFile for it.
        """
        with closing(sql_connect(path)) as conn, conn:
            with open(PATH_DB_SCHEMA, 'r', encoding='utf-8') as f:
                conn.executescript(f.read())
            conn.execute(f'PRAGMA user_version = {EXPECTED_DB_VERSION}')
        return cls(path)

    @classmethod
    def migrate(cls, path: str) -> DatabaseFile:
        """Apply migrations from_version to EXPECTED_DB_VERSION.

        After each migration script, the file's user_version is raised by
        one, so an interrupted run leaves the version of the last migration
        whose version bump was written.
        """
        migrations = cls._available_migrations()
        from_version = cls._get_version_of_db(path)
        migrations_todo = migrations[from_version+1:]
        for user_version, filename in enumerate(migrations_todo,
                                                start=from_version + 1):
            with closing(sql_connect(path)) as conn, conn:
                with open(f'{MIGRATIONS_DIR}/{filename}', 'r',
                          encoding='utf-8') as f:
                    conn.executescript(f.read())
            with closing(sql_connect(path)) as conn, conn:
                conn.execute(f'PRAGMA user_version = {user_version}')
        return cls(path)

    def _check(self) -> None:
        """Check file exists, and is of proper DB version and schema.

        Raises NotFoundException if there is no file at .path,
        UnmigratedDbException if its user_version is not
        EXPECTED_DB_VERSION, and HandledException on schema mismatch or a
        user_version beyond EXPECTED_DB_VERSION.
        """
        if not isfile(self.path):
            raise NotFoundException
        if self._user_version != EXPECTED_DB_VERSION:
            raise UnmigratedDbException()
        self._validate_schema()

    @staticmethod
    def _available_migrations() -> list[str]:
        """Validate migrations directory and return sorted entries.

        Except for FILENAME_DB_SCHEMA (which is skipped), every entry of
        MIGRATIONS_DIR must be named "<number>_<rest>", with numbers
        covering 0 to EXPECTED_DB_VERSION without gaps and none above;
        otherwise raises HandledException. The result's index equals the
        respective migration's number.
        """
        msg_too_big = 'Migration directory points beyond expected DB version.'
        msg_bad_entry = 'Migration directory contains unexpected entry: '
        msg_missing = 'Migration directory misses migration of number: '
        migrations: dict[int, str] = {}
        for entry in listdir(MIGRATIONS_DIR):
            if entry == FILENAME_DB_SCHEMA:
                continue
            toks = entry.split('_', 1)
            if len(toks) < 2:
                raise HandledException(msg_bad_entry + entry)
            try:
                i = int(toks[0])
            except ValueError as e:
                raise HandledException(msg_bad_entry + entry) from e
            if i > EXPECTED_DB_VERSION:
                raise HandledException(msg_too_big)
            migrations[i] = toks[1]
        migrations_list = []
        for i in range(EXPECTED_DB_VERSION + 1):
            if i not in migrations:
                raise HandledException(msg_missing + str(i))
            migrations_list.append(f'{i}_{migrations[i]}')
        return migrations_list

    @staticmethod
    def _get_version_of_db(path: str) -> int:
        """Get DB user_version, fail if outside expected range.

        Raises HandledException if the version found is greater than
        EXPECTED_DB_VERSION.
        """
        sql_for_db_version = 'PRAGMA user_version'
        with closing(sql_connect(path)) as conn, conn:
            db_version = list(conn.execute(sql_for_db_version))[0][0]
        if db_version > EXPECTED_DB_VERSION:
            msg = f'Wrong DB version, expected '\
                    f'{EXPECTED_DB_VERSION}, got unknown {db_version}.'
            raise HandledException(msg)
        assert isinstance(db_version, int)
        return db_version

    @property
    def _user_version(self) -> int:
        """Get DB user_version."""
        return self._get_version_of_db(self.path)

    def _validate_schema(self) -> None:
        """Compare found schema with what's stored at PATH_DB_SCHEMA.

        The statements found in sqlite_master are normalized in layout and
        joined by ";\n"; if the result differs from the (right-stripped)
        file content, raises HandledException carrying a line diff.
        """

        def reformat_rows(rows: list[str]) -> list[str]:
            """Normalize layout of each SQL statement in rows.

            Each statement is broken up at its newlines and at commas
            outside of parentheses, each resulting non-empty segment put
            onto its own line indented by four spaces, with only the first
            and the last line un-indented.
            """
            new_rows = []
            for row in rows:
                new_row = []
                for subrow in row.split('\n'):
                    subrow = subrow.rstrip()
                    in_parentheses = 0
                    split_at = []
                    for i, c in enumerate(subrow):
                        if '(' == c:
                            in_parentheses += 1
                        elif ')' == c:
                            in_parentheses -= 1
                        elif ',' == c and 0 == in_parentheses:
                            split_at += [i + 1]
                    prev_split = 0
                    for i in split_at:
                        segment = subrow[prev_split:i].strip()
                        if len(segment) > 0:
                            new_row += [f'    {segment}']
                        prev_split = i
                    segment = subrow[prev_split:].strip()
                    if len(segment) > 0:
                        new_row += [f'    {segment}']
                new_row[0] = new_row[0].lstrip()
                new_row[-1] = new_row[-1].lstrip()
                # Repair statements whose final line is not a lone ")" and
                # whose third-to-last line lacks a trailing comma: add that
                # comma, drop the second-to-last line, and move the last
                # line's final character onto a line of its own as ")".
                # (Presumably the layout sqlite leaves behind after
                # ALTER TABLE … ADD COLUMN – TODO confirm.)
                # Statements of fewer than three lines (e.g. one-line
                # CREATE INDEX) have no such third-to-last line and are
                # left as they are, rather than crashing on the index.
                if (len(new_row) >= 3
                        and new_row[-1] != ')'
                        and new_row[-3][-1] != ','):
                    new_row[-3] = new_row[-3] + ','
                    new_row[-2:] = ['    ' + new_row[-1][:-1]] + [')']
                new_rows += ['\n'.join(new_row)]
            return new_rows

        sql_for_schema = 'SELECT sql FROM sqlite_master ORDER BY sql'
        msg_err = 'Database has wrong tables schema. Diff:\n'
        with closing(sql_connect(self.path)) as conn, conn:
            # entries without SQL text (NULL) are skipped
            schema_rows = [r[0] for r in conn.execute(sql_for_schema) if r[0]]
        schema_rows = reformat_rows(schema_rows)
        retrieved_schema = ';\n'.join(schema_rows) + ';'
        with open(PATH_DB_SCHEMA, 'r', encoding='utf-8') as f:
            stored_schema = f.read().rstrip()
        if stored_schema != retrieved_schema:
            diff_msg = Differ().compare(retrieved_schema.splitlines(),
                                        stored_schema.splitlines())
            raise HandledException(msg_err + '\n'.join(diff_msg))
155
156
class DatabaseConnection:
    """Wrapper around one sqlite3 connection to a DatabaseFile's path."""

    def __init__(self, db_file: DatabaseFile) -> None:
        self.conn = sql_connect(db_file.path)

    def commit(self) -> None:
        """Commit what so far was collected into the SQL transaction."""
        self.conn.commit()

    def exec(self, code: str, inputs: tuple[Any, ...] = tuple()) -> Cursor:
        """Execute code with inputs on the connection, return the Cursor."""
        cursor = self.conn.execute(code, inputs)
        return cursor

    def exec_on_vals(self, code: str, inputs: tuple[Any, ...]) -> Cursor:
        """Run .exec with code extended by one " (?,…)" group for inputs."""
        placeholders = ','.join('?' for _ in inputs)
        return self.exec(f'{code} ({placeholders})', inputs)

    def close(self) -> None:
        """Close the underlying sqlite3 connection."""
        self.conn.close()

    def rewrite_relations(self, table_name: str, key: str, target: int | str,
                          rows: list[list[Any]], key_index: int = 0) -> None:
        # pylint: disable=too-many-arguments
        """Replace all rows of table_name where key == target by rows.

        Each of rows is to be provided without the value for column key;
        target is put into it at position key_index before insertion.
        """
        self.delete_where(table_name, key, target)
        insert_sql = f'INSERT INTO {table_name} VALUES'
        for row in rows:
            before, after = row[:key_index], row[key_index:]
            self.exec_on_vals(insert_sql, tuple(before + [target] + after))

    def row_where(self, table_name: str, key: str,
                  target: int | str) -> list[Row]:
        """Return all rows of table_name for which key == target."""
        sql = f'SELECT * FROM {table_name} WHERE {key} = ?'
        return self.exec(sql, (target,)).fetchall()

    # def column_where_pattern(self,
    #                          table_name: str,
    #                          column: str,
    #                          pattern: str,
    #                          keys: list[str]) -> list[Any]:
    #     """Return column of rows where one of keys matches pattern."""
    #     targets = tuple([f'%{pattern}%'] * len(keys))
    #     haystack = ' OR '.join([f'{k} LIKE ?' for k in keys])
    #     sql = f'SELECT {column} FROM {table_name} WHERE {haystack}'
    #     return [row[0] for row in self.exec(sql, targets)]

    def column_where(self, table_name: str, column: str, key: str,
                     target: int | str) -> list[Any]:
        """Return column's values from table_name rows where key == target."""
        sql = f'SELECT {column} FROM {table_name} WHERE {key} = ?'
        return [found[0] for found in self.exec(sql, (target,))]

    def column_all(self, table_name: str, column: str) -> list[Any]:
        """Return column's values from all rows of table_name."""
        sql = f'SELECT {column} FROM {table_name}'
        return [found[0] for found in self.exec(sql)]

    def delete_where(self, table_name: str, key: str,
                     target: int | str) -> None:
        """Delete those rows of table_name for which key == target."""
        sql = f'DELETE FROM {table_name} WHERE {key} = ?'
        self.exec(sql, (target,))
227
228
229 BaseModelId = TypeVar('BaseModelId', int, str)
230 BaseModelInstance = TypeVar('BaseModelInstance', bound='BaseModel[Any]')
231
232
233 class BaseModel(Generic[BaseModelId]):
234     """Template for most of the models we use/derive from the DB."""
235     table_name = ''
236     to_save_simples: list[str] = []
237     to_save_relations: list[tuple[str, str, str, int]] = []
238     versioned_defaults: dict[str, str | float] = {}
239     add_to_dict: list[str] = []
240     id_: None | BaseModelId
241     cache_: dict[BaseModelId, Self]
242     to_search: list[str] = []
243     can_create_by_id = False
244     _exists = True
245     sorters: dict[str, Callable[..., Any]] = {}
246
247     def __init__(self, id_: BaseModelId | None) -> None:
248         if isinstance(id_, int) and id_ < 1:
249             msg = f'illegal {self.__class__.__name__} ID, must be >=1: {id_}'
250             raise BadFormatException(msg)
251         if isinstance(id_, str) and "" == id_:
252             msg = f'illegal {self.__class__.__name__} ID, must be non-empty'
253             raise BadFormatException(msg)
254         self.id_ = id_
255
256     def __hash__(self) -> int:
257         hashable = [self.id_] + [getattr(self, name)
258                                  for name in self.to_save_simples]
259         for definition in self.to_save_relations:
260             attr = getattr(self, definition[2])
261             hashable += [tuple(rel.id_ for rel in attr)]
262         for name in self.to_save_versioned():
263             hashable += [hash(getattr(self, name))]
264         return hash(tuple(hashable))
265
266     def __eq__(self, other: object) -> bool:
267         if not isinstance(other, self.__class__):
268             return False
269         return hash(self) == hash(other)
270
271     def __lt__(self, other: Any) -> bool:
272         if not isinstance(other, self.__class__):
273             msg = 'cannot compare to object of different class'
274             raise HandledException(msg)
275         assert isinstance(self.id_, int)
276         assert isinstance(other.id_, int)
277         return self.id_ < other.id_
278
279     @classmethod
280     def to_save_versioned(cls) -> list[str]:
281         """Return keys of cls.versioned_defaults assuming we wanna save 'em."""
282         return list(cls.versioned_defaults.keys())
283
284     @property
285     def as_dict_and_refs(self) -> tuple[dict[str, object],
286                                         list[BaseModel[int] | BaseModel[str]]]:
287         """Return self as json.dumps-ready dict, list of referenced objects."""
288         d: dict[str, object] = {'id': self.id_}
289         refs: list[BaseModel[int] | BaseModel[str]] = []
290         for to_save in self.to_save_simples:
291             d[to_save] = getattr(self, to_save)
292         if len(self.to_save_versioned()) > 0:
293             d['_versioned'] = {}
294         for k in self.to_save_versioned():
295             attr = getattr(self, k)
296             assert isinstance(d['_versioned'], dict)
297             d['_versioned'][k] = attr.history
298         rels_to_collect = [rel[2] for rel in self.to_save_relations]
299         rels_to_collect += self.add_to_dict
300         for attr_name in rels_to_collect:
301             rel_list = []
302             for item in getattr(self, attr_name):
303                 rel_list += [item.id_]
304                 if item not in refs:
305                     refs += [item]
306             d[attr_name] = rel_list
307         return d, refs
308
309     @classmethod
310     def name_lowercase(cls) -> str:
311         """Convenience method to return cls' name in lowercase."""
312         return cls.__name__.lower()
313
314     @classmethod
315     def sort_by(cls, seq: list[Any], sort_key: str, default: str = 'title'
316                 ) -> str:
317         """Sort cls list by cls.sorters[sort_key] (reverse if '-'-prefixed).
318
319         Before cls.sorters[sort_key] is applied, seq is sorted by .id_, to
320         ensure predictability where parts of seq are of same sort value.
321         """
322         reverse = False
323         if len(sort_key) > 1 and '-' == sort_key[0]:
324             sort_key = sort_key[1:]
325             reverse = True
326         if sort_key not in cls.sorters:
327             sort_key = default
328         seq.sort(key=lambda x: x.id_, reverse=reverse)
329         sorter: Callable[..., Any] = cls.sorters[sort_key]
330         seq.sort(key=sorter, reverse=reverse)
331         if reverse:
332             sort_key = f'-{sort_key}'
333         return sort_key
334
335     # cache management
336     # (we primarily use the cache to ensure we work on the same object in
337     # memory no matter where and how we retrieve it, e.g. we don't want
338     # .by_id() calls to create a new object each time, but rather a pointer
339     # to the one already instantiated)
340
341     def __getattribute__(self, name: str) -> Any:
342         """Ensure fail if ._disappear() was called, except to check ._exists"""
343         if name != '_exists' and not super().__getattribute__('_exists'):
344             msg = f'Object for attribute does not exist: {name}'
345             raise HandledException(msg)
346         return super().__getattribute__(name)
347
348     def _disappear(self) -> None:
349         """Invalidate object, make future use raise exceptions."""
350         assert self.id_ is not None
351         if self._get_cached(self.id_):
352             self._uncache()
353         to_kill = list(self.__dict__.keys())
354         for attr in to_kill:
355             delattr(self, attr)
356         self._exists = False
357
358     @classmethod
359     def empty_cache(cls) -> None:
360         """Empty class's cache, and disappear all former inhabitants."""
361         # pylint: disable=protected-access
362         # (cause we remain within the class)
363         if hasattr(cls, 'cache_'):
364             to_disappear = list(cls.cache_.values())
365             for item in to_disappear:
366                 item._disappear()
367         cls.cache_ = {}
368
369     @classmethod
370     def get_cache(cls: type[BaseModelInstance]
371                   ) -> dict[Any, BaseModelInstance]:
372         """Get cache dictionary, create it if not yet existing."""
373         if not hasattr(cls, 'cache_'):
374             d: dict[Any, BaseModelInstance] = {}
375             cls.cache_ = d
376         return cls.cache_
377
378     @classmethod
379     def _get_cached(cls: type[BaseModelInstance],
380                     id_: BaseModelId
381                     ) -> BaseModelInstance | None:
382         """Get object of id_ from class's cache, or None if not found."""
383         cache = cls.get_cache()
384         if id_ in cache:
385             obj = cache[id_]
386             assert isinstance(obj, cls)
387             return obj
388         return None
389
390     def cache(self) -> None:
391         """Update object in class's cache.
392
393         Also calls ._disappear if cache holds older reference to object of same
394         ID, but different memory address, to avoid doing anything with
395         dangling leftovers.
396         """
397         if self.id_ is None:
398             raise HandledException('Cannot cache object without ID.')
399         cache = self.get_cache()
400         old_cached = self._get_cached(self.id_)
401         if old_cached and id(old_cached) != id(self):
402             # pylint: disable=protected-access
403             # (cause we remain within the class)
404             old_cached._disappear()
405         cache[self.id_] = self
406
407     def _uncache(self) -> None:
408         """Remove self from cache."""
409         if self.id_ is None:
410             raise HandledException('Cannot un-cache object without ID.')
411         cache = self.get_cache()
412         del cache[self.id_]
413
414     # object retrieval and generation
415
416     @classmethod
417     def from_table_row(cls: type[BaseModelInstance],
418                        # pylint: disable=unused-argument
419                        db_conn: DatabaseConnection,
420                        row: Row | list[Any]) -> BaseModelInstance:
421         """Make from DB row (sans relations), update DB cache with it."""
422         obj = cls(*row)
423         assert obj.id_ is not None
424         for attr_name in cls.to_save_versioned():
425             attr = getattr(obj, attr_name)
426             table_name = attr.table_name
427             for row_ in db_conn.row_where(table_name, 'parent', obj.id_):
428                 attr.history_from_row(row_)
429         obj.cache()
430         return obj
431
432     @classmethod
433     def by_id(cls, db_conn: DatabaseConnection, id_: BaseModelId) -> Self:
434         """Retrieve by id_, on failure throw NotFoundException.
435
436         First try to get from cls.cache_, only then check DB; if found,
437         put into cache.
438         """
439         obj = None
440         if id_ is not None:
441             if isinstance(id_, int) and id_ == 0:
442                 raise BadFormatException('illegal ID of value 0')
443             obj = cls._get_cached(id_)
444             if not obj:
445                 for row in db_conn.row_where(cls.table_name, 'id', id_):
446                     obj = cls.from_table_row(db_conn, row)
447                     break
448         if obj:
449             return obj
450         raise NotFoundException(f'found no object of ID {id_}')
451
452     @classmethod
453     def by_id_or_create(cls, db_conn: DatabaseConnection,
454                         id_: BaseModelId | None
455                         ) -> Self:
456         """Wrapper around .by_id, creating (not caching/saving) if no find."""
457         if not cls.can_create_by_id:
458             raise HandledException('Class cannot .by_id_or_create.')
459         if id_ is None:
460             return cls(None)
461         try:
462             return cls.by_id(db_conn, id_)
463         except NotFoundException:
464             return cls(id_)
465
466     @classmethod
467     def all(cls: type[BaseModelInstance],
468             db_conn: DatabaseConnection) -> list[BaseModelInstance]:
469         """Collect all objects of class into list.
470
471         Note that this primarily returns the contents of the cache, and only
472         _expands_ that by additional findings in the DB. This assumes the
473         cache is always instantly cleaned of any items that would be removed
474         from the DB.
475         """
476         items: dict[BaseModelId, BaseModelInstance] = {}
477         for k, v in cls.get_cache().items():
478             assert isinstance(v, cls)
479             items[k] = v
480         already_recorded = items.keys()
481         for id_ in db_conn.column_all(cls.table_name, 'id'):
482             if id_ not in already_recorded:
483                 item = cls.by_id(db_conn, id_)
484                 assert item.id_ is not None
485                 items[item.id_] = item
486         return sorted(list(items.values()))
487
488     @classmethod
489     def by_date_range_with_limits(cls: type[BaseModelInstance],
490                                   db_conn: DatabaseConnection,
491                                   date_range: tuple[str, str],
492                                   date_col: str = 'day'
493                                   ) -> tuple[list[BaseModelInstance], str,
494                                              str]:
495         """Return list of items in DB within (closed) date_range interval.
496
497         If no range values provided, defaults them to 'yesterday' and
498         'tomorrow'. Knows to properly interpret these and 'today' as value.
499         """
500         start_str = date_range[0] if date_range[0] else 'yesterday'
501         end_str = date_range[1] if date_range[1] else 'tomorrow'
502         start_date = valid_date(start_str)
503         end_date = valid_date(end_str)
504         items = []
505         sql = f'SELECT id FROM {cls.table_name} '
506         sql += f'WHERE {date_col} >= ? AND {date_col} <= ?'
507         for row in db_conn.exec(sql, (start_date, end_date)):
508             items += [cls.by_id(db_conn, row[0])]
509         return items, start_date, end_date
510
511     @classmethod
512     def matching(cls: type[BaseModelInstance], db_conn: DatabaseConnection,
513                  pattern: str) -> list[BaseModelInstance]:
514         """Return all objects whose .to_search match pattern."""
515         items = cls.all(db_conn)
516         if pattern:
517             filtered = []
518             for item in items:
519                 for attr_name in cls.to_search:
520                     toks = attr_name.split('.')
521                     parent = item
522                     for tok in toks:
523                         attr = getattr(parent, tok)
524                         parent = attr
525                     if pattern in attr:
526                         filtered += [item]
527                         break
528             return filtered
529         return items
530
531     # database writing
532
533     def save(self, db_conn: DatabaseConnection) -> None:
534         """Write self to DB and cache and ensure .id_.
535
536         Write both to DB, and to cache. To DB, write .id_ and attributes
537         listed in cls.to_save_[simples|versioned|_relations].
538
539         Ensure self.id_ by setting it to what the DB command returns as the
540         last saved row's ID (cursor.lastrowid), EXCEPT if self.id_ already
541         exists as a 'str', which implies we do our own ID creation (so far
542         only the case with the Day class, where it's to be a date string.
543         """
544         values = tuple([self.id_] + [getattr(self, key)
545                                      for key in self.to_save_simples])
546         table_name = self.table_name
547         cursor = db_conn.exec_on_vals(f'REPLACE INTO {table_name} VALUES',
548                                       values)
549         if not isinstance(self.id_, str):
550             self.id_ = cursor.lastrowid  # type: ignore[assignment]
551         self.cache()
552         for attr_name in self.to_save_versioned():
553             getattr(self, attr_name).save(db_conn)
554         for table, column, attr_name, key_index in self.to_save_relations:
555             assert isinstance(self.id_, (int, str))
556             db_conn.rewrite_relations(table, column, self.id_,
557                                       [[i.id_] for i
558                                        in getattr(self, attr_name)], key_index)
559
560     def remove(self, db_conn: DatabaseConnection) -> None:
561         """Remove from DB and cache, including dependencies."""
562         if self.id_ is None or self._get_cached(self.id_) is None:
563             raise HandledException('cannot remove unsaved item')
564         for attr_name in self.to_save_versioned():
565             getattr(self, attr_name).remove(db_conn)
566         for table, column, attr_name, _ in self.to_save_relations:
567             db_conn.delete_where(table, column, self.id_)
568         self._uncache()
569         db_conn.delete_where(self.table_name, 'id', self.id_)
570         self._disappear()