home · contact · privacy
Minor template improvements.
[plomtask] / plomtask / db.py
1 """Database management."""
2 from __future__ import annotations
3 from os import listdir
4 from os.path import isfile
5 from difflib import Differ
6 from sqlite3 import connect as sql_connect, Cursor, Row
7 from typing import Any, Self, TypeVar, Generic
8 from plomtask.exceptions import HandledException, NotFoundException
9 from plomtask.dating import valid_date
10
# DB schema version this code works with; older DB files must be migrated
# up to it, newer ones are rejected.
EXPECTED_DB_VERSION = 4
# Directory (relative to the working directory) holding the full schema
# script plus the numbered migration scripts.
MIGRATIONS_DIR = 'migrations'
# Full schema for EXPECTED_DB_VERSION: used both to create new DB files and
# to validate the schema of opened ones.
FILENAME_DB_SCHEMA = 'init_' + str(EXPECTED_DB_VERSION) + '.sql'
PATH_DB_SCHEMA = '/'.join((MIGRATIONS_DIR, FILENAME_DB_SCHEMA))
15
16
class UnmigratedDbException(HandledException):
    """Raised when a DB file's user_version differs from the expected one.

    Signals that the file needs to go through DatabaseFile.migrate first.
    """
19
20
class DatabaseFile:  # pylint: disable=too-few-public-methods
    """Represents the sqlite3 database's file.

    Constructing one validates the file at path (see _check). Use create_at
    to make a fresh file, or migrate to upgrade an outdated one.

    All connections opened here get explicitly closed after use. A plain
    sqlite3.Connection used as a context manager only commits or rolls back
    on exit; it does not close. Without the explicit close, handles linger,
    and newer Pythons emit a ResourceWarning.
    """

    def __init__(self, path: str) -> None:
        self.path = path
        self._check()

    @staticmethod
    def _connect(path: str) -> Any:
        """Return context manager yielding a connection to path, closed on exit.

        Nest a `with conn:` inside it where a commit (or rollback on error)
        is wanted as well.
        """
        # pylint: disable=import-outside-toplevel
        from contextlib import closing
        return closing(sql_connect(path))

    @classmethod
    def create_at(cls, path: str) -> DatabaseFile:
        """Make new DB file at path from the schema at PATH_DB_SCHEMA.

        Sets the new DB's user_version to EXPECTED_DB_VERSION, then returns
        the (thereby validated) DatabaseFile for path.
        """
        # Read the schema before connecting, so that a missing schema file
        # does not leave an empty DB file behind at path.
        with open(PATH_DB_SCHEMA, 'r', encoding='utf-8') as f:
            schema_sql = f.read()
        with cls._connect(path) as conn:
            with conn:  # commit on success
                conn.executescript(schema_sql)
                conn.execute(f'PRAGMA user_version = {EXPECTED_DB_VERSION}')
        return cls(path)

    @classmethod
    def migrate(cls, path: str) -> DatabaseFile:
        """Apply migrations from DB's current version to EXPECTED_DB_VERSION.

        Runs the migration script of each version above the DB's current
        user_version in ascending order. After each script it sets
        user_version to that script's version number.
        """
        migrations = cls._available_migrations()
        from_version = cls.get_version_of_db(path)
        with cls._connect(path) as conn:
            # migrations[i] is the script migrating to version i
            for version in range(from_version + 1, len(migrations)):
                with open(f'{MIGRATIONS_DIR}/{migrations[version]}', 'r',
                          encoding='utf-8') as f:
                    migration_sql = f.read()
                with conn:
                    conn.executescript(migration_sql)
                with conn:
                    conn.execute(f'PRAGMA user_version = {version}')
        return cls(path)

    def _check(self) -> None:
        """Check file exists, and is of proper DB version and schema.

        Raises NotFoundException if there is no file at self.path,
        UnmigratedDbException if its user_version is not EXPECTED_DB_VERSION,
        and HandledException (via _validate_schema) on schema mismatch.
        """
        # checked first so the version query below does not create the file
        if not isfile(self.path):
            raise NotFoundException
        if self.user_version != EXPECTED_DB_VERSION:
            raise UnmigratedDbException()
        self._validate_schema()

    @staticmethod
    def _available_migrations() -> list[str]:
        """Validate migrations directory and return sorted entries.

        Expects, besides FILENAME_DB_SCHEMA, exactly one entry named
        '{i}_{anything}' for each i from 0 to EXPECTED_DB_VERSION. Returns
        them ordered so that the list index equals the migration number.
        Raises HandledException on unexpected, missing or too-high entries.
        """
        msg_too_big = 'Migration directory points beyond expected DB version.'
        msg_bad_entry = 'Migration directory contains unexpected entry: '
        msg_missing = 'Migration directory misses migration of number: '
        migrations = {}
        for entry in listdir(MIGRATIONS_DIR):
            if entry == FILENAME_DB_SCHEMA:
                continue
            toks = entry.split('_', 1)
            if len(toks) < 2:
                raise HandledException(msg_bad_entry + entry)
            try:
                i = int(toks[0])
            except ValueError as e:
                raise HandledException(msg_bad_entry + entry) from e
            if i > EXPECTED_DB_VERSION:
                raise HandledException(msg_too_big)
            migrations[i] = toks[1]
        migrations_list = []
        for i in range(EXPECTED_DB_VERSION + 1):
            if i not in migrations:
                raise HandledException(msg_missing + str(i))
            migrations_list += [f'{i}_{migrations[i]}']
        return migrations_list

    @staticmethod
    def get_version_of_db(path: str) -> int:
        """Get DB user_version, fail if outside expected range.

        Raises HandledException if the version exceeds EXPECTED_DB_VERSION.
        """
        sql_for_db_version = 'PRAGMA user_version'
        with DatabaseFile._connect(path) as conn:
            db_version = conn.execute(sql_for_db_version).fetchone()[0]
        assert isinstance(db_version, int)
        if db_version > EXPECTED_DB_VERSION:
            msg = f'Wrong DB version, expected '\
                    f'{EXPECTED_DB_VERSION}, got unknown {db_version}.'
            raise HandledException(msg)
        return db_version

    @property
    def user_version(self) -> int:
        """Get DB user_version."""
        return self.get_version_of_db(self.path)

    @staticmethod
    def _reformat_schema_rows(rows: list[str]) -> list[str]:
        """Normalize the layout of SQL statements as read from sqlite_master.

        Each statement is re-broken into one line per comma-separated
        segment. Commas inside parentheses do not split; parenthesis depth is
        counted per physical line. Every line is indented by four spaces
        except the first and last.

        Some statements have at least three lines and do not end on a lone
        ')'. An example is the form sqlite3 stores after ALTER TABLE … ADD
        COLUMN, where ', col TYPE)' follows the old last column. In that case
        a comma is added to the third-to-last line, and the lone ',' line is
        replaced. What replaces it is the last line minus its final
        character, followed by a ')' line.

        Shorter statements (e.g. a one-line CREATE INDEX) are left as they
        are instead of failing on a missing third-to-last line.
        """
        new_rows = []
        for row in rows:
            new_row: list[str] = []
            for subrow in row.split('\n'):
                subrow = subrow.rstrip()
                in_parentheses = 0
                split_at = []
                for i, c in enumerate(subrow):
                    if '(' == c:
                        in_parentheses += 1
                    elif ')' == c:
                        in_parentheses -= 1
                    elif ',' == c and 0 == in_parentheses:
                        split_at += [i + 1]
                prev_split = 0
                for i in split_at + [len(subrow)]:
                    segment = subrow[prev_split:i].strip()
                    if segment:
                        new_row += [f'    {segment}']
                    prev_split = i
            if new_row:
                new_row[0] = new_row[0].lstrip()
                new_row[-1] = new_row[-1].lstrip()
            if (len(new_row) >= 3 and new_row[-1] != ')'
                    and new_row[-3][-1] != ','):
                new_row[-3] = new_row[-3] + ','
                new_row[-2:] = ['    ' + new_row[-1][:-1], ')']
            new_rows += ['\n'.join(new_row)]
        return new_rows

    def _validate_schema(self) -> None:
        """Compare found schema with what's stored at PATH_DB_SCHEMA.

        Raises HandledException carrying a line diff if they differ.
        """
        sql_for_schema = 'SELECT sql FROM sqlite_master ORDER BY sql'
        msg_err = 'Database has wrong tables schema. Diff:\n'
        with self._connect(self.path) as conn:
            # entries with NULL sql are skipped
            schema_rows = [r[0] for r in conn.execute(sql_for_schema) if r[0]]
        schema_rows = self._reformat_schema_rows(schema_rows)
        retrieved_schema = ';\n'.join(schema_rows) + ';'
        with open(PATH_DB_SCHEMA, 'r', encoding='utf-8') as f:
            stored_schema = f.read().rstrip()
        if stored_schema != retrieved_schema:
            diff_msg = Differ().compare(retrieved_schema.splitlines(),
                                        stored_schema.splitlines())
            raise HandledException(msg_err + '\n'.join(diff_msg))
