home · contact · privacy
b5461a507e9612e2593643e6d6e198779e2fc456
[plomtask] / plomtask / db.py
1 """Database management."""
from __future__ import annotations
from contextlib import closing
from os import listdir
from os.path import isfile
from difflib import Differ
from sqlite3 import connect as sql_connect, Cursor, Row
from typing import Any, Self, TypeVar, Generic
from plomtask.exceptions import HandledException, NotFoundException
from plomtask.dating import valid_date
10
# DB schema version this code works with; compared against the DB file's
# "PRAGMA user_version" value.
EXPECTED_DB_VERSION = 4
# Directory listed for migration scripts; the schema file lives there too.
MIGRATIONS_DIR = 'migrations'
# Name of the SQL script describing the schema at EXPECTED_DB_VERSION.
FILENAME_DB_SCHEMA = f'init_{EXPECTED_DB_VERSION}.sql'
# Path of that script; read to create new DB files and to validate old ones.
PATH_DB_SCHEMA = f'{MIGRATIONS_DIR}/{FILENAME_DB_SCHEMA}'
15
16
class UnmigratedDbException(HandledException):
    """Raised when a DB file's user_version is not EXPECTED_DB_VERSION."""
19
20
class DatabaseFile:  # pylint: disable=too-few-public-methods
    """Represents the sqlite3 database's file.

    Instantiating validates the file at path (see ._check). Use .create_at
    to make a fresh file, or .migrate to upgrade an outdated one.

    All sqlite3 connections opened here are wrapped in closing(): the
    connection's own context manager only commits or rolls back, it does
    not close the connection.
    """

    def __init__(self, path: str) -> None:
        """Remember path and validate the file found there."""
        self.path = path
        self._check()

    @classmethod
    def create_at(cls, path: str) -> DatabaseFile:
        """Make new DB file at path.

        Runs the SQL script at PATH_DB_SCHEMA on it, sets its user_version
        to EXPECTED_DB_VERSION, and returns the validated DatabaseFile.
        """
        with closing(sql_connect(path)) as conn, conn:
            with open(PATH_DB_SCHEMA, 'r', encoding='utf-8') as f:
                conn.executescript(f.read())
            conn.execute(f'PRAGMA user_version = {EXPECTED_DB_VERSION}')
        return cls(path)

    @classmethod
    def migrate(cls, path: str) -> DatabaseFile:
        """Apply migrations from_version to EXPECTED_DB_VERSION.

        Each pending migration script is executed in turn, and user_version
        is raised by one directly after each script, so an interrupted run
        can be resumed at the first migration not yet applied.
        """
        migrations = cls._available_migrations()
        from_version = cls.get_version_of_db(path)
        migrations_todo = migrations[from_version+1:]
        for user_version, filename in enumerate(migrations_todo,
                                                from_version + 1):
            with closing(sql_connect(path)) as conn, conn:
                with open(f'{MIGRATIONS_DIR}/{filename}', 'r',
                          encoding='utf-8') as f:
                    conn.executescript(f.read())
                conn.execute(f'PRAGMA user_version = {user_version}')
        return cls(path)

    def _check(self) -> None:
        """Check file exists, and is of proper DB version and schema.

        Raises NotFoundException if there is no file at .path,
        UnmigratedDbException if its user_version differs from
        EXPECTED_DB_VERSION, HandledException on a schema mismatch.
        """
        if not isfile(self.path):
            raise NotFoundException
        if self.user_version != EXPECTED_DB_VERSION:
            raise UnmigratedDbException()
        self._validate_schema()

    @staticmethod
    def _available_migrations() -> list[str]:
        """Validate migrations directory and return sorted entries.

        Every entry of MIGRATIONS_DIR other than FILENAME_DB_SCHEMA must be
        named '<number>_<rest>', with the numbers covering exactly the range
        0..EXPECTED_DB_VERSION; otherwise HandledException is raised. The
        returned list holds the filename for migration i at index i.
        """
        msg_too_big = 'Migration directory points beyond expected DB version.'
        msg_bad_entry = 'Migration directory contains unexpected entry: '
        msg_missing = 'Migration directory misses migration of number: '
        migrations: dict[int, str] = {}
        for entry in listdir(MIGRATIONS_DIR):
            if entry == FILENAME_DB_SCHEMA:
                continue
            toks = entry.split('_', 1)
            if len(toks) < 2:
                raise HandledException(msg_bad_entry + entry)
            try:
                i = int(toks[0])
            except ValueError as e:
                raise HandledException(msg_bad_entry + entry) from e
            if i > EXPECTED_DB_VERSION:
                raise HandledException(msg_too_big)
            migrations[i] = toks[1]
        migrations_list = []
        for i in range(EXPECTED_DB_VERSION + 1):
            if i not in migrations:
                raise HandledException(msg_missing + str(i))
            migrations_list.append(f'{i}_{migrations[i]}')
        return migrations_list

    @staticmethod
    def get_version_of_db(path: str) -> int:
        """Get DB user_version, fail if outside expected range.

        Raises HandledException if the version found is greater than
        EXPECTED_DB_VERSION.
        """
        sql_for_db_version = 'PRAGMA user_version'
        with closing(sql_connect(path)) as conn:
            db_version = list(conn.execute(sql_for_db_version))[0][0]
        assert isinstance(db_version, int)
        if db_version > EXPECTED_DB_VERSION:
            msg = f'Wrong DB version, expected '\
                    f'{EXPECTED_DB_VERSION}, got unknown {db_version}.'
            raise HandledException(msg)
        return db_version

    @property
    def user_version(self) -> int:
        """Get DB user_version."""
        return self.__class__.get_version_of_db(self.path)

    def _validate_schema(self) -> None:
        """Compare found schema with what's stored at PATH_DB_SCHEMA.

        The statements recorded in sqlite_master are re-formatted, joined
        with ';\\n', and compared to the (right-stripped) schema file. On
        any difference HandledException is raised, carrying a line diff.
        """

        def reformat_rows(rows: list[str]) -> list[str]:
            """Normalize layout of each SQL statement in rows.

            Each statement is broken up at newlines and at commas outside
            of parentheses; the resulting segments become lines indented by
            four spaces, except for the first and the last line.
            """
            new_rows = []
            for row in rows:
                new_row = []
                for subrow in row.split('\n'):
                    subrow = subrow.rstrip()
                    in_parentheses = 0
                    split_at = []
                    for i, c in enumerate(subrow):
                        if '(' == c:
                            in_parentheses += 1
                        elif ')' == c:
                            in_parentheses -= 1
                        elif ',' == c and 0 == in_parentheses:
                            # split right behind the comma
                            split_at.append(i + 1)
                    prev_split = 0
                    for i in split_at:
                        segment = subrow[prev_split:i].strip()
                        if len(segment) > 0:
                            new_row.append(f'    {segment}')
                        prev_split = i
                    segment = subrow[prev_split:].strip()
                    if len(segment) > 0:
                        new_row.append(f'    {segment}')
                new_row[0] = new_row[0].lstrip()
                new_row[-1] = new_row[-1].lstrip()
                # Statement does not end in a ')' line of its own, and the
                # third-last line lacks a trailing comma: move the comma
                # there (dropping the second-last line), and split the last
                # line's final character off into a closing ')' line.
                # NOTE(review): presumably this handles statements as
                # rewritten by SQLite on "ALTER TABLE … ADD COLUMN", whose
                # last lines look like "…\n, new_col TYPE)" – confirm.
                # The length guard avoids an IndexError on statements of
                # fewer than three lines, which are left as they are.
                if (new_row[-1] != ')' and len(new_row) >= 3
                        and new_row[-3][-1] != ','):
                    new_row[-3] = new_row[-3] + ','
                    new_row[-2:] = ['    ' + new_row[-1][:-1]] + [')']
                new_rows.append('\n'.join(new_row))
            return new_rows

        sql_for_schema = 'SELECT sql FROM sqlite_master ORDER BY sql'
        msg_err = 'Database has wrong tables schema. Diff:\n'
        with closing(sql_connect(self.path)) as conn:
            # entries without SQL text are skipped
            schema_rows = [r[0] for r in conn.execute(sql_for_schema) if r[0]]
        schema_rows = reformat_rows(schema_rows)
        retrieved_schema = ';\n'.join(schema_rows) + ';'
        with open(PATH_DB_SCHEMA, 'r', encoding='utf-8') as f:
            stored_schema = f.read().rstrip()
        if stored_schema != retrieved_schema:
            diff_msg = Differ().compare(retrieved_schema.splitlines(),
                                        stored_schema.splitlines())
            raise HandledException(msg_err + '\n'.join(diff_msg))
