1 """Database management."""
2 from __future__ import annotations
4 from os.path import isfile
5 from difflib import Differ
6 from sqlite3 import connect as sql_connect, Cursor, Row
7 from typing import Any, Self, TypeVar, Generic
8 from plomtask.exceptions import HandledException, NotFoundException
9 from plomtask.dating import valid_date
# Schema version the code expects; DB files at a different user_version are
# rejected as unmigrated.
EXPECTED_DB_VERSION = 5
# Directory that holds the migration scripts and the full-schema script.
MIGRATIONS_DIR = 'migrations'
# Name of the full-schema script for EXPECTED_DB_VERSION, and its path
# relative to the working directory.
FILENAME_DB_SCHEMA = 'init_{}.sql'.format(EXPECTED_DB_VERSION)
PATH_DB_SCHEMA = '/'.join((MIGRATIONS_DIR, FILENAME_DB_SCHEMA))
class UnmigratedDbException(HandledException):
    """Signal that a DB file's user_version is not the expected one."""
22 """Represents the sqlite3 database's file."""
23 # pylint: disable=too-few-public-methods
25 def __init__(self, path: str) -> None:
def create_at(cls, path: str) -> DatabaseFile:
    """Make new DB file at path.

    Runs the schema script stored at PATH_DB_SCHEMA on a connection to
    path and stamps the file with user_version EXPECTED_DB_VERSION.  The
    script is read before connecting, so a missing schema file does not
    leave an empty DB file behind, and the connection is closed explicitly
    (sqlite3's connection context manager only commits or rolls back).

    NOTE(review): no decorator and no result statement are visible for this
    cls-taking method; `return cls(path)` follows the declared return type
    and a @classmethod decorator is presumably applied above -- confirm.
    """
    with open(PATH_DB_SCHEMA, 'r', encoding='utf-8') as f:
        schema_script = f.read()
    conn = sql_connect(path)
    try:
        conn.executescript(schema_script)
        conn.execute(f'PRAGMA user_version = {EXPECTED_DB_VERSION}')
        conn.commit()
    finally:
        conn.close()
    return cls(path)
def migrate(cls, path: str) -> DatabaseFile:
    """Apply migrations from_version to EXPECTED_DB_VERSION.

    Reads the DB's current user_version, then runs each migration script
    listed after it, setting user_version to the step's number once its
    script succeeded.  Each step uses a single connection that is closed
    afterwards (sqlite3's connection context manager would leave it open).

    NOTE(review): no decorator and no result statement are visible for this
    cls-taking method; `return cls(path)` follows the declared return type
    and a @classmethod decorator is presumably applied above -- confirm.
    """
    migrations = cls._available_migrations()
    from_version = cls._get_version_of_db(path)
    first_todo = from_version + 1
    steps = enumerate(migrations[first_todo:], start=first_todo)
    for user_version, filename in steps:
        with open(f'{MIGRATIONS_DIR}/{filename}', 'r',
                  encoding='utf-8') as f:
            script = f.read()
        conn = sql_connect(path)
        try:
            conn.executescript(script)
            # only reached if the script ran through without error
            conn.execute(f'PRAGMA user_version = {user_version}')
            conn.commit()
        finally:
            conn.close()
    return cls(path)
def _check(self) -> None:
    """Verify the DB file is present, at the expected version, and valid.

    Raises NotFoundException if self.path is no file, and
    UnmigratedDbException if the file's user_version differs from
    EXPECTED_DB_VERSION; finally runs the schema validation.
    """
    file_exists = isfile(self.path)
    if not file_exists:
        raise NotFoundException
    version_found = self._user_version
    if version_found != EXPECTED_DB_VERSION:
        raise UnmigratedDbException()
    self._validate_schema()
