home · contact · privacy
b2f2142c9c6957c19e90674270a1635082050f59
[plomtask] / plomtask / db.py
1 """Database management."""
from __future__ import annotations
from contextlib import closing
from difflib import Differ
from os import listdir
from os.path import isfile
from sqlite3 import connect as sql_connect, Cursor, Row
from typing import Any, Self, TypeVar, Generic

from plomtask.exceptions import HandledException, NotFoundException
from plomtask.dating import valid_date
10
# Schema version (sqlite "PRAGMA user_version") this code expects to run on.
EXPECTED_DB_VERSION = 5
# Directory holding the numbered migration scripts and the full init schema.
MIGRATIONS_DIR = 'migrations'
# Full schema for a fresh DB at EXPECTED_DB_VERSION, and its path.
FILENAME_DB_SCHEMA = 'init_' + str(EXPECTED_DB_VERSION) + '.sql'
PATH_DB_SCHEMA = '/'.join((MIGRATIONS_DIR, FILENAME_DB_SCHEMA))
15
16
class UnmigratedDbException(HandledException):
    """Signals a DB file whose user_version is not the expected one.

    Such a file presumably needs migrating before use.
    """
19
20
class DatabaseFile:
    """Represents the sqlite3 database's file.

    Constructing an instance validates the file at path (see ._check): it
    must exist, be at EXPECTED_DB_VERSION, and match the schema stored at
    PATH_DB_SCHEMA.

    Every sqlite3 connection opened here is wrapped in contextlib.closing:
    using a sqlite3 Connection as a context manager only commits or rolls
    back, it never closes the connection.
    """
    # pylint: disable=too-few-public-methods

    def __init__(self, path: str) -> None:
        self.path = path
        self._check()

    @classmethod
    def create_at(cls, path: str) -> DatabaseFile:
        """Make new DB file at path from PATH_DB_SCHEMA, return it validated.

        The new DB's user_version is set to EXPECTED_DB_VERSION. The schema
        file is read before any connection to path is opened.
        """
        with open(PATH_DB_SCHEMA, 'r', encoding='utf-8') as f:
            schema = f.read()
        with closing(sql_connect(path)) as conn:
            with conn:  # commit on success, roll back on error
                conn.executescript(schema)
                conn.execute(f'PRAGMA user_version = {EXPECTED_DB_VERSION}')
        return cls(path)

    @classmethod
    def migrate(cls, path: str) -> DatabaseFile:
        """Apply migrations from the DB's version up to EXPECTED_DB_VERSION.

        The migration at index i of ._available_migrations() brings the DB to
        version i. user_version is set right after each script has run, so a
        failing script leaves the DB at the last successfully applied version.

        Returns:
            The (then validated) DatabaseFile at path.
        """
        migrations = cls._available_migrations()
        from_version = cls._get_version_of_db(path)
        for version in range(from_version + 1, len(migrations)):
            migration_path = f'{MIGRATIONS_DIR}/{migrations[version]}'
            with open(migration_path, 'r', encoding='utf-8') as f:
                script = f.read()
            with closing(sql_connect(path)) as conn:
                with conn:
                    conn.executescript(script)
                    conn.execute(f'PRAGMA user_version = {version}')
        return cls(path)

    def _check(self) -> None:
        """Check file exists, and is of proper DB version and schema.

        Raises:
            NotFoundException: if no file exists at .path.
            UnmigratedDbException: if the DB's user_version is not
                EXPECTED_DB_VERSION.
            HandledException: if the DB's schema differs from the stored one.
        """
        if not isfile(self.path):
            raise NotFoundException
        if self._user_version != EXPECTED_DB_VERSION:
            raise UnmigratedDbException()
        self._validate_schema()

    @staticmethod
    def _available_migrations() -> list[str]:
        """Validate migrations directory and return sorted entries.

        Every entry of MIGRATIONS_DIR except FILENAME_DB_SCHEMA must be named
        '{number}_{rest}', with number in 0..EXPECTED_DB_VERSION, and every
        number in that range must occur exactly once.

        Returns:
            The directory's actual filenames, at index i the one of number i
            (so e.g. a zero-padded '01_foo.sql' is returned as is).

        Raises:
            HandledException: on malformed, negative, duplicate, too big, or
                missing migration entries.
        """
        msg_too_big = 'Migration directory points beyond expected DB version.'
        msg_bad_entry = 'Migration directory contains unexpected entry: '
        msg_missing = 'Migration directory misses migration of number: '
        msg_duplicate = ('Migration directory contains more than one '
                         'migration of number: ')
        migrations: dict[int, str] = {}
        for entry in listdir(MIGRATIONS_DIR):
            if entry == FILENAME_DB_SCHEMA:
                continue
            toks = entry.split('_', 1)
            if len(toks) < 2:
                raise HandledException(msg_bad_entry + entry)
            try:
                i = int(toks[0])
            except ValueError as e:
                raise HandledException(msg_bad_entry + entry) from e
            if i < 0:
                raise HandledException(msg_bad_entry + entry)
            if i > EXPECTED_DB_VERSION:
                raise HandledException(msg_too_big)
            # listdir order is arbitrary, so a silent overwrite would make
            # the chosen script for this number unpredictable.
            if i in migrations:
                raise HandledException(msg_duplicate + str(i))
            migrations[i] = entry
        for i in range(EXPECTED_DB_VERSION + 1):
            if i not in migrations:
                raise HandledException(msg_missing + str(i))
        return [migrations[i] for i in range(EXPECTED_DB_VERSION + 1)]

    @staticmethod
    def _get_version_of_db(path: str) -> int:
        """Get DB user_version, fail if outside 0..EXPECTED_DB_VERSION.

        Raises:
            HandledException: if the version is negative or beyond
                EXPECTED_DB_VERSION.
        """
        sql_for_db_version = 'PRAGMA user_version'
        with closing(sql_connect(path)) as conn:
            db_version = conn.execute(sql_for_db_version).fetchone()[0]
        assert isinstance(db_version, int)
        # user_version may be set to negative values; no migration covers
        # those.
        if not 0 <= db_version <= EXPECTED_DB_VERSION:
            msg = f'Wrong DB version, expected '\
                    f'{EXPECTED_DB_VERSION}, got unknown {db_version}.'
            raise HandledException(msg)
        return db_version

    @property
    def _user_version(self) -> int:
        """Get DB user_version."""
        # pylint: disable=protected-access
        # (since we remain within class)
        return self.__class__._get_version_of_db(self.path)

    def _validate_schema(self) -> None:
        """Compare found schema with what's stored at PATH_DB_SCHEMA.

        Raises:
            HandledException: with a line diff if the schemas differ.
        """

        def reformat_rows(rows: list[str]) -> list[str]:
            # Re-layout each stored SQL statement into one item per line
            # (lines split at commas outside parentheses opened on that same
            # line, items indented by 4 spaces, first and last unindented),
            # to compare it line-by-line with the schema file's layout.
            new_rows = []
            for row in rows:
                new_row = []
                for subrow in row.split('\n'):
                    subrow = subrow.rstrip()
                    in_parentheses = 0
                    split_at = []
                    for i, c in enumerate(subrow):
                        if '(' == c:
                            in_parentheses += 1
                        elif ')' == c:
                            in_parentheses -= 1
                        elif ',' == c and 0 == in_parentheses:
                            split_at += [i + 1]
                    prev_split = 0
                    for i in split_at:
                        segment = subrow[prev_split:i].strip()
                        if len(segment) > 0:
                            new_row += [f'    {segment}']
                        prev_split = i
                    segment = subrow[prev_split:].strip()
                    if len(segment) > 0:
                        new_row += [f'    {segment}']
                new_row[0] = new_row[0].lstrip()
                new_row[-1] = new_row[-1].lstrip()
                # A trailing line like ", col TYPE)" splits into
                # [..., prev, ',', 'col TYPE)']: append the comma to prev,
                # drop the lone ',' item, and put ')' on its own line.
                # NOTE(review): presumably how sqlite stores columns added
                # via ALTER TABLE ... ADD COLUMN. Statements of fewer than
                # three items (e.g. a one-line CREATE INDEX) are left as they
                # are instead of raising IndexError.
                if (new_row[-1] != ')' and len(new_row) >= 3
                        and new_row[-3][-1] != ','):
                    new_row[-3] = new_row[-3] + ','
                    new_row[-2:] = ['    ' + new_row[-1][:-1]] + [')']
                new_rows += ['\n'.join(new_row)]
            return new_rows

        sql_for_schema = 'SELECT sql FROM sqlite_master ORDER BY sql'
        msg_err = 'Database has wrong tables schema. Diff:\n'
        with closing(sql_connect(self.path)) as conn:
            # sqlite_master rows without sql (e.g. auto-indices) are skipped
            schema_rows = [r[0] for r in conn.execute(sql_for_schema) if r[0]]
        schema_rows = reformat_rows(schema_rows)
        retrieved_schema = ';\n'.join(schema_rows) + ';'
        with open(PATH_DB_SCHEMA, 'r', encoding='utf-8') as f:
            stored_schema = f.read().rstrip()
        if stored_schema != retrieved_schema:
            diff_msg = Differ().compare(retrieved_schema.splitlines(),
                                        stored_schema.splitlines())
            raise HandledException(msg_err + '\n'.join(diff_msg))
