1 """Database management."""
2 from __future__ import annotations
4 from os.path import isfile
5 from difflib import Differ
6 from sqlite3 import connect as sql_connect, Cursor, Row
7 from typing import Any, Self, TypeVar, Generic, Callable
8 from plomtask.exceptions import HandledException, NotFoundException
9 from plomtask.dating import valid_date
# Schema version this code works with; stamped into and read back from the
# DB file's `PRAGMA user_version`.
EXPECTED_DB_VERSION = 5
# Directory holding the numbered migration scripts and the init schema.
MIGRATIONS_DIR = 'migrations'
# Full schema script for a fresh DB at EXPECTED_DB_VERSION ('init_5.sql').
FILENAME_DB_SCHEMA = 'init_' + str(EXPECTED_DB_VERSION) + '.sql'
PATH_DB_SCHEMA = '/'.join((MIGRATIONS_DIR, FILENAME_DB_SCHEMA))
class UnmigratedDbException(HandledException):
    """Raised when a DB file's user_version differs from EXPECTED_DB_VERSION.

    Such a file has to be migrated before it can be used.
    """
# NOTE(review): the class header this docstring belongs to is not present
# in this listing; per the create_at/migrate return annotations it is
# presumably `class DatabaseFile:` — confirm against the real source.
22 """Represents the sqlite3 database's file."""
23 # pylint: disable=too-few-public-methods
25 def __init__(self, path: str) -> None:
# Wrap the sqlite3 DB file at `path`.
# NOTE(review): no body is visible in this listing; presumably it sets
# self.path (read by _check, _user_version, _validate_schema) and may call
# self._check() — confirm.
30 def create_at(cls, path: str) -> DatabaseFile:
31 """Make new DB file at path."""
# Run the full schema script at PATH_DB_SCHEMA against `path` (sqlite3
# creates the file if absent), then stamp it as EXPECTED_DB_VERSION.
# sqlite3's connection context manager commits on success and rolls back
# on error, but it does not close the connection.
32 with sql_connect(path) as conn:
33 with open(PATH_DB_SCHEMA, 'r', encoding='utf-8') as f:
34 conn.executescript(f.read())
35 conn.execute(f'PRAGMA user_version = {EXPECTED_DB_VERSION}')
# NOTE(review): no return statement is visible although a DatabaseFile is
# promised; presumably `return cls(path)` follows — confirm.
39 def migrate(cls, path: str) -> DatabaseFile:
40 """Apply migrations from_version to EXPECTED_DB_VERSION."""
# migrations[i] is the script named '<i>_...'. Applying it brings the DB to
# version i, so every script up to and including the DB's current version
# is skipped.
41 migrations = cls._available_migrations()
42 from_version = cls._get_version_of_db(path)
43 migrations_todo = migrations[from_version+1:]
44 for j, filename in enumerate(migrations_todo):
# executescript() does not wrap the script in a transaction, so a script
# that fails midway may leave part of its statements applied.
45 with sql_connect(path) as conn:
46 with open(f'{MIGRATIONS_DIR}/{filename}', 'r',
47 encoding='utf-8') as f:
48 conn.executescript(f.read())
# Record the version reached. From this de-indented listing it cannot be
# told whether this runs once per script or only after the loop; in the
# latter case `j` would be unbound when nothing is left to migrate.
49 user_version = from_version + j + 1
50 with sql_connect(path) as conn:
51 conn.execute(f'PRAGMA user_version = {user_version}')
# NOTE(review): no return statement is visible although a DatabaseFile is
# promised; presumably `return cls(path)` follows — confirm.
def _check(self) -> None:
    """Ensure the DB file is usable: present, current and schema-correct.

    Raises:
        NotFoundException: no file exists at self.path.
        UnmigratedDbException: the file's user_version is not
            EXPECTED_DB_VERSION.
        HandledException: the DB version is newer than supported, or the
            schema differs from the one stored at PATH_DB_SCHEMA.
    """
    if not isfile(self.path):
        raise NotFoundException
    if self._user_version == EXPECTED_DB_VERSION:
        self._validate_schema()
        return
    raise UnmigratedDbException()
