1 """Database management."""
2 from __future__ import annotations
4 from os.path import isfile
5 from difflib import Differ
6 from sqlite3 import connect as sql_connect, Cursor, Row
7 from typing import Any, Self, TypeVar, Generic
8 from plomtask.exceptions import HandledException, NotFoundException
9 from plomtask.dating import valid_date
# Schema version this module works with: DB files reporting a higher
# user_version are rejected, files at a different one count as unmigrated.
EXPECTED_DB_VERSION = 5
# Directory holding the numbered migration scripts plus the init schema.
MIGRATIONS_DIR = 'migrations'
# Complete schema script for a fresh DB at EXPECTED_DB_VERSION; it is also
# the reference that existing DB files' schemas are compared against.
FILENAME_DB_SCHEMA = 'init_' + str(EXPECTED_DB_VERSION) + '.sql'
PATH_DB_SCHEMA = '/'.join((MIGRATIONS_DIR, FILENAME_DB_SCHEMA))
class UnmigratedDbException(HandledException):
    """Signals a DB file whose user_version differs from EXPECTED_DB_VERSION.

    Such a file is considered unmigrated: it must be brought up to the
    expected version before it can be used.
    """
22 """Represents the sqlite3 database's file."""
23 # pylint: disable=too-few-public-methods
25 def __init__(self, path: str) -> None:
30 def create_at(cls, path: str) -> DatabaseFile:
31 """Make new DB file at path."""
32 with sql_connect(path) as conn:
33 with open(PATH_DB_SCHEMA, 'r', encoding='utf-8') as f:
34 conn.executescript(f.read())
35 conn.execute(f'PRAGMA user_version = {EXPECTED_DB_VERSION}')
39 def migrate(cls, path: str) -> DatabaseFile:
40 """Apply migrations from_version to EXPECTED_DB_VERSION."""
41 migrations = cls._available_migrations()
42 from_version = cls._get_version_of_db(path)
43 migrations_todo = migrations[from_version+1:]
44 for j, filename in enumerate(migrations_todo):
45 with sql_connect(path) as conn:
46 with open(f'{MIGRATIONS_DIR}/{filename}', 'r',
47 encoding='utf-8') as f:
48 conn.executescript(f.read())
49 user_version = from_version + j + 1
50 with sql_connect(path) as conn:
51 conn.execute(f'PRAGMA user_version = {user_version}')
def _check(self) -> None:
    """Make sure self.path holds a usable DB file, raising if it does not.

    The checks run in this order:
    - the file must exist, else NotFoundException;
    - its user_version must equal EXPECTED_DB_VERSION, else
      UnmigratedDbException;
    - _validate_schema() must accept its schema.
    """
    exists = isfile(self.path)
    if not exists:
        raise NotFoundException
    found_version = self._user_version
    if found_version != EXPECTED_DB_VERSION:
        raise UnmigratedDbException()
    # Only reached for an existing file at the expected version.
    self._validate_schema()
63 def _available_migrations() -> list[str]:
64 """Validate migrations directory and return sorted entries."""
65 msg_too_big = 'Migration directory points beyond expected DB version.'
66 msg_bad_entry = 'Migration directory contains unexpected entry: '
67 msg_missing = 'Migration directory misses migration of number: '
69 for entry in listdir(MIGRATIONS_DIR):
70 if entry == FILENAME_DB_SCHEMA:
72 toks = entry.split('_', 1)
74 raise HandledException(msg_bad_entry + entry)
77 except ValueError as e:
78 raise HandledException(msg_bad_entry + entry) from e
79 if i > EXPECTED_DB_VERSION:
80 raise HandledException(msg_too_big)
81 migrations[i] = toks[1]
83 for i in range(EXPECTED_DB_VERSION + 1):
84 if i not in migrations:
85 raise HandledException(msg_missing + str(i))
86 migrations_list += [f'{i}_{migrations[i]}']
87 return migrations_list
90 def _get_version_of_db(path: str) -> int:
91 """Get DB user_version, fail if outside expected range."""
92 sql_for_db_version = 'PRAGMA user_version'
93 with sql_connect(path) as conn:
94 db_version = list(conn.execute(sql_for_db_version))[0][0]
95 if db_version > EXPECTED_DB_VERSION:
96 msg = f'Wrong DB version, expected '\
97 f'{EXPECTED_DB_VERSION}, got unknown {db_version}.'
98 raise HandledException(msg)
99 assert isinstance(db_version, int)
def _user_version(self) -> int:
    """Return the user_version recorded in the DB file at self.path.

    Delegates to the class's _get_version_of_db. That method rejects
    versions above EXPECTED_DB_VERSION, so the same range check applies here.
    """
    # NOTE(review): _check reads this as self._user_version without calling
    # it, so a @property decorator is presumably expected directly above;
    # that line is not in view -- confirm.
    # pylint: disable=protected-access
    # (since we remain within class)
    klass = self.__class__
    return klass._get_version_of_db(self.path)
