home · contact · privacy
Add text-based search/filter for Conditions and Processes.
[plomtask] / plomtask / db.py
1 """Database management."""
from __future__ import annotations

from contextlib import closing
from difflib import Differ
from operator import attrgetter
from os import listdir
from os.path import isfile
from sqlite3 import connect as sql_connect, Cursor, Row
from typing import Any, Self, TypeVar, Generic

from plomtask.exceptions import HandledException, NotFoundException
9
# DB schema version this code works with: DatabaseFile.migrate upgrades older
# DB files up to it, DatabaseFile.get_version_of_db rejects newer ones.
EXPECTED_DB_VERSION = 4
# Relative path, so resolved against the current working directory.
MIGRATIONS_DIR = 'migrations'
# Complete schema of an EXPECTED_DB_VERSION DB, used both to create fresh DB
# files and to validate existing ones; it lives inside MIGRATIONS_DIR.
FILENAME_DB_SCHEMA = 'init_' + str(EXPECTED_DB_VERSION) + '.sql'
PATH_DB_SCHEMA = '/'.join((MIGRATIONS_DIR, FILENAME_DB_SCHEMA))
14
15
class UnmigratedDbException(HandledException):
    """Signals a DB file not yet migrated to EXPECTED_DB_VERSION.

    Raised by DatabaseFile._check when the file's user_version differs
    from the expected one.
    """
