home · contact · privacy
3a661d37758f579a9cac09dc90b86f07f0572e04
[plomtask] / plomtask / db.py
1 """Database management."""
2 from __future__ import annotations
3 from os.path import isfile
4 from difflib import Differ
5 from sqlite3 import connect as sql_connect, Cursor, Row
6 from typing import Any, Self, TypeVar, Generic
7 from plomtask.exceptions import HandledException, NotFoundException
8
# SQL script defining the expected tables; a relative path, so it is
# resolved against the process's current working directory.
PATH_DB_SCHEMA = 'scripts/init.sql'
# Value the DB file's "PRAGMA user_version" must hold to be accepted.
EXPECTED_DB_VERSION = 0
11
12
class DatabaseFile:  # pylint: disable=too-few-public-methods
    """Represents the sqlite3 database's file.

    Construction checks whether a file exists at path. If one does, its
    "PRAGMA user_version" must equal EXPECTED_DB_VERSION and its table
    schema must match the PATH_DB_SCHEMA script; otherwise HandledException
    is raised. Every sqlite3 connection opened here is closed again before
    the opening method returns.
    """

    def __init__(self, path: str) -> None:
        self.path = path
        self._check()

    def remake(self) -> None:
        """Create tables in self.path file as per PATH_DB_SCHEMA sql file.

        Afterwards re-runs the existence/version/schema check.
        """
        # Read the script first so that a missing or unreadable schema file
        # does not leave a freshly created, empty DB file behind.
        with open(PATH_DB_SCHEMA, 'r', encoding='utf-8') as f:
            schema_sql = f.read()
        conn = sql_connect(self.path)
        try:
            # Using a sqlite3 connection as a context manager only commits
            # or rolls back the transaction; it does NOT close it.
            with conn:
                conn.executescript(schema_sql)
        finally:
            conn.close()
        self._check()

    def _check(self) -> None:
        """Check file exists, and is of proper DB version and schema.

        Sets self.exists; version and schema are only validated if the
        file exists.
        """
        self.exists = isfile(self.path)
        if self.exists:
            self._validate_user_version()
            self._validate_schema()

    def _fetch_all(self, sql: str) -> list[Any]:
        """Run sql on a short-lived connection to self.path, return all rows."""
        conn = sql_connect(self.path)
        try:
            return conn.execute(sql).fetchall()
        finally:
            conn.close()

    def _validate_user_version(self) -> None:
        """Compare DB user_version with EXPECTED_DB_VERSION.

        Raises HandledException if the two differ.
        """
        db_version = self._fetch_all('PRAGMA user_version')[0][0]
        if db_version != EXPECTED_DB_VERSION:
            msg = (f'Wrong DB version, expected '
                   f'{EXPECTED_DB_VERSION}, got {db_version}.')
            raise HandledException(msg)

    def _validate_schema(self) -> None:
        """Compare found schema with what's stored at PATH_DB_SCHEMA.

        The DB's schema is rebuilt from the non-NULL sql entries of
        sqlite_master, sorted, each terminated by ';' and newline-separated,
        then compared with the right-stripped text of the schema file.
        Raises HandledException carrying a line diff on mismatch.
        """
        msg_err = 'Database has wrong tables schema. Diff:\n'
        rows = self._fetch_all('SELECT sql FROM sqlite_master ORDER BY sql')
        # Some sqlite_master entries (e.g. SQLite's automatic indexes) have
        # a NULL sql column; they are skipped.
        retrieved_schema = ';\n'.join(r[0] for r in rows if r[0]) + ';'
        with open(PATH_DB_SCHEMA, 'r', encoding='utf-8') as f:
            stored_schema = f.read().rstrip()
        if stored_schema != retrieved_schema:
            diff_msg = Differ().compare(retrieved_schema.splitlines(),
                                        stored_schema.splitlines())
            raise HandledException(msg_err + '\n'.join(diff_msg))
