home · contact · privacy
Refactor database management code a little bit.
[plomtask] / plomtask / db.py
1 """Database management."""
from __future__ import annotations
from contextlib import closing
from difflib import Differ
from os import listdir
from os.path import isfile
from sqlite3 import connect as sql_connect, Cursor, Row
from typing import Any, Self, TypeVar, Generic

from plomtask.exceptions import HandledException, NotFoundException
from plomtask.dating import valid_date
10
# PRAGMA user_version a DB file must carry to be accepted by DatabaseFile.
EXPECTED_DB_VERSION = 5
# Relative path (i.e. resolved against the working directory) of the
# directory holding the schema file and the numbered migration scripts.
MIGRATIONS_DIR = 'migrations'
# Full schema for a fresh DB at EXPECTED_DB_VERSION, and its path.
FILENAME_DB_SCHEMA = f'init_{EXPECTED_DB_VERSION:d}.sql'
PATH_DB_SCHEMA = '/'.join([MIGRATIONS_DIR, FILENAME_DB_SCHEMA])
15
16
class UnmigratedDbException(HandledException):
    """Signals a DB file whose user_version is not EXPECTED_DB_VERSION."""
19
20
class DatabaseFile:
    """Represents the sqlite3 database's file.

    Instantiation validates the file at path (existence, user_version,
    schema); create_at and migrate prepare such a file and then return an
    instance for it.
    """
    # pylint: disable=too-few-public-methods

    def __init__(self, path: str) -> None:
        self.path = path
        self._check()

    @classmethod
    def create_at(cls, path: str) -> DatabaseFile:
        """Make new DB file at path from the schema at PATH_DB_SCHEMA.

        Stamps the file's user_version as EXPECTED_DB_VERSION and returns a
        (thereby validated) DatabaseFile for it.
        """
        # sqlite3's connection context manager only commits / rolls back and
        # never closes the connection, hence the additional closing().
        with closing(sql_connect(path)) as conn:
            with conn:
                with open(PATH_DB_SCHEMA, 'r', encoding='utf-8') as f:
                    conn.executescript(f.read())
                conn.execute(f'PRAGMA user_version = {EXPECTED_DB_VERSION}')
        return cls(path)

    @classmethod
    def migrate(cls, path: str) -> DatabaseFile:
        """Apply migrations from_version to EXPECTED_DB_VERSION.

        Runs, in order, every migration script numbered above the file's
        current user_version, setting user_version to each script's number
        right after it ran, so that a re-run after a failing script restarts
        at that script. Returns a validated DatabaseFile for path.
        """
        migrations = cls._available_migrations()
        from_version = cls._get_version_of_db(path)
        migrations_todo = migrations[from_version + 1:]
        for version, filename in enumerate(migrations_todo,
                                           start=from_version + 1):
            with closing(sql_connect(path)) as conn:
                with conn:
                    with open(f'{MIGRATIONS_DIR}/{filename}', 'r',
                              encoding='utf-8') as f:
                        conn.executescript(f.read())
                with conn:
                    conn.execute(f'PRAGMA user_version = {version}')
        return cls(path)

    def _check(self) -> None:
        """Check file exists, and is of proper DB version and schema.

        Raises NotFoundException if there's no file at self.path,
        UnmigratedDbException if its user_version is not
        EXPECTED_DB_VERSION, and HandledException on schema mismatch.
        """
        if not isfile(self.path):
            raise NotFoundException
        if self._user_version != EXPECTED_DB_VERSION:
            raise UnmigratedDbException()
        self._validate_schema()

    @staticmethod
    def _available_migrations() -> list[str]:
        """Validate migrations directory and return sorted entries.

        Expects, besides FILENAME_DB_SCHEMA, exactly entries named
        '<number>_<rest>' for each number from 0 to EXPECTED_DB_VERSION; the
        returned list holds at index i the filename of migration i.
        """
        msg_too_big = 'Migration directory points beyond expected DB version.'
        msg_bad_entry = 'Migration directory contains unexpected entry: '
        msg_missing = 'Migration directory misses migration of number: '
        migrations = {}
        for entry in listdir(MIGRATIONS_DIR):
            if entry == FILENAME_DB_SCHEMA:
                continue
            toks = entry.split('_', 1)
            if len(toks) < 2:
                raise HandledException(msg_bad_entry + entry)
            try:
                i = int(toks[0])
            except ValueError as e:
                raise HandledException(msg_bad_entry + entry) from e
            if i > EXPECTED_DB_VERSION:
                raise HandledException(msg_too_big)
            migrations[i] = toks[1]
        migrations_list = []
        for i in range(EXPECTED_DB_VERSION + 1):
            if i not in migrations:
                raise HandledException(msg_missing + str(i))
            migrations_list += [f'{i}_{migrations[i]}']
        return migrations_list

    @staticmethod
    def _get_version_of_db(path: str) -> int:
        """Get DB user_version, fail if outside expected range.

        Raises HandledException if the version exceeds EXPECTED_DB_VERSION.
        """
        with closing(sql_connect(path)) as conn:
            db_version = conn.execute('PRAGMA user_version').fetchone()[0]
        assert isinstance(db_version, int)
        if db_version > EXPECTED_DB_VERSION:
            msg = (f'Wrong DB version, expected {EXPECTED_DB_VERSION}, '
                   f'got unknown {db_version}.')
            raise HandledException(msg)
        return db_version

    @property
    def _user_version(self) -> int:
        """Get DB user_version."""
        # pylint: disable=protected-access
        # (since we remain within class)
        return self.__class__._get_version_of_db(self.path)

    def _validate_schema(self) -> None:
        """Compare found schema with what's stored at PATH_DB_SCHEMA.

        Raises HandledException carrying a line diff if they differ.
        """

        def reformat_rows(rows: list[str]) -> list[str]:
            # Re-lay-out each statement found in sqlite_master so that it can
            # be compared textually to the stored schema file: one segment
            # per line (split at newlines and at commas outside parentheses),
            # segments indented by four spaces except first and last.
            new_rows = []
            for row in rows:
                new_row = []
                for subrow in row.split('\n'):
                    subrow = subrow.rstrip()
                    # Parenthesis depth is tracked per line only; commas
                    # nested inside parentheses do not split the line.
                    in_parentheses = 0
                    split_at = []
                    for i, c in enumerate(subrow):
                        if '(' == c:
                            in_parentheses += 1
                        elif ')' == c:
                            in_parentheses -= 1
                        elif ',' == c and 0 == in_parentheses:
                            split_at += [i + 1]
                    prev_split = 0
                    for i in split_at:
                        segment = subrow[prev_split:i].strip()
                        if len(segment) > 0:
                            new_row += [f'    {segment}']
                        prev_split = i
                    segment = subrow[prev_split:].strip()
                    if len(segment) > 0:
                        new_row += [f'    {segment}']
                new_row[0] = new_row[0].lstrip()
                new_row[-1] = new_row[-1].lstrip()
                # If the statement doesn't end in a lone ')': append a comma
                # to the third-to-last segment, and replace the last two
                # segments with the last one minus its final character
                # (re-indented) plus a closing ')' line. NOTE(review):
                # presumably this normalizes columns appended via ALTER TABLE
                # ADD COLUMN (which SQLite records as ", col TYPE)") -
                # confirm. Statements of fewer than three segments not ending
                # in a lone ')' would raise IndexError here.
                if new_row[-1] != ')' and new_row[-3][-1] != ',':
                    new_row[-3] = new_row[-3] + ','
                    new_row[-2:] = ['    ' + new_row[-1][:-1]] + [')']
                new_rows += ['\n'.join(new_row)]
            return new_rows

        sql_for_schema = 'SELECT sql FROM sqlite_master ORDER BY sql'
        msg_err = 'Database has wrong tables schema. Diff:\n'
        with closing(sql_connect(self.path)) as conn:
            # Entries with NULL sql (e.g. automatic indexes) are skipped.
            schema_rows = [r[0] for r in conn.execute(sql_for_schema) if r[0]]
        schema_rows = reformat_rows(schema_rows)
        retrieved_schema = ';\n'.join(schema_rows) + ';'
        with open(PATH_DB_SCHEMA, 'r', encoding='utf-8') as f:
            stored_schema = f.read().rstrip()
        if stored_schema != retrieved_schema:
            diff_msg = Differ().compare(retrieved_schema.splitlines(),
                                        stored_schema.splitlines())
            raise HandledException(msg_err + '\n'.join(diff_msg))