153
154
class DatabaseConnection:
    """Thin wrapper around one sqlite3 connection to a DatabaseFile."""

    def __init__(self, db_file: DatabaseFile) -> None:
        """Open a connection to the file at db_file.path."""
        self.file = db_file
        self.conn = sql_connect(self.file.path)

    def commit(self) -> None:
        """Commit the current SQL transaction."""
        self.conn.commit()

    def exec(self, code: str, inputs: tuple[Any, ...] = tuple()) -> Cursor:
        """Execute code with inputs bound to its placeholders.

        Returns the resulting Cursor; nothing is committed here.
        """
        cursor = self.conn.execute(code, inputs)
        return cursor

    def close(self) -> None:
        """Close the underlying connection."""
        self.conn.close()

    def rewrite_relations(self, table_name: str, key: str, target: int | str,
                          rows: list[list[Any]]) -> None:
        """Replace all rows of table_name where key == target.

        The old rows are deleted, then one new row is inserted per entry
        of rows, each prefixed with target as its first value.
        """
        self.delete_where(table_name, key, target)
        for row in rows:
            values = (target, *row)
            placeholders = self.q_marks_from_values(values)
            sql = f'INSERT INTO {table_name} VALUES {placeholders}'
            self.exec(sql, values)

    def row_where(self, table_name: str, key: str,
                  target: int | str) -> list[Row]:
        """Return all rows of table_name whose key column equals target."""
        sql = f'SELECT * FROM {table_name} WHERE {key} = ?'
        return self.exec(sql, (target,)).fetchall()

    # def column_where_pattern(self,
    #                          table_name: str,
    #                          column: str,
    #                          pattern: str,
    #                          keys: list[str]) -> list[Any]:
    #     """Return column of rows where one of keys matches pattern."""
    #     targets = tuple([f'%{pattern}%'] * len(keys))
    #     haystack = ' OR '.join([f'{k} LIKE ?' for k in keys])
    #     sql = f'SELECT {column} FROM {table_name} WHERE {haystack}'
    #     return [row[0] for row in self.exec(sql, targets)]

    def column_where(self, table_name: str, column: str, key: str,
                     target: int | str) -> list[Any]:
        """Return column's values from rows whose key equals target."""
        sql = f'SELECT {column} FROM {table_name} WHERE {key} = ?'
        found = self.exec(sql, (target,))
        return [row[0] for row in found]

    def column_all(self, table_name: str, column: str) -> list[Any]:
        """Return column's values from every row of table_name."""
        found = self.exec(f'SELECT {column} FROM {table_name}')
        return [row[0] for row in found]

    def delete_where(self, table_name: str, key: str,
                     target: int | str) -> None:
        """Delete the rows of table_name whose key equals target."""
        sql = f'DELETE FROM {table_name} WHERE {key} = ?'
        self.exec(sql, (target,))

    @staticmethod
    def q_marks_from_values(values: tuple[Any]) -> str:
        """Return '(?,?,…)' with one placeholder per entry of values."""
        return f"({','.join('?' * len(values))})"