109 def _validate_schema(self) -> None:
110 """Compare found schema with what's stored at PATH_DB_SCHEMA."""
112 def reformat_rows(rows: list[str]) -> list[str]:
116 for subrow in row.split('\n'):
117 subrow = subrow.rstrip()
120 for i, c in enumerate(subrow):
125 elif ',' == c and 0 == in_parentheses:
129 segment = subrow[prev_split:i].strip()
131 new_row += [f' {segment}']
133 segment = subrow[prev_split:].strip()
135 new_row += [f' {segment}']
136 new_row[0] = new_row[0].lstrip()
137 new_row[-1] = new_row[-1].lstrip()
138 if new_row[-1] != ')' and new_row[-3][-1] != ',':
139 new_row[-3] = new_row[-3] + ','
140 new_row[-2:] = [' ' + new_row[-1][:-1]] + [')']
141 new_rows += ['\n'.join(new_row)]
144 sql_for_schema = 'SELECT sql FROM sqlite_master ORDER BY sql'
145 msg_err = 'Database has wrong tables schema. Diff:\n'
146 with sql_connect(self.path) as conn:
147 schema_rows = [r[0] for r in conn.execute(sql_for_schema) if r[0]]
148 schema_rows = reformat_rows(schema_rows)
149 retrieved_schema = ';\n'.join(schema_rows) + ';'
150 with open(PATH_DB_SCHEMA, 'r', encoding='utf-8') as f:
151 stored_schema = f.read().rstrip()
152 if stored_schema != retrieved_schema:
153 diff_msg = Differ().compare(retrieved_schema.splitlines(),
154 stored_schema.splitlines())
155 raise HandledException(msg_err + '\n'.join(diff_msg))
158 class DatabaseConnection:
159 """A single connection to the database."""
def __init__(self, db_file: DatabaseFile) -> None:
    """Open a sqlite3 connection to the file at db_file.path.

    The connection is kept as self.conn; exec() runs statements through it.
    """
    db_path = db_file.path
    self.conn = sql_connect(db_path)
164 def commit(self) -> None:
165 """Commit SQL transaction."""
def exec(self, code: str, inputs: tuple[Any, ...] = tuple()) -> Cursor:
    """Run one SQL statement on self.conn and hand back its cursor.

    The values in inputs are bound to the statement's "?" placeholders.
    This method does not itself commit.
    """
    cursor = self.conn.execute(code, inputs)
    return cursor
def exec_on_vals(self, code: str, inputs: tuple[Any, ...]) -> Cursor:
    """Run code with one "?" placeholder per input appended in parentheses.

    E.g. code='INSERT INTO t VALUES' with three inputs executes
    'INSERT INTO t VALUES (?,?,?)' with those inputs bound.
    """
    # Joining the characters of '???' yields '?,?,?'.
    placeholders = ','.join('?' * len(inputs))
    return self.exec(f'{code} ({placeholders})', inputs)