18
19
class DatabaseFile:  # pylint: disable=too-few-public-methods
    """Represents the sqlite3 database's file.

    Constructing an instance validates the file at path: it must exist, be
    at EXPECTED_DB_VERSION, and match the schema stored at PATH_DB_SCHEMA.

    Each method opens its own short-lived sqlite3 connection and closes it
    again. sqlite3.Connection's own context manager only commits or rolls
    back, it never closes, hence the explicit closing() wrappers.
    """

    def __init__(self, path: str) -> None:
        self.path = path
        self._check()

    @classmethod
    def create_at(cls, path: str) -> DatabaseFile:
        """Make new DB file at path from PATH_DB_SCHEMA, return it.

        The schema file is read before connecting, so a missing schema file
        fails before sqlite3 would create an empty DB file at path.
        """
        with open(PATH_DB_SCHEMA, 'r', encoding='utf-8') as f:
            schema_sql = f.read()
        with closing(sql_connect(path)) as conn:
            with conn:  # commit on success, roll back on error
                conn.executescript(schema_sql)
                conn.execute(f'PRAGMA user_version = {EXPECTED_DB_VERSION}')
        return cls(path)

    @classmethod
    def migrate(cls, path: str) -> DatabaseFile:
        """Apply migrations from the DB's version to EXPECTED_DB_VERSION.

        Migrations run in ascending order, each followed by setting the DB's
        user_version to that migration's number. Returns the validated file.
        """
        migrations = cls._available_migrations()
        from_version = cls.get_version_of_db(path)
        # migrations[i] is the migration to version i, so start right after
        # the DB's current version.
        todo = migrations[from_version + 1:]
        for version, filename in enumerate(todo, start=from_version + 1):
            with open(f'{MIGRATIONS_DIR}/{filename}', 'r',
                      encoding='utf-8') as f:
                migration_sql = f.read()
            with closing(sql_connect(path)) as conn:
                with conn:
                    conn.executescript(migration_sql)
                with conn:
                    conn.execute(f'PRAGMA user_version = {version}')
        return cls(path)

    def _check(self) -> None:
        """Check file exists, and is of proper DB version and schema.

        Raises NotFoundException if there is no file at self.path,
        UnmigratedDbException if its user_version is not
        EXPECTED_DB_VERSION, and HandledException on schema mismatch.
        """
        if not isfile(self.path):
            raise NotFoundException
        if self.user_version != EXPECTED_DB_VERSION:
            raise UnmigratedDbException()
        self._validate_schema()

    @staticmethod
    def _available_migrations() -> list[str]:
        """Validate migrations directory and return sorted entries.

        Every entry of MIGRATIONS_DIR other than FILENAME_DB_SCHEMA must be
        named '<int>_<rest>'. Numbers must be unique and none may exceed
        EXPECTED_DB_VERSION; every number from 0 to EXPECTED_DB_VERSION must
        be present. Returns the actual filenames ordered by number, so that
        index i holds migration i.
        """
        msg_too_big = 'Migration directory points beyond expected DB version.'
        msg_bad_entry = 'Migration directory contains unexpected entry: '
        msg_missing = 'Migration directory misses migration of number: '
        msg_duplicate = 'Migration directory contains duplicate migration ' \
            'of number: '
        migrations: dict[int, str] = {}
        for entry in listdir(MIGRATIONS_DIR):
            if entry == FILENAME_DB_SCHEMA:
                continue
            toks = entry.split('_', 1)
            if len(toks) < 2:
                raise HandledException(msg_bad_entry + entry)
            try:
                i = int(toks[0])
            except ValueError as e:
                raise HandledException(msg_bad_entry + entry) from e
            if i > EXPECTED_DB_VERSION:
                raise HandledException(msg_too_big)
            # listdir order is arbitrary: silently keeping either of two
            # same-numbered files would make migrations unpredictable.
            if i in migrations:
                raise HandledException(msg_duplicate + str(i))
            # Keep the real filename; rebuilding it from int(toks[0]) would
            # misname entries like '01_x.sql'.
            migrations[i] = entry
        for i in range(EXPECTED_DB_VERSION + 1):
            if i not in migrations:
                raise HandledException(msg_missing + str(i))
        return [migrations[i] for i in range(EXPECTED_DB_VERSION + 1)]

    @staticmethod
    def get_version_of_db(path: str) -> int:
        """Get DB user_version, fail if outside expected range.

        Raises HandledException if the DB's user_version is beyond
        EXPECTED_DB_VERSION.
        """
        sql_for_db_version = 'PRAGMA user_version'
        with closing(sql_connect(path)) as conn:
            db_version = conn.execute(sql_for_db_version).fetchone()[0]
        if db_version > EXPECTED_DB_VERSION:
            msg = f'Wrong DB version, expected '\
                    f'{EXPECTED_DB_VERSION}, got unknown {db_version}.'
            raise HandledException(msg)
        assert isinstance(db_version, int)
        return db_version

    @property
    def user_version(self) -> int:
        """Get DB user_version."""
        return self.__class__.get_version_of_db(self.path)

    def _validate_schema(self) -> None:
        """Compare found schema with what's stored at PATH_DB_SCHEMA.

        Raises HandledException carrying a line diff if they differ.
        """

        def reformat_rows(rows: list[str]) -> list[str]:
            # Re-layout each stored SQL statement: split every line at
            # commas outside of parentheses (depth is counted per line),
            # indent the pieces by four spaces, and leave the first and last
            # piece unindented.
            new_rows = []
            for row in rows:
                new_row = []
                for subrow in row.split('\n'):
                    subrow = subrow.rstrip()
                    in_parentheses = 0
                    split_at = []
                    for i, c in enumerate(subrow):
                        if '(' == c:
                            in_parentheses += 1
                        elif ')' == c:
                            in_parentheses -= 1
                        elif ',' == c and 0 == in_parentheses:
                            split_at += [i + 1]
                    prev_split = 0
                    for i in split_at:
                        segment = subrow[prev_split:i].strip()
                        if len(segment) > 0:
                            new_row += [f'    {segment}']
                        prev_split = i
                    segment = subrow[prev_split:].strip()
                    if len(segment) > 0:
                        new_row += [f'    {segment}']
                new_row[0] = new_row[0].lstrip()
                new_row[-1] = new_row[-1].lstrip()
                # If the statement does not end in a lone ')' line: append a
                # ',' to the third-to-last piece, drop the second-to-last
                # one, and split the final ')' off the last one.
                # NOTE(review): presumably this targets the ', col TYPE)'
                # suffix left by ALTER TABLE ... ADD COLUMN — confirm.
                # The length guard keeps short statements (e.g. one-line
                # CREATE INDEX) from raising IndexError.
                if (len(new_row) >= 3 and new_row[-1] != ')'
                        and new_row[-3][-1] != ','):
                    new_row[-3] = new_row[-3] + ','
                    new_row[-2:] = ['    ' + new_row[-1][:-1]] + [')']
                new_rows += ['\n'.join(new_row)]
            return new_rows

        sql_for_schema = 'SELECT sql FROM sqlite_master ORDER BY sql'
        msg_err = 'Database has wrong tables schema. Diff:\n'
        with closing(sql_connect(self.path)) as conn:
            # Entries without SQL text (NULL) are skipped.
            schema_rows = [r[0] for r in conn.execute(sql_for_schema) if r[0]]
        schema_rows = reformat_rows(schema_rows)
        retrieved_schema = ';\n'.join(schema_rows) + ';'
        with open(PATH_DB_SCHEMA, 'r', encoding='utf-8') as f:
            stored_schema = f.read().rstrip()
        if stored_schema != retrieved_schema:
            diff_msg = Differ().compare(retrieved_schema.splitlines(),
                                        stored_schema.splitlines())
            raise HandledException(msg_err + '\n'.join(diff_msg))
