1 """Database management."""
2 from __future__ import annotations
4 from os.path import isfile
5 from difflib import Differ
6 from sqlite3 import connect as sql_connect, Cursor, Row
7 from typing import Any, Self, TypeVar, Generic
8 from plomtask.exceptions import HandledException, NotFoundException
9 from plomtask.dating import valid_date
# Schema version this code works with; stored in the DB as PRAGMA user_version.
EXPECTED_DB_VERSION = 5
# Directory (a relative path, so resolved against the working directory)
# holding the numbered migration scripts and the full init schema.
MIGRATIONS_DIR = 'migrations'
# Full schema script for a fresh DB at EXPECTED_DB_VERSION, e.g. 'init_5.sql'.
FILENAME_DB_SCHEMA = 'init_' + str(EXPECTED_DB_VERSION) + '.sql'
PATH_DB_SCHEMA = '/'.join((MIGRATIONS_DIR, FILENAME_DB_SCHEMA))
class UnmigratedDbException(HandledException):
    """Signals a DB file whose user_version is not EXPECTED_DB_VERSION.

    Such a file needs migrating (see DatabaseFile.migrate) before use.
    """
# NOTE(review): the class header (original line 21) is missing from this
# listing; the methods below are DatabaseFile's.
22 """Represents the sqlite3 database file at .path: creating it, migrating it, and checking its version and schema."""
23 # pylint: disable=too-few-public-methods
# NOTE(review): this method's body (original lines 26-28) is missing from
# this listing. Other methods read self.path, so it presumably stores
# path there. Confirm.
25 def __init__(self, path: str) -> None:
# NOTE(review): original lines 29 and 36 are missing from this listing.
# They presumably hold the @classmethod decorator and the return of the
# new DatabaseFile. Confirm.
# NOTE(review): an sqlite3 connection used as a context manager only
# commits or rolls back on exit. It does not close the connection.
30 def create_at(cls, path: str) -> DatabaseFile:
31 """Initialize the DB file at path (sqlite3 creates it if absent) by running PATH_DB_SCHEMA and setting user_version to EXPECTED_DB_VERSION."""
32 with sql_connect(path) as conn:
33 with open(PATH_DB_SCHEMA, 'r', encoding='utf-8') as f:
34 conn.executescript(f.read())
35 conn.execute(f'PRAGMA user_version = {EXPECTED_DB_VERSION}')
# NOTE(review): original lines 38 and 52 are missing from this listing.
# They presumably hold @classmethod and the return statement. Confirm.
# NOTE(review): each `with sql_connect(...)` commits on exit but leaves
# the connection open.
39 def migrate(cls, path: str) -> DatabaseFile:
40 """Apply the migration scripts numbered above path's current user_version, up to EXPECTED_DB_VERSION, bumping user_version after each script."""
41 migrations = cls._available_migrations()
42 from_version = cls._get_version_of_db(path)
# migrations[k] is the script for version k, so skip every script up to
# and including the version the DB is already at.
43 migrations_todo = migrations[from_version+1:]
44 for j, filename in enumerate(migrations_todo):
45 with sql_connect(path) as conn:
46 with open(f'{MIGRATIONS_DIR}/{filename}', 'r',
47 encoding='utf-8') as f:
48 conn.executescript(f.read())
# user_version is stamped after every script, in its own connection,
# rather than once at the end.
49 user_version = from_version + j + 1
50 with sql_connect(path) as conn:
51 conn.execute(f'PRAGMA user_version = {user_version}')
def _check(self) -> None:
    """Validate the DB file at .path before use.

    Raises NotFoundException if no file exists there, and
    UnmigratedDbException if its user_version is not EXPECTED_DB_VERSION.
    Otherwise it raises whatever _validate_schema raises on a schema
    mismatch.
    """
    if isfile(self.path):
        if self._user_version == EXPECTED_DB_VERSION:
            self._validate_schema()
            return
        raise UnmigratedDbException()
    raise NotFoundException
