home · contact · privacy
In Calendar view, highlight today's date.
[plomtask] / plomtask / db.py
1 """Database management."""
2 from __future__ import annotations
3 from os.path import isfile
4 from difflib import Differ
5 from sqlite3 import connect as sql_connect, Cursor, Row
6 from typing import Any, Self, TypeVar, Generic
7 from plomtask.exceptions import HandledException, NotFoundException
8
# SQL script whose text defines the expected DB schema; executed by
# DatabaseFile.remake() and compared against the DB's sqlite_master.
PATH_DB_SCHEMA: str = 'scripts/init.sql'
# Value the DB file's 'PRAGMA user_version' must report to be accepted.
EXPECTED_DB_VERSION: int = 0
11
12
13 class DatabaseFile:  # pylint: disable=too-few-public-methods
14     """Represents the sqlite3 database's file."""
15
16     def __init__(self, path: str) -> None:
17         self.path = path
18         self._check()
19
20     def remake(self) -> None:
21         """Create tables in self.path file as per PATH_DB_SCHEMA sql file."""
22         with sql_connect(self.path) as conn:
23             with open(PATH_DB_SCHEMA, 'r', encoding='utf-8') as f:
24                 conn.executescript(f.read())
25         self._check()
26
27     def _check(self) -> None:
28         """Check file exists, and is of proper DB version and schema."""
29         self.exists = isfile(self.path)
30         if self.exists:
31             self._validate_user_version()
32             self._validate_schema()
33
34     def _validate_user_version(self) -> None:
35         """Compare DB user_version with EXPECTED_DB_VERSION."""
36         sql_for_db_version = 'PRAGMA user_version'
37         with sql_connect(self.path) as conn:
38             db_version = list(conn.execute(sql_for_db_version))[0][0]
39             if db_version != EXPECTED_DB_VERSION:
40                 msg = f'Wrong DB version, expected '\
41                         f'{EXPECTED_DB_VERSION}, got {db_version}.'
42                 raise HandledException(msg)
43
44     def _validate_schema(self) -> None:
45         """Compare found schema with what's stored at PATH_DB_SCHEMA."""
46         sql_for_schema = 'SELECT sql FROM sqlite_master ORDER BY sql'
47         msg_err = 'Database has wrong tables schema. Diff:\n'
48         with sql_connect(self.path) as conn:
49             schema_rows = [r[0] for r in conn.execute(sql_for_schema) if r[0]]
50             retrieved_schema = ';\n'.join(schema_rows) + ';'
51             with open(PATH_DB_SCHEMA, 'r', encoding='utf-8') as f:
52                 stored_schema = f.read().rstrip()
53                 if stored_schema != retrieved_schema:
54                     diff_msg = Differ().compare(retrieved_schema.splitlines(),
55                                                 stored_schema.splitlines())
56                     raise HandledException(msg_err + '\n'.join(diff_msg))
57
58
59 class DatabaseConnection:
60     """A single connection to the database."""
61
62     def __init__(self, db_file: DatabaseFile) -> None:
63         self.file = db_file
64         self.conn = sql_connect(self.file.path)
65
66     def commit(self) -> None:
67         """Commit SQL transaction."""
68         self.conn.commit()
69
70     def exec(self, code: str, inputs: tuple[Any, ...] = tuple()) -> Cursor:
71         """Add commands to SQL transaction."""
72         return self.conn.execute(code, inputs)
73
74     def close(self) -> None:
75         """Close DB connection."""
76         self.conn.close()
77
78     def rewrite_relations(self, table_name: str, key: str, target: int | str,
79                           rows: list[list[Any]]) -> None:
80         """Rewrite relations in table_name to target, with rows values."""
81         self.delete_where(table_name, key, target)
82         for row in rows:
83             values = tuple([target] + row)
84             q_marks = self.__class__.q_marks_from_values(values)
85             self.exec(f'INSERT INTO {table_name} VALUES {q_marks}', values)
86
87     def row_where(self, table_name: str, key: str,
88                   target: int | str) -> list[Row]:
89         """Return list of Rows at table where key == target."""
90         return list(self.exec(f'SELECT * FROM {table_name} WHERE {key} = ?',
91                               (target,)))
92
93     def column_where(self, table_name: str, column: str, key: str,
94                      target: int | str) -> list[Any]:
95         """Return column of table where key == target."""
96         return [row[0] for row in
97                 self.exec(f'SELECT {column} FROM {table_name} '
98                           f'WHERE {key} = ?', (target,))]
99
100     def column_all(self, table_name: str, column: str) -> list[Any]:
101         """Return complete column of table."""
102         return [row[0] for row in
103                 self.exec(f'SELECT {column} FROM {table_name}')]
104
105     def delete_where(self, table_name: str, key: str,
106                      target: int | str) -> None:
107         """Delete from table where key == target."""
108         self.exec(f'DELETE FROM {table_name} WHERE {key} = ?', (target,))
109
110     @staticmethod
111     def q_marks_from_values(values: tuple[Any]) -> str:
112         """Return placeholder to insert values into SQL code."""
113         return '(' + ','.join(['?'] * len(values)) + ')'
114
115
116 BaseModelId = TypeVar('BaseModelId', int, str)
117 BaseModelInstance = TypeVar('BaseModelInstance', bound='BaseModel[Any]')
118
119
class BaseModel(Generic[BaseModelId]):
    """Template for most of the models we use/derive from the DB.

    Subclasses configure persistence through the class attributes below.
    Objects are kept in a per-class cache dict keyed by .id_, which by_id()
    and all() consult before the DB.
    """
    # DB table holding this class's rows, with the ID in an 'id' column.
    table_name = ''
    # Attribute names written, in this order, after .id_ as a table row.
    to_save: list[str] = []
    # Attribute names whose values are saved/removed via their own
    # .save(db_conn)/.remove(db_conn) together with this object.
    to_save_versioned: list[str] = []
    # (table, column, attr_name) triples: attr_name holds objects whose .id_
    # are stored in table as relation rows keyed by self.id_ in column.
    to_save_relations: list[tuple[str, str, str]] = []
    id_: None | BaseModelId
    cache_: dict[BaseModelId, Self]

    def __init__(self, id_: BaseModelId | None) -> None:
        """Set .id_; raise HandledException for int IDs below 1."""
        if isinstance(id_, int) and id_ < 1:
            msg = f'illegal {self.__class__.__name__} ID, must be >=1: {id_}'
            raise HandledException(msg)
        self.id_ = id_

    def __eq__(self, other: object) -> bool:
        """Equal if other is of our class with equal .id_ and .to_save."""
        if not isinstance(other, self.__class__):
            return False
        # Compare the values themselves, not their hashes: distinct values
        # can share a hash (e.g. hash(-1) == hash(-2)), and unhashable
        # attribute values (lists etc.) would make hash() raise TypeError.
        to_compare_me = tuple([self.id_] +
                              [getattr(self, name) for name in self.to_save])
        to_compare_other = tuple([other.id_] +
                                 [getattr(other, name)
                                  for name in other.to_save])
        return to_compare_me == to_compare_other

    def __lt__(self, other: Any) -> bool:
        """Order by (int) .id_; raise HandledException for other classes."""
        if not isinstance(other, self.__class__):
            msg = 'cannot compare to object of different class'
            raise HandledException(msg)
        assert isinstance(self.id_, int)
        assert isinstance(other.id_, int)
        return self.id_ < other.id_

    @classmethod
    def get_cached(cls: type[BaseModelInstance],
                   id_: BaseModelId) -> BaseModelInstance | None:
        """Get object of id_ from class's cache, or None if not found."""
        obj = cls.get_cache().get(id_)
        if obj is None:
            return None
        assert isinstance(obj, cls)
        return obj

    @classmethod
    def empty_cache(cls) -> None:
        """Empty class's cache."""
        cls.cache_ = {}

    @classmethod
    def get_cache(cls: type[BaseModelInstance]) -> dict[Any, BaseModel[Any]]:
        """Get cache dictionary, create it if not yet existing.

        Looks into cls.__dict__ rather than using hasattr(), which would also
        find, and so share, a cache_ inherited from a parent class; each
        class thus gets its own dictionary.
        """
        if 'cache_' not in cls.__dict__:
            d: dict[Any, BaseModel[Any]] = {}
            cls.cache_ = d
        return cls.cache_

    def cache(self) -> None:
        """Update object in class's cache."""
        if self.id_ is None:
            raise HandledException('Cannot cache object without ID.')
        cache = self.__class__.get_cache()
        cache[self.id_] = self

    def uncache(self) -> None:
        """Remove self from cache (KeyError if not cached)."""
        if self.id_ is None:
            raise HandledException('Cannot un-cache object without ID.')
        cache = self.__class__.get_cache()
        del cache[self.id_]

    @classmethod
    def from_table_row(cls: type[BaseModelInstance],
                       # pylint: disable=unused-argument
                       db_conn: DatabaseConnection,
                       row: Row | list[Any]) -> BaseModelInstance:
        """Make from DB row, write to DB cache.

        The row's values are passed positionally to the constructor.
        """
        obj = cls(*row)
        obj.cache()
        return obj

    @classmethod
    def by_id(cls, db_conn: DatabaseConnection,
              id_: BaseModelId | None,
              # pylint: disable=unused-argument
              create: bool = False) -> Self:
        """Retrieve by id_, on failure throw NotFoundException.

        First try to get from cls.cache_, only then check DB; if found,
        put into cache.

        If create=True, make anew (but do not cache yet).
        """
        obj = None
        if id_ is not None:
            obj = cls.get_cached(id_)
            # Test for None explicitly: a found object may still be falsy
            # if a subclass defines __bool__ or __len__.
            if obj is None:
                for row in db_conn.row_where(cls.table_name, 'id', id_):
                    obj = cls.from_table_row(db_conn, row)
                    obj.cache()
                    break
        if obj is not None:
            return obj
        if create:
            return cls(id_)
        raise NotFoundException(f'found no object of ID {id_}')

    @classmethod
    def all(cls: type[BaseModelInstance],
            db_conn: DatabaseConnection) -> list[BaseModelInstance]:
        """Collect all objects of class into list.

        Note that this primarily returns the contents of the cache, and only
        _expands_ that by additional findings in the DB. This assumes the
        cache is always instantly cleaned of any items that would be removed
        from the DB.
        """
        items: dict[BaseModelId, BaseModelInstance] = {}
        for k, v in cls.get_cache().items():
            assert isinstance(v, cls)
            items[k] = v
        # Live view: also reflects items added within the loop below.
        already_recorded = items.keys()
        for id_ in db_conn.column_all(cls.table_name, 'id'):
            if id_ not in already_recorded:
                item = cls.by_id(db_conn, id_)
                assert item.id_ is not None
                items[item.id_] = item
        return list(items.values())

    def save(self, db_conn: DatabaseConnection) -> None:
        """Write self to DB and cache and ensure .id_.

        Write both to DB, and to cache. To DB, write .id_ and attributes
        listed in cls.to_save[_versioned|_relations].

        Ensure self.id_ by setting it to what the DB command returns as the
        last saved row's ID (cursor.lastrowid), EXCEPT if self.id_ already
        exists as a 'str', which implies we do our own ID creation (so far
        only the case with the Day class, where it's to be a date string).
        """
        values = tuple([self.id_] + [getattr(self, key)
                                     for key in self.to_save])
        q_marks = DatabaseConnection.q_marks_from_values(values)
        table_name = self.table_name
        cursor = db_conn.exec(f'REPLACE INTO {table_name} VALUES {q_marks}',
                              values)
        if not isinstance(self.id_, str):
            self.id_ = cursor.lastrowid  # type: ignore[assignment]
        self.cache()
        for attr_name in self.to_save_versioned:
            getattr(self, attr_name).save(db_conn)
        for table, column, attr_name in self.to_save_relations:
            assert isinstance(self.id_, (int, str))
            db_conn.rewrite_relations(table, column, self.id_,
                                      [[i.id_] for i
                                       in getattr(self, attr_name)])

    def remove(self, db_conn: DatabaseConnection) -> None:
        """Remove from DB and cache, including dependencies.

        Raises HandledException if self has no .id_ or is not in the class's
        cache.
        """
        if self.id_ is None or self.__class__.get_cached(self.id_) is None:
            raise HandledException('cannot remove unsaved item')
        for attr_name in self.to_save_versioned:
            getattr(self, attr_name).remove(db_conn)
        for table, column, attr_name in self.to_save_relations:
            db_conn.delete_where(table, column, self.id_)
        self.uncache()
        db_conn.delete_where(self.table_name, 'id', self.id_)
285             db_conn.delete_where(table, column, self.id_)
286         self.uncache()
287         db_conn.delete_where(self.table_name, 'id', self.id_)