1 """Database management."""
2 from __future__ import annotations
4 from os.path import isfile
5 from difflib import Differ
6 from sqlite3 import connect as sql_connect, Cursor, Row
7 from typing import Any, Self, TypeVar, Generic
8 from plomtask.exceptions import HandledException, NotFoundException
9 from plomtask.dating import valid_date
11 EXPECTED_DB_VERSION = 5  # schema version this code expects; stored in sqlite's PRAGMA user_version
12 MIGRATIONS_DIR = 'migrations'  # holds the init schema file and the numbered migration scripts
13 FILENAME_DB_SCHEMA = f'init_{EXPECTED_DB_VERSION}.sql'  # full schema for a fresh DB at EXPECTED_DB_VERSION
14 PATH_DB_SCHEMA = f'{MIGRATIONS_DIR}/{FILENAME_DB_SCHEMA}'  # relative path, resolved against the current working directory
17 class UnmigratedDbException(HandledException):  # raised by DatabaseFile._check when user_version != EXPECTED_DB_VERSION
18 """To identify case of unmigrated DB file."""
22 """Represents the sqlite3 database's file."""  # DatabaseFile; alternative entry points: create_at, migrate
23 # pylint: disable=too-few-public-methods
25 def __init__(self, path: str) -> None:  # path of the sqlite3 file (presumably stored as self.path, which _check reads)
30 def create_at(cls, path: str) -> DatabaseFile:  # takes cls: presumably a classmethod
31 """Make new DB file at path."""
32 with sql_connect(path) as conn:  # NOTE(review): sqlite3's `with` commits/rolls back but does NOT close conn; consider contextlib.closing
33 with open(PATH_DB_SCHEMA, 'r', encoding='utf-8') as f:
34 conn.executescript(f.read())  # run the whole init schema script
35 conn.execute(f'PRAGMA user_version = {EXPECTED_DB_VERSION}')  # stamp the file with the version _check expects
39 def migrate(cls, path: str) -> DatabaseFile:  # takes cls: presumably a classmethod
40 """Apply migrations from_version to EXPECTED_DB_VERSION."""
41 migrations = cls._available_migrations()  # ordered so that list index == migration number
42 from_version = cls._get_version_of_db(path)  # raises if the file is newer than this code
43 migrations_todo = migrations[from_version+1:]  # every migration after the DB's current version
44 for j, filename in enumerate(migrations_todo):
45 with sql_connect(path) as conn:  # NOTE(review): `with` only commits; this connection is never closed
46 with open(f'{MIGRATIONS_DIR}/{filename}', 'r',
47 encoding='utf-8') as f:
48 conn.executescript(f.read())
49 user_version = from_version + j + 1  # version reached after applying this migration
50 with sql_connect(path) as conn:  # NOTE(review): separate connection, so script and version bump are not atomic
51 conn.execute(f'PRAGMA user_version = {user_version}')
54 def _check(self) -> None:
55 """Check file exists, and is of proper DB version and schema."""
56 if not isfile(self.path):
57 raise NotFoundException  # NOTE(review): raised without a message naming the missing path
58 if self._user_version != EXPECTED_DB_VERSION:  # _user_version is read as an attribute here
59 raise UnmigratedDbException()  # only older versions reach this; newer ones raise in _get_version_of_db
60 self._validate_schema()  # raises HandledException with a diff on schema mismatch
63 def _available_migrations() -> list[str]:  # takes neither self nor cls: presumably a staticmethod
64 """Validate migrations directory and return sorted entries."""
65 msg_too_big = 'Migration directory points beyond expected DB version.'
66 msg_bad_entry = 'Migration directory contains unexpected entry: '
67 msg_missing = 'Migration directory misses migration of number: '
69 for entry in listdir(MIGRATIONS_DIR):  # entries are expected to look like '<int>_<rest>'
70 if entry == FILENAME_DB_SCHEMA:  # the init schema file is special-cased (presumably skipped)
72 toks = entry.split('_', 1)
74 raise HandledException(msg_bad_entry + entry)
77 except ValueError as e:  # non-integer prefix
78 raise HandledException(msg_bad_entry + entry) from e
79 if i > EXPECTED_DB_VERSION:
80 raise HandledException(msg_too_big)
81 migrations[i] = toks[1]  # NOTE(review): a duplicate number silently overwrites the earlier entry
83 for i in range(EXPECTED_DB_VERSION + 1):  # every number 0..EXPECTED_DB_VERSION must be present
84 if i not in migrations:
85 raise HandledException(msg_missing + str(i))
86 migrations_list += [f'{i}_{migrations[i]}']  # rebuild filenames in numeric order
87 return migrations_list
90 def _get_version_of_db(path: str) -> int:
91 """Get DB user_version, fail if above EXPECTED_DB_VERSION."""
92 sql_for_db_version = 'PRAGMA user_version'
93 with sql_connect(path) as conn:  # NOTE(review): connect() creates path if missing, and `with` does not close conn
94 db_version = list(conn.execute(sql_for_db_version))[0][0]  # single row, single column
95 if db_version > EXPECTED_DB_VERSION:  # only "too new" is rejected here
96 msg = f'Wrong DB version, expected '\
97 f'{EXPECTED_DB_VERSION}, got unknown {db_version}.'
98 raise HandledException(msg)
99 assert isinstance(db_version, int)  # NOTE(review): checked after the comparison above; stripped under -O
103 def _user_version(self) -> int:  # _check reads it without calling it, so presumably a @property
104 """Get DB user_version."""
105 return self._get_version_of_db(self.path)
107 def _validate_schema(self) -> None:
108 """Compare found schema with what's stored at PATH_DB_SCHEMA."""
110 def reformat_rows(rows: list[str]) -> list[str]:  # re-lays out sqlite_master SQL, presumably to match the schema file's layout
114 for subrow in row.split('\n'):
115 subrow = subrow.rstrip()
118 for i, c in enumerate(subrow):
123 elif ',' == c and 0 == in_parentheses:  # split only on commas outside parentheses
127 segment = subrow[prev_split:i].strip()
129 new_row += [f' {segment}']
131 segment = subrow[prev_split:].strip()
133 new_row += [f' {segment}']
134 new_row[0] = new_row[0].lstrip()
135 new_row[-1] = new_row[-1].lstrip()
136 if new_row[-1] != ')' and new_row[-3][-1] != ',':  # NOTE(review): indexing [-3] assumes >= 3 segments here — confirm
137 new_row[-3] = new_row[-3] + ','
138 new_row[-2:] = [' ' + new_row[-1][:-1]] + [')']  # last two entries become: last segment minus final char, then ')'
139 new_rows += ['\n'.join(new_row)]
142 sql_for_schema = 'SELECT sql FROM sqlite_master ORDER BY sql'  # ordered by text, independent of creation order
143 msg_err = 'Database has wrong tables schema. Diff:\n'
144 with sql_connect(self.path) as conn:  # NOTE(review): `with` does not close this connection
145 schema_rows = [r[0] for r in conn.execute(sql_for_schema) if r[0]]  # sql is NULL e.g. for automatic indexes
146 schema_rows = reformat_rows(schema_rows)
147 retrieved_schema = ';\n'.join(schema_rows) + ';'
148 with open(PATH_DB_SCHEMA, 'r', encoding='utf-8') as f:
149 stored_schema = f.read().rstrip()  # trailing whitespace/newlines are ignored
150 if stored_schema != retrieved_schema:
151 diff_msg = Differ().compare(retrieved_schema.splitlines(),
152 stored_schema.splitlines())  # '-' lines: only in DB, '+' lines: only in schema file
153 raise HandledException(msg_err + '\n'.join(diff_msg))
156 class DatabaseConnection:
157 """A single connection to the database."""
159 def __init__(self, db_file: DatabaseFile) -> None:
160 self.conn = sql_connect(db_file.path)  # one connection per instance; close() is meant to release it
162 def commit(self) -> None:
163 """Commit SQL transaction."""
166 def exec(self, code: str, inputs: tuple[Any, ...] = tuple()) -> Cursor:
167 """Add commands to SQL transaction."""
168 return self.conn.execute(code, inputs)  # inputs are bound to '?' placeholders, not interpolated
170 def exec_on_vals(self, code: str, inputs: tuple[Any, ...]) -> Cursor:
171 """Wrapper around .exec appending adequate " (?, …)" to code."""
172 q_marks_from_values = '(' + ','.join(['?'] * len(inputs)) + ')'  # e.g. 3 inputs -> '(?,?,?)'
173 return self.exec(f'{code} {q_marks_from_values}', inputs)
175 def close(self) -> None:
176 """Close DB connection."""
179 def rewrite_relations(self, table_name: str, key: str, target: int | str,
180 rows: list[list[Any]], key_index: int = 0) -> None:
181 # pylint: disable=too-many-arguments
182 """Rewrite relations in table_name to target, with rows values.
184 Note that single rows are expected without the column and value
185 identified by key and target, which are inserted inside the function
188 self.delete_where(table_name, key, target)  # drop all existing rows for target first
190 values = tuple(row[:key_index] + [target] + row[key_index:])  # target spliced in at position key_index
191 self.exec_on_vals(f'INSERT INTO {table_name} VALUES', values)
193 def row_where(self, table_name: str, key: str,
194 target: int | str) -> list[Row]:
195 """Return list of Rows at table where key == target."""
196 return list(self.exec(f'SELECT * FROM {table_name} WHERE {key} = ?',  # NOTE(review): table/column names are interpolated here and below; pass trusted identifiers only
199 # def column_where_pattern(self,
203 # keys: list[str]) -> list[Any]:
204 # """Return column of rows where one of keys matches pattern."""
205 # targets = tuple([f'%{pattern}%'] * len(keys))
206 # haystack = ' OR '.join([f'{k} LIKE ?' for k in keys])
207 # sql = f'SELECT {column} FROM {table_name} WHERE {haystack}'
208 # return [row[0] for row in self.exec(sql, targets)]
210 def column_where(self, table_name: str, column: str, key: str,
211 target: int | str) -> list[Any]:
212 """Return column of table where key == target."""
213 return [row[0] for row in
214 self.exec(f'SELECT {column} FROM {table_name} '
215 f'WHERE {key} = ?', (target,))]
217 def column_all(self, table_name: str, column: str) -> list[Any]:
218 """Return complete column of table."""
219 return [row[0] for row in
220 self.exec(f'SELECT {column} FROM {table_name}')]
222 def delete_where(self, table_name: str, key: str,
223 target: int | str) -> None:
224 """Delete from table where key == target."""
225 self.exec(f'DELETE FROM {table_name} WHERE {key} = ?', (target,))
228 BaseModelId = TypeVar('BaseModelId', int, str)  # int rowids, or self-assigned str IDs (see BaseModel.save)
229 BaseModelInstance = TypeVar('BaseModelInstance', bound='BaseModel[Any]')  # lets classmethods return the calling subclass
232 class BaseModel(Generic[BaseModelId]):
233 """Template for most of the models we use/derive from the DB."""
235 to_save: list[str] = []  # attribute names written as columns after id, in table column order
236 to_save_versioned: list[str] = []  # attributes with their own history tables (.save/.remove/.history)
237 to_save_relations: list[tuple[str, str, str, int]] = []  # (table, column, attr_name, key_index), see save()
238 id_: None | BaseModelId  # None marks an unsaved object (see remove)
239 cache_: dict[BaseModelId, Self]  # per-class, created lazily by get_cache()
240 to_search: list[str] = []  # dotted attribute paths inspected by matching()
243 def __init__(self, id_: BaseModelId | None) -> None:  # rejects invalid IDs before use
244 if isinstance(id_, int) and id_ < 1:  # int IDs must be >= 1
245 msg = f'illegal {self.__class__.__name__} ID, must be >=1: {id_}'
246 raise HandledException(msg)
247 if isinstance(id_, str) and "" == id_:  # str IDs must be non-empty
248 msg = f'illegal {self.__class__.__name__} ID, must be non-empty'
249 raise HandledException(msg)
252 def __hash__(self) -> int:  # NOTE(review): derived from mutable state, so the hash changes when the object is edited
253 hashable = [self.id_] + [getattr(self, name) for name in self.to_save]  # every to_save value must itself be hashable
254 for definition in self.to_save_relations:
255 attr = getattr(self, definition[2])  # definition[2] is the attribute name
256 hashable += [tuple(rel.id_ for rel in attr)]  # related objects contribute only their IDs
257 for name in self.to_save_versioned:
258 hashable += [hash(getattr(self, name))]
259 return hash(tuple(hashable))
261 def __eq__(self, other: object) -> bool:  # NOTE(review): equality via hash comparison; a hash collision would read as equal
262 if not isinstance(other, self.__class__):
264 return hash(self) == hash(other)
266 def __lt__(self, other: Any) -> bool:  # orders same-class objects by their int IDs
267 if not isinstance(other, self.__class__):
268 msg = 'cannot compare to object of different class'
269 raise HandledException(msg)  # NOTE(review): returning NotImplemented is the conventional alternative
270 assert isinstance(self.id_, int)  # str-ID classes cannot be ordered; asserts are stripped under -O
271 assert isinstance(other.id_, int)
272 return self.id_ < other.id_
275 def as_dict(self) -> dict[str, object]:  # used as x.as_dict without a call below, so presumably a @property
276 """Return self as (json.dumps-compatible) dict."""
277 d: dict[str, object] = {'id': self.id_}
278 if len(self.to_save_versioned) > 0:
280 for k in self.to_save:
281 attr = getattr(self, k)
282 if hasattr(attr, 'as_dict'):  # nested models are presumably serialized via their own as_dict
285 for k in self.to_save_versioned:
286 attr = getattr(self, k)
287 assert isinstance(d['_versioned'], dict)  # narrows d's object-typed value for the type checker
288 d['_versioned'][k] = attr.history  # full history per versioned attribute
289 for r in self.to_save_relations:
291 d[attr_name] = [x.as_dict for x in getattr(self, attr_name)]  # relations as lists of dicts
295 # (we primarily use the cache to ensure we work on the same object in
296 # memory no matter where and how we retrieve it, e.g. we don't want
297 # .by_id() calls to create a new object each time, but rather a pointer
298 # to the one already instantiated)
300 def __getattribute__(self, name: str) -> Any:
301 """Ensure fail if ._disappear() was called, except to check ._exists"""
302 if name != '_exists' and not super().__getattribute__('_exists'):  # runs on every attribute access; '_exists' exempt to avoid recursion
303 raise HandledException('Object does not exist.')
304 return super().__getattribute__(name)
306 def _disappear(self) -> None:
307 """Invalidate object, make future use raise exceptions."""
308 assert self.id_ is not None
309 if self._get_cached(self.id_):  # object (or one with its ID) is in the class cache
311 to_kill = list(self.__dict__.keys())  # copied to a list, presumably so attributes can be deleted while looping
317 def empty_cache(cls) -> None:  # takes cls, so it acts on the calling class's cache
318 """Empty class's cache."""
322 def get_cache(cls: type[BaseModelInstance]) -> dict[Any, BaseModel[Any]]:
323 """Get cache dictionary, create it if not yet existing."""
324 if not hasattr(cls, 'cache_'):  # NOTE(review): hasattr also finds a cache_ inherited from a base class — confirm subclasses should not share it
325 d: dict[Any, BaseModel[Any]] = {}
330 def _get_cached(cls: type[BaseModelInstance],
331 id_: BaseModelId) -> BaseModelInstance | None:
332 """Get object of id_ from class's cache, or None if not found."""
333 # pylint: disable=consider-iterating-dictionary
334 cache = cls.get_cache()
335 if id_ in cache.keys():  # `id_ in cache` would be equivalent and satisfy pylint
337 assert isinstance(obj, cls)  # cache entries must belong to the calling class
341 def _cache(self) -> None:
342 """Update object in class's cache.
344 Also calls ._disappear if cache holds older reference to object of same
345 ID, but different memory address, to avoid doing anything with
349 raise HandledException('Cannot cache object without ID.')
350 cache = self.get_cache()
351 old_cached = self._get_cached(self.id_)
352 if old_cached and id(old_cached) != id(self):  # NOTE(review): truthiness test; `is not None` would not depend on subclass __bool__/__len__
353 # pylint: disable=protected-access
354 # (cause we remain within the class)
355 old_cached._disappear()  # invalidate the stale duplicate with the same ID
356 cache[self.id_] = self
358 def _uncache(self) -> None:
359 """Remove self from cache."""
361 raise HandledException('Cannot un-cache object without ID.')  # mirrors the guard in _cache
362 cache = self.get_cache()
365 # object retrieval and generation
368 def from_table_row(cls: type[BaseModelInstance],
369 # pylint: disable=unused-argument
370 db_conn: DatabaseConnection,  # NOTE(review): db_conn is used below, so the unused-argument disable looks stale — confirm
371 row: Row | list[Any]) -> BaseModelInstance:
372 """Make from DB row (sans relations), update DB cache with it."""
374 assert obj.id_ is not None
375 for attr_name in cls.to_save_versioned:
376 attr = getattr(obj, attr_name)
377 table_name = attr.table_name
378 for row_ in db_conn.row_where(table_name, 'parent', obj.id_):  # history rows reference their owner via 'parent'
379 attr.history_from_row(row_)  # replay each history row onto the attribute
384 def by_id(cls, db_conn: DatabaseConnection,
385 id_: BaseModelId | None,
386 # pylint: disable=unused-argument
387 create: bool = False) -> Self:
388 """Retrieve by id_, on failure throw NotFoundException.
390 First try to get from cls.cache_, only then check DB; if found,
393 If create=True, make anew (but do not cache yet).
397 obj = cls._get_cached(id_)  # a cache hit preserves object identity
399 for row in db_conn.row_where(cls.table_name, 'id', id_):  # 'id' is presumably the primary key, so at most one row
400 obj = cls.from_table_row(db_conn, row)
407 raise NotFoundException(f'found no object of ID {id_}')
410 def all(cls: type[BaseModelInstance],
411 db_conn: DatabaseConnection) -> list[BaseModelInstance]:
412 """Collect all objects of class into list.
414 Note that this primarily returns the contents of the cache, and only
415 _expands_ that by additional findings in the DB. This assumes the
416 cache is always instantly cleaned of any items that would be removed
419 items: dict[BaseModelId, BaseModelInstance] = {}
420 for k, v in cls.get_cache().items():
421 assert isinstance(v, cls)
423 already_recorded = items.keys()  # live dict view, so it also reflects items added below
424 for id_ in db_conn.column_all(cls.table_name, 'id'):
425 if id_ not in already_recorded:
426 item = cls.by_id(db_conn, id_)  # NOTE(review): one extra query per uncached row
427 assert item.id_ is not None
428 items[item.id_] = item
429 return list(items.values())
432 def by_date_range_with_limits(cls: type[BaseModelInstance],
433 db_conn: DatabaseConnection,
434 date_range: tuple[str, str],
435 date_col: str = 'day'
436 ) -> tuple[list[BaseModelInstance], str,
438 """Return list of items in database within (open) date_range interval.
440 If no range values provided, defaults them to 'yesterday' and
441 'tomorrow'. Knows to properly interpret these and 'today' as value.
443 start_str = date_range[0] if date_range[0] else 'yesterday'  # falsy start (e.g. '') -> 'yesterday'
444 end_str = date_range[1] if date_range[1] else 'tomorrow'  # falsy end -> 'tomorrow'
445 start_date = valid_date(start_str)  # resolves relative names like 'yesterday' to a date string
446 end_date = valid_date(end_str)
448 sql = f'SELECT id FROM {cls.table_name} '
449 sql += f'WHERE {date_col} >= ? AND {date_col} <= ?'  # both bounds inclusive (>=, <=)
450 for row in db_conn.exec(sql, (start_date, end_date)):
451 items += [cls.by_id(db_conn, row[0])]
452 return items, start_date, end_date  # resolved bounds are returned alongside the items
455 def matching(cls: type[BaseModelInstance], db_conn: DatabaseConnection,
456 pattern: str) -> list[BaseModelInstance]:
457 """Return all objects whose .to_search match pattern."""
458 items = cls.all(db_conn)  # candidates are every object of the class, loaded via all()
462 for attr_name in cls.to_search:
463 toks = attr_name.split('.')  # dotted path, walked one getattr per segment
466 attr = getattr(parent, tok)
476 def save(self, db_conn: DatabaseConnection) -> None:
477 """Write self to DB and cache and ensure .id_.
479 Write both to DB, and to cache. To DB, write .id_ and attributes
480 listed in cls.to_save[_versioned|_relations].
482 Ensure self.id_ by setting it to what the DB command returns as the
483 last saved row's ID (cursor.lastrowid), EXCEPT if self.id_ already
484 exists as a 'str', which implies we do our own ID creation (so far
485 only the case with the Day class, where it's to be a date string).
487 values = tuple([self.id_] + [getattr(self, key)  # column order: id, then to_save
488 for key in self.to_save])
489 table_name = self.table_name
490 cursor = db_conn.exec_on_vals(f'REPLACE INTO {table_name} VALUES',  # NOTE(review): SQLite REPLACE deletes a conflicting row and inserts anew, not an in-place update
492 if not isinstance(self.id_, str):  # int IDs come from sqlite's lastrowid; str IDs are caller-chosen
493 self.id_ = cursor.lastrowid # type: ignore[assignment]
495 for attr_name in self.to_save_versioned:
496 getattr(self, attr_name).save(db_conn)
497 for table, column, attr_name, key_index in self.to_save_relations:
498 assert isinstance(self.id_, (int, str))
499 db_conn.rewrite_relations(table, column, self.id_,  # relation rows for this ID are deleted and rebuilt
501 in getattr(self, attr_name)], key_index)
503 def remove(self, db_conn: DatabaseConnection) -> None:
504 """Remove from DB and cache, including dependencies."""
505 if self.id_ is None or self._get_cached(self.id_) is None:
506 raise HandledException('cannot remove unsaved item')
507 for attr_name in self.to_save_versioned:
508 getattr(self, attr_name).remove(db_conn)
509 for table, column, attr_name, _ in self.to_save_relations:
510 db_conn.delete_where(table, column, self.id_)
512 db_conn.delete_where(self.table_name, 'id', self.id_)