home · contact · privacy
Minor BaseModel code re-organization.
[plomtask] / plomtask / db.py
1 """Database management."""
from __future__ import annotations
from contextlib import closing
from difflib import Differ
from os import listdir
from os.path import isfile
from sqlite3 import connect as sql_connect, Cursor, Row
from typing import Any, Self, TypeVar, Generic

from plomtask.exceptions import HandledException, NotFoundException
from plomtask.dating import valid_date
10
# DB user_version this code works with; migrations lead up to it.
EXPECTED_DB_VERSION = 5
# Relative path, so resolved against the process's working directory.
MIGRATIONS_DIR = 'migrations'
# Full schema of a DB at EXPECTED_DB_VERSION: used to create new DB files
# and as the reference existing DB files' schemas are validated against.
FILENAME_DB_SCHEMA = 'init_' + str(EXPECTED_DB_VERSION) + '.sql'
PATH_DB_SCHEMA = '/'.join((MIGRATIONS_DIR, FILENAME_DB_SCHEMA))
15
16
class UnmigratedDbException(HandledException):
    """Signals a DB file whose version still needs migrating."""
19
20
class DatabaseFile:
    """Represents the sqlite3 database's file."""
    # pylint: disable=too-few-public-methods

    def __init__(self, path: str) -> None:
        """Wrap the DB file at path, validating it first (see ._check)."""
        self.path = path
        self._check()

    @classmethod
    def create_at(cls, path: str) -> DatabaseFile:
        """Make new DB file at path from PATH_DB_SCHEMA, return it wrapped.

        The new file's user_version is set to EXPECTED_DB_VERSION. The
        schema file is read before the DB file is opened.
        """
        with open(PATH_DB_SCHEMA, 'r', encoding='utf-8') as f:
            schema_sql = f.read()
        # A sqlite3 connection's own context manager only commits (or rolls
        # back) but never closes the connection, hence closing() around it.
        with closing(sql_connect(path)) as conn:
            with conn:
                conn.executescript(schema_sql)
                conn.execute(f'PRAGMA user_version = {EXPECTED_DB_VERSION}')
        return cls(path)

    @classmethod
    def migrate(cls, path: str) -> DatabaseFile:
        """Apply migrations from DB's current version to EXPECTED_DB_VERSION.

        Migration number N is the MIGRATIONS_DIR entry named 'N_…'. After
        each migration script has run, the DB's user_version is set to its
        N, so user_version records the last migration that completed.
        """
        migrations = cls._available_migrations()
        from_version = cls._get_version_of_db(path)
        todo = migrations[from_version + 1:]
        for user_version, filename in enumerate(todo, start=from_version + 1):
            with open(f'{MIGRATIONS_DIR}/{filename}', 'r',
                      encoding='utf-8') as f:
                migration_sql = f.read()
            with closing(sql_connect(path)) as conn:
                with conn:
                    conn.executescript(migration_sql)
                    conn.execute(f'PRAGMA user_version = {user_version}')
        return cls(path)

    def _check(self) -> None:
        """Check file exists, and is of proper DB version and schema.

        Raises NotFoundException if self.path is not a file,
        HandledException if its user_version exceeds EXPECTED_DB_VERSION,
        UnmigratedDbException if it otherwise differs from it, and
        HandledException if its schema differs from PATH_DB_SCHEMA's.
        """
        if not isfile(self.path):
            raise NotFoundException
        if self._user_version != EXPECTED_DB_VERSION:
            raise UnmigratedDbException()
        self._validate_schema()

    @staticmethod
    def _available_migrations() -> list[str]:
        """Validate migrations directory and return sorted entries.

        Apart from FILENAME_DB_SCHEMA, every entry must be named 'N_…' with
        integer N <= EXPECTED_DB_VERSION, and every N from 0 through
        EXPECTED_DB_VERSION must be present; the returned list is ordered
        by N, so a migration's list index equals its number. Raises
        HandledException otherwise.
        """
        # NOTE(review): if two entries share the same N, the later one
        # listed by listdir() silently wins.
        msg_too_big = 'Migration directory points beyond expected DB version.'
        msg_bad_entry = 'Migration directory contains unexpected entry: '
        msg_missing = 'Migration directory misses migration of number: '
        migrations = {}
        for entry in listdir(MIGRATIONS_DIR):
            if entry == FILENAME_DB_SCHEMA:
                continue
            toks = entry.split('_', 1)
            if len(toks) < 2:
                raise HandledException(msg_bad_entry + entry)
            try:
                i = int(toks[0])
            except ValueError as e:
                raise HandledException(msg_bad_entry + entry) from e
            if i > EXPECTED_DB_VERSION:
                raise HandledException(msg_too_big)
            migrations[i] = toks[1]
        migrations_list = []
        for i in range(EXPECTED_DB_VERSION + 1):
            if i not in migrations:
                raise HandledException(msg_missing + str(i))
            migrations_list += [f'{i}_{migrations[i]}']
        return migrations_list

    @staticmethod
    def _get_version_of_db(path: str) -> int:
        """Get DB user_version, fail if outside expected range.

        Raises HandledException if it exceeds EXPECTED_DB_VERSION.
        """
        sql_for_db_version = 'PRAGMA user_version'
        with closing(sql_connect(path)) as conn:
            db_version = list(conn.execute(sql_for_db_version))[0][0]
        if db_version > EXPECTED_DB_VERSION:
            msg = f'Wrong DB version, expected '\
                    f'{EXPECTED_DB_VERSION}, got unknown {db_version}.'
            raise HandledException(msg)
        assert isinstance(db_version, int)
        return db_version

    @property
    def _user_version(self) -> int:
        """Get DB user_version."""
        # pylint: disable=protected-access
        # (since we remain within class)
        return self.__class__._get_version_of_db(self.path)

    def _validate_schema(self) -> None:
        """Compare found schema with what's stored at PATH_DB_SCHEMA.

        Raises HandledException carrying a line diff if they differ.
        """

        def reformat_rows(rows: list[str]) -> list[str]:
            # Re-lay-out each SQL statement text from sqlite_master: one
            # item per line (splitting after commas outside parentheses),
            # every item but the first and last indented by four spaces.
            new_rows = []
            for row in rows:
                new_row = []
                for subrow in row.split('\n'):
                    subrow = subrow.rstrip()
                    in_parentheses = 0
                    split_at = []
                    for i, c in enumerate(subrow):
                        if '(' == c:
                            in_parentheses += 1
                        elif ')' == c:
                            in_parentheses -= 1
                        elif ',' == c and 0 == in_parentheses:
                            split_at += [i + 1]
                    prev_split = 0
                    for i in split_at:
                        segment = subrow[prev_split:i].strip()
                        if len(segment) > 0:
                            new_row += [f'    {segment}']
                        prev_split = i
                    segment = subrow[prev_split:].strip()
                    if len(segment) > 0:
                        new_row += [f'    {segment}']
                new_row[0] = new_row[0].lstrip()
                new_row[-1] = new_row[-1].lstrip()
                # NOTE(review): presumably normalizes columns appended via
                # ALTER TABLE ADD COLUMN, stored by sqlite as '\n, col)':
                # the comma moves onto the item before and ')' gets its own
                # line. Needs >= 3 items; shorter (e.g. single-line)
                # statements used to crash here with an IndexError.
                if len(new_row) >= 3 and new_row[-1] != ')' \
                        and new_row[-3][-1] != ',':
                    new_row[-3] = new_row[-3] + ','
                    new_row[-2:] = ['    ' + new_row[-1][:-1]] + [')']
                new_rows += ['\n'.join(new_row)]
            return new_rows

        sql_for_schema = 'SELECT sql FROM sqlite_master ORDER BY sql'
        msg_err = 'Database has wrong tables schema. Diff:\n'
        with closing(sql_connect(self.path)) as conn:
            schema_rows = [r[0] for r in conn.execute(sql_for_schema) if r[0]]
        schema_rows = reformat_rows(schema_rows)
        retrieved_schema = ';\n'.join(schema_rows) + ';'
        with open(PATH_DB_SCHEMA, 'r', encoding='utf-8') as f:
            stored_schema = f.read().rstrip()
        if stored_schema != retrieved_schema:
            diff_msg = Differ().compare(retrieved_schema.splitlines(),
                                        stored_schema.splitlines())
            raise HandledException(msg_err + '\n'.join(diff_msg))