149             diff_msg = Differ().compare(retrieved_schema.splitlines(),
150                                         stored_schema.splitlines())
151             raise HandledException(msg_err + '\n'.join(diff_msg))
152
153
class DatabaseConnection:
    """A single connection to the database.

    Wraps one sqlite3 connection to the file of a DatabaseFile. Statements
    issued via .exec and the helpers below become durable only once
    .commit() is called.

    Table, column and key names are interpolated into the SQL text, so they
    must be trusted identifiers (not user input). Only the compared/inserted
    values are passed as bound parameters.
    """

    def __init__(self, db_file: DatabaseFile) -> None:
        self.file = db_file
        self.conn = sql_connect(self.file.path)

    def commit(self) -> None:
        """Commit SQL transaction."""
        self.conn.commit()

    def exec(self, code: str, inputs: tuple[Any, ...] = tuple()) -> Cursor:
        """Add commands to SQL transaction, return the resulting Cursor."""
        return self.conn.execute(code, inputs)

    def close(self) -> None:
        """Close DB connection."""
        self.conn.close()

    def rewrite_relations(self, table_name: str, key: str, target: int | str,
                          rows: list[list[Any]]) -> None:
        """Rewrite relations in table_name to target, with rows values.

        Deletes all rows of table_name where key == target, then inserts one
        row per entry of rows, each prefixed with target as first column.
        """
        self.delete_where(table_name, key, target)
        for row in rows:
            values = tuple([target] + row)
            q_marks = self.__class__.q_marks_from_values(values)
            self.exec(f'INSERT INTO {table_name} VALUES {q_marks}', values)

    def row_where(self, table_name: str, key: str,
                  target: int | str) -> list[Row]:
        """Return list of Rows at table where key == target."""
        return list(self.exec(f'SELECT * FROM {table_name} WHERE {key} = ?',
                              (target,)))

    def column_where(self, table_name: str, column: str, key: str,
                     target: int | str) -> list[Any]:
        """Return column of table where key == target."""
        return [row[0] for row in
                self.exec(f'SELECT {column} FROM {table_name} '
                          f'WHERE {key} = ?', (target,))]

    def column_all(self, table_name: str, column: str) -> list[Any]:
        """Return complete column of table."""
        return [row[0] for row in
                self.exec(f'SELECT {column} FROM {table_name}')]

    def delete_where(self, table_name: str, key: str,
                     target: int | str) -> None:
        """Delete from table where key == target."""
        self.exec(f'DELETE FROM {table_name} WHERE {key} = ?', (target,))

    @staticmethod
    def q_marks_from_values(values: tuple[Any, ...]) -> str:
        """Return placeholder to insert values into SQL code.

        One '?' per entry of values, e.g. '(?,?,?)' for a 3-tuple.
        """
        return '(' + ','.join(['?'] * len(values)) + ')'