63 def _available_migrations() -> list[str]:
64 """Validate migrations directory and return sorted entries."""
# Result: one '<i>_<rest>' filename per i in 0..EXPECTED_DB_VERSION, in
# numeric order (so '10_x' follows '9_y', unlike a lexical sort).
# Malformed names, numbers beyond EXPECTED_DB_VERSION, and gaps all raise
# HandledException.
65 msg_too_big = 'Migration directory points beyond expected DB version.'
66 msg_bad_entry = 'Migration directory contains unexpected entry: '
67 msg_missing = 'Migration directory misses migration of number: '
# NOTE(review): `listdir` is not among the imports visible in this listing
# (presumably `from os import listdir`) — confirm.
69 for entry in listdir(MIGRATIONS_DIR):
# The init schema file is special-cased. The line after this check is not
# visible; presumably the entry is skipped.
70 if entry == FILENAME_DB_SCHEMA:
72 toks = entry.split('_', 1)
74 raise HandledException(msg_bad_entry + entry)
# The ValueError presumably comes from int() on the '<i>' prefix; the try
# block is not visible in this listing.
77 except ValueError as e:
78 raise HandledException(msg_bad_entry + entry) from e
79 if i > EXPECTED_DB_VERSION:
80 raise HandledException(msg_too_big)
# NOTE(review): a second file with the same number silently overwrites the
# first, and listdir order is arbitrary; consider rejecting duplicates.
81 migrations[i] = toks[1]
# Every number from 0 up to EXPECTED_DB_VERSION must be present.
83 for i in range(EXPECTED_DB_VERSION + 1):
84 if i not in migrations:
85 raise HandledException(msg_missing + str(i))
86 migrations_list += [f'{i}_{migrations[i]}']
87 return migrations_list
90 def _get_version_of_db(path: str) -> int:
91 """Get DB user_version, fail if outside expected range."""
92 sql_for_db_version = 'PRAGMA user_version'
# Note: sqlite3.connect creates an empty file if none exists at `path`.
93 with sql_connect(path) as conn:
94 db_version = list(conn.execute(sql_for_db_version))[0][0]
# Only versions newer than this code supports are rejected here. Older ones
# are left to the caller: migrate(), or _check's UnmigratedDbException.
95 if db_version > EXPECTED_DB_VERSION:
96 msg = f'Wrong DB version, expected '\
97 f'{EXPECTED_DB_VERSION}, got unknown {db_version}.'
98 raise HandledException(msg)
# NOTE(review): this type narrowing would read better before the `>`
# comparison above. No return is visible although an int is promised;
# presumably `return db_version` follows — confirm.
99 assert isinstance(db_version, int)
def _user_version(self) -> int:
    """Read `PRAGMA user_version` from the DB file at self.path."""
    db_path = self.path
    return self._get_version_of_db(db_path)
