1 """Database management."""
2 from __future__ import annotations
4 from os.path import isfile
5 from difflib import Differ
6 from sqlite3 import connect as sql_connect, Cursor, Row
7 from typing import Any, Self, TypeVar, Generic, Callable
8 from plomtask.exceptions import HandledException, NotFoundException
9 from plomtask.dating import valid_date
# DB schema version this code works with; also names the schema file.
EXPECTED_DB_VERSION = 5
# Directory holding the schema file.
MIGRATIONS_DIR = 'migrations'
FILENAME_DB_SCHEMA = f'init_{EXPECTED_DB_VERSION}.sql'
# Path of the schema file, relative to the working directory.
PATH_DB_SCHEMA = '/'.join((MIGRATIONS_DIR, FILENAME_DB_SCHEMA))
class UnmigratedDbException(HandledException):
    """Identify the case of an unmigrated DB file.

    A HandledException subclass without behavior of its own; its type
    alone lets callers tell this situation apart from other handled errors.
    """
22 """Collects references for future library building."""
23 # pylint: disable=too-few-public-methods
25 def __init__(self, d: dict[str, list[int | str]]) -> None:
26 # NB: For tighter mypy testing, we might prefer the library argument
27 # to be of type dict[str, list[int] | list[str] instead. But my
28 # current coding knowledge only manages to make that work by turning
29 # the code much more complex, so let's leave it at
33 def update(self, other: CtxReferences) -> bool:
34 """Updates other with entries in self."""
36 for cls_name, id_list in self.d.items():
37 if cls_name not in other.d:
38 other.d[cls_name] = []
40 if id_ not in other.d[cls_name]:
41 other.d[cls_name] += [id_]
46 """Represents the sqlite3 database's file."""
47 # pylint: disable=too-few-public-methods
49 def __init__(self, path: str) -> None:
54 def create_at(cls, path: str) -> DatabaseFile:
55 """Make new DB file at path."""
56 with sql_connect(path) as conn:
57 with open(PATH_DB_SCHEMA, 'r', encoding='utf-8') as f:
58 conn.executescript(f.read())
59 conn.execute(f'PRAGMA user_version = {EXPECTED_DB_VERSION}')
63 def migrate(cls, path: str) -> DatabaseFile:
64 """Apply migrations from_version to EXPECTED_DB_VERSION."""
65 migrations = cls._available_migrations()
66 from_version = cls._get_version_of_db(path)
67 migrations_todo = migrations[from_version+1:]
68 for j, filename in enumerate(migrations_todo):
69 with sql_connect(path) as conn:
70 with open(f'{MIGRATIONS_DIR}/{filename}', 'r',
71 encoding='utf-8') as f:
72 conn.executescript(f.read())
73 user_version = from_version + j + 1
74 with sql_connect(path) as conn:
75 conn.execute(f'PRAGMA user_version = {user_version}')
def _check(self) -> None:
    """Verify the DB file exists and has the expected version and schema.

    Raises NotFoundException if self.path is not an existing file, and
    UnmigratedDbException if the file's user_version differs from
    EXPECTED_DB_VERSION. Schema comparison is left to
    self._validate_schema().
    """
    file_exists = isfile(self.path)
    if not file_exists:
        raise NotFoundException
    if EXPECTED_DB_VERSION != self._user_version:
        raise UnmigratedDbException()
    self._validate_schema()
87 def _available_migrations() -> list[str]:
88 """Validate migrations directory and return sorted entries."""
89 msg_too_big = 'Migration directory points beyond expected DB version.'
90 msg_bad_entry = 'Migration directory contains unexpected entry: '
91 msg_missing = 'Migration directory misses migration of number: '
93 for entry in listdir(MIGRATIONS_DIR):
94 if entry == FILENAME_DB_SCHEMA:
96 toks = entry.split('_', 1)
98 raise HandledException(msg_bad_entry + entry)
101 except ValueError as e:
102 raise HandledException(msg_bad_entry + entry) from e
103 if i > EXPECTED_DB_VERSION:
104 raise HandledException(msg_too_big)
105 migrations[i] = toks[1]
107 for i in range(EXPECTED_DB_VERSION + 1):
108 if i not in migrations:
109 raise HandledException(msg_missing + str(i))
110 migrations_list += [f'{i}_{migrations[i]}']
111 return migrations_list
114 def _get_version_of_db(path: str) -> int:
115 """Get DB user_version, fail if outside expected range."""
116 sql_for_db_version = 'PRAGMA user_version'
117 with sql_connect(path) as conn:
118 db_version = list(conn.execute(sql_for_db_version))[0][0]
119 if db_version > EXPECTED_DB_VERSION:
120 msg = f'Wrong DB version, expected '\
121 f'{EXPECTED_DB_VERSION}, got unknown {db_version}.'
122 raise HandledException(msg)
123 assert isinstance(db_version, int)
def _user_version(self) -> int:
    """Return the user_version of the DB file found at self.path."""
    version = self._get_version_of_db(self.path)
    return version
