home · contact · privacy
a47dff15917651940483afe83f41152399bedf7f
[plomtask] / plomtask / db.py
1 """Database management."""
from __future__ import annotations
from contextlib import closing
from difflib import Differ
from os import listdir
from os.path import isfile
from sqlite3 import connect as sql_connect, Cursor, Row
from typing import Any, Self, TypeVar, Generic

from plomtask.exceptions import HandledException, NotFoundException
from plomtask.dating import valid_date
10
# Schema version this code works with; kept in the DB file's SQLite
# "PRAGMA user_version".
EXPECTED_DB_VERSION = 5
# Directory of numbered migration scripts plus the full current schema; a
# relative path, so it is resolved against the process's working directory.
MIGRATIONS_DIR = 'migrations'
FILENAME_DB_SCHEMA = 'init_' + str(EXPECTED_DB_VERSION) + '.sql'
PATH_DB_SCHEMA = '/'.join((MIGRATIONS_DIR, FILENAME_DB_SCHEMA))
15
16
class UnmigratedDbException(HandledException):
    """Signals a DB file whose user_version is not EXPECTED_DB_VERSION.

    Raised by DatabaseFile's validation; DatabaseFile.migrate exists to
    bring such a file up to the expected version.
    """
19
20
class DatabaseFile:
    """Represents the sqlite3 database's file.

    Constructing an instance validates the file at path: it must exist,
    report EXPECTED_DB_VERSION as its SQLite user_version, and hold exactly
    the schema stored at PATH_DB_SCHEMA.

    sqlite3's Connection context manager only commits or rolls back, it does
    not close the connection; connections opened here are therefore wrapped
    in contextlib.closing so their file handles are released right away.
    """
    # pylint: disable=too-few-public-methods

    def __init__(self, path: str) -> None:
        self.path = path
        self._check()

    @classmethod
    def create_at(cls, path: str) -> DatabaseFile:
        """Make new DB file at path from PATH_DB_SCHEMA and return it."""
        with closing(sql_connect(path)) as conn, conn:
            with open(PATH_DB_SCHEMA, 'r', encoding='utf-8') as f:
                conn.executescript(f.read())
            conn.execute(f'PRAGMA user_version = {EXPECTED_DB_VERSION}')
        return cls(path)

    @classmethod
    def migrate(cls, path: str) -> DatabaseFile:
        """Apply migrations from DB's version up to EXPECTED_DB_VERSION.

        The migration numbered i upgrades a DB of version i-1 to version i.
        user_version is bumped right after each script, in the same
        connection, so a re-run starts after the last completed migration.
        """
        migrations = cls._available_migrations()
        from_version = cls._get_version_of_db(path)
        # _available_migrations returns one entry per version, index ==
        # migration number; _get_version_of_db guarantees from_version >= 0.
        for version in range(from_version + 1, len(migrations)):
            filename = migrations[version]
            with closing(sql_connect(path)) as conn, conn:
                with open(f'{MIGRATIONS_DIR}/{filename}', 'r',
                          encoding='utf-8') as f:
                    conn.executescript(f.read())
                conn.execute(f'PRAGMA user_version = {version}')
        return cls(path)

    def _check(self) -> None:
        """Check file exists, and is of proper DB version and schema.

        Raises NotFoundException if there is no file at self.path,
        UnmigratedDbException if its user_version is not
        EXPECTED_DB_VERSION, and HandledException (from the helpers) on an
        out-of-range version or a schema mismatch.
        """
        if not isfile(self.path):
            raise NotFoundException(f'no DB file at path: {self.path}')
        if self._user_version != EXPECTED_DB_VERSION:
            raise UnmigratedDbException()
        self._validate_schema()

    @staticmethod
    def _available_migrations() -> list[str]:
        """Validate migrations directory and return sorted entries.

        Every entry other than FILENAME_DB_SCHEMA must be named '<i>_<rest>'
        with i an integer from 0 to EXPECTED_DB_VERSION, and each such i must
        occur exactly once. Returns the entry names ordered by i, so that
        list index == migration number.

        Raises HandledException if any of this does not hold.
        """
        msg_too_big = 'Migration directory points beyond expected DB version.'
        msg_bad_entry = 'Migration directory contains unexpected entry: '
        msg_missing = 'Migration directory misses migration of number: '
        msg_duplicate = 'Migration directory duplicates migration of number: '
        migrations: dict[int, str] = {}
        for entry in listdir(MIGRATIONS_DIR):
            if entry == FILENAME_DB_SCHEMA:
                continue
            toks = entry.split('_', 1)
            if len(toks) < 2:
                raise HandledException(msg_bad_entry + entry)
            try:
                i = int(toks[0])
            except ValueError as e:
                raise HandledException(msg_bad_entry + entry) from e
            if i < 0:
                raise HandledException(msg_bad_entry + entry)
            if i > EXPECTED_DB_VERSION:
                raise HandledException(msg_too_big)
            if i in migrations:
                # listdir() order is arbitrary, so keeping either of two
                # same-numbered files would be a random pick.
                raise HandledException(msg_duplicate + str(i))
            # Keep the real entry name: rebuilding it from i would break for
            # names like '01_foo.sql', whose int() loses the leading zero.
            migrations[i] = entry
        migrations_list = []
        for i in range(EXPECTED_DB_VERSION + 1):
            if i not in migrations:
                raise HandledException(msg_missing + str(i))
            migrations_list += [migrations[i]]
        return migrations_list

    @staticmethod
    def _get_version_of_db(path: str) -> int:
        """Get DB user_version, fail if outside 0 to EXPECTED_DB_VERSION."""
        sql_for_db_version = 'PRAGMA user_version'
        with closing(sql_connect(path)) as conn:
            db_version = conn.execute(sql_for_db_version).fetchone()[0]
        assert isinstance(db_version, int)
        if not 0 <= db_version <= EXPECTED_DB_VERSION:
            msg = f'Wrong DB version, expected '\
                    f'{EXPECTED_DB_VERSION}, got unknown {db_version}.'
            raise HandledException(msg)
        return db_version

    @property
    def _user_version(self) -> int:
        """Get DB user_version."""
        # pylint: disable=protected-access
        # (since we remain within class)
        return self.__class__._get_version_of_db(self.path)

    def _validate_schema(self) -> None:
        """Compare found schema with what's stored at PATH_DB_SCHEMA.

        The non-NULL sql of all sqlite_master entries, sorted, is re-formatted
        (see reformat_rows), joined by ';\\n' plus a final ';', and compared
        textually against the stored file.

        Raises HandledException carrying a line diff on any mismatch.
        """

        def reformat_rows(rows: list[str]) -> list[str]:
            new_rows = []
            for row in rows:
                new_row: list[str] = []
                for subrow in row.split('\n'):
                    subrow = subrow.rstrip()
                    # Split the line after every comma not inside
                    # parentheses opened on that same line, giving each
                    # segment its own four-space-indented line.
                    in_parentheses = 0
                    split_at = []
                    for i, c in enumerate(subrow):
                        if '(' == c:
                            in_parentheses += 1
                        elif ')' == c:
                            in_parentheses -= 1
                        elif ',' == c and 0 == in_parentheses:
                            split_at += [i + 1]
                    prev_split = 0
                    for i in split_at:
                        segment = subrow[prev_split:i].strip()
                        if len(segment) > 0:
                            new_row += [f'    {segment}']
                        prev_split = i
                    segment = subrow[prev_split:].strip()
                    if len(segment) > 0:
                        new_row += [f'    {segment}']
                new_row[0] = new_row[0].lstrip()
                new_row[-1] = new_row[-1].lstrip()
                # If the closing ')' is glued to the last segment and the
                # third-to-last lacks a trailing comma, add that comma, drop
                # the second-to-last segment and move ')' to its own line.
                # NOTE(review): this appears to target a lone ',' segment
                # followed by 'col TYPE)', as text appended by ALTER TABLE
                # ... ADD COLUMN may look — confirm. Rows of fewer than three
                # segments (e.g. a one-line CREATE INDEX) are left as-is
                # instead of crashing on new_row[-3].
                if (len(new_row) >= 3 and new_row[-1] != ')'
                        and new_row[-3][-1] != ','):
                    new_row[-3] = new_row[-3] + ','
                    new_row[-2:] = ['    ' + new_row[-1][:-1]] + [')']
                new_rows += ['\n'.join(new_row)]
            return new_rows

        sql_for_schema = 'SELECT sql FROM sqlite_master ORDER BY sql'
        msg_err = 'Database has wrong tables schema. Diff:\n'
        with closing(sql_connect(self.path)) as conn:
            schema_rows = [r[0] for r in conn.execute(sql_for_schema) if r[0]]
        schema_rows = reformat_rows(schema_rows)
        retrieved_schema = ';\n'.join(schema_rows) + ';'
        with open(PATH_DB_SCHEMA, 'r', encoding='utf-8') as f:
            stored_schema = f.read().rstrip()
        if stored_schema != retrieved_schema:
            diff_msg = Differ().compare(retrieved_schema.splitlines(),
                                        stored_schema.splitlines())
            raise HandledException(msg_err + '\n'.join(diff_msg))