63 def _available_migrations() -> list[str]:
64 """Validate migrations directory and return sorted entries."""
65 msg_too_big = 'Migration directory points beyond expected DB version.'
66 msg_bad_entry = 'Migration directory contains unexpected entry: '
67 msg_missing = 'Migration directory misses migration of number: '
69 for entry in listdir(MIGRATIONS_DIR):
70 if entry == FILENAME_DB_SCHEMA:
72 toks = entry.split('_', 1)
74 raise HandledException(msg_bad_entry + entry)
77 except ValueError as e:
78 raise HandledException(msg_bad_entry + entry) from e
79 if i > EXPECTED_DB_VERSION:
80 raise HandledException(msg_too_big)
81 migrations[i] = toks[1]
83 for i in range(EXPECTED_DB_VERSION + 1):
84 if i not in migrations:
85 raise HandledException(msg_missing + str(i))
86 migrations_list += [f'{i}_{migrations[i]}']
87 return migrations_list
def _get_version_of_db(path: str) -> int:
    """Get DB user_version, fail if outside expected range.

    Opens the file at path, reads `PRAGMA user_version`, and closes the
    connection again (sqlite3's connection context manager would leave it
    open).  Raises HandledException if the version found is higher than
    EXPECTED_DB_VERSION; lower versions are returned as-is.

    NOTE(review): the result statement is not visible for this function;
    returning db_version follows the annotation and the callers' use of
    the value -- confirm.
    """
    conn = sql_connect(path)
    try:
        db_version = conn.execute('PRAGMA user_version').fetchone()[0]
    finally:
        conn.close()
    assert isinstance(db_version, int)  # narrows the type for the checker
    if db_version > EXPECTED_DB_VERSION:
        msg = f'Wrong DB version, expected '\
              f'{EXPECTED_DB_VERSION}, got unknown {db_version}.'
        raise HandledException(msg)
    return db_version
def _user_version(self) -> int:
    """Return the user_version of the DB file at self.path."""
    db_path = self.path
    return self._get_version_of_db(db_path)
107 def _validate_schema(self) -> None:
108 """Compare found schema with what's stored at PATH_DB_SCHEMA."""
110 def reformat_rows(rows: list[str]) -> list[str]:
114 for subrow in row.split('\n'):
115 subrow = subrow.rstrip()
118 for i, c in enumerate(subrow):
123 elif ',' == c and 0 == in_parentheses:
127 segment = subrow[prev_split:i].strip()
129 new_row += [f' {segment}']
131 segment = subrow[prev_split:].strip()
133 new_row += [f' {segment}']
134 new_row[0] = new_row[0].lstrip()
135 new_row[-1] = new_row[-1].lstrip()
136 if new_row[-1] != ')' and new_row[-3][-1] != ',':
137 new_row[-3] = new_row[-3] + ','
138 new_row[-2:] = [' ' + new_row[-1][:-1]] + [')']
139 new_rows += ['\n'.join(new_row)]
142 sql_for_schema = 'SELECT sql FROM sqlite_master ORDER BY sql'
143 msg_err = 'Database has wrong tables schema. Diff:\n'
144 with sql_connect(self.path) as conn:
145 schema_rows = [r[0] for r in conn.execute(sql_for_schema) if r[0]]
146 schema_rows = reformat_rows(schema_rows)
147 retrieved_schema = ';\n'.join(schema_rows) + ';'
148 with open(PATH_DB_SCHEMA, 'r', encoding='utf-8') as f:
149 stored_schema = f.read().rstrip()
150 if stored_schema != retrieved_schema:
151 diff_msg = Differ().compare(retrieved_schema.splitlines(),
152 stored_schema.splitlines())
153 raise HandledException(msg_err + '\n'.join(diff_msg))
156 class DatabaseConnection:
157 """A single connection to the database."""
def __init__(self, db_file: DatabaseFile) -> None:
    """Open an sqlite3 connection to db_file.path and keep it as .conn."""
    path_to_db = db_file.path
    self.conn = sql_connect(path_to_db)
def commit(self) -> None:
    """Commit SQL transaction.

    Delegates to the commit of the sqlite3 connection held in self.conn.
    """
    self.conn.commit()
def exec(self, code: str, inputs: tuple[Any, ...] = tuple()) -> Cursor:
    """Execute code on the connection, binding inputs to its placeholders.

    Returns the Cursor produced by the connection's execute call.
    """
    cursor = self.conn.execute(code, inputs)
    return cursor
def exec_on_vals(self, code: str, inputs: tuple[Any, ...]) -> Cursor:
    """Run code via .exec with a " (?,…)" group appended for inputs.

    One "?" placeholder is generated per element of inputs, so code is
    expected to end where the value list should start.
    """
    placeholders = ','.join('?' for _ in inputs)
    full_code = f'{code} ({placeholders})'
    return self.exec(full_code, inputs)
def close(self) -> None:
    """Close DB connection.

    Delegates to the close of the sqlite3 connection held in self.conn.
    """
    self.conn.close()