131 def _validate_schema(self) -> None:
132 """Compare found schema with what's stored at PATH_DB_SCHEMA."""
134 def reformat_rows(rows: list[str]) -> list[str]:
138 for subrow in row.split('\n'):
139 subrow = subrow.rstrip()
142 for i, c in enumerate(subrow):
147 elif ',' == c and 0 == in_parentheses:
151 segment = subrow[prev_split:i].strip()
153 new_row += [f' {segment}']
155 segment = subrow[prev_split:].strip()
157 new_row += [f' {segment}']
158 new_row[0] = new_row[0].lstrip()
159 new_row[-1] = new_row[-1].lstrip()
160 if new_row[-1] != ')' and new_row[-3][-1] != ',':
161 new_row[-3] = new_row[-3] + ','
162 new_row[-2:] = [' ' + new_row[-1][:-1]] + [')']
163 new_rows += ['\n'.join(new_row)]
166 sql_for_schema = 'SELECT sql FROM sqlite_master ORDER BY sql'
167 msg_err = 'Database has wrong tables schema. Diff:\n'
168 with sql_connect(self.path) as conn:
169 schema_rows = [r[0] for r in conn.execute(sql_for_schema) if r[0]]
170 schema_rows = reformat_rows(schema_rows)
171 retrieved_schema = ';\n'.join(schema_rows) + ';'
172 with open(PATH_DB_SCHEMA, 'r', encoding='utf-8') as f:
173 stored_schema = f.read().rstrip()
174 if stored_schema != retrieved_schema:
175 diff_msg = Differ().compare(retrieved_schema.splitlines(),
176 stored_schema.splitlines())
177 raise HandledException(msg_err + '\n'.join(diff_msg))
180 class DatabaseConnection:
181 """A single connection to the database."""
def __init__(self, db_file: DatabaseFile) -> None:
    """Open a sqlite3 connection to the file at db_file.path."""
    db_path = db_file.path
    self.conn = sql_connect(db_path)
186 def commit(self) -> None:
187 """Commit SQL transaction."""
def exec(self, code: str, inputs: tuple[Any, ...] = tuple()) -> Cursor:
    """Execute code on the connection, as part of the SQL transaction.

    The values in inputs are bound to the placeholders in code; the
    resulting cursor is handed back to the caller.
    """
    cursor = self.conn.execute(code, inputs)
    return cursor
def exec_on_vals(self, code: str, inputs: tuple[Any, ...]) -> Cursor:
    """Run .exec on code extended by a fitting " (?,…)" placeholder group.

    One question mark is added per element of inputs, so that code may end
    in e.g. "VALUES" and receive exactly as many slots as values provided.
    """
    placeholders = ','.join('?' for _ in inputs)
    return self.exec(f'{code} ({placeholders})', inputs)
