1 """Database management."""
2 from __future__ import annotations
4 from os.path import isfile
5 from difflib import Differ
6 from sqlite3 import connect as sql_connect, Cursor, Row
7 from typing import Any, Self, TypeVar, Generic, Callable
8 from plomtask.exceptions import (HandledException, NotFoundException,
10 from plomtask.dating import valid_date
12 EXPECTED_DB_VERSION = 5
13 MIGRATIONS_DIR = 'migrations'
14 FILENAME_DB_SCHEMA = f'init_{EXPECTED_DB_VERSION}.sql'
15 PATH_DB_SCHEMA = f'{MIGRATIONS_DIR}/{FILENAME_DB_SCHEMA}'
class UnmigratedDbException(HandledException):
    """Marks a DB file whose user_version is not the expected one."""
23 """Represents the sqlite3 database's file."""
24 # pylint: disable=too-few-public-methods
26 def __init__(self, path: str) -> None:
31 def create_at(cls, path: str) -> DatabaseFile:
32 """Make new DB file at path."""
33 with sql_connect(path) as conn:
34 with open(PATH_DB_SCHEMA, 'r', encoding='utf-8') as f:
35 conn.executescript(f.read())
36 conn.execute(f'PRAGMA user_version = {EXPECTED_DB_VERSION}')
40 def migrate(cls, path: str) -> DatabaseFile:
41 """Apply migrations from_version to EXPECTED_DB_VERSION."""
42 migrations = cls._available_migrations()
43 from_version = cls._get_version_of_db(path)
44 migrations_todo = migrations[from_version+1:]
45 for j, filename in enumerate(migrations_todo):
46 with sql_connect(path) as conn:
47 with open(f'{MIGRATIONS_DIR}/{filename}', 'r',
48 encoding='utf-8') as f:
49 conn.executescript(f.read())
50 user_version = from_version + j + 1
51 with sql_connect(path) as conn:
52 conn.execute(f'PRAGMA user_version = {user_version}')
def _check(self) -> None:
    """Verify the DB file is present, current, and of the stored schema.

    Raises NotFoundException if nothing exists at self.path, and
    UnmigratedDbException if the file's user_version differs from
    EXPECTED_DB_VERSION; otherwise hands over to ._validate_schema().
    """
    file_present = isfile(self.path)
    if not file_present:
        raise NotFoundException
    found_version = self._user_version
    if found_version == EXPECTED_DB_VERSION:
        # Only a file of the expected version gets its schema compared.
        self._validate_schema()
        return
    raise UnmigratedDbException()
64 def _available_migrations() -> list[str]:
65 """Validate migrations directory and return sorted entries."""
66 msg_too_big = 'Migration directory points beyond expected DB version.'
67 msg_bad_entry = 'Migration directory contains unexpected entry: '
68 msg_missing = 'Migration directory misses migration of number: '
70 for entry in listdir(MIGRATIONS_DIR):
71 if entry == FILENAME_DB_SCHEMA:
73 toks = entry.split('_', 1)
75 raise HandledException(msg_bad_entry + entry)
78 except ValueError as e:
79 raise HandledException(msg_bad_entry + entry) from e
80 if i > EXPECTED_DB_VERSION:
81 raise HandledException(msg_too_big)
82 migrations[i] = toks[1]
84 for i in range(EXPECTED_DB_VERSION + 1):
85 if i not in migrations:
86 raise HandledException(msg_missing + str(i))
87 migrations_list += [f'{i}_{migrations[i]}']
88 return migrations_list
91 def _get_version_of_db(path: str) -> int:
92 """Get DB user_version, fail if outside expected range."""
93 sql_for_db_version = 'PRAGMA user_version'
94 with sql_connect(path) as conn:
95 db_version = list(conn.execute(sql_for_db_version))[0][0]
96 if db_version > EXPECTED_DB_VERSION:
97 msg = f'Wrong DB version, expected '\
98 f'{EXPECTED_DB_VERSION}, got unknown {db_version}.'
99 raise HandledException(msg)
100 assert isinstance(db_version, int)
def _user_version(self) -> int:
    """Return the user_version of the DB file at self.path."""
    db_path = self.path
    return self._get_version_of_db(db_path)
108 def _validate_schema(self) -> None:
109 """Compare found schema with what's stored at PATH_DB_SCHEMA."""
111 def reformat_rows(rows: list[str]) -> list[str]:
115 for subrow in row.split('\n'):
116 subrow = subrow.rstrip()
119 for i, c in enumerate(subrow):
124 elif ',' == c and 0 == in_parentheses:
128 segment = subrow[prev_split:i].strip()
130 new_row += [f' {segment}']
132 segment = subrow[prev_split:].strip()
134 new_row += [f' {segment}']
135 new_row[0] = new_row[0].lstrip()
136 new_row[-1] = new_row[-1].lstrip()
137 if new_row[-1] != ')' and new_row[-3][-1] != ',':
138 new_row[-3] = new_row[-3] + ','
139 new_row[-2:] = [' ' + new_row[-1][:-1]] + [')']
140 new_rows += ['\n'.join(new_row)]
143 sql_for_schema = 'SELECT sql FROM sqlite_master ORDER BY sql'
144 msg_err = 'Database has wrong tables schema. Diff:\n'
145 with sql_connect(self.path) as conn:
146 schema_rows = [r[0] for r in conn.execute(sql_for_schema) if r[0]]
147 schema_rows = reformat_rows(schema_rows)
148 retrieved_schema = ';\n'.join(schema_rows) + ';'
149 with open(PATH_DB_SCHEMA, 'r', encoding='utf-8') as f:
150 stored_schema = f.read().rstrip()
151 if stored_schema != retrieved_schema:
152 diff_msg = Differ().compare(retrieved_schema.splitlines(),
153 stored_schema.splitlines())
154 raise HandledException(msg_err + '\n'.join(diff_msg))
157 class DatabaseConnection:
158 """A single connection to the database."""
def __init__(self, db_file: DatabaseFile) -> None:
    """Open a sqlite3 connection to the path held by db_file."""
    db_path = db_file.path
    self.conn = sql_connect(db_path)
163 def commit(self) -> None:
164 """Commit SQL transaction."""
def exec(self, code: str, inputs: tuple[Any, ...] = tuple()) -> Cursor:
    """Run code on the connection with inputs bound to its placeholders.

    Returns the Cursor that the connection's .execute() hands back.
    """
    cursor = self.conn.execute(code, inputs)
    return cursor
def exec_on_vals(self, code: str, inputs: tuple[Any, ...]) -> Cursor:
    """Call .exec with code extended by " (?,…)", one "?" per input."""
    # Joining the characters of '???' yields '?,?,?'.
    placeholders = ','.join('?' * len(inputs))
    full_code = f'{code} ({placeholders})'
    return self.exec(full_code, inputs)
176 def close(self) -> None:
177 """Close DB connection."""
180 def rewrite_relations(self, table_name: str, key: str, target: int | str,
181 rows: list[list[Any]], key_index: int = 0) -> None:
182 # pylint: disable=too-many-arguments
183 """Rewrite relations in table_name to target, with rows values.
185 Note that single rows are expected without the column and value
186 identified by key and target, which are inserted inside the function
189 self.delete_where(table_name, key, target)
191 values = tuple(row[:key_index] + [target] + row[key_index:])
192 self.exec_on_vals(f'INSERT INTO {table_name} VALUES', values)
194 def row_where(self, table_name: str, key: str,
195 target: int | str) -> list[Row]:
196 """Return list of Rows at table where key == target."""
197 return list(self.exec(f'SELECT * FROM {table_name} WHERE {key} = ?',
200 # def column_where_pattern(self,
204 # keys: list[str]) -> list[Any]:
205 # """Return column of rows where one of keys matches pattern."""
206 # targets = tuple([f'%{pattern}%'] * len(keys))
207 # haystack = ' OR '.join([f'{k} LIKE ?' for k in keys])
208 # sql = f'SELECT {column} FROM {table_name} WHERE {haystack}'
209 # return [row[0] for row in self.exec(sql, targets)]
def column_where(self, table_name: str, column: str, key: str,
                 target: int | str) -> list[Any]:
    """Collect column's values from table_name's rows where key == target.

    Only target is bound as an SQL parameter; table_name, column and key
    are put into the statement text as given.
    """
    sql = f'SELECT {column} FROM {table_name} WHERE {key} = ?'
    cursor = self.exec(sql, (target,))
    return [row[0] for row in cursor]
def column_all(self, table_name: str, column: str) -> list[Any]:
    """Collect the values of column from every row of table_name."""
    values: list[Any] = []
    for row in self.exec(f'SELECT {column} FROM {table_name}'):
        values.append(row[0])
    return values
def delete_where(self, table_name: str, key: str,
                 target: int | str) -> None:
    """Delete those rows of table_name whose key column equals target.

    Only target is bound as an SQL parameter; table_name and key are put
    into the statement text as given.
    """
    sql = f'DELETE FROM {table_name} WHERE {key} = ?'
    self.exec(sql, (target,))
229 BaseModelId = TypeVar('BaseModelId', int, str)
230 BaseModelInstance = TypeVar('BaseModelInstance', bound='BaseModel[Any]')
233 class BaseModel(Generic[BaseModelId]):
234 """Template for most of the models we use/derive from the DB."""
236 to_save_simples: list[str] = []
237 to_save_relations: list[tuple[str, str, str, int]] = []
238 versioned_defaults: dict[str, str | float] = {}
239 add_to_dict: list[str] = []
240 id_: None | BaseModelId
241 cache_: dict[BaseModelId, Self]
242 to_search: list[str] = []
243 can_create_by_id = False
245 sorters: dict[str, Callable[..., Any]] = {}
247 def __init__(self, id_: BaseModelId | None) -> None:
248 if isinstance(id_, int) and id_ < 1:
249 msg = f'illegal {self.__class__.__name__} ID, must be >=1: {id_}'
250 raise BadFormatException(msg)
251 if isinstance(id_, str) and "" == id_:
252 msg = f'illegal {self.__class__.__name__} ID, must be non-empty'
253 raise BadFormatException(msg)
def __hash__(self) -> int:
    """Hash over .id_ and everything listed for saving.

    Folds in, in this order: .id_, the attributes named in
    .to_save_simples, per .to_save_relations entry a tuple of the related
    items' .id_, and the hash of each attribute from .to_save_versioned().
    """
    parts: list[Any] = [self.id_]
    parts.extend(getattr(self, name) for name in self.to_save_simples)
    for definition in self.to_save_relations:
        # Third field of a relation definition names the attribute.
        related = getattr(self, definition[2])
        parts.append(tuple(rel.id_ for rel in related))
    for name in self.to_save_versioned():
        parts.append(hash(getattr(self, name)))
    return hash(tuple(parts))
266 def __eq__(self, other: object) -> bool:
267 if not isinstance(other, self.__class__):
269 return hash(self) == hash(other)
def __lt__(self, other: Any) -> bool:
    """Order by integer .id_ against another object of self's class.

    Raises HandledException if other is no instance of self's class.
    """
    same_class = isinstance(other, self.__class__)
    if not same_class:
        raise HandledException('cannot compare to object of different class')
    own_id = self.id_
    assert isinstance(own_id, int)
    their_id = other.id_
    assert isinstance(their_id, int)
    return own_id < their_id
def to_save_versioned(cls) -> list[str]:
    """List the keys of cls.versioned_defaults."""
    defaults = cls.versioned_defaults
    return list(defaults)
285 def as_dict_and_refs(self) -> tuple[dict[str, object],
286 list[BaseModel[int] | BaseModel[str]]]:
287 """Return self as json.dumps-ready dict, list of referenced objects."""
288 d: dict[str, object] = {'id': self.id_}
289 refs: list[BaseModel[int] | BaseModel[str]] = []
290 for to_save in self.to_save_simples:
291 d[to_save] = getattr(self, to_save)
292 if len(self.to_save_versioned()) > 0:
294 for k in self.to_save_versioned():
295 attr = getattr(self, k)
296 assert isinstance(d['_versioned'], dict)
297 d['_versioned'][k] = attr.history
298 rels_to_collect = [rel[2] for rel in self.to_save_relations]
299 rels_to_collect += self.add_to_dict
300 for attr_name in rels_to_collect:
302 for item in getattr(self, attr_name):
303 rel_list += [item.id_]
306 d[attr_name] = rel_list
def name_lowercase(cls) -> str:
    """Return the name of cls, lowercased."""
    class_name = cls.__name__
    return class_name.lower()
315 def sort_by(cls, seq: list[Any], sort_key: str, default: str = 'title'
317 """Sort cls list by cls.sorters[sort_key] (reverse if '-'-prefixed).
319 Before cls.sorters[sort_key] is applied, seq is sorted by .id_, to
320 ensure predictability where parts of seq are of same sort value.
323 if len(sort_key) > 1 and '-' == sort_key[0]:
324 sort_key = sort_key[1:]
326 if sort_key not in cls.sorters:
328 seq.sort(key=lambda x: x.id_, reverse=reverse)
329 sorter: Callable[..., Any] = cls.sorters[sort_key]
330 seq.sort(key=sorter, reverse=reverse)
332 sort_key = f'-{sort_key}'
336 # (we primarily use the cache to ensure we work on the same object in
337 # memory no matter where and how we retrieve it, e.g. we don't want
338 # .by_id() calls to create a new object each time, but rather a pointer
339 # to the one already instantiated)
def __getattribute__(self, name: str) -> Any:
    """Refuse attribute access once ._exists is falsy.

    Reading ._exists itself always works; any other name raises
    HandledException when ._exists is falsy.
    """
    lookup = super().__getattribute__
    if name == '_exists' or lookup('_exists'):
        return lookup(name)
    raise HandledException(f'Object for attribute does not exist: {name}')
348 def _disappear(self) -> None:
349 """Invalidate object, make future use raise exceptions."""
350 assert self.id_ is not None
351 if self._get_cached(self.id_):
353 to_kill = list(self.__dict__.keys())
359 def empty_cache(cls) -> None:
360 """Empty class's cache, and disappear all former inhabitants."""
361 # pylint: disable=protected-access
362 # (cause we remain within the class)
363 if hasattr(cls, 'cache_'):
364 to_disappear = list(cls.cache_.values())
365 for item in to_disappear:
370 def get_cache(cls: type[BaseModelInstance]
371 ) -> dict[Any, BaseModelInstance]:
372 """Get cache dictionary, create it if not yet existing."""
373 if not hasattr(cls, 'cache_'):
374 d: dict[Any, BaseModelInstance] = {}
379 def _get_cached(cls: type[BaseModelInstance],
381 ) -> BaseModelInstance | None:
382 """Get object of id_ from class's cache, or None if not found."""
383 cache = cls.get_cache()
386 assert isinstance(obj, cls)
390 def cache(self) -> None:
391 """Update object in class's cache.
393 Also calls ._disappear if cache holds older reference to object of same
394 ID, but different memory address, to avoid doing anything with
398 raise HandledException('Cannot cache object without ID.')
399 cache = self.get_cache()
400 old_cached = self._get_cached(self.id_)
401 if old_cached and id(old_cached) != id(self):
402 # pylint: disable=protected-access
403 # (cause we remain within the class)
404 old_cached._disappear()
405 cache[self.id_] = self
407 def _uncache(self) -> None:
408 """Remove self from cache."""
410 raise HandledException('Cannot un-cache object without ID.')
411 cache = self.get_cache()
414 # object retrieval and generation
417 def from_table_row(cls: type[BaseModelInstance],
418 # pylint: disable=unused-argument
419 db_conn: DatabaseConnection,
420 row: Row | list[Any]) -> BaseModelInstance:
421 """Make from DB row (sans relations), update DB cache with it."""
423 assert obj.id_ is not None
424 for attr_name in cls.to_save_versioned():
425 attr = getattr(obj, attr_name)
426 table_name = attr.table_name
427 for row_ in db_conn.row_where(table_name, 'parent', obj.id_):
428 attr.history_from_row(row_)
433 def by_id(cls, db_conn: DatabaseConnection, id_: BaseModelId) -> Self:
434 """Retrieve by id_, on failure throw NotFoundException.
436 First try to get from cls.cache_, only then check DB; if found,
441 if isinstance(id_, int) and id_ == 0:
442 raise BadFormatException('illegal ID of value 0')
443 obj = cls._get_cached(id_)
445 for row in db_conn.row_where(cls.table_name, 'id', id_):
446 obj = cls.from_table_row(db_conn, row)
450 raise NotFoundException(f'found no object of ID {id_}')
453 def by_id_or_create(cls, db_conn: DatabaseConnection,
454 id_: BaseModelId | None
456 """Wrapper around .by_id, creating (not caching/saving) if no find."""
457 if not cls.can_create_by_id:
458 raise HandledException('Class cannot .by_id_or_create.')
462 return cls.by_id(db_conn, id_)
463 except NotFoundException:
467 def all(cls: type[BaseModelInstance],
468 db_conn: DatabaseConnection) -> list[BaseModelInstance]:
469 """Collect all objects of class into list.
471 Note that this primarily returns the contents of the cache, and only
472 _expands_ that by additional findings in the DB. This assumes the
473 cache is always instantly cleaned of any items that would be removed
476 items: dict[BaseModelId, BaseModelInstance] = {}
477 for k, v in cls.get_cache().items():
478 assert isinstance(v, cls)
480 already_recorded = items.keys()
481 for id_ in db_conn.column_all(cls.table_name, 'id'):
482 if id_ not in already_recorded:
483 item = cls.by_id(db_conn, id_)
484 assert item.id_ is not None
485 items[item.id_] = item
486 return sorted(list(items.values()))
489 def by_date_range_with_limits(cls: type[BaseModelInstance],
490 db_conn: DatabaseConnection,
491 date_range: tuple[str, str],
492 date_col: str = 'day'
493 ) -> tuple[list[BaseModelInstance], str,
495 """Return list of items in DB within (closed) date_range interval.
497 If no range values provided, defaults them to 'yesterday' and
498 'tomorrow'. Knows to properly interpret these and 'today' as value.
500 start_str = date_range[0] if date_range[0] else 'yesterday'
501 end_str = date_range[1] if date_range[1] else 'tomorrow'
502 start_date = valid_date(start_str)
503 end_date = valid_date(end_str)
505 sql = f'SELECT id FROM {cls.table_name} '
506 sql += f'WHERE {date_col} >= ? AND {date_col} <= ?'
507 for row in db_conn.exec(sql, (start_date, end_date)):
508 items += [cls.by_id(db_conn, row[0])]
509 return items, start_date, end_date
512 def matching(cls: type[BaseModelInstance], db_conn: DatabaseConnection,
513 pattern: str) -> list[BaseModelInstance]:
514 """Return all objects whose .to_search match pattern."""
515 items = cls.all(db_conn)
519 for attr_name in cls.to_search:
520 toks = attr_name.split('.')
523 attr = getattr(parent, tok)
533 def save(self, db_conn: DatabaseConnection) -> None:
534 """Write self to DB and cache and ensure .id_.
536 Write both to DB, and to cache. To DB, write .id_ and attributes
537 listed in cls.to_save_[simples|versioned|_relations].
539 Ensure self.id_ by setting it to what the DB command returns as the
540 last saved row's ID (cursor.lastrowid), EXCEPT if self.id_ already
541 exists as a 'str', which implies we do our own ID creation (so far
542 only the case with the Day class, where it's to be a date string.
544 values = tuple([self.id_] + [getattr(self, key)
545 for key in self.to_save_simples])
546 table_name = self.table_name
547 cursor = db_conn.exec_on_vals(f'REPLACE INTO {table_name} VALUES',
549 if not isinstance(self.id_, str):
550 self.id_ = cursor.lastrowid # type: ignore[assignment]
552 for attr_name in self.to_save_versioned():
553 getattr(self, attr_name).save(db_conn)
554 for table, column, attr_name, key_index in self.to_save_relations:
555 assert isinstance(self.id_, (int, str))
556 db_conn.rewrite_relations(table, column, self.id_,
558 in getattr(self, attr_name)], key_index)
560 def remove(self, db_conn: DatabaseConnection) -> None:
561 """Remove from DB and cache, including dependencies."""
562 if self.id_ is None or self._get_cached(self.id_) is None:
563 raise HandledException('cannot remove unsaved item')
564 for attr_name in self.to_save_versioned():
565 getattr(self, attr_name).remove(db_conn)
566 for table, column, attr_name, _ in self.to_save_relations:
567 db_conn.delete_where(table, column, self.id_)
569 db_conn.delete_where(self.table_name, 'id', self.id_)