177 def close(self) -> None:
178 """Close DB connection."""
181 def rewrite_relations(self, table_name: str, key: str, target: int | str,
182 rows: list[list[Any]], key_index: int = 0) -> None:
183 # pylint: disable=too-many-arguments
184 """Rewrite relations in table_name to target, with rows values.
186 Note that single rows are expected without the column and value
187 identified by key and target, which are inserted inside the function
190 self.delete_where(table_name, key, target)
192 values = tuple(row[:key_index] + [target] + row[key_index:])
193 self.exec_on_vals(f'INSERT INTO {table_name} VALUES', values)
195 def row_where(self, table_name: str, key: str,
196 target: int | str) -> list[Row]:
197 """Return list of Rows at table where key == target."""
198 return list(self.exec(f'SELECT * FROM {table_name} WHERE {key} = ?',
201 # def column_where_pattern(self,
205 # keys: list[str]) -> list[Any]:
206 # """Return column of rows where one of keys matches pattern."""
207 # targets = tuple([f'%{pattern}%'] * len(keys))
208 # haystack = ' OR '.join([f'{k} LIKE ?' for k in keys])
209 # sql = f'SELECT {column} FROM {table_name} WHERE {haystack}'
210 # return [row[0] for row in self.exec(sql, targets)]
def column_where(self, table_name: str, column: str, key: str,
                 target: int | str) -> list[Any]:
    """Collect column's values from the rows of table_name where key == target.

    target is bound as an SQL parameter. table_name, column and key are
    interpolated into the SQL text unescaped, so they must come from trusted
    code rather than user input.
    """
    query = f'SELECT {column} FROM {table_name} WHERE {key} = ?'
    rows = self.exec(query, (target,))
    return [row[0] for row in rows]
def column_all(self, table_name: str, column: str) -> list[Any]:
    """List the values of column for every row in table_name.

    Both names are interpolated into the SQL text unescaped, so they must
    come from trusted code.
    """
    values = []
    for row in self.exec(f'SELECT {column} FROM {table_name}'):
        values.append(row[0])
    return values
def delete_where(self, table_name: str, key: str,
                 target: int | str) -> None:
    """Remove every row of table_name whose key column equals target.

    target is bound as an SQL parameter; table_name and key are
    interpolated unescaped and must come from trusted code.
    """
    statement = f'DELETE FROM {table_name} WHERE {key} = ?'
    params = (target,)
    self.exec(statement, params)
# Stands for "the concrete BaseModel subclass" in classmethod signatures such
# as `cls: type[BaseModelInstance]`, so they return the caller's own type.
BaseModelInstance = TypeVar('BaseModelInstance', bound='BaseModel[Any]')
# Model IDs are either ints (taken from the DB's lastrowid on save) or strs
# for models that create their own IDs.
BaseModelId = TypeVar('BaseModelId', int, str)
234 class BaseModel(Generic[BaseModelId]):
235 """Template for most of the models we use/derive from the DB."""
237 to_save: list[str] = []
238 to_save_versioned: list[str] = []
239 to_save_relations: list[tuple[str, str, str, int]] = []
240 id_: None | BaseModelId
241 cache_: dict[BaseModelId, Self]
242 to_search: list[str] = []
244 def __init__(self, id_: BaseModelId | None) -> None:
245 if isinstance(id_, int) and id_ < 1:
246 msg = f'illegal {self.__class__.__name__} ID, must be >=1: {id_}'
247 raise HandledException(msg)
248 if isinstance(id_, str) and "" == id_:
249 msg = f'illegal {self.__class__.__name__} ID, must be non-empty'
250 raise HandledException(msg)
def __hash__(self) -> int:
    """Hash the ID together with everything that gets persisted.

    The hashed tuple holds, in this order:
    - the id_;
    - each to_save attribute's value;
    - for each to_save_relations entry, the tuple of its related objects'
      IDs (so their order matters);
    - the hash of each to_save_versioned attribute.
    """
    parts = [self.id_, *(getattr(self, name) for name in self.to_save)]
    for relation in self.to_save_relations:
        # Third field of a relation definition names the attribute that
        # holds the related objects.
        related = getattr(self, relation[2])
        parts.append(tuple(item.id_ for item in related))
    parts.extend(hash(getattr(self, name)) for name in self.to_save_versioned)
    return hash(tuple(parts))
262 def __eq__(self, other: object) -> bool:
263 if not isinstance(other, self.__class__):
265 return hash(self) == hash(other)
def __lt__(self, other: Any) -> bool:
    """Order two instances of the same class by their integer id_.

    Raises HandledException if other is not an instance of this class.
    Both IDs are asserted to be ints.
    """
    if isinstance(other, self.__class__):
        own_id = self.id_
        assert isinstance(own_id, int)
        other_id = other.id_
        assert isinstance(other_id, int)
        return own_id < other_id
    raise HandledException('cannot compare to object of different class')
276 def get_cached(cls: type[BaseModelInstance],
277 id_: BaseModelId) -> BaseModelInstance | None:
278 """Get object of id_ from class's cache, or None if not found."""
279 # pylint: disable=consider-iterating-dictionary
280 cache = cls.get_cache()
281 if id_ in cache.keys():
283 assert isinstance(obj, cls)
288 def empty_cache(cls) -> None:
289 """Empty class's cache."""
293 def get_cache(cls: type[BaseModelInstance]) -> dict[Any, BaseModel[Any]]:
294 """Get cache dictionary, create it if not yet existing."""
295 if not hasattr(cls, 'cache_'):
296 d: dict[Any, BaseModel[Any]] = {}
300 def cache(self) -> None:
301 """Update object in class's cache."""
303 raise HandledException('Cannot cache object without ID.')
304 cache = self.__class__.get_cache()
305 cache[self.id_] = self
307 def uncache(self) -> None:
308 """Remove self from cache."""
310 raise HandledException('Cannot un-cache object without ID.')
311 cache = self.__class__.get_cache()
315 def from_table_row(cls: type[BaseModelInstance],
316 # pylint: disable=unused-argument
317 db_conn: DatabaseConnection,
318 row: Row | list[Any]) -> BaseModelInstance:
319 """Make from DB row, write to DB cache."""
325 def by_id(cls, db_conn: DatabaseConnection,
326 id_: BaseModelId | None,
327 # pylint: disable=unused-argument
328 create: bool = False) -> Self:
329 """Retrieve by id_, on failure throw NotFoundException.
331 First try to get from cls.cache_, only then check DB; if found,
334 If create=True, make anew (but do not cache yet).
338 obj = cls.get_cached(id_)
340 for row in db_conn.row_where(cls.table_name, 'id', id_):
341 obj = cls.from_table_row(db_conn, row)
349 raise NotFoundException(f'found no object of ID {id_}')
352 def all(cls: type[BaseModelInstance],
353 db_conn: DatabaseConnection) -> list[BaseModelInstance]:
354 """Collect all objects of class into list.
356 Note that this primarily returns the contents of the cache, and only
357 _expands_ that by additional findings in the DB. This assumes the
358 cache is always instantly cleaned of any items that would be removed
361 items: dict[BaseModelId, BaseModelInstance] = {}
362 for k, v in cls.get_cache().items():
363 assert isinstance(v, cls)
365 already_recorded = items.keys()
366 for id_ in db_conn.column_all(cls.table_name, 'id'):
367 if id_ not in already_recorded:
368 item = cls.by_id(db_conn, id_)
369 assert item.id_ is not None
370 items[item.id_] = item
371 return list(items.values())
374 def by_date_range_with_limits(cls: type[BaseModelInstance],
375 db_conn: DatabaseConnection,
376 date_range: tuple[str, str],
377 date_col: str = 'day'
378 ) -> tuple[list[BaseModelInstance], str,
380 """Return list of items in database within (open) date_range interval.
382 If no range values provided, defaults them to 'yesterday' and
383 'tomorrow'. Knows to properly interpret these and 'today' as value.
385 start_str = date_range[0] if date_range[0] else 'yesterday'
386 end_str = date_range[1] if date_range[1] else 'tomorrow'
387 start_date = valid_date(start_str)
388 end_date = valid_date(end_str)
390 sql = f'SELECT id FROM {cls.table_name} '
391 sql += f'WHERE {date_col} >= ? AND {date_col} <= ?'
392 for row in db_conn.exec(sql, (start_date, end_date)):
393 items += [cls.by_id(db_conn, row[0])]
394 return items, start_date, end_date
397 def matching(cls: type[BaseModelInstance], db_conn: DatabaseConnection,
398 pattern: str) -> list[BaseModelInstance]:
399 """Return all objects whose .to_search match pattern."""
400 items = cls.all(db_conn)
404 for attr_name in cls.to_search:
405 toks = attr_name.split('.')
408 attr = getattr(parent, tok)
416 def save(self, db_conn: DatabaseConnection) -> None:
417 """Write self to DB and cache and ensure .id_.
419 Write both to DB, and to cache. To DB, write .id_ and attributes
420 listed in cls.to_save[_versioned|_relations].
422 Ensure self.id_ by setting it to what the DB command returns as the
423 last saved row's ID (cursor.lastrowid), EXCEPT if self.id_ already
424 exists as a 'str', which implies we do our own ID creation (so far
425 only the case with the Day class, where it's to be a date string.
427 values = tuple([self.id_] + [getattr(self, key)
428 for key in self.to_save])
429 table_name = self.table_name
430 cursor = db_conn.exec_on_vals(f'REPLACE INTO {table_name} VALUES',
432 if not isinstance(self.id_, str):
433 self.id_ = cursor.lastrowid # type: ignore[assignment]
435 for attr_name in self.to_save_versioned:
436 getattr(self, attr_name).save(db_conn)
437 for table, column, attr_name, key_index in self.to_save_relations:
438 assert isinstance(self.id_, (int, str))
439 db_conn.rewrite_relations(table, column, self.id_,
441 in getattr(self, attr_name)], key_index)
443 def remove(self, db_conn: DatabaseConnection) -> None:
444 """Remove from DB and cache, including dependencies."""
445 if self.id_ is None or self.__class__.get_cached(self.id_) is None:
446 raise HandledException('cannot remove unsaved item')
447 for attr_name in self.to_save_versioned:
448 getattr(self, attr_name).remove(db_conn)
449 for table, column, attr_name, _ in self.to_save_relations:
450 db_conn.delete_where(table, column, self.id_)
452 db_conn.delete_where(self.table_name, 'id', self.id_)