199 def close(self) -> None:
200 """Close DB connection."""
203 def rewrite_relations(self, table_name: str, key: str, target: int | str,
204 rows: list[list[Any]], key_index: int = 0) -> None:
205 # pylint: disable=too-many-arguments
206 """Rewrite relations in table_name to target, with rows values.
208 Note that single rows are expected without the column and value
209 identified by key and target, which are inserted inside the function
212 self.delete_where(table_name, key, target)
214 values = tuple(row[:key_index] + [target] + row[key_index:])
215 self.exec_on_vals(f'INSERT INTO {table_name} VALUES', values)
217 def row_where(self, table_name: str, key: str,
218 target: int | str) -> list[Row]:
219 """Return list of Rows at table where key == target."""
220 return list(self.exec(f'SELECT * FROM {table_name} WHERE {key} = ?',
223 # def column_where_pattern(self,
227 # keys: list[str]) -> list[Any]:
228 # """Return column of rows where one of keys matches pattern."""
229 # targets = tuple([f'%{pattern}%'] * len(keys))
230 # haystack = ' OR '.join([f'{k} LIKE ?' for k in keys])
231 # sql = f'SELECT {column} FROM {table_name} WHERE {haystack}'
232 # return [row[0] for row in self.exec(sql, targets)]
234 def column_where(self, table_name: str, column: str, key: str,
235 target: int | str) -> list[Any]:
236 """Return column of table where key == target."""
237 return [row[0] for row in
238 self.exec(f'SELECT {column} FROM {table_name} '
239 f'WHERE {key} = ?', (target,))]
def column_all(self, table_name: str, column: str) -> list[Any]:
    """Collect column's value from every row of table_name.

    Both table_name and column are interpolated into the SQL text as they
    are.
    """
    values = []
    for row in self.exec(f'SELECT {column} FROM {table_name}'):
        values.append(row[0])
    return values
246 def delete_where(self, table_name: str, key: str,
247 target: int | str) -> None:
248 """Delete from table where key == target."""
249 self.exec(f'DELETE FROM {table_name} WHERE {key} = ?', (target,))
# Type of a model's ID: constrained to be either int or str.
BaseModelId = TypeVar('BaseModelId', int, str)
# Stands for any BaseModel (sub)class, whatever its ID type; used to type
# class-level methods so that they return the class they are called on.
BaseModelInstance = TypeVar('BaseModelInstance', bound='BaseModel[Any]')
256 class BaseModel(Generic[BaseModelId]):
257 """Template for most of the models we use/derive from the DB."""
259 to_save_simples: list[str] = []
260 to_save_relations: list[tuple[str, str, str, int]] = []
261 versioned_defaults: dict[str, str | float] = {}
262 add_to_dict: list[str] = []
263 id_: None | BaseModelId
264 cache_: dict[BaseModelId, Self]
265 to_search: list[str] = []
266 can_create_by_id = False
268 sorters: dict[str, Callable[..., Any]] = {}
270 def __init__(self, id_: BaseModelId | None) -> None:
271 if isinstance(id_, int) and id_ < 1:
272 msg = f'illegal {self.__class__.__name__} ID, must be >=1: {id_}'
273 raise HandledException(msg)
274 if isinstance(id_, str) and "" == id_:
275 msg = f'illegal {self.__class__.__name__} ID, must be non-empty'
276 raise HandledException(msg)
def __hash__(self) -> int:
    """Hash over ID, simple attributes, relations' IDs, versioned attributes.

    Collected in this order: self.id_, the values of the attributes named
    in .to_save_simples, per entry of .to_save_relations a tuple of the
    related objects' .id_, and the hash of each attribute named by
    .to_save_versioned().
    """
    parts = [self.id_]
    parts.extend(getattr(self, simple) for simple in self.to_save_simples)
    for relation in self.to_save_relations:
        # third field of a relation definition names the attribute
        related = getattr(self, relation[2])
        parts.append(tuple(item.id_ for item in related))
    for versioned in self.to_save_versioned():
        parts.append(hash(getattr(self, versioned)))
    return hash(tuple(parts))
289 def __eq__(self, other: object) -> bool:
290 if not isinstance(other, self.__class__):
292 return hash(self) == hash(other)
def __lt__(self, other: Any) -> bool:
    """Order by .id_, which must be an int on both sides.

    Raises HandledException if other is not an instance of self's class.
    """
    same_class = isinstance(other, self.__class__)
    if not same_class:
        raise HandledException('cannot compare to object of different class')
    own_id = self.id_
    assert isinstance(own_id, int)
    other_id = other.id_
    assert isinstance(other_id, int)
    return own_id < other_id
def to_save_versioned(cls) -> list[str]:
    """List the keys of cls.versioned_defaults, i.e. what to save of them."""
    return [key for key in cls.versioned_defaults]
308 def as_dict(self) -> dict[str, object]:
309 """Return self as (json.dumps-compatible) dict."""
310 references = CtxReferences({})
311 d: dict[str, object] = {'id': self.id_, '_references': references}
312 for to_save in self.to_save_simples:
313 d[to_save] = getattr(self, to_save)
314 if len(self.to_save_versioned()) > 0:
316 for k in self.to_save_versioned():
317 attr = getattr(self, k)
318 assert isinstance(d['_versioned'], dict)
319 d['_versioned'][k] = attr.history
320 for r in self.to_save_relations:
322 l: list[int | str] = []
323 for rel in getattr(self, attr_name):
324 cls_name = rel.__class__.__name__
325 if cls_name not in references.d:
326 references.d[cls_name] = []
328 references.d[cls_name] += [rel.id_]
330 for k in self.add_to_dict:
331 d[k] = [x.into_reference(references)
332 for x in getattr(self, k)]
335 def into_reference(self, references: CtxReferences) -> int | str:
336 """Return self.id_ and write into references for class.."""
337 cls_name = self.__class__.__name__
338 if cls_name not in references.d:
339 references.d[cls_name] = []
340 assert self.id_ is not None
341 references.d[cls_name] += [self.id_]
342 own_refs = self.as_dict['_references']
343 assert isinstance(own_refs, CtxReferences)
344 own_refs.update(references)
def name_lowercase(cls) -> str:
    """Return the name of cls, converted to lowercase."""
    class_name = cls.__name__
    return class_name.lower()
353 def sort_by(cls, seq: list[Any], sort_key: str, default: str = 'title'
355 """Sort cls list by cls.sorters[sort_key] (reverse if '-'-prefixed).
357 Before cls.sorters[sort_key] is applied, seq is sorted by .id_, to
358 ensure predictability where parts of seq are of same sort value.
361 if len(sort_key) > 1 and '-' == sort_key[0]:
362 sort_key = sort_key[1:]
364 if sort_key not in cls.sorters:
366 seq.sort(key=lambda x: x.id_, reverse=reverse)
367 sorter: Callable[..., Any] = cls.sorters[sort_key]
368 seq.sort(key=sorter, reverse=reverse)
370 sort_key = f'-{sort_key}'
374 # (we primarily use the cache to ensure we work on the same object in
375 # memory no matter where and how we retrieve it, e.g. we don't want
376 # .by_id() calls to create a new object each time, but rather a pointer
377 # to the one already instantiated)
def __getattribute__(self, name: str) -> Any:
    """Refuse attribute access on objects whose ._exists is falsy.

    Looking up '_exists' itself is always let through, so that the flag
    can still be checked; any other name raises HandledException if the
    flag is falsy.
    """
    if name == '_exists':
        return super().__getattribute__(name)
    if not super().__getattribute__('_exists'):
        raise HandledException('Object does not exist.')
    return super().__getattribute__(name)