156
157
class DatabaseConnection:
    """Wraps one sqlite3 connection to a DatabaseFile's path.

    Table, column and key names given to the helpers below are pasted into
    the SQL text as-is, so they must come from trusted code; only the
    compared or inserted values are bound as '?' parameters.
    """

    def __init__(self, db_file: DatabaseFile) -> None:
        """Open a connection to db_file.path, kept as .conn."""
        self.conn = sql_connect(db_file.path)

    def commit(self) -> None:
        """Commit the pending SQL transaction on .conn."""
        self.conn.commit()

    def exec(self, code: str, inputs: tuple[Any, ...] = ()) -> Cursor:
        """Execute code with inputs bound to its '?' placeholders."""
        cursor = self.conn.execute(code, inputs)
        return cursor

    def exec_on_vals(self, code: str, inputs: tuple[Any, ...]) -> Cursor:
        """Run code followed by " (?,…)" with one '?' per item of inputs."""
        placeholders = ','.join('?' for _ in inputs)
        return self.exec(f'{code} ({placeholders})', inputs)

    def close(self) -> None:
        """Close .conn; call .commit() first to keep pending changes."""
        self.conn.close()

    def rewrite_relations(self, table_name: str, key: str, target: int | str,
                          rows: list[list[Any]], key_index: int = 0) -> None:
        # pylint: disable=too-many-arguments
        """Replace table_name's rows where key == target by rows.

        Each entry of rows lacks the key column; target is spliced into it
        at position key_index before it gets inserted.
        """
        self.delete_where(table_name, key, target)
        insert_sql = f'INSERT INTO {table_name} VALUES'
        for row in rows:
            head, tail = row[:key_index], row[key_index:]
            self.exec_on_vals(insert_sql, tuple(head + [target] + tail))

    def row_where(self, table_name: str, key: str,
                  target: int | str) -> list[Row]:
        """Fetch all rows of table_name whose key column equals target."""
        sql = f'SELECT * FROM {table_name} WHERE {key} = ?'
        return list(self.exec(sql, (target,)))

    # def column_where_pattern(self,
    #                          table_name: str,
    #                          column: str,
    #                          pattern: str,
    #                          keys: list[str]) -> list[Any]:
    #     """Return column of rows where one of keys matches pattern."""
    #     targets = tuple([f'%{pattern}%'] * len(keys))
    #     haystack = ' OR '.join([f'{k} LIKE ?' for k in keys])
    #     sql = f'SELECT {column} FROM {table_name} WHERE {haystack}'
    #     return [row[0] for row in self.exec(sql, targets)]

    def column_where(self, table_name: str, column: str, key: str,
                     target: int | str) -> list[Any]:
        """List column's values in table_name rows where key == target."""
        sql = f'SELECT {column} FROM {table_name} WHERE {key} = ?'
        cursor = self.exec(sql, (target,))
        return [found[0] for found in cursor]

    def column_all(self, table_name: str, column: str) -> list[Any]:
        """List column's values across all rows of table_name."""
        cursor = self.exec(f'SELECT {column} FROM {table_name}')
        return [found[0] for found in cursor]

    def delete_where(self, table_name: str, key: str,
                     target: int | str) -> None:
        """Delete all rows of table_name whose key column equals target."""
        sql = f'DELETE FROM {table_name} WHERE {key} = ?'
        self.exec(sql, (target,))