217     def q_marks_from_values(values: tuple[Any]) -> str:
218         """Return placeholder to insert values into SQL code."""
219         return '(' + ','.join(['?'] * len(values)) + ')'
220
221
222 BaseModelId = TypeVar('BaseModelId', int, str)
223 BaseModelInstance = TypeVar('BaseModelInstance', bound='BaseModel[Any]')
224
225
226 class BaseModel(Generic[BaseModelId]):
227     """Template for most of the models we use/derive from the DB."""
228     table_name = ''
229     to_save: list[str] = []
230     to_save_versioned: list[str] = []
231     to_save_relations: list[tuple[str, str, str]] = []
232     id_: None | BaseModelId
233     cache_: dict[BaseModelId, Self]
234     to_search: list[str] = []
235
236     def __init__(self, id_: BaseModelId | None) -> None:
237         if isinstance(id_, int) and id_ < 1:
238             msg = f'illegal {self.__class__.__name__} ID, must be >=1: {id_}'
239             raise HandledException(msg)
240         self.id_ = id_
241
242     def __eq__(self, other: object) -> bool:
243         if not isinstance(other, self.__class__):
244             return False
245         to_hash_me = tuple([self.id_] +
246                            [getattr(self, name) for name in self.to_save])
247         to_hash_other = tuple([other.id_] +
248                               [getattr(other, name) for name in other.to_save])
249         return hash(to_hash_me) == hash(to_hash_other)
250
251     def __lt__(self, other: Any) -> bool:
252         if not isinstance(other, self.__class__):
253             msg = 'cannot compare to object of different class'
254             raise HandledException(msg)
255         assert isinstance(self.id_, int)
256         assert isinstance(other.id_, int)
257         return self.id_ < other.id_
258
259     @classmethod
260     def get_cached(cls: type[BaseModelInstance],
261                    id_: BaseModelId) -> BaseModelInstance | None:
262         """Get object of id_ from class's cache, or None if not found."""
263         # pylint: disable=consider-iterating-dictionary
264         cache = cls.get_cache()
265         if id_ in cache.keys():
266             obj = cache[id_]
267             assert isinstance(obj, cls)
268             return obj
269         return None
270
271     @classmethod
272     def empty_cache(cls) -> None:
273         """Empty class's cache."""
274         cls.cache_ = {}
275
276     @classmethod
277     def get_cache(cls: type[BaseModelInstance]) -> dict[Any, BaseModel[Any]]:
278         """Get cache dictionary, create it if not yet existing."""
279         if not hasattr(cls, 'cache_'):
280             d: dict[Any, BaseModel[Any]] = {}
281             cls.cache_ = d
282         return cls.cache_
283
284     def cache(self) -> None:
285         """Update object in class's cache."""
286         if self.id_ is None:
287             raise HandledException('Cannot cache object without ID.')
288         cache = self.__class__.get_cache()
289         cache[self.id_] = self
290
291     def uncache(self) -> None:
292         """Remove self from cache."""
293         if self.id_ is None:
294             raise HandledException('Cannot un-cache object without ID.')
295         cache = self.__class__.get_cache()
296         del cache[self.id_]
297
298     @classmethod
299     def from_table_row(cls: type[BaseModelInstance],
300                        # pylint: disable=unused-argument
301                        db_conn: DatabaseConnection,
302                        row: Row | list[Any]) -> BaseModelInstance:
303         """Make from DB row, write to DB cache."""
304         obj = cls(*row)
305         obj.cache()
306         return obj
307
308     @classmethod
309     def by_id(cls, db_conn: DatabaseConnection,
310               id_: BaseModelId | None,
311               # pylint: disable=unused-argument
312               create: bool = False) -> Self:
313         """Retrieve by id_, on failure throw NotFoundException.
314
315         First try to get from cls.cache_, only then check DB; if found,
316         put into cache.
317
318         If create=True, make anew (but do not cache yet).
319         """
320         obj = None
321         if id_ is not None:
322             obj = cls.get_cached(id_)
323             if not obj:
324                 for row in db_conn.row_where(cls.table_name, 'id', id_):
325                     obj = cls.from_table_row(db_conn, row)
326                     obj.cache()
327                     break
328         if obj:
329             return obj
330         if create:
331             obj = cls(id_)
332             return obj
333         raise NotFoundException(f'found no object of ID {id_}')
334
335     @classmethod
336     def all(cls: type[BaseModelInstance],
337             db_conn: DatabaseConnection) -> list[BaseModelInstance]:
338         """Collect all objects of class into list.
339
340         Note that this primarily returns the contents of the cache, and only
341         _expands_ that by additional findings in the DB. This assumes the
342         cache is always instantly cleaned of any items that would be removed
343         from the DB.
344         """
345         items: dict[BaseModelId, BaseModelInstance] = {}
346         for k, v in cls.get_cache().items():
347             assert isinstance(v, cls)
348             items[k] = v
349         already_recorded = items.keys()
350         for id_ in db_conn.column_all(cls.table_name, 'id'):
351             if id_ not in already_recorded:
352                 item = cls.by_id(db_conn, id_)
353                 assert item.id_ is not None
354                 items[item.id_] = item
355         return list(items.values())
356
357     @classmethod
358     def matching(cls: type[BaseModelInstance], db_conn: DatabaseConnection,
359                  pattern: str) -> list[BaseModelInstance]:
360         """Return all objects whose .to_search match pattern."""
361         items = cls.all(db_conn)
362         if pattern:
363             filtered = []
364             for item in items:
365                 for attr_name in cls.to_search:
366                     toks = attr_name.split('.')
367                     parent = item
368                     for tok in toks:
369                         attr = getattr(parent, tok)
370                         parent = attr
371                     if pattern in attr:
372                         filtered += [item]
373                         break
374             return filtered
375         return items
376
377     def save(self, db_conn: DatabaseConnection) -> None:
378         """Write self to DB and cache and ensure .id_.
379
380         Write both to DB, and to cache. To DB, write .id_ and attributes
381         listed in cls.to_save[_versioned|_relations].
382
383         Ensure self.id_ by setting it to what the DB command returns as the
384         last saved row's ID (cursor.lastrowid), EXCEPT if self.id_ already
385         exists as a 'str', which implies we do our own ID creation (so far
386         only the case with the Day class, where it's to be a date string.
387         """
388         values = tuple([self.id_] + [getattr(self, key)
389                                      for key in self.to_save])
390         q_marks = DatabaseConnection.q_marks_from_values(values)
391         table_name = self.table_name
392         cursor = db_conn.exec(f'REPLACE INTO {table_name} VALUES {q_marks}',
393                               values)
394         if not isinstance(self.id_, str):
395             self.id_ = cursor.lastrowid  # type: ignore[assignment]
396         self.cache()
397         for attr_name in self.to_save_versioned:
398             getattr(self, attr_name).save(db_conn)
399         for table, column, attr_name in self.to_save_relations:
400             assert isinstance(self.id_, (int, str))
401             db_conn.rewrite_relations(table, column, self.id_,
402                                       [[i.id_] for i
403                                        in getattr(self, attr_name)])
404
405     def remove(self, db_conn: DatabaseConnection) -> None:
406         """Remove from DB and cache, including dependencies."""
407         if self.id_ is None or self.__class__.get_cached(self.id_) is None:
408             raise HandledException('cannot remove unsaved item')
409         for attr_name in self.to_save_versioned:
410             getattr(self, attr_name).remove(db_conn)
411         for table, column, attr_name in self.to_save_relations:
412             db_conn.delete_where(table, column, self.id_)
413         self.uncache()
414         db_conn.delete_where(self.table_name, 'id', self.id_)