57
58
class DatabaseConnection:
    """A single connection to the database.

    Thin wrapper around one sqlite3 connection to the DatabaseFile's path.
    Table, column and key names passed to the helper methods are
    interpolated into the SQL text unescaped, so they must come from trusted
    code (such as a model's table_name), never from user input. Only the
    compared/inserted values are passed as bound parameters.
    """

    def __init__(self, db_file: DatabaseFile) -> None:
        self.file = db_file
        self.conn = sql_connect(self.file.path)

    def commit(self) -> None:
        """Commit SQL transaction."""
        self.conn.commit()

    def exec(self, code: str, inputs: tuple[Any, ...] = tuple()) -> Cursor:
        """Add commands to SQL transaction.

        Executes the single statement code with inputs bound to its '?'
        placeholders and returns the resulting cursor.
        """
        return self.conn.execute(code, inputs)

    def close(self) -> None:
        """Close DB connection."""
        self.conn.close()

    def rewrite_relations(self, table_name: str, key: str, target: int,
                          rows: list[list[Any]]) -> None:
        """Rewrite relations in table_name to target, with rows values.

        Deletes all rows whose key column equals target, then inserts one
        row per entry of rows, with target prepended as the first value
        (this assumes key is the table's first column - confirm per table).
        """
        self.delete_where(table_name, key, target)
        for row in rows:
            values = (target, *row)
            q_marks = self.q_marks_from_values(values)
            self.exec(f'INSERT INTO {table_name} VALUES {q_marks}', values)

    def row_where(self, table_name: str, key: str,
                  target: int | str) -> list[Row]:
        """Return list of Rows at table where key == target."""
        return list(self.exec(f'SELECT * FROM {table_name} WHERE {key} = ?',
                              (target,)))

    def column_where(self, table_name: str, column: str, key: str,
                     target: int | str) -> list[Any]:
        """Return column of table where key == target."""
        return [row[0] for row in
                self.exec(f'SELECT {column} FROM {table_name} '
                          f'WHERE {key} = ?', (target,))]

    def column_all(self, table_name: str, column: str) -> list[Any]:
        """Return complete column of table."""
        return [row[0] for row in
                self.exec(f'SELECT {column} FROM {table_name}')]

    def delete_where(self, table_name: str, key: str,
                     target: int | str) -> None:
        """Delete from table where key == target."""
        self.exec(f'DELETE FROM {table_name} WHERE {key} = ?', (target,))

    @staticmethod
    def q_marks_from_values(values: tuple[Any, ...]) -> str:
        """Return placeholder to insert values into SQL code.

        One '?' per value, e.g. a 3-tuple yields '(?,?,?)'.
        """
        return '(' + ','.join(['?'] * len(values)) + ')'