385 def _disappear(self) -> None:
386 """Invalidate object, make future use raise exceptions."""
387 assert self.id_ is not None
388 if self._get_cached(self.id_):
390 to_kill = list(self.__dict__.keys())
396 def empty_cache(cls) -> None:
397 """Empty class's cache, and disappear all former inhabitants."""
398 # pylint: disable=protected-access
399 # (cause we remain within the class)
400 if hasattr(cls, 'cache_'):
401 to_disappear = list(cls.cache_.values())
402 for item in to_disappear:
407 def get_cache(cls: type[BaseModelInstance]) -> dict[Any, BaseModel[Any]]:
408 """Get cache dictionary, create it if not yet existing."""
409 if not hasattr(cls, 'cache_'):
410 d: dict[Any, BaseModel[Any]] = {}
415 def _get_cached(cls: type[BaseModelInstance],
416 id_: BaseModelId) -> BaseModelInstance | None:
417 """Get object of id_ from class's cache, or None if not found."""
418 cache = cls.get_cache()
421 assert isinstance(obj, cls)
425 def cache(self) -> None:
426 """Update object in class's cache.
428 Also calls ._disappear if cache holds older reference to object of same
429 ID, but different memory address, to avoid doing anything with
433 raise HandledException('Cannot cache object without ID.')
434 cache = self.get_cache()
435 old_cached = self._get_cached(self.id_)
436 if old_cached and id(old_cached) != id(self):
437 # pylint: disable=protected-access
438 # (cause we remain within the class)
439 old_cached._disappear()
440 cache[self.id_] = self
442 def _uncache(self) -> None:
443 """Remove self from cache."""
445 raise HandledException('Cannot un-cache object without ID.')
446 cache = self.get_cache()
449 # object retrieval and generation
452 def from_table_row(cls: type[BaseModelInstance],
453 # pylint: disable=unused-argument
454 db_conn: DatabaseConnection,
455 row: Row | list[Any]) -> BaseModelInstance:
456 """Make from DB row (sans relations), update DB cache with it."""
458 assert obj.id_ is not None
459 for attr_name in cls.to_save_versioned():
460 attr = getattr(obj, attr_name)
461 table_name = attr.table_name
462 for row_ in db_conn.row_where(table_name, 'parent', obj.id_):
463 attr.history_from_row(row_)
468 def by_id(cls, db_conn: DatabaseConnection, id_: BaseModelId) -> Self:
469 """Retrieve by id_, on failure throw NotFoundException.
471 First try to get from cls.cache_, only then check DB; if found,
476 obj = cls._get_cached(id_)
478 for row in db_conn.row_where(cls.table_name, 'id', id_):
479 obj = cls.from_table_row(db_conn, row)
483 raise NotFoundException(f'found no object of ID {id_}')
486 def by_id_or_create(cls, db_conn: DatabaseConnection,
487 id_: BaseModelId | None
489 """Wrapper around .by_id, creating (not caching/saving) if not find."""
490 if not cls.can_create_by_id:
491 raise HandledException('Class cannot .by_id_or_create.')
495 return cls.by_id(db_conn, id_)
496 except NotFoundException:
500 def all(cls: type[BaseModelInstance],
501 db_conn: DatabaseConnection) -> list[BaseModelInstance]:
502 """Collect all objects of class into list.
504 Note that this primarily returns the contents of the cache, and only
505 _expands_ that by additional findings in the DB. This assumes the
506 cache is always instantly cleaned of any items that would be removed
509 items: dict[BaseModelId, BaseModelInstance] = {}
510 for k, v in cls.get_cache().items():
511 assert isinstance(v, cls)
513 already_recorded = items.keys()
514 for id_ in db_conn.column_all(cls.table_name, 'id'):
515 if id_ not in already_recorded:
516 item = cls.by_id(db_conn, id_)
517 assert item.id_ is not None
518 items[item.id_] = item
519 return list(items.values())
522 def by_date_range_with_limits(cls: type[BaseModelInstance],
523 db_conn: DatabaseConnection,
524 date_range: tuple[str, str],
525 date_col: str = 'day'
526 ) -> tuple[list[BaseModelInstance], str,
528 """Return list of items in DB within (closed) date_range interval.
530 If no range values provided, defaults them to 'yesterday' and
531 'tomorrow'. Knows to properly interpret these and 'today' as value.
533 start_str = date_range[0] if date_range[0] else 'yesterday'
534 end_str = date_range[1] if date_range[1] else 'tomorrow'
535 start_date = valid_date(start_str)
536 end_date = valid_date(end_str)
538 sql = f'SELECT id FROM {cls.table_name} '
539 sql += f'WHERE {date_col} >= ? AND {date_col} <= ?'
540 for row in db_conn.exec(sql, (start_date, end_date)):
541 items += [cls.by_id(db_conn, row[0])]
542 return items, start_date, end_date
545 def matching(cls: type[BaseModelInstance], db_conn: DatabaseConnection,
546 pattern: str) -> list[BaseModelInstance]:
547 """Return all objects whose .to_search match pattern."""
548 items = cls.all(db_conn)
552 for attr_name in cls.to_search:
553 toks = attr_name.split('.')
556 attr = getattr(parent, tok)
566 def save(self, db_conn: DatabaseConnection) -> None:
567 """Write self to DB and cache and ensure .id_.
569 Write both to DB, and to cache. To DB, write .id_ and attributes
570 listed in cls.to_save_[simples|versioned|_relations].
572 Ensure self.id_ by setting it to what the DB command returns as the
573 last saved row's ID (cursor.lastrowid), EXCEPT if self.id_ already
574 exists as a 'str', which implies we do our own ID creation (so far
575 only the case with the Day class, where it's to be a date string.
577 values = tuple([self.id_] + [getattr(self, key)
578 for key in self.to_save_simples])
579 table_name = self.table_name
580 cursor = db_conn.exec_on_vals(f'REPLACE INTO {table_name} VALUES',
582 if not isinstance(self.id_, str):
583 self.id_ = cursor.lastrowid # type: ignore[assignment]
585 for attr_name in self.to_save_versioned():
586 getattr(self, attr_name).save(db_conn)
587 for table, column, attr_name, key_index in self.to_save_relations:
588 assert isinstance(self.id_, (int, str))
589 db_conn.rewrite_relations(table, column, self.id_,
591 in getattr(self, attr_name)], key_index)
593 def remove(self, db_conn: DatabaseConnection) -> None:
594 """Remove from DB and cache, including dependencies."""
595 if self.id_ is None or self._get_cached(self.id_) is None:
596 raise HandledException('cannot remove unsaved item')
597 for attr_name in self.to_save_versioned():
598 getattr(self, attr_name).remove(db_conn)
599 for table, column, attr_name, _ in self.to_save_relations:
600 db_conn.delete_where(table, column, self.id_)
602 db_conn.delete_where(self.table_name, 'id', self.id_)