153             diff_msg = Differ().compare(retrieved_schema.splitlines(),
154                                         stored_schema.splitlines())
155             raise HandledException(msg_err + '\n'.join(diff_msg))
156
157
class DatabaseConnection:
    """Wraps one sqlite3 connection opened on a DatabaseFile's path.

    Table and column names are interpolated into the SQL text as given, so
    they must come from trusted code; only values are bound as parameters.
    """

    def __init__(self, db_file: DatabaseFile) -> None:
        self.conn = sql_connect(db_file.path)

    def commit(self) -> None:
        """Commit the currently pending SQL transaction."""
        self.conn.commit()

    def exec(self, code: str, inputs: tuple[Any, ...] = tuple()) -> Cursor:
        """Execute code, binding inputs to its '?' placeholders."""
        return self.conn.execute(code, inputs)

    def exec_on_vals(self, code: str, inputs: tuple[Any, ...]) -> Cursor:
        """Execute code suffixed by a " (?,?,…)" group sized to inputs."""
        placeholders = ','.join('?' for _ in inputs)
        return self.exec(f'{code} ({placeholders})', inputs)

    def close(self) -> None:
        """Close the underlying sqlite3 connection."""
        self.conn.close()

    def rewrite_relations(self, table_name: str, key: str, target: int | str,
                          rows: list[list[Any]], key_index: int = 0) -> None:
        # pylint: disable=too-many-arguments
        """Replace all rows of table_name with key == target by rows.

        Each of rows lacks the key column; target is spliced into a copy of
        it at position key_index before it is inserted.
        """
        self.delete_where(table_name, key, target)
        for row in rows:
            spliced = list(row)
            spliced.insert(key_index, target)
            self.exec_on_vals(f'INSERT INTO {table_name} VALUES',
                              tuple(spliced))

    def row_where(self, table_name: str, key: str,
                  target: int | str) -> list[Row]:
        """Fetch all rows of table_name whose key equals target."""
        sql = f'SELECT * FROM {table_name} WHERE {key} = ?'
        return self.exec(sql, (target,)).fetchall()

    def column_where(self, table_name: str, column: str, key: str,
                     target: int | str) -> list[Any]:
        """Fetch column's values from rows of table_name where key == target."""
        sql = f'SELECT {column} FROM {table_name} WHERE {key} = ?'
        cursor = self.exec(sql, (target,))
        return [fetched[0] for fetched in cursor]

    def column_all(self, table_name: str, column: str) -> list[Any]:
        """Fetch column's values from every row of table_name."""
        cursor = self.exec(f'SELECT {column} FROM {table_name}')
        return [fetched[0] for fetched in cursor]

    def delete_where(self, table_name: str, key: str,
                     target: int | str) -> None:
        """Delete all rows of table_name whose key equals target."""
        sql = f'DELETE FROM {table_name} WHERE {key} = ?'
        self.exec(sql, (target,))