156
157
class DatabaseConnection:
    """A single connection to the database.

    Statements run through .exec are only persisted by an explicit .commit.
    Table and column names are interpolated into the SQL text as given, so
    they must come from trusted code, never from user input; values are
    always passed as bound parameters.
    """

    def __init__(self, db_file: DatabaseFile) -> None:
        self.conn = sql_connect(db_file.path)

    def commit(self) -> None:
        """Commit SQL transaction."""
        self.conn.commit()

    def exec(self, code: str, inputs: tuple[Any, ...] = ()) -> Cursor:
        """Add commands to SQL transaction, with inputs as bound values."""
        return self.conn.execute(code, inputs)

    def exec_on_vals(self, code: str, inputs: tuple[Any, ...]) -> Cursor:
        """Wrapper around .exec appending adequate " (?, …)" to code."""
        q_marks_from_values = '(' + ','.join(['?'] * len(inputs)) + ')'
        return self.exec(f'{code} {q_marks_from_values}', inputs)

    def close(self) -> None:
        """Close DB connection."""
        self.conn.close()

    def rewrite_relations(self, table_name: str, key: str, target: int | str,
                          rows: list[list[Any]], key_index: int = 0) -> None:
        # pylint: disable=too-many-arguments
        """Rewrite relations in table_name to target, with rows values.

        First deletes all rows of table_name where key == target, then
        inserts one row per entry of rows.

        Note that single rows are expected without the column and value
        identified by key and target, which are inserted inside the function
        at key_index. Rows may be lists or any other sequence (e.g. tuples).
        """
        self.delete_where(table_name, key, target)
        for row in rows:
            values = (*row[:key_index], target, *row[key_index:])
            self.exec_on_vals(f'INSERT INTO {table_name} VALUES', values)

    def row_where(self, table_name: str, key: str,
                  target: int | str) -> list[Row]:
        """Return list of Rows at table where key == target."""
        return list(self.exec(f'SELECT * FROM {table_name} WHERE {key} = ?',
                              (target,)))

    def column_where(self, table_name: str, column: str, key: str,
                     target: int | str) -> list[Any]:
        """Return column of table where key == target."""
        return [row[0] for row in
                self.exec(f'SELECT {column} FROM {table_name} '
                          f'WHERE {key} = ?', (target,))]

    def column_all(self, table_name: str, column: str) -> list[Any]:
        """Return complete column of table."""
        return [row[0] for row in
                self.exec(f'SELECT {column} FROM {table_name}')]

    def delete_where(self, table_name: str, key: str,
                     target: int | str) -> None:
        """Delete from table where key == target."""
        self.exec(f'DELETE FROM {table_name} WHERE {key} = ?', (target,))