221
222
223 BaseModelId = TypeVar('BaseModelId', int, str)
224 BaseModelInstance = TypeVar('BaseModelInstance', bound='BaseModel[Any]')
225
226
227 class BaseModel(Generic[BaseModelId]):
228     """Template for most of the models we use/derive from the DB."""
229     table_name = ''
230     to_save: list[str] = []
231     to_save_versioned: list[str] = []
232     to_save_relations: list[tuple[str, str, str]] = []
233     id_: None | BaseModelId
234     cache_: dict[BaseModelId, Self]
235     to_search: list[str] = []
236
237     def __init__(self, id_: BaseModelId | None) -> None:
238         if isinstance(id_, int) and id_ < 1:
239             msg = f'illegal {self.__class__.__name__} ID, must be >=1: {id_}'
240             raise HandledException(msg)
241         self.id_ = id_
242
243     def __eq__(self, other: object) -> bool:
244         if not isinstance(other, self.__class__):
245             return False
246         to_hash_me = tuple([self.id_] +
247                            [getattr(self, name) for name in self.to_save])
248         to_hash_other = tuple([other.id_] +
249                               [getattr(other, name) for name in other.to_save])
250         return hash(to_hash_me) == hash(to_hash_other)
251
252     def __lt__(self, other: Any) -> bool:
253         if not isinstance(other, self.__class__):
254             msg = 'cannot compare to object of different class'
255             raise HandledException(msg)
256         assert isinstance(self.id_, int)
257         assert isinstance(other.id_, int)
258         return self.id_ < other.id_
259
260     @classmethod
261     def get_cached(cls: type[BaseModelInstance],
262                    id_: BaseModelId) -> BaseModelInstance | None:
263         """Get object of id_ from class's cache, or None if not found."""
264         # pylint: disable=consider-iterating-dictionary
265         cache = cls.get_cache()
266         if id_ in cache.keys():
267             obj = cache[id_]
268             assert isinstance(obj, cls)
269             return obj
270         return None
271
272     @classmethod
273     def empty_cache(cls) -> None:
274         """Empty class's cache."""
275         cls.cache_ = {}
276
277     @classmethod
278     def get_cache(cls: type[BaseModelInstance]) -> dict[Any, BaseModel[Any]]:
279         """Get cache dictionary, create it if not yet existing."""
280         if not hasattr(cls, 'cache_'):
281             d: dict[Any, BaseModel[Any]] = {}
282             cls.cache_ = d
283         return cls.cache_
284
285     def cache(self) -> None:
286         """Update object in class's cache."""
287         if self.id_ is None:
288             raise HandledException('Cannot cache object without ID.')
289         cache = self.__class__.get_cache()
290         cache[self.id_] = self
291
292     def uncache(self) -> None:
293         """Remove self from cache."""
294         if self.id_ is None:
295             raise HandledException('Cannot un-cache object without ID.')
296         cache = self.__class__.get_cache()
297         del cache[self.id_]
298
299     @classmethod
300     def from_table_row(cls: type[BaseModelInstance],
301                        # pylint: disable=unused-argument
302                        db_conn: DatabaseConnection,
303                        row: Row | list[Any]) -> BaseModelInstance:
304         """Make from DB row, write to DB cache."""
305         obj = cls(*row)
306         obj.cache()
307         return obj
308
309     @classmethod
310     def by_id(cls, db_conn: DatabaseConnection,
311               id_: BaseModelId | None,
312               # pylint: disable=unused-argument
313               create: bool = False) -> Self:
314         """Retrieve by id_, on failure throw NotFoundException.
315
316         First try to get from cls.cache_, only then check DB; if found,
317         put into cache.
318
319         If create=True, make anew (but do not cache yet).
320         """
321         obj = None
322         if id_ is not None:
323             obj = cls.get_cached(id_)
324             if not obj:
325                 for row in db_conn.row_where(cls.table_name, 'id', id_):
326                     obj = cls.from_table_row(db_conn, row)
327                     obj.cache()
328                     break
329         if obj:
330             return obj
331         if create:
332             obj = cls(id_)
333             return obj
334         raise NotFoundException(f'found no object of ID {id_}')
335
336     @classmethod
337     def all(cls: type[BaseModelInstance],
338             db_conn: DatabaseConnection) -> list[BaseModelInstance]:
339         """Collect all objects of class into list.
340
341         Note that this primarily returns the contents of the cache, and only
342         _expands_ that by additional findings in the DB. This assumes the
343         cache is always instantly cleaned of any items that would be removed
344         from the DB.
345         """
346         items: dict[BaseModelId, BaseModelInstance] = {}
347         for k, v in cls.get_cache().items():
348             assert isinstance(v, cls)
349             items[k] = v
350         already_recorded = items.keys()
351         for id_ in db_conn.column_all(cls.table_name, 'id'):
352             if id_ not in already_recorded:
353                 item = cls.by_id(db_conn, id_)
354                 assert item.id_ is not None
355                 items[item.id_] = item
356         return list(items.values())
357
358     @classmethod
359     def by_date_range_with_limits(cls: type[BaseModelInstance],
360                                   db_conn: DatabaseConnection,
361                                   date_range: tuple[str, str],
362                                   date_col: str = 'day'
363                                   ) -> tuple[list[BaseModelInstance], str,
364                                              str]:
365         """Return list of Days in database within (open) date_range interval.
366
367         If no range values provided, defaults them to 'yesterday' and
368         'tomorrow'. Knows to properly interpret these and 'today' as value.
369         """
370         start_str = date_range[0] if date_range[0] else 'yesterday'
371         end_str = date_range[1] if date_range[1] else 'tomorrow'
372         start_date = valid_date(start_str)
373         end_date = valid_date(end_str)
374         items = []
375         sql = f'SELECT id FROM {cls.table_name} '
376         sql += f'WHERE {date_col} >= ? AND {date_col} <= ?'
377         for row in db_conn.exec(sql, (start_date, end_date)):
378             items += [cls.by_id(db_conn, row[0])]
379         return items, start_date, end_date
380
381     @classmethod
382     def matching(cls: type[BaseModelInstance], db_conn: DatabaseConnection,
383                  pattern: str) -> list[BaseModelInstance]:
384         """Return all objects whose .to_search match pattern."""
385         items = cls.all(db_conn)
386         if pattern:
387             filtered = []
388             for item in items:
389                 for attr_name in cls.to_search:
390                     toks = attr_name.split('.')
391                     parent = item
392                     for tok in toks:
393                         attr = getattr(parent, tok)
394                         parent = attr
395                     if pattern in attr:
396                         filtered += [item]
397                         break
398             return filtered
399         return items
400
401     def save(self, db_conn: DatabaseConnection) -> None:
402         """Write self to DB and cache and ensure .id_.
403
404         Write both to DB, and to cache. To DB, write .id_ and attributes
405         listed in cls.to_save[_versioned|_relations].
406
407         Ensure self.id_ by setting it to what the DB command returns as the
408         last saved row's ID (cursor.lastrowid), EXCEPT if self.id_ already
409         exists as a 'str', which implies we do our own ID creation (so far
410         only the case with the Day class, where it's to be a date string.
411         """
412         values = tuple([self.id_] + [getattr(self, key)
413                                      for key in self.to_save])
414         q_marks = DatabaseConnection.q_marks_from_values(values)
415         table_name = self.table_name
416         cursor = db_conn.exec(f'REPLACE INTO {table_name} VALUES {q_marks}',
417                               values)
418         if not isinstance(self.id_, str):
419             self.id_ = cursor.lastrowid  # type: ignore[assignment]
420         self.cache()
421         for attr_name in self.to_save_versioned:
422             getattr(self, attr_name).save(db_conn)
423         for table, column, attr_name in self.to_save_relations:
424             assert isinstance(self.id_, (int, str))
425             db_conn.rewrite_relations(table, column, self.id_,
426                                       [[i.id_] for i
427                                        in getattr(self, attr_name)])
428
429     def remove(self, db_conn: DatabaseConnection) -> None:
430         """Remove from DB and cache, including dependencies."""
431         if self.id_ is None or self.__class__.get_cached(self.id_) is None:
432             raise HandledException('cannot remove unsaved item')
433         for attr_name in self.to_save_versioned:
434             getattr(self, attr_name).remove(db_conn)
435         for table, column, attr_name in self.to_save_relations:
436             db_conn.delete_where(table, column, self.id_)
437         self.uncache()
438         db_conn.delete_where(self.table_name, 'id', self.id_)