home · contact · privacy
Extend tests.
[plomtask] / plomtask / db.py
1 """Database management."""
2 from __future__ import annotations
3 from os.path import isfile
4 from difflib import Differ
5 from sqlite3 import connect as sql_connect, Cursor, Row
6 from typing import Any, Self, TypeVar, Generic
7 from plomtask.exceptions import HandledException, NotFoundException
8
# Path of the SQL script that defines the expected tables; it is opened
# relative to the current working directory.
PATH_DB_SCHEMA = 'scripts/init.sql'
# Value that the DB file's "PRAGMA user_version" is required to report.
EXPECTED_DB_VERSION = 0


class DatabaseFile:  # pylint: disable=too-few-public-methods
    """Represents the sqlite3 database's file.

    On construction, and again after remake(), the file at .path is checked:
    .exists records whether it is a regular file, and if it is, its
    user_version and its schema must match EXPECTED_DB_VERSION and the
    script at PATH_DB_SCHEMA, otherwise HandledException is raised.
    """

    def __init__(self, path: str) -> None:
        """Remember path and validate whatever file is found there."""
        self.path = path
        self._check()

    def remake(self) -> None:
        """Create tables in self.path file as per PATH_DB_SCHEMA sql file.

        Re-runs the validation afterwards, so .exists gets updated and a
        resulting mismatch raises HandledException.
        """
        # Read the script before connecting: connecting creates the DB file,
        # and a missing script should not leave an empty DB file behind.
        with open(PATH_DB_SCHEMA, 'r', encoding='utf-8') as f:
            schema_script = f.read()
        conn = sql_connect(self.path)
        try:
            # sqlite3's connection context manager only commits/rolls back,
            # it does not close – hence the explicit close() below.
            with conn:
                conn.executescript(schema_script)
        finally:
            conn.close()
        self._check()

    def _check(self) -> None:
        """Check file exists, and is of proper DB version and schema."""
        self.exists = isfile(self.path)
        if self.exists:
            self._validate_user_version()
            self._validate_schema()

    def _validate_user_version(self) -> None:
        """Compare DB user_version with EXPECTED_DB_VERSION.

        Raises HandledException if the two differ.
        """
        conn = sql_connect(self.path)
        try:
            db_version = conn.execute('PRAGMA user_version').fetchone()[0]
        finally:
            conn.close()
        if db_version != EXPECTED_DB_VERSION:
            msg = f'Wrong DB version, expected '\
                    f'{EXPECTED_DB_VERSION}, got {db_version}.'
            raise HandledException(msg)

    def _validate_schema(self) -> None:
        """Compare found schema with what's stored at PATH_DB_SCHEMA.

        The found schema is built from the non-empty "sql" entries of
        sqlite_master, ordered by their text and joined with ';\\n'; the
        stored one is the script file with trailing whitespace removed.
        Raises HandledException, carrying a line diff, if they differ.
        """
        sql_for_schema = 'SELECT sql FROM sqlite_master ORDER BY sql'
        msg_err = 'Database has wrong tables schema. Diff:\n'
        conn = sql_connect(self.path)
        try:
            # Entries without SQL text (NULL/empty) are skipped.
            schema_rows = [r[0] for r in conn.execute(sql_for_schema) if r[0]]
        finally:
            conn.close()
        retrieved_schema = ';\n'.join(schema_rows) + ';'
        with open(PATH_DB_SCHEMA, 'r', encoding='utf-8') as f:
            stored_schema = f.read().rstrip()
        if stored_schema != retrieved_schema:
            diff_msg = Differ().compare(retrieved_schema.splitlines(),
                                        stored_schema.splitlines())
            raise HandledException(msg_err + '\n'.join(diff_msg))