107 def _validate_schema(self) -> None:
108 """Compare found schema with what's stored at PATH_DB_SCHEMA."""
# Raise HandledException with a line diff unless the DB's schema SQL, as
# normalized by reformat_rows, equals the text of PATH_DB_SCHEMA.
110 def reformat_rows(rows: list[str]) -> list[str]:
# Re-lay out each stored CREATE statement to match the schema file's style.
# Lines are split at commas outside parentheses (the depth-tracking lines
# are not visible in this listing) and each segment is indented. The first
# and last segments are then fixed up so the closing ')' ends up on its own
# line. The indentation inside these f-strings must match PATH_DB_SCHEMA
# exactly, because the comparison below is purely textual.
114 for subrow in row.split('\n'):
115 subrow = subrow.rstrip()
118 for i, c in enumerate(subrow):
123 elif ',' == c and 0 == in_parentheses:
127 segment = subrow[prev_split:i].strip()
129 new_row += [f' {segment}']
131 segment = subrow[prev_split:].strip()
133 new_row += [f' {segment}']
134 new_row[0] = new_row[0].lstrip()
135 new_row[-1] = new_row[-1].lstrip()
136 if new_row[-1] != ')' and new_row[-3][-1] != ',':
137 new_row[-3] = new_row[-3] + ','
138 new_row[-2:] = [' ' + new_row[-1][:-1]] + [')']
139 new_rows += ['\n'.join(new_row)]
# sqlite_master rows with NULL sql (e.g. automatic indexes) are skipped.
# This covers every schema object with SQL text, not only tables.
142 sql_for_schema = 'SELECT sql FROM sqlite_master ORDER BY sql'
143 msg_err = 'Database has wrong tables schema. Diff:\n'
144 with sql_connect(self.path) as conn:
145 schema_rows = [r[0] for r in conn.execute(sql_for_schema) if r[0]]
146 schema_rows = reformat_rows(schema_rows)
147 retrieved_schema = ';\n'.join(schema_rows) + ';'
148 with open(PATH_DB_SCHEMA, 'r', encoding='utf-8') as f:
149 stored_schema = f.read().rstrip()
150 if stored_schema != retrieved_schema:
# In the diff, '- ' lines come from the DB and '+ ' lines from the file.
151 diff_msg = Differ().compare(retrieved_schema.splitlines(),
152 stored_schema.splitlines())
153 raise HandledException(msg_err + '\n'.join(diff_msg))
156 class DatabaseConnection:
157 """A single connection to the database."""
# NOTE(review): table, column and key names below are interpolated into SQL
# via f-strings, because identifiers cannot be bound as parameters. Only
# values go through '?' placeholders, so callers must pass trusted
# identifiers only.
159 def __init__(self, db_file: DatabaseFile) -> None:
# The connection is opened immediately; close() is there to release it.
160 self.conn = sql_connect(db_file.path)
162 def commit(self) -> None:
163 """Commit SQL transaction."""
# NOTE(review): the body of commit() is not visible in this listing.
166 def exec(self, code: str, inputs: tuple[Any, ...] = tuple()) -> Cursor:
167 """Add commands to SQL transaction."""
# Under sqlite3's default transaction handling, data changes stay
# uncommitted until commit() is called.
168 return self.conn.execute(code, inputs)
170 def exec_on_vals(self, code: str, inputs: tuple[Any, ...]) -> Cursor:
171 """Wrapper around .exec appending adequate " (?, …)" to code."""
# e.g. 'INSERT INTO t VALUES' with 3 inputs -> 'INSERT INTO t VALUES (?,?,?)'.
# An empty inputs tuple yields '()', which SQLite rejects for VALUES.
172 q_marks_from_values = '(' + ','.join(['?'] * len(inputs)) + ')'
173 return self.exec(f'{code} {q_marks_from_values}', inputs)
175 def close(self) -> None:
176 """Close DB connection."""
# NOTE(review): the body of close() is not visible in this listing.
179 def rewrite_relations(self, table_name: str, key: str, target: int | str,
180 rows: list[list[Any]], key_index: int = 0) -> None:
181 # pylint: disable=too-many-arguments
182 """Rewrite relations in table_name to target, with rows values.
184 Note that single rows are expected without the column and value
185 identified by key and target, which are inserted inside the function
188 self.delete_where(table_name, key, target)
# All rows with key == target are dropped above. Each new row is then
# inserted with target spliced in at position key_index; rows must be
# lists for this concatenation. The loop header is not visible here.
190 values = tuple(row[:key_index] + [target] + row[key_index:])
191 self.exec_on_vals(f'INSERT INTO {table_name} VALUES', values)
193 def row_where(self, table_name: str, key: str,
194 target: int | str) -> list[Row]:
195 """Return list of Rows at table where key == target."""
196 return list(self.exec(f'SELECT * FROM {table_name} WHERE {key} = ?',
# NOTE(review): the dead code below is kept as comments; either delete it
# or restore it as a real method.
199 # def column_where_pattern(self,
203 # keys: list[str]) -> list[Any]:
204 # """Return column of rows where one of keys matches pattern."""
205 # targets = tuple([f'%{pattern}%'] * len(keys))
206 # haystack = ' OR '.join([f'{k} LIKE ?' for k in keys])
207 # sql = f'SELECT {column} FROM {table_name} WHERE {haystack}'
208 # return [row[0] for row in self.exec(sql, targets)]
210 def column_where(self, table_name: str, column: str, key: str,
211 target: int | str) -> list[Any]:
212 """Return column of table where key == target."""
213 return [row[0] for row in
214 self.exec(f'SELECT {column} FROM {table_name} '
215 f'WHERE {key} = ?', (target,))]
217 def column_all(self, table_name: str, column: str) -> list[Any]:
218 """Return complete column of table."""
219 return [row[0] for row in
220 self.exec(f'SELECT {column} FROM {table_name}')]
222 def delete_where(self, table_name: str, key: str,
223 target: int | str) -> None:
224 """Delete from table where key == target."""
225 self.exec(f'DELETE FROM {table_name} WHERE {key} = ?', (target,))
# Any concrete BaseModel instance; types classmethods that return `cls`
# instances.
BaseModelInstance = TypeVar('BaseModelInstance', bound='BaseModel[Any]')
# A model's ID type: int for DB-assigned row IDs, str for models that
# create their own IDs (see the save() docstring).
BaseModelId = TypeVar('BaseModelId', int, str)
232 class BaseModel(Generic[BaseModelId]):
233 """Template for most of the models we use/derive from the DB."""
# Class-level configuration read by the generic methods below; subclasses
# presumably override it. The mutable defaults are shared by every subclass
# that does not override them, so never mutate them in place.
# Attribute names stored as columns after the ID (see save()).
235 to_save_simples: list[str] = []
# (table, column, attribute name, key_index) per relation table; see save().
236 to_save_relations: list[tuple[str, str, str, int]] = []
# Keys name the versioned attributes saved/removed along with the object.
237 versioned_defaults: dict[str, str | float] = {}
# Further attribute names whose items' IDs as_dict_and_refs() exports.
238 add_to_dict: list[str] = []
# None until saved; save() fills it from cursor.lastrowid unless it is a str.
239 id_: None | BaseModelId
# Per-class id -> instance registry, created lazily by get_cache().
240 cache_: dict[BaseModelId, Self]
# Dotted attribute paths inspected by matching().
241 to_search: list[str] = []
# Whether by_id_or_create() is permitted for this class.
242 can_create_by_id = False
# sort_by() key -> key function passed to list.sort().
244 sorters: dict[str, Callable[..., Any]] = {}
# NOTE(review): __getattribute__ reads self._exists, whose definition is not
# visible in this listing (presumably a class attribute `_exists = True`).
246 def __init__(self, id_: BaseModelId | None) -> None:
# Reject int IDs below 1 and empty-string IDs. None (not yet saved) passes.
247 if isinstance(id_, int) and id_ < 1:
248 msg = f'illegal {self.__class__.__name__} ID, must be >=1: {id_}'
249 raise HandledException(msg)
250 if isinstance(id_, str) and "" == id_:
251 msg = f'illegal {self.__class__.__name__} ID, must be non-empty'
252 raise HandledException(msg)
# NOTE(review): the assignment of id_ to self.id_ is not visible in this
# listing — confirm that it follows.
def __hash__(self) -> int:
    """Hash the ID together with all persisted state.

    Covers the to_save_simples values, the IDs of the objects in each
    to_save_relations attribute, and the hashes of the versioned
    attributes.
    """
    parts: list[Any] = [self.id_]
    parts.extend(getattr(self, simple) for simple in self.to_save_simples)
    for rel_def in self.to_save_relations:
        related = getattr(self, rel_def[2])
        parts.append(tuple(obj.id_ for obj in related))
    parts.extend(hash(getattr(self, versioned))
                 for versioned in self.to_save_versioned())
    return hash(tuple(parts))
265 def __eq__(self, other: object) -> bool:
# Objects of another class are presumably unequal; the body of this branch
# is not visible in this listing.
266 if not isinstance(other, self.__class__):
# Equality is delegated to __hash__ (ID, simple fields, relation IDs and
# versioned attributes). NOTE(review): a hash collision would make distinct
# objects compare equal; comparing the hashed tuples directly would avoid
# that.
268 return hash(self) == hash(other)
def __lt__(self, other: Any) -> bool:
    """Order objects of the same class by their integer .id_.

    Raises HandledException when other is of a different class. Both IDs
    are asserted to be ints.
    """
    if isinstance(other, self.__class__):
        assert isinstance(self.id_, int)
        assert isinstance(other.id_, int)
        return self.id_ < other.id_
    raise HandledException('cannot compare to object of different class')
def to_save_versioned(cls) -> list[str]:
    """Name the versioned attributes to persist.

    These are the keys of cls.versioned_defaults, in insertion order.
    """
    return [attr_name for attr_name in cls.versioned_defaults]
284 def as_dict_and_refs(self) -> tuple[dict[str, object],
285 list[BaseModel[int] | BaseModel[str]]]:
286 """Return self as json.dumps-ready dict, list of referenced objects."""
287 d: dict[str, object] = {'id': self.id_}
288 refs: list[BaseModel[int] | BaseModel[str]] = []
289 for to_save in self.to_save_simples:
290 d[to_save] = getattr(self, to_save)
# Versioned attributes are exported as their .history under d['_versioned'].
# The creation of that dict is not visible in this listing.
291 if len(self.to_save_versioned()) > 0:
293 for k in self.to_save_versioned():
294 attr = getattr(self, k)
295 assert isinstance(d['_versioned'], dict)
296 d['_versioned'][k] = attr.history
# Relation and add_to_dict attributes are exported as lists of IDs. The
# referenced objects presumably go into refs; those lines are not visible.
297 rels_to_collect = [rel[2] for rel in self.to_save_relations]
298 rels_to_collect += self.add_to_dict
299 for attr_name in rels_to_collect:
301 for item in getattr(self, attr_name):
302 rel_list += [item.id_]
305 d[attr_name] = rel_list
# NOTE(review): no return is visible; presumably `return d, refs` — confirm.
def name_lowercase(cls) -> str:
    """Give the class name in lower case, e.g. 'BaseModel' -> 'basemodel'."""
    class_name: str = cls.__name__
    return class_name.lower()
314 def sort_by(cls, seq: list[Any], sort_key: str, default: str = 'title'
316 """Sort cls list by cls.sorters[sort_key] (reverse if '-'-prefixed).
318 Before cls.sorters[sort_key] is applied, seq is sorted by .id_, to
319 ensure predictability where parts of seq are of same sort value.
322 if len(sort_key) > 1 and '-' == sort_key[0]:
# A leading '-' asks for descending order. It is stripped here and,
# on a line not visible in this listing, presumably recorded in `reverse`.
323 sort_key = sort_key[1:]
# Unknown keys presumably fall back to `default` (that line is not visible).
325 if sort_key not in cls.sorters:
# Pre-sort by ID. list.sort is stable, so ties under the second sort keep
# this ID order, which is descending when reverse is set.
327 seq.sort(key=lambda x: x.id_, reverse=reverse)
328 sorter: Callable[..., Any] = cls.sorters[sort_key]
329 seq.sort(key=sorter, reverse=reverse)
# The '-' is restored on the effective key when sorting descending; the
# key is presumably returned, but the surrounding lines are not visible.
331 sort_key = f'-{sort_key}'
335 # (we primarily use the cache to ensure we work on the same object in
336 # memory no matter where and how we retrieve it, e.g. we don't want
337 # .by_id() calls to create a new object each time, but rather a pointer
338 # to the one already instantiated)
def __getattribute__(self, name: str) -> Any:
    """Look up attributes, failing once ._disappear() invalidated self.

    `_exists` itself stays readable so that the liveness check can run.
    Any other name raises HandledException while `_exists` is falsy.
    """
    if name == '_exists':
        return super().__getattribute__(name)
    if not super().__getattribute__('_exists'):
        msg = f'Object for attribute does not exist: {name}'
        raise HandledException(msg)
    return super().__getattribute__(name)