228
229
# Any instance of a BaseModel (sub)class; lets classmethods be typed as
# returning instances of the very class they are called on.
BaseModelInstance = TypeVar("BaseModelInstance", bound="BaseModel[Any]")
# Model ID type: int, or str for models creating their own IDs (see
# BaseModel.save).
BaseModelId = TypeVar("BaseModelId", int, str)
232
233
class BaseModel(Generic[BaseModelId]):
    """Template for most of the models we use/derive from the DB.

    Subclasses configure persistence via class attributes:

    - table_name: DB table holding one row per object, its 'id' column
      first, followed by the to_save columns;
    - to_save: attribute names written, in order, after the ID;
    - to_save_versioned: attributes holding objects with their own
      .save(db_conn) / .remove(db_conn), called on save/remove;
    - to_save_relations: (table, column, attribute name, key index) tuples;
      on save, table's rows where column == .id_ are rewritten from the .id_
      of each item of the named attribute;
    - to_search: (possibly dotted) attribute paths checked by .matching.

    Loaded and saved objects are kept in a cache per class, keyed by .id_.
    """
    table_name = ''
    to_save: list[str] = []
    to_save_versioned: list[str] = []
    to_save_relations: list[tuple[str, str, str, int]] = []
    id_: None | BaseModelId
    cache_: dict[BaseModelId, Self]
    to_search: list[str] = []

    def __init__(self, id_: BaseModelId | None) -> None:
        """Set .id_ (None if not yet determined, see .save).

        Raises:
            HandledException: on an int ID < 1, or an empty str ID.
        """
        if isinstance(id_, int) and id_ < 1:
            msg = f'illegal {self.__class__.__name__} ID, must be >=1: {id_}'
            raise HandledException(msg)
        if isinstance(id_, str) and "" == id_:
            msg = f'illegal {self.__class__.__name__} ID, must be non-empty'
            raise HandledException(msg)
        self.id_ = id_

    def __eq__(self, other: object) -> bool:
        """Equal if other is of self's class and has same .id_ and .to_save.

        The value tuples are compared directly, not their hashes: equal
        hashes do not imply equal values (e.g. hash(-1) == hash(-2) in
        CPython), and hashing would fail on unhashable attribute values.
        """
        if not isinstance(other, self.__class__):
            return False
        values_me = tuple([self.id_] +
                          [getattr(self, name) for name in self.to_save])
        values_other = tuple([other.id_] +
                             [getattr(other, name) for name in other.to_save])
        return values_me == values_other

    def __lt__(self, other: Any) -> bool:
        """Order by (int) .id_; only defined within the same class."""
        if not isinstance(other, self.__class__):
            msg = 'cannot compare to object of different class'
            raise HandledException(msg)
        assert isinstance(self.id_, int)
        assert isinstance(other.id_, int)
        return self.id_ < other.id_

    @classmethod
    def get_cached(cls: type[BaseModelInstance],
                   id_: BaseModelId) -> BaseModelInstance | None:
        """Get object of id_ from class's cache, or None if not found."""
        obj = cls.get_cache().get(id_)
        if obj is None:
            return None
        assert isinstance(obj, cls)
        return obj

    @classmethod
    def empty_cache(cls) -> None:
        """Empty class's cache."""
        cls.cache_ = {}

    @classmethod
    def get_cache(cls: type[BaseModelInstance]) -> dict[Any, BaseModel[Any]]:
        """Get cache dictionary, create it if not yet existing.

        Only cls's own namespace is checked (not hasattr), so a subclass
        never shares, and mixes its IDs into, a cache_ inherited from a
        parent class.
        """
        if 'cache_' not in cls.__dict__:
            d: dict[Any, BaseModel[Any]] = {}
            cls.cache_ = d
        return cls.cache_

    def cache(self) -> None:
        """Update object in class's cache."""
        if self.id_ is None:
            raise HandledException('Cannot cache object without ID.')
        cache = self.__class__.get_cache()
        cache[self.id_] = self

    def uncache(self) -> None:
        """Remove self from cache (KeyError if not in it)."""
        if self.id_ is None:
            raise HandledException('Cannot un-cache object without ID.')
        cache = self.__class__.get_cache()
        del cache[self.id_]

    @classmethod
    def from_table_row(cls: type[BaseModelInstance],
                       # pylint: disable=unused-argument
                       db_conn: DatabaseConnection,
                       row: Row | list[Any]) -> BaseModelInstance:
        """Make from DB row, write to DB cache."""
        obj = cls(*row)
        obj.cache()
        return obj

    @classmethod
    def by_id(cls, db_conn: DatabaseConnection,
              id_: BaseModelId | None,
              create: bool = False) -> Self:
        """Retrieve by id_, on failure throw NotFoundException.

        First try to get from cls.cache_, only then check DB; if found,
        put into cache.

        If create=True, make anew (but do not cache yet).
        """
        obj = None
        if id_ is not None:
            obj = cls.get_cached(id_)
            if obj is None:
                for row in db_conn.row_where(cls.table_name, 'id', id_):
                    obj = cls.from_table_row(db_conn, row)
                    # subclass overrides of from_table_row may not cache
                    obj.cache()
                    break
        if obj is not None:
            return obj
        if create:
            return cls(id_)
        raise NotFoundException(f'found no object of ID {id_}')

    @classmethod
    def all(cls: type[BaseModelInstance],
            db_conn: DatabaseConnection) -> list[BaseModelInstance]:
        """Collect all objects of class into list.

        Note that this primarily returns the contents of the cache, and only
        _expands_ that by additional findings in the DB. This assumes the
        cache is always instantly cleaned of any items that would be removed
        from the DB.
        """
        items: dict[BaseModelId, BaseModelInstance] = {}
        for k, v in cls.get_cache().items():
            assert isinstance(v, cls)
            items[k] = v
        for id_ in db_conn.column_all(cls.table_name, 'id'):
            if id_ not in items:
                item = cls.by_id(db_conn, id_)
                assert item.id_ is not None
                items[item.id_] = item
        return list(items.values())

    @classmethod
    def by_date_range_with_limits(cls: type[BaseModelInstance],
                                  db_conn: DatabaseConnection,
                                  date_range: tuple[str, str],
                                  date_col: str = 'day'
                                  ) -> tuple[list[BaseModelInstance], str,
                                             str]:
        """Return items with date_col inside date_range, plus its limits.

        The interval is closed: both limits are included. Empty range values
        default to 'yesterday' and 'tomorrow'; all values (these and 'today'
        included) are resolved via valid_date, whose results are returned as
        the second and third tuple elements.
        """
        start_str = date_range[0] if date_range[0] else 'yesterday'
        end_str = date_range[1] if date_range[1] else 'tomorrow'
        start_date = valid_date(start_str)
        end_date = valid_date(end_str)
        items = []
        sql = f'SELECT id FROM {cls.table_name} '
        sql += f'WHERE {date_col} >= ? AND {date_col} <= ?'
        for row in db_conn.exec(sql, (start_date, end_date)):
            items += [cls.by_id(db_conn, row[0])]
        return items, start_date, end_date

    @classmethod
    def matching(cls: type[BaseModelInstance], db_conn: DatabaseConnection,
                 pattern: str) -> list[BaseModelInstance]:
        """Return all objects whose .to_search match pattern.

        Each .to_search entry is a dotted attribute path; an object matches
        if `pattern in` the value at any of its paths. An empty pattern
        returns all objects.
        """
        items = cls.all(db_conn)
        if pattern:
            filtered = []
            for item in items:
                for attr_name in cls.to_search:
                    attr: Any = item
                    for tok in attr_name.split('.'):
                        attr = getattr(attr, tok)
                    if pattern in attr:
                        filtered += [item]
                        break
            return filtered
        return items

    def save(self, db_conn: DatabaseConnection) -> None:
        """Write self to DB and cache and ensure .id_.

        Write both to DB, and to cache. To DB, write .id_ and attributes
        listed in cls.to_save[_versioned|_relations].

        Ensure self.id_ by setting it to what the DB command returns as the
        last saved row's ID (cursor.lastrowid), EXCEPT if self.id_ already
        exists as a 'str', which implies we do our own ID creation (so far
        only the case with the Day class, where it's to be a date string).
        """
        values = tuple([self.id_] + [getattr(self, key)
                                     for key in self.to_save])
        table_name = self.table_name
        cursor = db_conn.exec_on_vals(f'REPLACE INTO {table_name} VALUES',
                                      values)
        if not isinstance(self.id_, str):
            self.id_ = cursor.lastrowid  # type: ignore[assignment]
        self.cache()
        for attr_name in self.to_save_versioned:
            getattr(self, attr_name).save(db_conn)
        for table, column, attr_name, key_index in self.to_save_relations:
            assert isinstance(self.id_, (int, str))
            db_conn.rewrite_relations(table, column, self.id_,
                                      [[i.id_] for i
                                       in getattr(self, attr_name)], key_index)

    def remove(self, db_conn: DatabaseConnection) -> None:
        """Remove from DB and cache, including dependencies.

        Raises:
            HandledException: if self has no ID or is not in the cache.
        """
        if self.id_ is None or self.__class__.get_cached(self.id_) is None:
            raise HandledException('cannot remove unsaved item')
        for attr_name in self.to_save_versioned:
            getattr(self, attr_name).remove(db_conn)
        for table, column, _, _ in self.to_save_relations:
            db_conn.delete_where(table, column, self.id_)
        self.uncache()
        db_conn.delete_where(self.table_name, 'id', self.id_)