57
58
class DatabaseConnection:
    """A single connection to the database.

    Thin wrapper around one sqlite3 connection opened on a DatabaseFile's
    path. Note that table, column and key names handed to the helper
    methods are interpolated directly into the SQL text (only the compared
    or inserted values are bound as parameters), so they must come from
    trusted code, never from user input.
    """

    def __init__(self, db_file: DatabaseFile) -> None:
        """Open a connection to the file at db_file.path."""
        self.file = db_file
        self.conn = sql_connect(self.file.path)

    def commit(self) -> None:
        """Commit SQL transaction."""
        self.conn.commit()

    def exec(self, code: str, inputs: tuple[Any, ...] = tuple()) -> Cursor:
        """Add commands to SQL transaction.

        Executes code with inputs bound to its '?' placeholders and returns
        the resulting cursor.
        """
        return self.conn.execute(code, inputs)

    def close(self) -> None:
        """Close DB connection."""
        self.conn.close()

    def rewrite_relations(self, table_name: str, key: str, target: int,
                          rows: list[list[Any]]) -> None:
        """Rewrite relations in table_name to target, with rows values.

        First deletes every row where key == target, then inserts one row
        per entry of rows, each prefixed with target as its first value.
        """
        self.delete_where(table_name, key, target)
        for row in rows:
            values = (target, *row)
            # Computed per row, since nothing forces rows to be equally long.
            q_marks = self.q_marks_from_values(values)
            self.exec(f'INSERT INTO {table_name} VALUES {q_marks}', values)

    def row_where(self, table_name: str, key: str,
                  target: int | str) -> list[Row]:
        """Return list of Rows at table where key == target."""
        cursor = self.exec(f'SELECT * FROM {table_name} WHERE {key} = ?',
                           (target,))
        return list(cursor)

    def column_where(self, table_name: str, column: str, key: str,
                     target: int | str) -> list[Any]:
        """Return column of table where key == target."""
        cursor = self.exec(f'SELECT {column} FROM {table_name} '
                           f'WHERE {key} = ?', (target,))
        return [row[0] for row in cursor]

    def column_all(self, table_name: str, column: str) -> list[Any]:
        """Return complete column of table."""
        cursor = self.exec(f'SELECT {column} FROM {table_name}')
        return [row[0] for row in cursor]

    def delete_where(self, table_name: str, key: str,
                     target: int | str) -> None:
        """Delete from table where key == target."""
        self.exec(f'DELETE FROM {table_name} WHERE {key} = ?', (target,))

    @staticmethod
    def q_marks_from_values(values: tuple[Any, ...]) -> str:
        """Return placeholder to insert values into SQL code.

        The result is a parenthesized, comma-separated list holding one '?'
        per element of values, e.g. '(?,?,?)' for three values.
        """
        return '(' + ','.join(['?'] * len(values)) + ')'