228
229
# Bound for classmethods returning instances of whichever model class
# they are called on.
BaseModelInstance = TypeVar('BaseModelInstance', bound='BaseModel[Any]')
# Model IDs are ints (row IDs set from the DB on save) or self-made strs.
BaseModelId = TypeVar('BaseModelId', int, str)
232
233
class BaseModel(Generic[BaseModelId]):
    """Template for most of the models we use/derive from the DB.

    Subclasses describe their persistence via class attributes:

    table_name -- table holding one row per object, keyed by column 'id';
    to_save -- names of attributes stored after .id_ in that row;
    to_save_versioned -- names of attributes whose values have their own
        .save(db_conn) / .remove(db_conn), called on .save / .remove;
    to_save_relations -- (table, column, attr_name, key_index) tuples: on
        .save, the rows of table where column == .id_ are rewritten to one
        per object in attr_name, holding that object's .id_, with our .id_
        inserted at key_index;
    to_search -- dotted attribute paths whose values .matching() searches.

    Loaded and saved objects are kept in a class-level cache (cls.cache_)
    keyed by .id_.
    """
    table_name = ''
    to_save: list[str] = []
    to_save_versioned: list[str] = []
    to_save_relations: list[tuple[str, str, str, int]] = []
    id_: None | BaseModelId
    cache_: dict[BaseModelId, Self]
    to_search: list[str] = []

    def __init__(self, id_: BaseModelId | None) -> None:
        """Set .id_; reject int IDs below 1 and empty string IDs.

        None is accepted; for non-str IDs, .save() later sets .id_ to the
        row ID the DB reports.
        """
        if isinstance(id_, int) and id_ < 1:
            msg = f'illegal {self.__class__.__name__} ID, must be >=1: {id_}'
            raise HandledException(msg)
        if isinstance(id_, str) and "" == id_:
            msg = f'illegal {self.__class__.__name__} ID, must be non-empty'
            raise HandledException(msg)
        self.id_ = id_

    def __hash__(self) -> int:
        # Covers .id_, all to_save values, the IDs of related objects and
        # the hashes of versioned attributes, so it changes as they change.
        hashable = [self.id_] + [getattr(self, name) for name in self.to_save]
        for definition in self.to_save_relations:
            attr = getattr(self, definition[2])
            hashable += [tuple(rel.id_ for rel in attr)]
        for name in self.to_save_versioned:
            hashable += [hash(getattr(self, name))]
        return hash(tuple(hashable))

    def __eq__(self, other: object) -> bool:
        # Equality is defined via __hash__, among objects of the same class.
        if not isinstance(other, self.__class__):
            return False
        return hash(self) == hash(other)

    def __lt__(self, other: Any) -> bool:
        # Order by .id_; only for same-class objects with int IDs.
        if not isinstance(other, self.__class__):
            msg = 'cannot compare to object of different class'
            raise HandledException(msg)
        assert isinstance(self.id_, int)
        assert isinstance(other.id_, int)
        return self.id_ < other.id_

    # cache management

    @classmethod
    def _get_cached(cls: type[BaseModelInstance],
                    id_: BaseModelId) -> BaseModelInstance | None:
        """Get object of id_ from class's cache, or None if not found."""
        obj = cls.get_cache().get(id_)
        if obj is None:
            return None
        assert isinstance(obj, cls)
        return obj

    @classmethod
    def empty_cache(cls) -> None:
        """Empty class's cache."""
        cls.cache_ = {}

    @classmethod
    def get_cache(cls: type[BaseModelInstance]) -> dict[Any, BaseModel[Any]]:
        """Get cache dictionary, create it if not yet existing."""
        # NOTE(review): hasattr also finds a cache_ inherited from a parent
        # class, in which case parent and subclass share one dict; confirm
        # that is intended for any model class hierarchies.
        if not hasattr(cls, 'cache_'):
            d: dict[Any, BaseModel[Any]] = {}
            cls.cache_ = d
        return cls.cache_

    def cache(self) -> None:
        """Update object in class's cache; raise if .id_ is None."""
        if self.id_ is None:
            raise HandledException('Cannot cache object without ID.')
        cache = self.__class__.get_cache()
        cache[self.id_] = self

    def uncache(self) -> None:
        """Remove self from cache; raise if .id_ is None."""
        if self.id_ is None:
            raise HandledException('Cannot un-cache object without ID.')
        cache = self.__class__.get_cache()
        del cache[self.id_]

    # object retrieval and generation

    @classmethod
    def from_table_row(cls: type[BaseModelInstance],
                       # pylint: disable=unused-argument
                       db_conn: DatabaseConnection,
                       row: Row | list[Any]) -> BaseModelInstance:
        """Make from DB row, write to DB cache."""
        obj = cls(*row)
        obj.cache()
        return obj

    @classmethod
    def by_id(cls, db_conn: DatabaseConnection,
              id_: BaseModelId | None,
              create: bool = False) -> Self:
        """Retrieve by id_, on failure throw NotFoundException.

        First try to get from cls.cache_, only then check DB; if found,
        put into cache.

        If create=True, make anew (but do not cache yet).
        """
        obj = None
        if id_ is not None:
            obj = cls._get_cached(id_)
            if obj is None:
                rows = db_conn.row_where(cls.table_name, 'id', id_)
                if rows:
                    obj = cls.from_table_row(db_conn, rows[0])
                    # kept in case a subclass's from_table_row skips caching
                    obj.cache()
        # "is not None" rather than truthiness: a model subclass defining
        # __len__ or __bool__ must not be mistaken for "not found".
        if obj is not None:
            return obj
        if create:
            return cls(id_)
        raise NotFoundException(f'found no object of ID {id_}')

    @classmethod
    def all(cls: type[BaseModelInstance],
            db_conn: DatabaseConnection) -> list[BaseModelInstance]:
        """Collect all objects of class into list.

        Note that this primarily returns the contents of the cache, and only
        _expands_ that by additional findings in the DB. This assumes the
        cache is always instantly cleaned of any items that would be removed
        from the DB.
        """
        items: dict[BaseModelId, BaseModelInstance] = {}
        for k, v in cls.get_cache().items():
            assert isinstance(v, cls)
            items[k] = v
        already_recorded = items.keys()
        for id_ in db_conn.column_all(cls.table_name, 'id'):
            if id_ not in already_recorded:
                item = cls.by_id(db_conn, id_)
                assert item.id_ is not None
                items[item.id_] = item
        return list(items.values())

    @classmethod
    def by_date_range_with_limits(cls: type[BaseModelInstance],
                                  db_conn: DatabaseConnection,
                                  date_range: tuple[str, str],
                                  date_col: str = 'day'
                                  ) -> tuple[list[BaseModelInstance], str,
                                             str]:
        """Return items whose date_col lies within date_range, plus bounds.

        Both bounds are inclusive (a closed interval). Empty range values
        default to 'yesterday' (start) and 'tomorrow' (end); valid_date is
        relied on to interpret these and 'today' as values. Returns
        (items, start_date, end_date), the bounds as valid_date gave them.
        """
        start_str = date_range[0] if date_range[0] else 'yesterday'
        end_str = date_range[1] if date_range[1] else 'tomorrow'
        start_date = valid_date(start_str)
        end_date = valid_date(end_str)
        items = []
        sql = f'SELECT id FROM {cls.table_name} '
        sql += f'WHERE {date_col} >= ? AND {date_col} <= ?'
        for row in db_conn.exec(sql, (start_date, end_date)):
            items += [cls.by_id(db_conn, row[0])]
        return items, start_date, end_date

    @classmethod
    def matching(cls: type[BaseModelInstance], db_conn: DatabaseConnection,
                 pattern: str) -> list[BaseModelInstance]:
        """Return all objects whose .to_search match pattern.

        Each cls.to_search entry is a dotted attribute path resolved on the
        item; the item matches if pattern is contained ('in') in any of the
        resolved values. An empty pattern returns all items.
        """
        items = cls.all(db_conn)
        if pattern:
            filtered = []
            for item in items:
                for attr_name in cls.to_search:
                    toks = attr_name.split('.')
                    parent = item
                    for tok in toks:
                        attr = getattr(parent, tok)
                        parent = attr
                    if pattern in attr:
                        filtered += [item]
                        break
            return filtered
        return items

    # database writing

    def save(self, db_conn: DatabaseConnection) -> None:
        """Write self to DB and cache and ensure .id_.

        Write both to DB, and to cache. To DB, write .id_ and attributes
        listed in cls.to_save[_versioned|_relations].

        Ensure self.id_ by setting it to what the DB command returns as the
        last saved row's ID (cursor.lastrowid), EXCEPT if self.id_ already
        exists as a 'str', which implies we do our own ID creation (so far
        only the case with the Day class, where it's to be a date string).
        """
        values = tuple([self.id_] + [getattr(self, key)
                                     for key in self.to_save])
        table_name = self.table_name
        cursor = db_conn.exec_on_vals(f'REPLACE INTO {table_name} VALUES',
                                      values)
        if not isinstance(self.id_, str):
            self.id_ = cursor.lastrowid  # type: ignore[assignment]
        self.cache()
        for attr_name in self.to_save_versioned:
            getattr(self, attr_name).save(db_conn)
        for table, column, attr_name, key_index in self.to_save_relations:
            assert isinstance(self.id_, (int, str))
            db_conn.rewrite_relations(table, column, self.id_,
                                      [[i.id_] for i
                                       in getattr(self, attr_name)], key_index)

    def remove(self, db_conn: DatabaseConnection) -> None:
        """Remove from DB and cache, including dependencies.

        Raises HandledException if .id_ is None or self is not cached.
        """
        # pylint: disable=protected-access
        # (since we remain within class)
        if self.id_ is None or self.__class__._get_cached(self.id_) is None:
            raise HandledException('cannot remove unsaved item')
        for attr_name in self.to_save_versioned:
            getattr(self, attr_name).remove(db_conn)
        for table, column, _, _ in self.to_save_relations:
            db_conn.delete_where(table, column, self.id_)
        self.uncache()
        db_conn.delete_where(self.table_name, 'id', self.id_)