156
157
class DatabaseConnection:
    """Thin wrapper around one sqlite3 connection to a DatabaseFile's path.

    Table and column names given to the helpers below are interpolated
    straight into the SQL text (only values go through '?' placeholders),
    so they must come from trusted code, never from user input.
    """

    def __init__(self, db_file: DatabaseFile) -> None:
        self.conn = sql_connect(db_file.path)

    def commit(self) -> None:
        """Commit the pending SQL transaction."""
        self.conn.commit()

    def exec(self, code: str, inputs: tuple[Any, ...] = tuple()) -> Cursor:
        """Run code with inputs bound to its '?' placeholders."""
        cursor = self.conn.execute(code, inputs)
        return cursor

    def exec_on_vals(self, code: str, inputs: tuple[Any, ...]) -> Cursor:
        """Run code followed by a " (?,…)" list sized to match inputs.

        E.g. code='INSERT INTO t VALUES' with three inputs runs
        'INSERT INTO t VALUES (?,?,?)'.
        """
        placeholders = ','.join('?' for _ in inputs)
        return self.exec(f'{code} ({placeholders})', inputs)

    def close(self) -> None:
        """Close the underlying sqlite3 connection."""
        self.conn.close()

    def rewrite_relations(self, table_name: str, key: str, target: int | str,
                          rows: list[list[Any]], key_index: int = 0) -> None:
        # pylint: disable=too-many-arguments
        """Replace all rows of table_name where key == target by rows.

        Entries of rows omit the key column; target is spliced into each at
        position key_index before it gets inserted.
        """
        self.delete_where(table_name, key, target)
        insert_sql = f'INSERT INTO {table_name} VALUES'
        for row in rows:
            before, after = row[:key_index], row[key_index:]
            self.exec_on_vals(insert_sql, (*before, target, *after))

    def row_where(self, table_name: str, key: str,
                  target: int | str) -> list[Row]:
        """Return all rows of table_name whose key column equals target."""
        query = f'SELECT * FROM {table_name} WHERE {key} = ?'
        return list(self.exec(query, (target,)))

    def column_where(self, table_name: str, column: str, key: str,
                     target: int | str) -> list[Any]:
        """Return column's values from rows of table_name where key == target."""
        query = f'SELECT {column} FROM {table_name} WHERE {key} = ?'
        cursor = self.exec(query, (target,))
        return [record[0] for record in cursor]

    def column_all(self, table_name: str, column: str) -> list[Any]:
        """Return column's values from every row of table_name."""
        cursor = self.exec(f'SELECT {column} FROM {table_name}')
        return [record[0] for record in cursor]

    def delete_where(self, table_name: str, key: str,
                     target: int | str) -> None:
        """Delete the rows of table_name whose key column equals target."""
        query = f'DELETE FROM {table_name} WHERE {key} = ?'
        self.exec(query, (target,))