114
115
116 BaseModelId = TypeVar('BaseModelId', int, str)
117 BaseModelInstance = TypeVar('BaseModelInstance', bound='BaseModel[Any]')
118
119
120 class BaseModel(Generic[BaseModelId]):
121     """Template for most of the models we use/derive from the DB."""
122     table_name = ''
123     to_save: list[str] = []
124     id_: None | BaseModelId
125     cache_: dict[BaseModelId, Self]
126
127     def __init__(self, id_: BaseModelId | None) -> None:
128         if isinstance(id_, int) and id_ < 1:
129             msg = f'illegal {self.__class__.__name__} ID, must be >=1: {id_}'
130             raise HandledException(msg)
131         self.id_ = id_
132
133     def __eq__(self, other: object) -> bool:
134         if not isinstance(other, self.__class__):
135             return False
136         to_hash_me = tuple([self.id_] +
137                            [getattr(self, name) for name in self.to_save])
138         to_hash_other = tuple([other.id_] +
139                               [getattr(other, name) for name in other.to_save])
140         return hash(to_hash_me) == hash(to_hash_other)
141
142     def __lt__(self, other: Any) -> bool:
143         if not isinstance(other, self.__class__):
144             msg = 'cannot compare to object of different class'
145             raise HandledException(msg)
146         assert isinstance(self.id_, int)
147         assert isinstance(other.id_, int)
148         return self.id_ < other.id_
149
150     @classmethod
151     def get_cached(cls: type[BaseModelInstance],
152                    id_: BaseModelId) -> BaseModelInstance | None:
153         """Get object of id_ from class's cache, or None if not found."""
154         # pylint: disable=consider-iterating-dictionary
155         cache = cls.get_cache()
156         if id_ in cache.keys():
157             obj = cache[id_]
158             assert isinstance(obj, cls)
159             return obj
160         return None
161
162     @classmethod
163     def empty_cache(cls) -> None:
164         """Empty class's cache."""
165         cls.cache_ = {}
166
167     @classmethod
168     def get_cache(cls: type[BaseModelInstance]) -> dict[Any, BaseModel[Any]]:
169         """Get cache dictionary, create it if not yet existing."""
170         if not hasattr(cls, 'cache_'):
171             d: dict[Any, BaseModel[Any]] = {}
172             cls.cache_ = d
173         return cls.cache_
174
175     def cache(self) -> None:
176         """Update object in class's cache."""
177         if self.id_ is None:
178             raise HandledException('Cannot cache object without ID.')
179         cache = self.__class__.get_cache()
180         cache[self.id_] = self
181
182     def uncache(self) -> None:
183         """Remove self from cache."""
184         if self.id_ is None:
185             raise HandledException('Cannot un-cache object without ID.')
186         cache = self.__class__.get_cache()
187         del cache[self.id_]
188
189     @classmethod
190     def from_table_row(cls: type[BaseModelInstance],
191                        # pylint: disable=unused-argument
192                        db_conn: DatabaseConnection,
193                        row: Row | list[Any]) -> BaseModelInstance:
194         """Make from DB row, write to DB cache."""
195         obj = cls(*row)
196         obj.cache()
197         return obj
198
199     @classmethod
200     def by_id(cls, db_conn: DatabaseConnection,
201               id_: BaseModelId | None,
202               # pylint: disable=unused-argument
203               create: bool = False) -> Self:
204         """Retrieve by id_, on failure throw NotFoundException.
205
206         First try to get from cls.cache_, only then check DB; if found,
207         put into cache.
208
209         If create=True, make anew (but do not cache yet).
210         """
211         obj = None
212         if id_ is not None:
213             obj = cls.get_cached(id_)
214             if not obj:
215                 for row in db_conn.row_where(cls.table_name, 'id', id_):
216                     obj = cls.from_table_row(db_conn, row)
217                     obj.cache()
218                     break
219         if obj:
220             return obj
221         if create:
222             obj = cls(id_)
223             return obj
224         raise NotFoundException(f'found no object of ID {id_}')
225
226     @classmethod
227     def all(cls: type[BaseModelInstance],
228             db_conn: DatabaseConnection) -> list[BaseModelInstance]:
229         """Collect all objects of class into list.
230
231         Note that this primarily returns the contents of the cache, and only
232         _expands_ that by additional findings in the DB. This assumes the
233         cache is always instantly cleaned of any items that would be removed
234         from the DB.
235         """
236         items: dict[BaseModelId, BaseModelInstance] = {}
237         for k, v in cls.get_cache().items():
238             assert isinstance(v, cls)
239             items[k] = v
240         already_recorded = items.keys()
241         for id_ in db_conn.column_all(cls.table_name, 'id'):
242             if id_ not in already_recorded:
243                 item = cls.by_id(db_conn, id_)
244                 assert item.id_ is not None
245                 items[item.id_] = item
246         return list(items.values())
247
248     def save_core(self, db_conn: DatabaseConnection) -> None:
249         """Write bare-bones self (sans connected items), ensuring self.id_.
250
251         Write both to DB, and to cache. To DB, write .id_ and attributes
252         listed in cls.to_save.
253
254         Ensure self.id_ by setting it to what the DB command returns as the
255         last saved row's ID (cursor.lastrowid), EXCEPT if self.id_ already
256         exists as a 'str', which implies we do our own ID creation (so far
257         only the case with the Day class, where it's to be a date string.
258         """
259         values = tuple([self.id_] + [getattr(self, key)
260                                      for key in self.to_save])
261         q_marks = DatabaseConnection.q_marks_from_values(values)
262         table_name = self.table_name
263         cursor = db_conn.exec(f'REPLACE INTO {table_name} VALUES {q_marks}',
264                               values)
265         if not isinstance(self.id_, str):
266             self.id_ = cursor.lastrowid  # type: ignore[assignment]
267         self.cache()
268
269     def remove(self, db_conn: DatabaseConnection) -> None:
270         """Remove from DB and cache."""
271         if self.id_ is None or self.__class__.get_cached(self.id_) is None:
272             raise HandledException('cannot remove unsaved item')
273         self.uncache()
274         db_conn.delete_where(self.table_name, 'id', self.id_)