153
154
class DatabaseConnection:
    """A single connection to the database.

    Wraps a sqlite3 connection to db_file.path and adds helpers for simple
    single-table queries. Table and column names are interpolated into the
    SQL unescaped, so they must come from code, never from user input. Only
    values are passed as '?' parameters.
    """

    def __init__(self, db_file: DatabaseFile) -> None:
        self.file = db_file
        self.conn = sql_connect(self.file.path)

    def commit(self) -> None:
        """Commit SQL transaction."""
        self.conn.commit()

    def exec(self, code: str, inputs: tuple[Any, ...] = ()) -> Cursor:
        """Add commands to SQL transaction, binding inputs to '?' marks."""
        return self.conn.execute(code, inputs)

    def close(self) -> None:
        """Close DB connection."""
        self.conn.close()

    def rewrite_relations(self, table_name: str, key: str, target: int | str,
                          rows: list[list[Any]], key_index: int = 0) -> None:
        # pylint: disable=too-many-arguments
        """Rewrite relations in table_name to target, with rows values.

        First deletes all rows of table_name where key == target. Then it
        inserts one row per entry of rows.

        Note that single rows are expected without the column and value
        identified by key and target, which are inserted inside the function
        at key_index. Rows may be given as lists or tuples.
        """
        self.delete_where(table_name, key, target)
        for row in rows:
            cells = list(row)  # so tuples can be spliced like lists
            values = tuple(cells[:key_index] + [target] + cells[key_index:])
            q_marks = self.q_marks_from_values(values)
            self.exec(f'INSERT INTO {table_name} VALUES {q_marks}', values)

    def row_where(self, table_name: str, key: str,
                  target: int | str) -> list[Row]:
        """Return list of Rows at table where key == target."""
        return list(self.exec(f'SELECT * FROM {table_name} WHERE {key} = ?',
                              (target,)))

    def column_where(self, table_name: str, column: str, key: str,
                     target: int | str) -> list[Any]:
        """Return column of table where key == target."""
        return [row[0] for row in
                self.exec(f'SELECT {column} FROM {table_name} '
                          f'WHERE {key} = ?', (target,))]

    def column_all(self, table_name: str, column: str) -> list[Any]:
        """Return complete column of table."""
        return [row[0] for row in
                self.exec(f'SELECT {column} FROM {table_name}')]

    def delete_where(self, table_name: str, key: str,
                     target: int | str) -> None:
        """Delete from table where key == target."""
        self.exec(f'DELETE FROM {table_name} WHERE {key} = ?', (target,))

    @staticmethod
    def q_marks_from_values(values: tuple[Any, ...]) -> str:
        """Return placeholder to insert values into SQL code, e.g. '(?,?)'."""
        return '(' + ','.join(['?'] * len(values)) + ')'