# NOTE(review): original lines 62, 68, 71, 73, 75-76 and 82 are missing
# from this listing. They presumably hold @staticmethod, the `migrations`
# dict and `migrations_list` initializations, `continue`, the token-count
# check, and `try: i = int(toks[0])`. Confirm. `listdir` is also not among
# the visible imports.
# NOTE(review): if two entries share a number, `migrations[i] = toks[1]`
# silently keeps only the later one seen by listdir.
63 def _available_migrations() -> list[str]:
64 """Validate MIGRATIONS_DIR and return its migration filenames ordered by number, so that index k holds the script for version k; entries other than FILENAME_DB_SCHEMA must be named '<number>_<rest>', must not exceed EXPECTED_DB_VERSION, and must cover every number 0..EXPECTED_DB_VERSION, else HandledException is raised."""
65 msg_too_big = 'Migration directory points beyond expected DB version.'
66 msg_bad_entry = 'Migration directory contains unexpected entry: '
67 msg_missing = 'Migration directory misses migration of number: '
69 for entry in listdir(MIGRATIONS_DIR):
70 if entry == FILENAME_DB_SCHEMA:
72 toks = entry.split('_', 1)
74 raise HandledException(msg_bad_entry + entry)
77 except ValueError as e:
78 raise HandledException(msg_bad_entry + entry) from e
79 if i > EXPECTED_DB_VERSION:
80 raise HandledException(msg_too_big)
81 migrations[i] = toks[1]
# Rebuild the filenames in ascending number order, failing on any gap.
83 for i in range(EXPECTED_DB_VERSION + 1):
84 if i not in migrations:
85 raise HandledException(msg_missing + str(i))
86 migrations_list += [f'{i}_{migrations[i]}']
87 return migrations_list
# NOTE(review): original lines 89 and 100 are missing from this listing.
# They presumably hold @staticmethod and `return db_version`. Confirm.
# NOTE(review): sqlite3.connect creates an empty file (user_version 0)
# if path does not exist yet.
90 def _get_version_of_db(path: str) -> int:
91 """Return PRAGMA user_version of the DB at path; raise HandledException if it exceeds EXPECTED_DB_VERSION (no lower bound is checked)."""
92 sql_for_db_version = 'PRAGMA user_version'
93 with sql_connect(path) as conn:
94 db_version = list(conn.execute(sql_for_db_version))[0][0]
95 if db_version > EXPECTED_DB_VERSION:
96 msg = f'Wrong DB version, expected '\
97 f'{EXPECTED_DB_VERSION}, got unknown {db_version}.'
98 raise HandledException(msg)
99 assert isinstance(db_version, int)
def _user_version(self) -> int:
    """Return the DB file's PRAGMA user_version via _get_version_of_db."""
    # pylint: disable=protected-access
    # (the helper belongs to this same class)
    owning_class = self.__class__
    return owning_class._get_version_of_db(self.path)
# NOTE(review): many lines of this method are missing from this listing:
# original 111, 113-115, 118-119, 121-124, 126-128, 130, 132, 134 and
# 142-143. They include most of reformat_rows' setup and branches.
# reformat_rows appears to split each CREATE statement at commas outside
# parentheses (in_parentheses == 0). This lets sqlite_master's SQL be
# compared line by line with the schema file's layout.
# NOTE(review): the `with sql_connect(...)` block commits on exit but
# does not close the connection.
109 def _validate_schema(self) -> None:
110 """Raise HandledException with a line diff if the DB schema (sqlite_master SQL, reformatted) differs from the text of PATH_DB_SCHEMA."""
112 def reformat_rows(rows: list[str]) -> list[str]:
116 for subrow in row.split('\n'):
117 subrow = subrow.rstrip()
120 for i, c in enumerate(subrow):
125 elif ',' == c and 0 == in_parentheses:
129 segment = subrow[prev_split:i].strip()
131 new_row += [f' {segment}']
133 segment = subrow[prev_split:].strip()
135 new_row += [f' {segment}']
136 new_row[0] = new_row[0].lstrip()
137 new_row[-1] = new_row[-1].lstrip()
138 if new_row[-1] != ')' and new_row[-3][-1] != ',':
139 new_row[-3] = new_row[-3] + ','
140 new_row[-2:] = [' ' + new_row[-1][:-1]] + [')']
141 new_rows += ['\n'.join(new_row)]
# Rows whose sql is NULL/empty are skipped by the `if r[0]` filter.
144 sql_for_schema = 'SELECT sql FROM sqlite_master ORDER BY sql'
145 msg_err = 'Database has wrong tables schema. Diff:\n'
146 with sql_connect(self.path) as conn:
147 schema_rows = [r[0] for r in conn.execute(sql_for_schema) if r[0]]
148 schema_rows = reformat_rows(schema_rows)
149 retrieved_schema = ';\n'.join(schema_rows) + ';'
150 with open(PATH_DB_SCHEMA, 'r', encoding='utf-8') as f:
151 stored_schema = f.read().rstrip()
152 if stored_schema != retrieved_schema:
# Differ marks DB-only lines with '- ' and schema-file-only lines with '+ '.
153 diff_msg = Differ().compare(retrieved_schema.splitlines(),
154 stored_schema.splitlines())
155 raise HandledException(msg_err + '\n'.join(diff_msg))
158 class DatabaseConnection:
159 """A single sqlite3 connection to a DatabaseFile, with small SQL helpers."""
# NOTE(review): the helpers of this class interpolate table/column/key
# names into SQL via f-strings. Only values are bound as parameters, so
# callers must pass trusted identifiers.
def __init__(self, db_file: DatabaseFile) -> None:
    """Open an sqlite3 connection to db_file.path, kept as .conn."""
    file_path = db_file.path
    self.conn = sql_connect(file_path)
