1 """Database management."""
2 from __future__ import annotations
4 from os.path import isfile
5 from difflib import Differ
6 from sqlite3 import connect as sql_connect, Cursor, Row
7 from typing import Any, Self, TypeVar, Generic
8 from plomtask.exceptions import HandledException, NotFoundException
9 from plomtask.dating import valid_date
11 EXPECTED_DB_VERSION = 5
12 MIGRATIONS_DIR = 'migrations'
13 FILENAME_DB_SCHEMA = f'init_{EXPECTED_DB_VERSION}.sql'
14 PATH_DB_SCHEMA = f'{MIGRATIONS_DIR}/{FILENAME_DB_SCHEMA}'
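# With EXPECTED_DB_VERSION = 5, PATH_DB_SCHEMA resolves to
# 'migrations/init_5.sql': the full schema of the current version, used both
# to create fresh DB files and as the reference when validating their schema,
# while the numbered files in MIGRATIONS_DIR describe the steps between
# versions.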


class UnmigratedDbException(HandledException):
    """To identify case of unmigrated DB file."""


class DatabaseFile:  # pylint: disable=too-few-public-methods
    """Represents the sqlite3 database's file."""

    def __init__(self, path: str) -> None:
        self.path = path
        self._check()

    @classmethod
    def create_at(cls, path: str) -> DatabaseFile:
        """Make new DB file at path."""
        with sql_connect(path) as conn:
            with open(PATH_DB_SCHEMA, 'r', encoding='utf-8') as f:
                conn.executescript(f.read())
            conn.execute(f'PRAGMA user_version = {EXPECTED_DB_VERSION}')
        return cls(path)

    @classmethod
    def migrate(cls, path: str) -> DatabaseFile:
        """Apply migrations from_version to EXPECTED_DB_VERSION."""
        migrations = cls._available_migrations()
        from_version = cls.get_version_of_db(path)
        migrations_todo = migrations[from_version+1:]
        for j, filename in enumerate(migrations_todo):
            with sql_connect(path) as conn:
                with open(f'{MIGRATIONS_DIR}/{filename}', 'r',
                          encoding='utf-8') as f:
                    conn.executescript(f.read())
            user_version = from_version + j + 1
            with sql_connect(path) as conn:
                conn.execute(f'PRAGMA user_version = {user_version}')
        return cls(path)

    def _check(self) -> None:
        """Check file exists, and is of proper DB version and schema."""
        if not isfile(self.path):
            raise NotFoundException
        if self.user_version != EXPECTED_DB_VERSION:
            raise UnmigratedDbException()
        self._validate_schema()

    @staticmethod
    def _available_migrations() -> list[str]:
        """Validate migrations directory and return sorted entries."""
        msg_too_big = 'Migration directory points beyond expected DB version.'
        msg_bad_entry = 'Migration directory contains unexpected entry: '
        msg_missing = 'Migration directory misses migration of number: '
        migrations: dict[int, str] = {}
        for entry in listdir(MIGRATIONS_DIR):
            if entry == FILENAME_DB_SCHEMA:
                continue
            toks = entry.split('_', 1)
            if len(toks) < 2:
                raise HandledException(msg_bad_entry + entry)
            try:
                i = int(toks[0])
            except ValueError as e:
                raise HandledException(msg_bad_entry + entry) from e
            if i > EXPECTED_DB_VERSION:
                raise HandledException(msg_too_big)
            migrations[i] = toks[1]
        migrations_list: list[str] = []
        for i in range(EXPECTED_DB_VERSION + 1):
            if i not in migrations:
                raise HandledException(msg_missing + str(i))
            migrations_list += [f'{i}_{migrations[i]}']
        return migrations_list

    @staticmethod
    def get_version_of_db(path: str) -> int:
        """Get DB user_version, fail if outside expected range."""
        sql_for_db_version = 'PRAGMA user_version'
        with sql_connect(path) as conn:
            db_version = list(conn.execute(sql_for_db_version))[0][0]
        if db_version > EXPECTED_DB_VERSION:
            msg = f'Wrong DB version, expected '\
                    f'{EXPECTED_DB_VERSION}, got unknown {db_version}.'
            raise HandledException(msg)
        assert isinstance(db_version, int)
        return db_version

    @property
    def user_version(self) -> int:
        """Get DB user_version."""
        return self.__class__.get_version_of_db(self.path)

    def _validate_schema(self) -> None:
        """Compare found schema with what's stored at PATH_DB_SCHEMA."""

        def reformat_rows(rows: list[str]) -> list[str]:
            new_rows: list[str] = []
            for row in rows:
                new_row: list[str] = []
                for subrow in row.split('\n'):
                    subrow = subrow.rstrip()
                    in_parentheses = 0
                    split_at_comma: list[int] = []
                    for i, c in enumerate(subrow):
                        if '(' == c:
                            in_parentheses += 1
                        elif ')' == c:
                            in_parentheses -= 1
                        elif ',' == c and 0 == in_parentheses:
                            split_at_comma += [i + 1]
                    prev_split = 0
                    for i in split_at_comma:
                        segment = subrow[prev_split:i].strip()
                        if len(segment) > 0:
                            new_row += [f' {segment}']
                        prev_split = i
                    segment = subrow[prev_split:].strip()
                    if len(segment) > 0:
                        new_row += [f' {segment}']
                new_row[0] = new_row[0].lstrip()
                new_row[-1] = new_row[-1].lstrip()
                if new_row[-1] != ')' and new_row[-3][-1] != ',':
                    new_row[-3] = new_row[-3] + ','
                    new_row[-2:] = [' ' + new_row[-1][:-1]] + [')']
                new_rows += ['\n'.join(new_row)]
            return new_rows

        sql_for_schema = 'SELECT sql FROM sqlite_master ORDER BY sql'
        msg_err = 'Database has wrong tables schema. Diff:\n'
        with sql_connect(self.path) as conn:
            schema_rows = [r[0] for r in conn.execute(sql_for_schema) if r[0]]
        schema_rows = reformat_rows(schema_rows)
        retrieved_schema = ';\n'.join(schema_rows) + ';'
        with open(PATH_DB_SCHEMA, 'r', encoding='utf-8') as f:
            stored_schema = f.read().rstrip()
        if stored_schema != retrieved_schema:
            diff_msg = Differ().compare(retrieved_schema.splitlines(),
                                        stored_schema.splitlines())
            raise HandledException(msg_err + '\n'.join(diff_msg))


class DatabaseConnection:
    """A single connection to the database."""

    def __init__(self, db_file: DatabaseFile) -> None:
        self.file = db_file
        self.conn = sql_connect(self.file.path)

    def commit(self) -> None:
        """Commit SQL transaction."""
        self.conn.commit()

    def exec(self, code: str, inputs: tuple[Any, ...] = tuple()) -> Cursor:
        """Add commands to SQL transaction."""
        return self.conn.execute(code, inputs)

    def close(self) -> None:
        """Close DB connection."""
        self.conn.close()

    def rewrite_relations(self, table_name: str, key: str, target: int | str,
                          rows: list[list[Any]], key_index: int = 0) -> None:
        # pylint: disable=too-many-arguments
        """Rewrite relations in table_name to target, with rows values.

        Note that single rows are expected without the column and value
        identified by key and target, which are inserted inside the function
        at key_index.
        """
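        # Example (illustrative values): with key_index=0, target=3 and
        # rows=[['a'], ['b']], the rows actually written are (3, 'a') and
        # (3, 'b'); with key_index=1 the target lands in the second column
        # instead.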
        self.delete_where(table_name, key, target)
        for row in rows:
            values = tuple(row[:key_index] + [target] + row[key_index:])
            q_marks = self.__class__.q_marks_from_values(values)
            self.exec(f'INSERT INTO {table_name} VALUES {q_marks}', values)

    def row_where(self, table_name: str, key: str,
                  target: int | str) -> list[Row]:
        """Return list of Rows at table where key == target."""
        return list(self.exec(f'SELECT * FROM {table_name} WHERE {key} = ?',
                              (target,)))

    # def column_where_pattern(self,
    #                          table_name: str,
    #                          pattern: str,
    #                          column: str,
    #                          keys: list[str]) -> list[Any]:
    #     """Return column of rows where one of keys matches pattern."""
    #     targets = tuple([f'%{pattern}%'] * len(keys))
    #     haystack = ' OR '.join([f'{k} LIKE ?' for k in keys])
    #     sql = f'SELECT {column} FROM {table_name} WHERE {haystack}'
    #     return [row[0] for row in self.exec(sql, targets)]

    def column_where(self, table_name: str, column: str, key: str,
                     target: int | str) -> list[Any]:
        """Return column of table where key == target."""
        return [row[0] for row in
                self.exec(f'SELECT {column} FROM {table_name} '
                          f'WHERE {key} = ?', (target,))]

    def column_all(self, table_name: str, column: str) -> list[Any]:
        """Return complete column of table."""
        return [row[0] for row in
                self.exec(f'SELECT {column} FROM {table_name}')]

    def delete_where(self, table_name: str, key: str,
                     target: int | str) -> None:
        """Delete from table where key == target."""
        self.exec(f'DELETE FROM {table_name} WHERE {key} = ?', (target,))

    @staticmethod
    def q_marks_from_values(values: tuple[Any]) -> str:
        """Return placeholder to insert values into SQL code."""
        return '(' + ','.join(['?'] * len(values)) + ')'
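    # Example: a three-element values tuple yields the placeholder string
    # '(?,?,?)', ready for use in an INSERT/REPLACE statement.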


BaseModelId = TypeVar('BaseModelId', int, str)
BaseModelInstance = TypeVar('BaseModelInstance', bound='BaseModel[Any]')


class BaseModel(Generic[BaseModelId]):
    """Template for most of the models we use/derive from the DB."""
    table_name = ''
    to_save: list[str] = []
    to_save_versioned: list[str] = []
    to_save_relations: list[tuple[str, str, str, int]] = []
    id_: None | BaseModelId
    cache_: dict[BaseModelId, Self]
    to_search: list[str] = []

    def __init__(self, id_: BaseModelId | None) -> None:
        if isinstance(id_, int) and id_ < 1:
            msg = f'illegal {self.__class__.__name__} ID, must be >=1: {id_}'
            raise HandledException(msg)
        self.id_ = id_

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, self.__class__):
            return False
        to_hash_me = tuple([self.id_] +
                           [getattr(self, name) for name in self.to_save])
        to_hash_other = tuple([other.id_] +
                              [getattr(other, name) for name in other.to_save])
        return hash(to_hash_me) == hash(to_hash_other)

    def __lt__(self, other: Any) -> bool:
        if not isinstance(other, self.__class__):
            msg = 'cannot compare to object of different class'
            raise HandledException(msg)
        assert isinstance(self.id_, int)
        assert isinstance(other.id_, int)
        return self.id_ < other.id_

    @classmethod
    def get_cached(cls: type[BaseModelInstance],
                   id_: BaseModelId) -> BaseModelInstance | None:
        """Get object of id_ from class's cache, or None if not found."""
        # pylint: disable=consider-iterating-dictionary
        cache = cls.get_cache()
        if id_ in cache.keys():
            obj = cache[id_]
            assert isinstance(obj, cls)
            return obj
        return None

    @classmethod
    def empty_cache(cls) -> None:
        """Empty class's cache."""
        cls.cache_ = {}

    @classmethod
    def get_cache(cls: type[BaseModelInstance]) -> dict[Any, BaseModel[Any]]:
        """Get cache dictionary, create it if not yet existing."""
        if not hasattr(cls, 'cache_'):
            d: dict[Any, BaseModel[Any]] = {}
            cls.cache_ = d
        return cls.cache_

    def cache(self) -> None:
        """Update object in class's cache."""
        if self.id_ is None:
            raise HandledException('Cannot cache object without ID.')
        cache = self.__class__.get_cache()
        cache[self.id_] = self

    def uncache(self) -> None:
        """Remove self from cache."""
        if self.id_ is None:
            raise HandledException('Cannot un-cache object without ID.')
        cache = self.__class__.get_cache()
        del cache[self.id_]

    @classmethod
    def from_table_row(cls: type[BaseModelInstance],
                       # pylint: disable=unused-argument
                       db_conn: DatabaseConnection,
                       row: Row | list[Any]) -> BaseModelInstance:
        """Make from DB row, write to DB cache."""
        obj = cls(*row)
        obj.cache()
        return obj

    @classmethod
    def by_id(cls, db_conn: DatabaseConnection,
              id_: BaseModelId | None,
              # pylint: disable=unused-argument
              create: bool = False) -> Self:
        """Retrieve by id_, on failure throw NotFoundException.

        First try to get from cls.cache_, only then check DB; if found,
        cache it.

        If create=True, make anew (but do not cache yet).
        """
        obj = None
        if id_ is not None:
            obj = cls.get_cached(id_)
            if not obj:
                for row in db_conn.row_where(cls.table_name, 'id', id_):
                    obj = cls.from_table_row(db_conn, row)
                    break
        if obj:
            return obj
        if create:
            obj = cls(id_)
            return obj
        raise NotFoundException(f'found no object of ID {id_}')

    @classmethod
    def all(cls: type[BaseModelInstance],
            db_conn: DatabaseConnection) -> list[BaseModelInstance]:
        """Collect all objects of class into list.

        Note that this primarily returns the contents of the cache, and only
        _expands_ that by additional findings in the DB. This assumes the
        cache is always instantly cleaned of any items that would be removed
        from the DB.
        """
        items: dict[BaseModelId, BaseModelInstance] = {}
        for k, v in cls.get_cache().items():
            assert isinstance(v, cls)
            items[k] = v
        already_recorded = items.keys()
        for id_ in db_conn.column_all(cls.table_name, 'id'):
            if id_ not in already_recorded:
                item = cls.by_id(db_conn, id_)
                assert item.id_ is not None
                items[item.id_] = item
        return list(items.values())

    @classmethod
    def by_date_range_with_limits(cls: type[BaseModelInstance],
                                  db_conn: DatabaseConnection,
                                  date_range: tuple[str, str],
                                  date_col: str = 'day'
                                  ) -> tuple[list[BaseModelInstance], str,
                                             str]:
        """Return list of items in database within (open) date_range interval.

        If no range values provided, defaults them to 'yesterday' and
        'tomorrow'. Knows to properly interpret these and 'today' as value.
        """
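        # Example (illustrative values): date_range=('', '2024-05-01') queries
        # from yesterday's date up to 2024-05-01 (both bounds inclusive, per
        # the >=/<= comparisons below); the empty string triggers the default,
        # and valid_date resolves relative names like 'today' into concrete
        # date strings before they reach the SQL.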
        start_str = date_range[0] if date_range[0] else 'yesterday'
        end_str = date_range[1] if date_range[1] else 'tomorrow'
        start_date = valid_date(start_str)
        end_date = valid_date(end_str)
        items: list[BaseModelInstance] = []
        sql = f'SELECT id FROM {cls.table_name} '
        sql += f'WHERE {date_col} >= ? AND {date_col} <= ?'
        for row in db_conn.exec(sql, (start_date, end_date)):
            items += [cls.by_id(db_conn, row[0])]
        return items, start_date, end_date

    @classmethod
    def matching(cls: type[BaseModelInstance], db_conn: DatabaseConnection,
                 pattern: str) -> list[BaseModelInstance]:
        """Return all objects whose .to_search match pattern."""
        items = cls.all(db_conn)
        if pattern:
            filtered: list[BaseModelInstance] = []
            for item in items:
                for attr_name in cls.to_search:
                    toks = attr_name.split('.')
                    parent = item
                    for tok in toks:
                        attr = getattr(parent, tok)
                        parent = attr
                    if pattern in attr:
                        filtered += [item]
                        break
            return filtered
        return items

    def save(self, db_conn: DatabaseConnection) -> None:
        """Write self to DB and cache and ensure .id_.

        Write both to DB, and to cache. To DB, write .id_ and attributes
        listed in cls.to_save[_versioned|_relations].

        Ensure self.id_ by setting it to what the DB command returns as the
        last saved row's ID (cursor.lastrowid), EXCEPT if self.id_ already
        exists as a 'str', which implies we do our own ID creation (so far
        only the case with the Day class, where it's to be a date string).
        """
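        # In other words: integer-keyed models get their id_ back-filled from
        # sqlite3's lastrowid after the REPLACE below, while string-keyed
        # models (such as Day with its date string) keep the id_ they were
        # constructed with.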
        values = tuple([self.id_] + [getattr(self, key)
                                     for key in self.to_save])
        q_marks = DatabaseConnection.q_marks_from_values(values)
        table_name = self.table_name
        cursor = db_conn.exec(f'REPLACE INTO {table_name} VALUES {q_marks}',
                              values)
        if not isinstance(self.id_, str):
            self.id_ = cursor.lastrowid  # type: ignore[assignment]
        self.cache()
        for attr_name in self.to_save_versioned:
            getattr(self, attr_name).save(db_conn)
        for table, column, attr_name, key_index in self.to_save_relations:
            assert isinstance(self.id_, (int, str))
            db_conn.rewrite_relations(table, column, self.id_,
                                      [[i.id_] for i
                                       in getattr(self, attr_name)], key_index)

    def remove(self, db_conn: DatabaseConnection) -> None:
        """Remove from DB and cache, including dependencies."""
        if self.id_ is None or self.__class__.get_cached(self.id_) is None:
            raise HandledException('cannot remove unsaved item')
        for attr_name in self.to_save_versioned:
            getattr(self, attr_name).remove(db_conn)
        for table, column, attr_name, _ in self.to_save_relations:
            db_conn.delete_where(table, column, self.id_)
        self.uncache()
        db_conn.delete_where(self.table_name, 'id', self.id_)