347 def _disappear(self) -> None:
348 """Invalidate object, make future use raise exceptions."""
349 assert self.id_ is not None
# Drop self from the class cache if it is there. The body of this branch is
# not visible here; presumably it calls self._uncache().
350 if self._get_cached(self.id_):
# Snapshot the instance attribute names, presumably so they can be deleted
# afterwards. The deletion and the `_exists` reset that makes
# __getattribute__ raise are on lines not visible in this listing.
352 to_kill = list(self.__dict__.keys())
358 def empty_cache(cls) -> None:
359 """Empty class's cache, and disappear all former inhabitants."""
360 # pylint: disable=protected-access
361 # (cause we remain within the class)
362 if hasattr(cls, 'cache_'):
# Copy the values first: disappearing an item presumably removes it from
# cls.cache_, and a dict must not change size while being iterated.
363 to_disappear = list(cls.cache_.values())
364 for item in to_disappear:
369 def get_cache(cls: type[BaseModelInstance]
370 ) -> dict[Any, BaseModelInstance]:
371 """Get cache dictionary, create it if not yet existing."""
# NOTE(review): hasattr() also finds a cache_ set on a base class. A
# subclass of a model whose cache already exists would therefore share that
# dict; checking `'cache_' in cls.__dict__` would keep the caches per class.
372 if not hasattr(cls, 'cache_'):
373 d: dict[Any, BaseModelInstance] = {}
378 def _get_cached(cls: type[BaseModelInstance],
379 id_: BaseModelId) -> BaseModelInstance | None:
380 """Get object of id_ from class's cache, or None if not found."""
381 cache = cls.get_cache()
# The lookup itself is not visible in this listing. Any hit is asserted to
# be an instance of cls.
384 assert isinstance(obj, cls)
388 def cache(self) -> None:
389 """Update object in class's cache.
391 Also calls ._disappear if cache holds older reference to object of same
392 ID, but different memory address, to avoid doing anything with
396 raise HandledException('Cannot cache object without ID.')
397 cache = self.get_cache()
398 old_cached = self._get_cached(self.id_)
# `id(a) != id(b)` is the same test as `a is not b`: only a different object
# with the same ID gets invalidated before self takes its slot.
399 if old_cached and id(old_cached) != id(self):
400 # pylint: disable=protected-access
401 # (cause we remain within the class)
402 old_cached._disappear()
403 cache[self.id_] = self
405 def _uncache(self) -> None:
406 """Remove self from cache."""
# Objects without an ID are refused; the guard condition is not visible in
# this listing.
408 raise HandledException('Cannot un-cache object without ID.')
409 cache = self.get_cache()
# NOTE(review): the removal itself is not visible in this listing.
412 # object retrieval and generation
415 def from_table_row(cls: type[BaseModelInstance],
416 # pylint: disable=unused-argument
417 db_conn: DatabaseConnection,
418 row: Row | list[Any]) -> BaseModelInstance:
419 """Make from DB row (sans relations), update DB cache with it."""
# NOTE(review): the construction of obj from `row` is not visible in this
# listing. db_conn is used below, so the unused-argument pragma above looks
# stale.
421 assert obj.id_ is not None
# Load each versioned attribute's history from its own table, selecting the
# rows whose 'parent' column equals obj.id_.
422 for attr_name in cls.to_save_versioned():
423 attr = getattr(obj, attr_name)
424 table_name = attr.table_name
425 for row_ in db_conn.row_where(table_name, 'parent', obj.id_):
426 attr.history_from_row(row_)
431 def by_id(cls, db_conn: DatabaseConnection, id_: BaseModelId) -> Self:
432 """Retrieve by id_, on failure throw NotFoundException.
434 First try to get from cls.cache_, only then check DB; if found,
439 obj = cls._get_cached(id_)
# On a cache miss, build the object from the matching DB row.
# cls.table_name is not defined on BaseModel in this listing; presumably
# each subclass defines it.
441 for row in db_conn.row_where(cls.table_name, 'id', id_):
442 obj = cls.from_table_row(db_conn, row)
446 raise NotFoundException(f'found no object of ID {id_}')
449 def by_id_or_create(cls, db_conn: DatabaseConnection,
450 id_: BaseModelId | None
452 """Wrapper around .by_id, creating (not caching/saving) if not found."""
453 if not cls.can_create_by_id:
454 raise HandledException('Class cannot .by_id_or_create.')
# The None-ID shortcut and the `try:` are on lines not visible here. A
# NotFoundException from by_id presumably leads to a fresh, unsaved object.
458 return cls.by_id(db_conn, id_)
459 except NotFoundException:
463 def all(cls: type[BaseModelInstance],
464 db_conn: DatabaseConnection) -> list[BaseModelInstance]:
465 """Collect all objects of class into list.
467 Note that this primarily returns the contents of the cache, and only
468 _expands_ that by additional findings in the DB. This assumes the
469 cache is always instantly cleaned of any items that would be removed
472 items: dict[BaseModelId, BaseModelInstance] = {}
473 for k, v in cls.get_cache().items():
474 assert isinstance(v, cls)
# keys() is a live view, so it also reflects IDs added in the loop below.
476 already_recorded = items.keys()
477 for id_ in db_conn.column_all(cls.table_name, 'id'):
478 if id_ not in already_recorded:
479 item = cls.by_id(db_conn, id_)
480 assert item.id_ is not None
481 items[item.id_] = item
# NOTE(review): sorted() relies on __lt__, which asserts int IDs, so a model
# with str IDs must override __lt__. The inner list() is redundant.
482 return sorted(list(items.values()))
485 def by_date_range_with_limits(cls: type[BaseModelInstance],
486 db_conn: DatabaseConnection,
487 date_range: tuple[str, str],
488 date_col: str = 'day'
489 ) -> tuple[list[BaseModelInstance], str,
491 """Return list of items in DB within (closed) date_range interval.
493 If no range values provided, defaults them to 'yesterday' and
494 'tomorrow'. Knows to properly interpret these and 'today' as value.
496 start_str = date_range[0] if date_range[0] else 'yesterday'
497 end_str = date_range[1] if date_range[1] else 'tomorrow'
# valid_date presumably normalizes to a sortable date string (e.g.
# 'YYYY-MM-DD'); the >=/<= comparisons in the SQL below rely on that.
498 start_date = valid_date(start_str)
499 end_date = valid_date(end_str)
# The items list is initialized on a line not visible here. cls.table_name
# and date_col are interpolated into the SQL, so both must be trusted.
501 sql = f'SELECT id FROM {cls.table_name} '
502 sql += f'WHERE {date_col} >= ? AND {date_col} <= ?'
503 for row in db_conn.exec(sql, (start_date, end_date)):
504 items += [cls.by_id(db_conn, row[0])]
505 return items, start_date, end_date
508 def matching(cls: type[BaseModelInstance], db_conn: DatabaseConnection,
509 pattern: str) -> list[BaseModelInstance]:
510 """Return all objects whose .to_search match pattern."""
511 items = cls.all(db_conn)
# Each to_search entry is a dotted path, walked one attribute at a time
# (some of the loop headers are not visible in this listing).
515 for attr_name in cls.to_search:
516 toks = attr_name.split('.')
519 attr = getattr(parent, tok)
# NOTE(review): the matching test and the return are not visible here.
529 def save(self, db_conn: DatabaseConnection) -> None:
530 """Write self to DB and cache and ensure .id_.
532 Write both to DB, and to cache. To DB, write .id_ and attributes
533 listed in cls.to_save_[simples|versioned|relations].
535 Ensure self.id_ by setting it to what the DB command returns as the
536 last saved row's ID (cursor.lastrowid), EXCEPT if self.id_ already
537 exists as a 'str', which implies we do our own ID creation (so far
538 only the case with the Day class, where it's to be a date string).
540 values = tuple([self.id_] + [getattr(self, key)
541 for key in self.to_save_simples])
542 table_name = self.table_name
# REPLACE inserts a new row, or deletes a conflicting row and re-inserts it;
# that delete can fire ON DELETE actions. With id_ None, SQLite assigns a
# new rowid (assuming id is an INTEGER PRIMARY KEY), read back below.
543 cursor = db_conn.exec_on_vals(f'REPLACE INTO {table_name} VALUES',
545 if not isinstance(self.id_, str):
546 self.id_ = cursor.lastrowid # type: ignore[assignment]
# Versioned attributes and relation tables are written after the main row,
# once self.id_ is known.
548 for attr_name in self.to_save_versioned():
549 getattr(self, attr_name).save(db_conn)
550 for table, column, attr_name, key_index in self.to_save_relations:
551 assert isinstance(self.id_, (int, str))
552 db_conn.rewrite_relations(table, column, self.id_,
554 in getattr(self, attr_name)], key_index)
556 def remove(self, db_conn: DatabaseConnection) -> None:
557 """Remove from DB and cache, including dependencies."""
# "Unsaved" is judged by cache membership, so an object whose ID is set but
# which is not cached is refused as well.
558 if self.id_ is None or self._get_cached(self.id_) is None:
559 raise HandledException('cannot remove unsaved item')
560 for attr_name in self.to_save_versioned():
561 getattr(self, attr_name).remove(db_conn)
# Only table and column are needed to clear the relation rows keyed by
# self.id_; attr_name is unused here and could be `_`.
562 for table, column, attr_name, _ in self.to_save_relations:
563 db_conn.delete_where(table, column, self.id_)
565 db_conn.delete_where(self.table_name, 'id', self.id_)
# NOTE(review): whatever follows the main-row delete (e.g. uncaching) is not
# visible in this listing.