# NOTE(review): this method's body (original line 166) is missing from
# this listing.
164 def commit(self) -> None:
165 """Commit the SQL transaction built up via .exec."""
def exec(self, code: str, inputs: tuple[Any, ...] = tuple()) -> Cursor:
    """Run SQL code on .conn with inputs bound to its placeholders.

    The statement becomes part of the current SQL transaction. The cursor
    is returned so callers can iterate over result rows.
    """
    connection = self.conn
    cursor = connection.execute(code, inputs)
    return cursor
def exec_on_vals(self, code: str, inputs: tuple[Any, ...]) -> Cursor:
    """Run code followed by " (?,?,…)", one placeholder per input value."""
    # '?' * n joined by ',' gives '?,?,…,?' (or '' when there are no inputs)
    placeholders = ','.join('?' * len(inputs))
    return self.exec(f'{code} ({placeholders})', inputs)
# NOTE(review): this method's body (original line 179) is missing from
# this listing.
177 def close(self) -> None:
178 """Close DB connection."""
# First deletes every row of table_name where key == target. Then inserts
# one row per entry of rows, with target spliced in at index key_index.
# NOTE(review): original lines 185 and 188-189 (the rest of the
# docstring) and 191 (presumably `for row in rows:`) are missing from
# this listing. Confirm.
181 def rewrite_relations(self, table_name: str, key: str, target: int | str,
182 rows: list[list[Any]], key_index: int = 0) -> None:
183 # pylint: disable=too-many-arguments
184 """Rewrite relations in table_name to target, with rows values.
186 Note that single rows are expected without the column and value
187 identified by key and target, which are inserted inside the function
190 self.delete_where(table_name, key, target)
192 values = tuple(row[:key_index] + [target] + row[key_index:])
193 self.exec_on_vals(f'INSERT INTO {table_name} VALUES', values)
# NOTE(review): the end of the return expression (original lines
# 199-200) is missing from this listing. It presumably holds the bound
# `(target,)` argument and the closing parentheses.
195 def row_where(self, table_name: str, key: str,
196 target: int | str) -> list[Row]:
197 """Return list of Rows at table where key == target."""
198 return list(self.exec(f'SELECT * FROM {table_name} WHERE {key} = ?',
201 # def column_where_pattern(self,
205 # keys: list[str]) -> list[Any]:
206 # """Return column of rows where one of keys matches pattern."""
207 # targets = tuple([f'%{pattern}%'] * len(keys))
208 # haystack = ' OR '.join([f'{k} LIKE ?' for k in keys])
209 # sql = f'SELECT {column} FROM {table_name} WHERE {haystack}'
210 # return [row[0] for row in self.exec(sql, targets)]
def column_where(self, table_name: str, column: str, key: str,
                 target: int | str) -> list[Any]:
    """Return the column values of table_name's rows whose key == target.

    table_name, column and key are pasted into the SQL text; only target
    is bound as a parameter.
    """
    query = f'SELECT {column} FROM {table_name} WHERE {key} = ?'
    cursor = self.exec(query, (target,))
    return [values[0] for values in cursor]
def column_all(self, table_name: str, column: str) -> list[Any]:
    """Return every value of column across all rows of table_name."""
    cursor = self.exec(f'SELECT {column} FROM {table_name}')
    return [values[0] for values in cursor]
def delete_where(self, table_name: str, key: str,
                 target: int | str) -> None:
    """Delete every row of table_name whose key equals target."""
    statement = f'DELETE FROM {table_name} WHERE {key} = ?'
    self.exec(statement, (target,))
# ID type of a BaseModel: int (taken from the DB's lastrowid on save) or
# str (IDs the model creates itself).
BaseModelId = TypeVar('BaseModelId', int, str)
# Any concrete BaseModel instance; lets classmethods return the subclass.
BaseModelInstance = TypeVar(
    'BaseModelInstance',
    bound='BaseModel[Any]')
234 class BaseModel(Generic[BaseModelId]):
235 """Template for most of the models we use/derive from the DB."""
# Class-level persistence config, meant to be set per subclass.
# NOTE(review): these are mutable lists shared with any subclass that
# does not rebind them, so mutate with care.
# to_save: attribute names that save() writes positionally after id_.
237 to_save: list[str] = []
# to_save_versioned: attributes whose values have their own save()/remove().
238 to_save_versioned: list[str] = []
# to_save_relations: (table, key column, attribute name, key_index)
# tuples. save() rewrites that table's rows for this object from the
# attribute's contents.
239 to_save_relations: list[tuple[str, str, str, int]] = []
240 id_: None | BaseModelId
241 cache_: dict[BaseModelId, Self]
# to_search: attribute paths (may be dot-separated) inspected by matching().
242 to_search: list[str] = []
# Rejects int IDs below 1 and empty str IDs. None (no ID yet) passes.
# NOTE(review): bool is an int subclass, so True passes as ID 1.
# NOTE(review): original line 251 is missing from this listing. It is
# presumably `self.id_ = id_`. Confirm.
244 def __init__(self, id_: BaseModelId | None) -> None:
245 if isinstance(id_, int) and id_ < 1:
246 msg = f'illegal {self.__class__.__name__} ID, must be >=1: {id_}'
247 raise HandledException(msg)
248 if isinstance(id_, str) and "" == id_:
249 msg = f'illegal {self.__class__.__name__} ID, must be non-empty'
250 raise HandledException(msg)
def __hash__(self) -> int:
    """Hash .id_ together with all of the object's persisted state.

    Combined, in order: the to_save attribute values, a tuple of related
    objects' .id_ per to_save_relations entry, and the hash of each
    to_save_versioned attribute. __eq__ compares objects by this hash.
    """
    own_values = tuple(getattr(self, name) for name in self.to_save)
    relation_ids = tuple(
        tuple(related.id_ for related in getattr(self, rel_def[2]))
        for rel_def in self.to_save_relations)
    versioned_hashes = tuple(
        hash(getattr(self, name)) for name in self.to_save_versioned)
    return hash((self.id_,) + own_values + relation_ids + versioned_hashes)
# Objects of the same class are equal iff their __hash__ values are
# equal. A hash collision would therefore make distinct states compare
# equal.
# NOTE(review): original line 264 is missing from this listing. It is
# presumably `return False` for objects of another class.
262 def __eq__(self, other: object) -> bool:
263 if not isinstance(other, self.__class__):
265 return hash(self) == hash(other)
def __lt__(self, other: Any) -> bool:
    """Order objects of the same class by their integer .id_.

    Raises HandledException when other is of a different class.
    """
    if isinstance(other, self.__class__):
        # narrowing for the type checker: only int IDs are orderable here
        assert isinstance(self.id_, int)
        assert isinstance(other.id_, int)
        return self.id_ < other.id_
    msg = 'cannot compare to object of different class'
    raise HandledException(msg)
# NOTE(review): original lines 274-277 (presumably including
# @classmethod) and 284, 286-287 are missing from this listing. The
# latter presumably hold `obj = cache[id_]`, `return obj` and
# `return None`. Confirm.
278 def _get_cached(cls: type[BaseModelInstance],
279 id_: BaseModelId) -> BaseModelInstance | None:
280 """Get object of id_ from class's cache, or None if not found."""
281 # pylint: disable=consider-iterating-dictionary
282 cache = cls.get_cache()
283 if id_ in cache.keys():
285 assert isinstance(obj, cls)
# NOTE(review): the decorator (original line 289) and the body (original
# line 292) are missing from this listing.
290 def empty_cache(cls) -> None:
291 """Empty class's cache."""
# NOTE(review): original lines 294 and 299-300 are missing from this
# listing. They presumably hold @classmethod, storing `d` as cls.cache_,
# and the return. Confirm.
# NOTE(review): hasattr() also finds a cache_ inherited from a parent
# class. A subclass would then share its parent's cache if the parent's
# get_cache ran first.
295 def get_cache(cls: type[BaseModelInstance]) -> dict[Any, BaseModel[Any]]:
296 """Get cache dictionary, create it if not yet existing."""
297 if not hasattr(cls, 'cache_'):
298 d: dict[Any, BaseModel[Any]] = {}
# Stores self in the class's cache dict under self.id_.
# NOTE(review): original line 304 is missing from this listing. It is
# presumably `if self.id_ is None:`, guarding the raise. Confirm.
302 def cache(self) -> None:
303 """Update object in class's cache."""
305 raise HandledException('Cannot cache object without ID.')
306 cache = self.__class__.get_cache()
307 cache[self.id_] = self
# NOTE(review): original lines 311 and 314 are missing from this listing.
# They presumably hold the `self.id_ is None` guard and the actual
# removal from `cache`. Confirm.
309 def uncache(self) -> None:
310 """Remove self from cache."""
312 raise HandledException('Cannot un-cache object without ID.')
313 cache = self.__class__.get_cache()
316 # object retrieval and generation
# NOTE(review): the decorator (original line 318) and the body (original
# lines 324-327) are missing from this listing.
319 def from_table_row(cls: type[BaseModelInstance],
320 # pylint: disable=unused-argument
321 db_conn: DatabaseConnection,
322 row: Row | list[Any]) -> BaseModelInstance:
323 """Make from DB row, write to DB cache."""
# Looks id_ up in the class cache first, then falls back to the table's
# rows whose 'id' column equals id_.
# NOTE(review): original lines 328, 334, 336-337, 339-341, 343 and
# 346-352 are missing from this listing, including the create=True
# branch. Confirm that behaviour against the full file.
329 def by_id(cls, db_conn: DatabaseConnection,
330 id_: BaseModelId | None,
331 # pylint: disable=unused-argument
332 create: bool = False) -> Self:
333 """Retrieve by id_, on failure throw NotFoundException.
335 First try to get from cls.cache_, only then check DB; if found,
338 If create=True, make anew (but do not cache yet).
342 obj = cls._get_cached(id_)
344 for row in db_conn.row_where(cls.table_name, 'id', id_):
345 obj = cls.from_table_row(db_conn, row)
353 raise NotFoundException(f'found no object of ID {id_}')
# NOTE(review): original lines 355, 363-364 and 368 are missing from this
# listing. They presumably hold @classmethod, the end of the docstring,
# and `items[k] = v`. Confirm.
# NOTE(review): already_recorded is a live keys() view of items, so IDs
# added inside the DB loop count as recorded too.
356 def all(cls: type[BaseModelInstance],
357 db_conn: DatabaseConnection) -> list[BaseModelInstance]:
358 """Collect all objects of class into list.
360 Note that this primarily returns the contents of the cache, and only
361 _expands_ that by additional findings in the DB. This assumes the
362 cache is always instantly cleaned of any items that would be removed
365 items: dict[BaseModelId, BaseModelInstance] = {}
366 for k, v in cls.get_cache().items():
367 assert isinstance(v, cls)
369 already_recorded = items.keys()
370 for id_ in db_conn.column_all(cls.table_name, 'id'):
371 if id_ not in already_recorded:
372 item = cls.by_id(db_conn, id_)
373 assert item.id_ is not None
374 items[item.id_] = item
375 return list(items.values())
# Returns the matching items plus the resolved start and end dates.
# NOTE(review): the SQL below uses `>=` and `<=`, so both bounds are
# inclusive. Confirm what "(open)" in the docstring is meant to convey.
# NOTE(review): original lines 377, 383, 385, 388 and 393 are missing
# from this listing. They presumably hold @classmethod, the rest of the
# return annotation, docstring lines, and `items = []`.
378 def by_date_range_with_limits(cls: type[BaseModelInstance],
379 db_conn: DatabaseConnection,
380 date_range: tuple[str, str],
381 date_col: str = 'day'
382 ) -> tuple[list[BaseModelInstance], str,
384 """Return list of items in database within (open) date_range interval.
386 If no range values provided, defaults them to 'yesterday' and
387 'tomorrow'. Knows to properly interpret these and 'today' as value.
389 start_str = date_range[0] if date_range[0] else 'yesterday'
390 end_str = date_range[1] if date_range[1] else 'tomorrow'
391 start_date = valid_date(start_str)
392 end_date = valid_date(end_str)
394 sql = f'SELECT id FROM {cls.table_name} '
395 sql += f'WHERE {date_col} >= ? AND {date_col} <= ?'
396 for row in db_conn.exec(sql, (start_date, end_date)):
397 items += [cls.by_id(db_conn, row[0])]
398 return items, start_date, end_date
# Each to_search entry is split on '.' and resolved tok by tok via
# getattr, so nested attributes can be searched.
# NOTE(review): original lines 400, 405-407, 410-411 and 413-421 are
# missing from this listing, including the actual pattern comparison.
401 def matching(cls: type[BaseModelInstance], db_conn: DatabaseConnection,
402 pattern: str) -> list[BaseModelInstance]:
403 """Return all objects whose .to_search match pattern."""
404 items = cls.all(db_conn)
408 for attr_name in cls.to_search:
409 toks = attr_name.split('.')
412 attr = getattr(parent, tok)
# NOTE(review): REPLACE INTO ... VALUES is positional. The table's
# columns must therefore be id followed by to_save, in that order.
# NOTE(review): original lines 424, 427, 432, 437, 440 and 446 are
# missing from this listing. They presumably hold docstring lines, the
# `values)` argument, and the rows comprehension head. Confirm.
422 def save(self, db_conn: DatabaseConnection) -> None:
423 """Write self to DB and cache and ensure .id_.
425 Write both to DB, and to cache. To DB, write .id_ and attributes
426 listed in cls.to_save[_versioned|_relations].
428 Ensure self.id_ by setting it to what the DB command returns as the
429 last saved row's ID (cursor.lastrowid), EXCEPT if self.id_ already
430 exists as a 'str', which implies we do our own ID creation (so far
431 only the case with the Day class, where it's to be a date string).
433 values = tuple([self.id_] + [getattr(self, key)
434 for key in self.to_save])
435 table_name = self.table_name
436 cursor = db_conn.exec_on_vals(f'REPLACE INTO {table_name} VALUES',
438 if not isinstance(self.id_, str):
439 self.id_ = cursor.lastrowid # type: ignore[assignment]
441 for attr_name in self.to_save_versioned:
442 getattr(self, attr_name).save(db_conn)
443 for table, column, attr_name, key_index in self.to_save_relations:
444 assert isinstance(self.id_, (int, str))
445 db_conn.rewrite_relations(table, column, self.id_,
447 in getattr(self, attr_name)], key_index)
# Refuses objects without an ID or missing from the cache. Otherwise it
# removes versioned attributes, then this object's relation rows, then
# its own table row.
# NOTE(review): original line 459 is missing and the method may continue
# past the last visible line. Whatever follows, presumably the cache
# removal, is not shown here.
# NOTE(review): attr_name is unpacked but unused in the relations loop.
449 def remove(self, db_conn: DatabaseConnection) -> None:
450 """Remove from DB and cache, including dependencies."""
451 # pylint: disable=protected-access
452 # (since we remain within class)
453 if self.id_ is None or self.__class__._get_cached(self.id_) is None:
454 raise HandledException('cannot remove unsaved item')
455 for attr_name in self.to_save_versioned:
456 getattr(self, attr_name).remove(db_conn)
457 for table, column, attr_name, _ in self.to_save_relations:
458 db_conn.delete_where(table, column, self.id_)
460 db_conn.delete_where(self.table_name, 'id', self.id_)