114
115
116 BaseModelId = TypeVar('BaseModelId', int, str)
117 BaseModelInstance = TypeVar('BaseModelInstance', bound='BaseModel[Any]')
118
119
120 class BaseModel(Generic[BaseModelId]):
121     """Template for most of the models we use/derive from the DB."""
122     table_name = ''
123     to_save: list[str] = []
124     id_: None | BaseModelId
125     cache_: dict[BaseModelId, Self]
126
127     def __init__(self, id_: BaseModelId | None) -> None:
128         if isinstance(id_, int) and id_ < 1:
129             msg = f'illegal {self.__class__.__name__} ID, must be >=1: {id_}'
130             raise HandledException(msg)
131         self.id_ = id_
132
133     @classmethod
134     def get_cached(cls: type[BaseModelInstance],
135                    id_: BaseModelId) -> BaseModelInstance | None:
136         """Get object of id_ from class's cache, or None if not found."""
137         # pylint: disable=consider-iterating-dictionary
138         cache = cls.get_cache()
139         if id_ in cache.keys():
140             obj = cache[id_]
141             assert isinstance(obj, cls)
142             return obj
143         return None
144
145     @classmethod
146     def empty_cache(cls) -> None:
147         """Empty class's cache."""
148         cls.cache_ = {}
149
150     @classmethod
151     def get_cache(cls: type[BaseModelInstance]) -> dict[Any, BaseModel[Any]]:
152         """Get cache dictionary, create it if not yet existing."""
153         if not hasattr(cls, 'cache_'):
154             d: dict[Any, BaseModel[Any]] = {}
155             cls.cache_ = d
156         return cls.cache_
157
158     def cache(self) -> None:
159         """Update object in class's cache."""
160         if self.id_ is None:
161             raise HandledException('Cannot cache object without ID.')
162         cache = self.__class__.get_cache()
163         cache[self.id_] = self
164
165     def uncache(self) -> None:
166         """Remove self from cache."""
167         if self.id_ is None:
168             raise HandledException('Cannot un-cache object without ID.')
169         cache = self.__class__.get_cache()
170         del cache[self.id_]
171
172     @classmethod
173     def from_table_row(cls: type[BaseModelInstance],
174                        # pylint: disable=unused-argument
175                        db_conn: DatabaseConnection,
176                        row: Row | list[Any]) -> BaseModelInstance:
177         """Make from DB row, write to DB cache."""
178         obj = cls(*row)
179         obj.cache()
180         return obj
181
182     @classmethod
183     def by_id(cls, db_conn: DatabaseConnection,
184               id_: BaseModelId | None,
185               # pylint: disable=unused-argument
186               create: bool = False) -> Self:
187         """Retrieve by id_, on failure throw NotFoundException.
188
189         First try to get from cls.cache_, only then check DB; if found,
190         put into cache.
191
192         If create=True, make anew (but do not cache yet).
193         """
194         obj = None
195         if id_ is not None:
196             obj = cls.get_cached(id_)
197             if not obj:
198                 for row in db_conn.row_where(cls.table_name, 'id', id_):
199                     obj = cls.from_table_row(db_conn, row)
200                     obj.cache()
201                     break
202         if obj:
203             return obj
204         if create:
205             obj = cls(id_)
206             return obj
207         raise NotFoundException(f'found no object of ID {id_}')
208
209     @classmethod
210     def all(cls: type[BaseModelInstance],
211             db_conn: DatabaseConnection) -> list[BaseModelInstance]:
212         """Collect all objects of class into list.
213
214         Note that this primarily returns the contents of the cache, and only
215         _expands_ that by additional findings in the DB. This assumes the
216         cache is always instantly cleaned of any items that would be removed
217         from the DB.
218         """
219         items: dict[BaseModelId, BaseModelInstance] = {}
220         for k, v in cls.get_cache().items():
221             assert isinstance(v, cls)
222             items[k] = v
223         already_recorded = items.keys()
224         for id_ in db_conn.column_all(cls.table_name, 'id'):
225             if id_ not in already_recorded:
226                 item = cls.by_id(db_conn, id_)
227                 assert item.id_ is not None
228                 items[item.id_] = item
229         return list(items.values())
230
231     def __eq__(self, other: object) -> bool:
232         if not isinstance(other, self.__class__):
233             msg = 'cannot compare to object of different class'
234             raise HandledException(msg)
235         to_hash_me = tuple([self.id_] +
236                            [getattr(self, name) for name in self.to_save])
237         to_hash_other = tuple([other.id_] +
238                               [getattr(other, name) for name in other.to_save])
239         return hash(to_hash_me) == hash(to_hash_other)
240
241     def save_core(self, db_conn: DatabaseConnection) -> None:
242         """Write bare-bones self (sans connected items), ensuring self.id_.
243
244         Write both to DB, and to cache. To DB, write .id_ and attributes
245         listed in cls.to_save.
246
247         Ensure self.id_ by setting it to what the DB command returns as the
248         last saved row's ID (cursor.lastrowid), EXCEPT if self.id_ already
249         exists as a 'str', which implies we do our own ID creation (so far
250         only the case with the Day class, where it's to be a date string.
251         """
252         values = tuple([self.id_] + [getattr(self, key)
253                                      for key in self.to_save])
254         q_marks = DatabaseConnection.q_marks_from_values(values)
255         table_name = self.table_name
256         cursor = db_conn.exec(f'REPLACE INTO {table_name} VALUES {q_marks}',
257                               values)
258         if not isinstance(self.id_, str):
259             self.id_ = cursor.lastrowid  # type: ignore[assignment]
260         self.cache()
261
262     def remove(self, db_conn: DatabaseConnection) -> None:
263         """Remove from DB and cache."""
264         assert isinstance(self.id_, int | str)
265         self.uncache()
266         db_conn.delete_where(self.table_name, 'id', self.id_)