224     def q_marks_from_values(values: tuple[Any]) -> str:
225         """Return placeholder to insert values into SQL code."""
226         return '(' + ','.join(['?'] * len(values)) + ')'
227
228
229 BaseModelId = TypeVar('BaseModelId', int, str)
230 BaseModelInstance = TypeVar('BaseModelInstance', bound='BaseModel[Any]')
231
232
233 class BaseModel(Generic[BaseModelId]):
234     """Template for most of the models we use/derive from the DB."""
235     table_name = ''
236     to_save: list[str] = []
237     to_save_versioned: list[str] = []
238     to_save_relations: list[tuple[str, str, str, int]] = []
239     id_: None | BaseModelId
240     cache_: dict[BaseModelId, Self]
241     to_search: list[str] = []
242
243     def __init__(self, id_: BaseModelId | None) -> None:
244         if isinstance(id_, int) and id_ < 1:
245             msg = f'illegal {self.__class__.__name__} ID, must be >=1: {id_}'
246             raise HandledException(msg)
247         self.id_ = id_
248
249     def __eq__(self, other: object) -> bool:
250         if not isinstance(other, self.__class__):
251             return False
252         to_hash_me = tuple([self.id_] +
253                            [getattr(self, name) for name in self.to_save])
254         to_hash_other = tuple([other.id_] +
255                               [getattr(other, name) for name in other.to_save])
256         return hash(to_hash_me) == hash(to_hash_other)
257
258     def __lt__(self, other: Any) -> bool:
259         if not isinstance(other, self.__class__):
260             msg = 'cannot compare to object of different class'
261             raise HandledException(msg)
262         assert isinstance(self.id_, int)
263         assert isinstance(other.id_, int)
264         return self.id_ < other.id_
265
266     @classmethod
267     def get_cached(cls: type[BaseModelInstance],
268                    id_: BaseModelId) -> BaseModelInstance | None:
269         """Get object of id_ from class's cache, or None if not found."""
270         # pylint: disable=consider-iterating-dictionary
271         cache = cls.get_cache()
272         if id_ in cache.keys():
273             obj = cache[id_]
274             assert isinstance(obj, cls)
275             return obj
276         return None
277
278     @classmethod
279     def empty_cache(cls) -> None:
280         """Empty class's cache."""
281         cls.cache_ = {}
282
283     @classmethod
284     def get_cache(cls: type[BaseModelInstance]) -> dict[Any, BaseModel[Any]]:
285         """Get cache dictionary, create it if not yet existing."""
286         if not hasattr(cls, 'cache_'):
287             d: dict[Any, BaseModel[Any]] = {}
288             cls.cache_ = d
289         return cls.cache_
290
291     def cache(self) -> None:
292         """Update object in class's cache."""
293         if self.id_ is None:
294             raise HandledException('Cannot cache object without ID.')
295         cache = self.__class__.get_cache()
296         cache[self.id_] = self
297
298     def uncache(self) -> None:
299         """Remove self from cache."""
300         if self.id_ is None:
301             raise HandledException('Cannot un-cache object without ID.')
302         cache = self.__class__.get_cache()
303         del cache[self.id_]
304
305     @classmethod
306     def from_table_row(cls: type[BaseModelInstance],
307                        # pylint: disable=unused-argument
308                        db_conn: DatabaseConnection,
309                        row: Row | list[Any]) -> BaseModelInstance:
310         """Make from DB row, write to DB cache."""
311         obj = cls(*row)
312         obj.cache()
313         return obj
314
315     @classmethod
316     def by_id(cls, db_conn: DatabaseConnection,
317               id_: BaseModelId | None,
318               # pylint: disable=unused-argument
319               create: bool = False) -> Self:
320         """Retrieve by id_, on failure throw NotFoundException.
321
322         First try to get from cls.cache_, only then check DB; if found,
323         put into cache.
324
325         If create=True, make anew (but do not cache yet).
326         """
327         obj = None
328         if id_ is not None:
329             obj = cls.get_cached(id_)
330             if not obj:
331                 for row in db_conn.row_where(cls.table_name, 'id', id_):
332                     obj = cls.from_table_row(db_conn, row)
333                     obj.cache()
334                     break
335         if obj:
336             return obj
337         if create:
338             obj = cls(id_)
339             return obj
340         raise NotFoundException(f'found no object of ID {id_}')
341
342     @classmethod
343     def all(cls: type[BaseModelInstance],
344             db_conn: DatabaseConnection) -> list[BaseModelInstance]:
345         """Collect all objects of class into list.
346
347         Note that this primarily returns the contents of the cache, and only
348         _expands_ that by additional findings in the DB. This assumes the
349         cache is always instantly cleaned of any items that would be removed
350         from the DB.
351         """
352         items: dict[BaseModelId, BaseModelInstance] = {}
353         for k, v in cls.get_cache().items():
354             assert isinstance(v, cls)
355             items[k] = v
356         already_recorded = items.keys()
357         for id_ in db_conn.column_all(cls.table_name, 'id'):
358             if id_ not in already_recorded:
359                 item = cls.by_id(db_conn, id_)
360                 assert item.id_ is not None
361                 items[item.id_] = item
362         return list(items.values())
363
364     @classmethod
365     def by_date_range_with_limits(cls: type[BaseModelInstance],
366                                   db_conn: DatabaseConnection,
367                                   date_range: tuple[str, str],
368                                   date_col: str = 'day'
369                                   ) -> tuple[list[BaseModelInstance], str,
370                                              str]:
371         """Return list of Days in database within (open) date_range interval.
372
373         If no range values provided, defaults them to 'yesterday' and
374         'tomorrow'. Knows to properly interpret these and 'today' as value.
375         """
376         start_str = date_range[0] if date_range[0] else 'yesterday'
377         end_str = date_range[1] if date_range[1] else 'tomorrow'
378         start_date = valid_date(start_str)
379         end_date = valid_date(end_str)
380         items = []
381         sql = f'SELECT id FROM {cls.table_name} '
382         sql += f'WHERE {date_col} >= ? AND {date_col} <= ?'
383         for row in db_conn.exec(sql, (start_date, end_date)):
384             items += [cls.by_id(db_conn, row[0])]
385         return items, start_date, end_date
386
387     @classmethod
388     def matching(cls: type[BaseModelInstance], db_conn: DatabaseConnection,
389                  pattern: str) -> list[BaseModelInstance]:
390         """Return all objects whose .to_search match pattern."""
391         items = cls.all(db_conn)
392         if pattern:
393             filtered = []
394             for item in items:
395                 for attr_name in cls.to_search:
396                     toks = attr_name.split('.')
397                     parent = item
398                     for tok in toks:
399                         attr = getattr(parent, tok)
400                         parent = attr
401                     if pattern in attr:
402                         filtered += [item]
403                         break
404             return filtered
405         return items
406
407     def save(self, db_conn: DatabaseConnection) -> None:
408         """Write self to DB and cache and ensure .id_.
409
410         Write both to DB, and to cache. To DB, write .id_ and attributes
411         listed in cls.to_save[_versioned|_relations].
412
413         Ensure self.id_ by setting it to what the DB command returns as the
414         last saved row's ID (cursor.lastrowid), EXCEPT if self.id_ already
415         exists as a 'str', which implies we do our own ID creation (so far
416         only the case with the Day class, where it's to be a date string.
417         """
418         values = tuple([self.id_] + [getattr(self, key)
419                                      for key in self.to_save])
420         q_marks = DatabaseConnection.q_marks_from_values(values)
421         table_name = self.table_name
422         cursor = db_conn.exec(f'REPLACE INTO {table_name} VALUES {q_marks}',
423                               values)
424         if not isinstance(self.id_, str):
425             self.id_ = cursor.lastrowid  # type: ignore[assignment]
426         self.cache()
427         for attr_name in self.to_save_versioned:
428             getattr(self, attr_name).save(db_conn)
429         for table, column, attr_name, key_index in self.to_save_relations:
430             assert isinstance(self.id_, (int, str))
431             db_conn.rewrite_relations(table, column, self.id_,
432                                       [[i.id_] for i
433                                        in getattr(self, attr_name)], key_index)
434
435     def remove(self, db_conn: DatabaseConnection) -> None:
436         """Remove from DB and cache, including dependencies."""
437         if self.id_ is None or self.__class__.get_cached(self.id_) is None:
438             raise HandledException('cannot remove unsaved item')
439         for attr_name in self.to_save_versioned:
440             getattr(self, attr_name).remove(db_conn)
441         for table, column, attr_name, _ in self.to_save_relations:
442             db_conn.delete_where(table, column, self.id_)
443         self.uncache()
444         db_conn.delete_where(self.table_name, 'id', self.id_)