228
229
230 BaseModelId = TypeVar('BaseModelId', int, str)
231 BaseModelInstance = TypeVar('BaseModelInstance', bound='BaseModel[Any]')
232
233
234 class BaseModel(Generic[BaseModelId]):
235     """Template for most of the models we use/derive from the DB."""
236     table_name = ''
237     to_save: list[str] = []
238     to_save_versioned: list[str] = []
239     to_save_relations: list[tuple[str, str, str, int]] = []
240     id_: None | BaseModelId
241     cache_: dict[BaseModelId, Self]
242     to_search: list[str] = []
243
244     def __init__(self, id_: BaseModelId | None) -> None:
245         if isinstance(id_, int) and id_ < 1:
246             msg = f'illegal {self.__class__.__name__} ID, must be >=1: {id_}'
247             raise HandledException(msg)
248         if isinstance(id_, str) and "" == id_:
249             msg = f'illegal {self.__class__.__name__} ID, must be non-empty'
250             raise HandledException(msg)
251         self.id_ = id_
252
253     def __hash__(self) -> int:
254         hashable = [self.id_] + [getattr(self, name) for name in self.to_save]
255         for definition in self.to_save_relations:
256             attr = getattr(self, definition[2])
257             hashable += [tuple(rel.id_ for rel in attr)]
258         for name in self.to_save_versioned:
259             hashable += [hash(getattr(self, name))]
260         return hash(tuple(hashable))
261
262     def __eq__(self, other: object) -> bool:
263         if not isinstance(other, self.__class__):
264             return False
265         return hash(self) == hash(other)
266
267     def __lt__(self, other: Any) -> bool:
268         if not isinstance(other, self.__class__):
269             msg = 'cannot compare to object of different class'
270             raise HandledException(msg)
271         assert isinstance(self.id_, int)
272         assert isinstance(other.id_, int)
273         return self.id_ < other.id_
274
275     @classmethod
276     def get_cached(cls: type[BaseModelInstance],
277                    id_: BaseModelId) -> BaseModelInstance | None:
278         """Get object of id_ from class's cache, or None if not found."""
279         # pylint: disable=consider-iterating-dictionary
280         cache = cls.get_cache()
281         if id_ in cache.keys():
282             obj = cache[id_]
283             assert isinstance(obj, cls)
284             return obj
285         return None
286
287     @classmethod
288     def empty_cache(cls) -> None:
289         """Empty class's cache."""
290         cls.cache_ = {}
291
292     @classmethod
293     def get_cache(cls: type[BaseModelInstance]) -> dict[Any, BaseModel[Any]]:
294         """Get cache dictionary, create it if not yet existing."""
295         if not hasattr(cls, 'cache_'):
296             d: dict[Any, BaseModel[Any]] = {}
297             cls.cache_ = d
298         return cls.cache_
299
300     def cache(self) -> None:
301         """Update object in class's cache."""
302         if self.id_ is None:
303             raise HandledException('Cannot cache object without ID.')
304         cache = self.__class__.get_cache()
305         cache[self.id_] = self
306
307     def uncache(self) -> None:
308         """Remove self from cache."""
309         if self.id_ is None:
310             raise HandledException('Cannot un-cache object without ID.')
311         cache = self.__class__.get_cache()
312         del cache[self.id_]
313
314     @classmethod
315     def from_table_row(cls: type[BaseModelInstance],
316                        # pylint: disable=unused-argument
317                        db_conn: DatabaseConnection,
318                        row: Row | list[Any]) -> BaseModelInstance:
319         """Make from DB row, write to DB cache."""
320         obj = cls(*row)
321         obj.cache()
322         return obj
323
324     @classmethod
325     def by_id(cls, db_conn: DatabaseConnection,
326               id_: BaseModelId | None,
327               # pylint: disable=unused-argument
328               create: bool = False) -> Self:
329         """Retrieve by id_, on failure throw NotFoundException.
330
331         First try to get from cls.cache_, only then check DB; if found,
332         put into cache.
333
334         If create=True, make anew (but do not cache yet).
335         """
336         obj = None
337         if id_ is not None:
338             obj = cls.get_cached(id_)
339             if not obj:
340                 for row in db_conn.row_where(cls.table_name, 'id', id_):
341                     obj = cls.from_table_row(db_conn, row)
342                     obj.cache()
343                     break
344         if obj:
345             return obj
346         if create:
347             obj = cls(id_)
348             return obj
349         raise NotFoundException(f'found no object of ID {id_}')
350
351     @classmethod
352     def all(cls: type[BaseModelInstance],
353             db_conn: DatabaseConnection) -> list[BaseModelInstance]:
354         """Collect all objects of class into list.
355
356         Note that this primarily returns the contents of the cache, and only
357         _expands_ that by additional findings in the DB. This assumes the
358         cache is always instantly cleaned of any items that would be removed
359         from the DB.
360         """
361         items: dict[BaseModelId, BaseModelInstance] = {}
362         for k, v in cls.get_cache().items():
363             assert isinstance(v, cls)
364             items[k] = v
365         already_recorded = items.keys()
366         for id_ in db_conn.column_all(cls.table_name, 'id'):
367             if id_ not in already_recorded:
368                 item = cls.by_id(db_conn, id_)
369                 assert item.id_ is not None
370                 items[item.id_] = item
371         return list(items.values())
372
373     @classmethod
374     def by_date_range_with_limits(cls: type[BaseModelInstance],
375                                   db_conn: DatabaseConnection,
376                                   date_range: tuple[str, str],
377                                   date_col: str = 'day'
378                                   ) -> tuple[list[BaseModelInstance], str,
379                                              str]:
380         """Return list of items in database within (open) date_range interval.
381
382         If no range values provided, defaults them to 'yesterday' and
383         'tomorrow'. Knows to properly interpret these and 'today' as value.
384         """
385         start_str = date_range[0] if date_range[0] else 'yesterday'
386         end_str = date_range[1] if date_range[1] else 'tomorrow'
387         start_date = valid_date(start_str)
388         end_date = valid_date(end_str)
389         items = []
390         sql = f'SELECT id FROM {cls.table_name} '
391         sql += f'WHERE {date_col} >= ? AND {date_col} <= ?'
392         for row in db_conn.exec(sql, (start_date, end_date)):
393             items += [cls.by_id(db_conn, row[0])]
394         return items, start_date, end_date
395
396     @classmethod
397     def matching(cls: type[BaseModelInstance], db_conn: DatabaseConnection,
398                  pattern: str) -> list[BaseModelInstance]:
399         """Return all objects whose .to_search match pattern."""
400         items = cls.all(db_conn)
401         if pattern:
402             filtered = []
403             for item in items:
404                 for attr_name in cls.to_search:
405                     toks = attr_name.split('.')
406                     parent = item
407                     for tok in toks:
408                         attr = getattr(parent, tok)
409                         parent = attr
410                     if pattern in attr:
411                         filtered += [item]
412                         break
413             return filtered
414         return items
415
416     def save(self, db_conn: DatabaseConnection) -> None:
417         """Write self to DB and cache and ensure .id_.
418
419         Write both to DB, and to cache. To DB, write .id_ and attributes
420         listed in cls.to_save[_versioned|_relations].
421
422         Ensure self.id_ by setting it to what the DB command returns as the
423         last saved row's ID (cursor.lastrowid), EXCEPT if self.id_ already
424         exists as a 'str', which implies we do our own ID creation (so far
425         only the case with the Day class, where it's to be a date string.
426         """
427         values = tuple([self.id_] + [getattr(self, key)
428                                      for key in self.to_save])
429         table_name = self.table_name
430         cursor = db_conn.exec_on_vals(f'REPLACE INTO {table_name} VALUES',
431                                       values)
432         if not isinstance(self.id_, str):
433             self.id_ = cursor.lastrowid  # type: ignore[assignment]
434         self.cache()
435         for attr_name in self.to_save_versioned:
436             getattr(self, attr_name).save(db_conn)
437         for table, column, attr_name, key_index in self.to_save_relations:
438             assert isinstance(self.id_, (int, str))
439             db_conn.rewrite_relations(table, column, self.id_,
440                                       [[i.id_] for i
441                                        in getattr(self, attr_name)], key_index)
442
443     def remove(self, db_conn: DatabaseConnection) -> None:
444         """Remove from DB and cache, including dependencies."""
445         if self.id_ is None or self.__class__.get_cached(self.id_) is None:
446             raise HandledException('cannot remove unsaved item')
447         for attr_name in self.to_save_versioned:
448             getattr(self, attr_name).remove(db_conn)
449         for table, column, attr_name, _ in self.to_save_relations:
450             db_conn.delete_where(table, column, self.id_)
451         self.uncache()
452         db_conn.delete_where(self.table_name, 'id', self.id_)