228
229
230 BaseModelId = TypeVar('BaseModelId', int, str)
231 BaseModelInstance = TypeVar('BaseModelInstance', bound='BaseModel[Any]')
232
233
234 class BaseModel(Generic[BaseModelId]):
235     """Template for most of the models we use/derive from the DB."""
236     table_name = ''
237     to_save: list[str] = []
238     to_save_versioned: list[str] = []
239     to_save_relations: list[tuple[str, str, str, int]] = []
240     id_: None | BaseModelId
241     cache_: dict[BaseModelId, Self]
242     to_search: list[str] = []
243
244     def __init__(self, id_: BaseModelId | None) -> None:
245         if isinstance(id_, int) and id_ < 1:
246             msg = f'illegal {self.__class__.__name__} ID, must be >=1: {id_}'
247             raise HandledException(msg)
248         self.id_ = id_
249
250     def __eq__(self, other: object) -> bool:
251         if not isinstance(other, self.__class__):
252             return False
253         to_hash_me = tuple([self.id_] +
254                            [getattr(self, name) for name in self.to_save])
255         to_hash_other = tuple([other.id_] +
256                               [getattr(other, name) for name in other.to_save])
257         return hash(to_hash_me) == hash(to_hash_other)
258
259     def __lt__(self, other: Any) -> bool:
260         if not isinstance(other, self.__class__):
261             msg = 'cannot compare to object of different class'
262             raise HandledException(msg)
263         assert isinstance(self.id_, int)
264         assert isinstance(other.id_, int)
265         return self.id_ < other.id_
266
267     @classmethod
268     def get_cached(cls: type[BaseModelInstance],
269                    id_: BaseModelId) -> BaseModelInstance | None:
270         """Get object of id_ from class's cache, or None if not found."""
271         # pylint: disable=consider-iterating-dictionary
272         cache = cls.get_cache()
273         if id_ in cache.keys():
274             obj = cache[id_]
275             assert isinstance(obj, cls)
276             return obj
277         return None
278
279     @classmethod
280     def empty_cache(cls) -> None:
281         """Empty class's cache."""
282         cls.cache_ = {}
283
284     @classmethod
285     def get_cache(cls: type[BaseModelInstance]) -> dict[Any, BaseModel[Any]]:
286         """Get cache dictionary, create it if not yet existing."""
287         if not hasattr(cls, 'cache_'):
288             d: dict[Any, BaseModel[Any]] = {}
289             cls.cache_ = d
290         return cls.cache_
291
292     def cache(self) -> None:
293         """Update object in class's cache."""
294         if self.id_ is None:
295             raise HandledException('Cannot cache object without ID.')
296         cache = self.__class__.get_cache()
297         cache[self.id_] = self
298
299     def uncache(self) -> None:
300         """Remove self from cache."""
301         if self.id_ is None:
302             raise HandledException('Cannot un-cache object without ID.')
303         cache = self.__class__.get_cache()
304         del cache[self.id_]
305
306     @classmethod
307     def from_table_row(cls: type[BaseModelInstance],
308                        # pylint: disable=unused-argument
309                        db_conn: DatabaseConnection,
310                        row: Row | list[Any]) -> BaseModelInstance:
311         """Make from DB row, write to DB cache."""
312         obj = cls(*row)
313         obj.cache()
314         return obj
315
316     @classmethod
317     def by_id(cls, db_conn: DatabaseConnection,
318               id_: BaseModelId | None,
319               # pylint: disable=unused-argument
320               create: bool = False) -> Self:
321         """Retrieve by id_, on failure throw NotFoundException.
322
323         First try to get from cls.cache_, only then check DB; if found,
324         put into cache.
325
326         If create=True, make anew (but do not cache yet).
327         """
328         obj = None
329         if id_ is not None:
330             obj = cls.get_cached(id_)
331             if not obj:
332                 for row in db_conn.row_where(cls.table_name, 'id', id_):
333                     obj = cls.from_table_row(db_conn, row)
334                     obj.cache()
335                     break
336         if obj:
337             return obj
338         if create:
339             obj = cls(id_)
340             return obj
341         raise NotFoundException(f'found no object of ID {id_}')
342
343     @classmethod
344     def all(cls: type[BaseModelInstance],
345             db_conn: DatabaseConnection) -> list[BaseModelInstance]:
346         """Collect all objects of class into list.
347
348         Note that this primarily returns the contents of the cache, and only
349         _expands_ that by additional findings in the DB. This assumes the
350         cache is always instantly cleaned of any items that would be removed
351         from the DB.
352         """
353         items: dict[BaseModelId, BaseModelInstance] = {}
354         for k, v in cls.get_cache().items():
355             assert isinstance(v, cls)
356             items[k] = v
357         already_recorded = items.keys()
358         for id_ in db_conn.column_all(cls.table_name, 'id'):
359             if id_ not in already_recorded:
360                 item = cls.by_id(db_conn, id_)
361                 assert item.id_ is not None
362                 items[item.id_] = item
363         return list(items.values())
364
365     @classmethod
366     def by_date_range_with_limits(cls: type[BaseModelInstance],
367                                   db_conn: DatabaseConnection,
368                                   date_range: tuple[str, str],
369                                   date_col: str = 'day'
370                                   ) -> tuple[list[BaseModelInstance], str,
371                                              str]:
372         """Return list of items in database within (open) date_range interval.
373
374         If no range values provided, defaults them to 'yesterday' and
375         'tomorrow'. Knows to properly interpret these and 'today' as value.
376         """
377         start_str = date_range[0] if date_range[0] else 'yesterday'
378         end_str = date_range[1] if date_range[1] else 'tomorrow'
379         start_date = valid_date(start_str)
380         end_date = valid_date(end_str)
381         items = []
382         sql = f'SELECT id FROM {cls.table_name} '
383         sql += f'WHERE {date_col} >= ? AND {date_col} <= ?'
384         for row in db_conn.exec(sql, (start_date, end_date)):
385             items += [cls.by_id(db_conn, row[0])]
386         return items, start_date, end_date
387
388     @classmethod
389     def matching(cls: type[BaseModelInstance], db_conn: DatabaseConnection,
390                  pattern: str) -> list[BaseModelInstance]:
391         """Return all objects whose .to_search match pattern."""
392         items = cls.all(db_conn)
393         if pattern:
394             filtered = []
395             for item in items:
396                 for attr_name in cls.to_search:
397                     toks = attr_name.split('.')
398                     parent = item
399                     for tok in toks:
400                         attr = getattr(parent, tok)
401                         parent = attr
402                     if pattern in attr:
403                         filtered += [item]
404                         break
405             return filtered
406         return items
407
408     def save(self, db_conn: DatabaseConnection) -> None:
409         """Write self to DB and cache and ensure .id_.
410
411         Write both to DB, and to cache. To DB, write .id_ and attributes
412         listed in cls.to_save[_versioned|_relations].
413
414         Ensure self.id_ by setting it to what the DB command returns as the
415         last saved row's ID (cursor.lastrowid), EXCEPT if self.id_ already
416         exists as a 'str', which implies we do our own ID creation (so far
417         only the case with the Day class, where it's to be a date string.
418         """
419         values = tuple([self.id_] + [getattr(self, key)
420                                      for key in self.to_save])
421         table_name = self.table_name
422         cursor = db_conn.exec_on_vals(f'REPLACE INTO {table_name} VALUES',
423                                       values)
424         if not isinstance(self.id_, str):
425             self.id_ = cursor.lastrowid  # type: ignore[assignment]
426         self.cache()
427         for attr_name in self.to_save_versioned:
428             getattr(self, attr_name).save(db_conn)
429         for table, column, attr_name, key_index in self.to_save_relations:
430             assert isinstance(self.id_, (int, str))
431             db_conn.rewrite_relations(table, column, self.id_,
432                                       [[i.id_] for i
433                                        in getattr(self, attr_name)], key_index)
434
435     def remove(self, db_conn: DatabaseConnection) -> None:
436         """Remove from DB and cache, including dependencies."""
437         if self.id_ is None or self.__class__.get_cached(self.id_) is None:
438             raise HandledException('cannot remove unsaved item')
439         for attr_name in self.to_save_versioned:
440             getattr(self, attr_name).remove(db_conn)
441         for table, column, attr_name, _ in self.to_save_relations:
442             db_conn.delete_where(table, column, self.id_)
443         self.uncache()
444         db_conn.delete_where(self.table_name, 'id', self.id_)