179 def rewrite_relations(self, table_name: str, key: str, target: int | str,
180 rows: list[list[Any]], key_index: int = 0) -> None:
181 # pylint: disable=too-many-arguments
182 """Rewrite relations in table_name to target, with rows values.
184 Note that single rows are expected without the column and value
185 identified by key and target, which are inserted inside the function
188 self.delete_where(table_name, key, target)
190 values = tuple(row[:key_index] + [target] + row[key_index:])
191 self.exec_on_vals(f'INSERT INTO {table_name} VALUES', values)
193 def row_where(self, table_name: str, key: str,
194 target: int | str) -> list[Row]:
195 """Return list of Rows at table where key == target."""
196 return list(self.exec(f'SELECT * FROM {table_name} WHERE {key} = ?',
199 # def column_where_pattern(self,
203 # keys: list[str]) -> list[Any]:
204 # """Return column of rows where one of keys matches pattern."""
205 # targets = tuple([f'%{pattern}%'] * len(keys))
206 # haystack = ' OR '.join([f'{k} LIKE ?' for k in keys])
207 # sql = f'SELECT {column} FROM {table_name} WHERE {haystack}'
208 # return [row[0] for row in self.exec(sql, targets)]
210 def column_where(self, table_name: str, column: str, key: str,
211 target: int | str) -> list[Any]:
212 """Return column of table where key == target."""
213 return [row[0] for row in
214 self.exec(f'SELECT {column} FROM {table_name} '
215 f'WHERE {key} = ?', (target,))]
def column_all(self, table_name: str, column: str) -> list[Any]:
    """Collect every value of column found in table.

    Both names are interpolated into the SQL text.
    """
    sql = f'SELECT {column} FROM {table_name}'
    found: list[Any] = []
    for result_row in self.exec(sql):
        found.append(result_row[0])
    return found
222 def delete_where(self, table_name: str, key: str,
223 target: int | str) -> None:
224 """Delete from table where key == target."""
225 self.exec(f'DELETE FROM {table_name} WHERE {key} = ?', (target,))
228 BaseModelId = TypeVar('BaseModelId', int, str)
229 BaseModelInstance = TypeVar('BaseModelInstance', bound='BaseModel[Any]')
232 class BaseModel(Generic[BaseModelId]):
233 """Template for most of the models we use/derive from the DB."""
235 to_save: list[str] = []
236 to_save_versioned: list[str] = []
237 to_save_relations: list[tuple[str, str, str, int]] = []
238 add_to_dict: list[str] = []
239 id_: None | BaseModelId
240 cache_: dict[BaseModelId, Self]
241 to_search: list[str] = []
242 can_create_by_id = False
def __init__(self, id_: BaseModelId | None) -> None:
    """Validate id_ and store it as self.id_.

    Raises HandledException for an integer ID below 1 or an empty string
    ID; None passes without any check.

    NOTE(review): the assignment to self.id_ is not visible for this
    method, but the class declares id_ and its other methods read
    self.id_ -- confirm.
    """
    if isinstance(id_, int) and id_ < 1:
        msg = f'illegal {self.__class__.__name__} ID, must be >=1: {id_}'
        raise HandledException(msg)
    if isinstance(id_, str) and "" == id_:
        msg = f'illegal {self.__class__.__name__} ID, must be non-empty'
        raise HandledException(msg)
    self.id_ = id_
def __hash__(self) -> int:
    """Hash the ID together with all attributes that get saved.

    Covers, in this order: .id_, the attributes named in .to_save, for each
    .to_save_relations entry the tuple of the related objects' .id_ values,
    and the hash of each attribute named in .to_save_versioned.
    """
    parts: list[Any] = [self.id_]
    parts.extend(getattr(self, name) for name in self.to_save)
    for definition in self.to_save_relations:
        # third element of a relation definition names the attribute
        related = getattr(self, definition[2])
        parts.append(tuple(rel.id_ for rel in related))
    parts.extend(hash(getattr(self, name))
                 for name in self.to_save_versioned)
    return hash(tuple(parts))
def __eq__(self, other: object) -> bool:
    """Treat objects as equal if same class and same hash.

    Anything that is not an instance of self's class compares unequal;
    otherwise the result is whether both objects hash to the same value.
    """
    if not isinstance(other, self.__class__):
        return False
    return hash(self) == hash(other)
def __lt__(self, other: Any) -> bool:
    """Order objects of the same class by their integer .id_.

    Raises HandledException if other is not an instance of self's class;
    both IDs are asserted to be integers.
    """
    if isinstance(other, self.__class__):
        own_id = self.id_
        assert isinstance(own_id, int)
        other_id = other.id_
        assert isinstance(other_id, int)
        return own_id < other_id
    msg = 'cannot compare to object of different class'
    raise HandledException(msg)
277 def as_dict(self) -> dict[str, object]:
278 """Return self as (json.dumps-compatible) dict."""
279 library: dict[str, dict[str | int, object]] = {}
280 d: dict[str, object] = {'id': self.id_, '_library': library}
281 for to_save in self.to_save:
282 attr = getattr(self, to_save)
283 if hasattr(attr, 'as_dict_into_reference'):
284 d[to_save] = attr.as_dict_into_reference(library)
287 if len(self.to_save_versioned) > 0:
289 for k in self.to_save_versioned:
290 attr = getattr(self, k)
291 assert isinstance(d['_versioned'], dict)
292 d['_versioned'][k] = attr.history
293 for r in self.to_save_relations:
295 l: list[int | str] = []
296 for rel in getattr(self, attr_name):
297 l += [rel.as_dict_into_reference(library)]
299 for k in self.add_to_dict:
300 d[k] = [x.as_dict_into_reference(library)
301 for x in getattr(self, k)]
304 def as_dict_into_reference(self,
305 library: dict[str, dict[str | int, object]]
307 """Return self.id_ while writing .as_dict into library."""
308 def into_library(library: dict[str, dict[str | int, object]],
313 if cls_name not in library:
314 library[cls_name] = {}
315 if id_ in library[cls_name]:
316 if library[cls_name][id_] != d:
317 msg = 'Unexpected inequality of entries for ' +\
318 f'_library at: {cls_name}/{id_}'
319 raise HandledException(msg)
321 library[cls_name][id_] = d
322 as_dict = self.as_dict
323 assert isinstance(as_dict['_library'], dict)
324 for cls_name, dict_of_objs in as_dict['_library'].items():
325 for id_, obj in dict_of_objs.items():
326 into_library(library, cls_name, id_, obj)
327 del as_dict['_library']
328 assert self.id_ is not None
329 into_library(library, self.__class__.__name__, self.id_, as_dict)
330 assert isinstance(as_dict['id'], (int, str))
def name_lowercase(cls) -> str:
    """Return the name of cls converted to lowercase."""
    return str.lower(cls.__name__)
339 # (we primarily use the cache to ensure we work on the same object in
340 # memory no matter where and how we retrieve it, e.g. we don't want
341 # .by_id() calls to create a new object each time, but rather a pointer
342 # to the one already instantiated)
def __getattribute__(self, name: str) -> Any:
    """Refuse attribute access once ._exists is falsy, except to ._exists.

    Reading ._exists itself always works; any other attribute raises
    HandledException while ._exists evaluates as false.
    """
    if name == '_exists':
        return super().__getattribute__(name)
    if not super().__getattribute__('_exists'):
        raise HandledException('Object does not exist.')
    return super().__getattribute__(name)
350 def _disappear(self) -> None:
351 """Invalidate object, make future use raise exceptions."""
352 assert self.id_ is not None
353 if self._get_cached(self.id_):
355 to_kill = list(self.__dict__.keys())
361 def empty_cache(cls) -> None:
362 """Empty class's cache, and disappear all former inhabitants."""
363 # pylint: disable=protected-access
364 # (cause we remain within the class)
365 if hasattr(cls, 'cache_'):
366 to_disappear = list(cls.cache_.values())
367 for item in to_disappear:
372 def get_cache(cls: type[BaseModelInstance]) -> dict[Any, BaseModel[Any]]:
373 """Get cache dictionary, create it if not yet existing."""
374 if not hasattr(cls, 'cache_'):
375 d: dict[Any, BaseModel[Any]] = {}
380 def _get_cached(cls: type[BaseModelInstance],
381 id_: BaseModelId) -> BaseModelInstance | None:
382 """Get object of id_ from class's cache, or None if not found."""
383 cache = cls.get_cache()
386 assert isinstance(obj, cls)
390 def cache(self) -> None:
391 """Update object in class's cache.
393 Also calls ._disappear if cache holds older reference to object of same
394 ID, but different memory address, to avoid doing anything with
398 raise HandledException('Cannot cache object without ID.')
399 cache = self.get_cache()
400 old_cached = self._get_cached(self.id_)
401 if old_cached and id(old_cached) != id(self):
402 # pylint: disable=protected-access
403 # (cause we remain within the class)
404 old_cached._disappear()
405 cache[self.id_] = self
407 def _uncache(self) -> None:
408 """Remove self from cache."""
410 raise HandledException('Cannot un-cache object without ID.')
411 cache = self.get_cache()
414 # object retrieval and generation
417 def from_table_row(cls: type[BaseModelInstance],
418 # pylint: disable=unused-argument
419 db_conn: DatabaseConnection,
420 row: Row | list[Any]) -> BaseModelInstance:
421 """Make from DB row (sans relations), update DB cache with it."""
423 assert obj.id_ is not None
424 for attr_name in cls.to_save_versioned:
425 attr = getattr(obj, attr_name)
426 table_name = attr.table_name
427 for row_ in db_conn.row_where(table_name, 'parent', obj.id_):
428 attr.history_from_row(row_)
433 def by_id(cls, db_conn: DatabaseConnection, id_: BaseModelId) -> Self:
434 """Retrieve by id_, on failure throw NotFoundException.
436 First try to get from cls.cache_, only then check DB; if found,
441 obj = cls._get_cached(id_)
443 for row in db_conn.row_where(cls.table_name, 'id', id_):
444 obj = cls.from_table_row(db_conn, row)
448 raise NotFoundException(f'found no object of ID {id_}')
451 def by_id_or_create(cls, db_conn: DatabaseConnection,
452 id_: BaseModelId | None
454 """Wrapper around .by_id, creating (not caching/saving) if not find."""
455 if not cls.can_create_by_id:
456 raise HandledException('Class cannot .by_id_or_create.')
460 return cls.by_id(db_conn, id_)
461 except NotFoundException:
465 def all(cls: type[BaseModelInstance],
466 db_conn: DatabaseConnection) -> list[BaseModelInstance]:
467 """Collect all objects of class into list.
469 Note that this primarily returns the contents of the cache, and only
470 _expands_ that by additional findings in the DB. This assumes the
471 cache is always instantly cleaned of any items that would be removed
474 items: dict[BaseModelId, BaseModelInstance] = {}
475 for k, v in cls.get_cache().items():
476 assert isinstance(v, cls)
478 already_recorded = items.keys()
479 for id_ in db_conn.column_all(cls.table_name, 'id'):
480 if id_ not in already_recorded:
481 item = cls.by_id(db_conn, id_)
482 assert item.id_ is not None
483 items[item.id_] = item
484 return list(items.values())
487 def by_date_range_with_limits(cls: type[BaseModelInstance],
488 db_conn: DatabaseConnection,
489 date_range: tuple[str, str],
490 date_col: str = 'day'
491 ) -> tuple[list[BaseModelInstance], str,
493 """Return list of items in database within (open) date_range interval.
495 If no range values provided, defaults them to 'yesterday' and
496 'tomorrow'. Knows to properly interpret these and 'today' as value.
498 start_str = date_range[0] if date_range[0] else 'yesterday'
499 end_str = date_range[1] if date_range[1] else 'tomorrow'
500 start_date = valid_date(start_str)
501 end_date = valid_date(end_str)
503 sql = f'SELECT id FROM {cls.table_name} '
504 sql += f'WHERE {date_col} >= ? AND {date_col} <= ?'
505 for row in db_conn.exec(sql, (start_date, end_date)):
506 items += [cls.by_id(db_conn, row[0])]
507 return items, start_date, end_date
510 def matching(cls: type[BaseModelInstance], db_conn: DatabaseConnection,
511 pattern: str) -> list[BaseModelInstance]:
512 """Return all objects whose .to_search match pattern."""
513 items = cls.all(db_conn)
517 for attr_name in cls.to_search:
518 toks = attr_name.split('.')
521 attr = getattr(parent, tok)
531 def save(self, db_conn: DatabaseConnection) -> None:
532 """Write self to DB and cache and ensure .id_.
534 Write both to DB, and to cache. To DB, write .id_ and attributes
535 listed in cls.to_save[_versioned|_relations].
537 Ensure self.id_ by setting it to what the DB command returns as the
538 last saved row's ID (cursor.lastrowid), EXCEPT if self.id_ already
539 exists as a 'str', which implies we do our own ID creation (so far
540 only the case with the Day class, where it's to be a date string.
542 values = tuple([self.id_] + [getattr(self, key)
543 for key in self.to_save])
544 table_name = self.table_name
545 cursor = db_conn.exec_on_vals(f'REPLACE INTO {table_name} VALUES',
547 if not isinstance(self.id_, str):
548 self.id_ = cursor.lastrowid # type: ignore[assignment]
550 for attr_name in self.to_save_versioned:
551 getattr(self, attr_name).save(db_conn)
552 for table, column, attr_name, key_index in self.to_save_relations:
553 assert isinstance(self.id_, (int, str))
554 db_conn.rewrite_relations(table, column, self.id_,
556 in getattr(self, attr_name)], key_index)
558 def remove(self, db_conn: DatabaseConnection) -> None:
559 """Remove from DB and cache, including dependencies."""
560 if self.id_ is None or self._get_cached(self.id_) is None:
561 raise HandledException('cannot remove unsaved item')
562 for attr_name in self.to_save_versioned:
563 getattr(self, attr_name).remove(db_conn)
564 for table, column, attr_name, _ in self.to_save_relations:
565 db_conn.delete_where(table, column, self.id_)
567 db_conn.delete_where(self.table_name, 'id', self.id_)