1 """Database management."""
2 from __future__ import annotations
4 from os.path import isfile
5 from difflib import Differ
6 from sqlite3 import connect as sql_connect, Cursor, Row
7 from typing import Any, Self, TypeVar, Generic
8 from plomtask.exceptions import HandledException, NotFoundException
9 from plomtask.dating import valid_date
# DB user_version this code requires; also selects the schema file below.
EXPECTED_DB_VERSION = 5
# Directory that is listed for migration scripts and holds the schema file.
MIGRATIONS_DIR = 'migrations'
# File name of the schema script for the expected version, and its path.
FILENAME_DB_SCHEMA = f'init_{EXPECTED_DB_VERSION}.sql'
PATH_DB_SCHEMA = '/'.join((MIGRATIONS_DIR, FILENAME_DB_SCHEMA))
class UnmigratedDbException(HandledException):
    """Signal that a DB file's user_version is not the expected one."""
22 """Represents the sqlite3 database's file."""
23 # pylint: disable=too-few-public-methods
25 def __init__(self, path: str) -> None:
30 def create_at(cls, path: str) -> DatabaseFile:
31 """Make new DB file at path."""
32 with sql_connect(path) as conn:
33 with open(PATH_DB_SCHEMA, 'r', encoding='utf-8') as f:
34 conn.executescript(f.read())
35 conn.execute(f'PRAGMA user_version = {EXPECTED_DB_VERSION}')
39 def migrate(cls, path: str) -> DatabaseFile:
40 """Apply migrations from_version to EXPECTED_DB_VERSION."""
41 migrations = cls._available_migrations()
42 from_version = cls._get_version_of_db(path)
43 migrations_todo = migrations[from_version+1:]
44 for j, filename in enumerate(migrations_todo):
45 with sql_connect(path) as conn:
46 with open(f'{MIGRATIONS_DIR}/{filename}', 'r',
47 encoding='utf-8') as f:
48 conn.executescript(f.read())
49 user_version = from_version + j + 1
50 with sql_connect(path) as conn:
51 conn.execute(f'PRAGMA user_version = {user_version}')
def _check(self) -> None:
    """Ensure DB file exists, has expected user_version, matches schema.

    Raises NotFoundException if there is no file at self.path, and
    UnmigratedDbException if its version differs from EXPECTED_DB_VERSION.
    """
    if not isfile(self.path):
        raise NotFoundException
    # Only query the version once the file is known to exist.
    found_version = self._user_version
    if found_version != EXPECTED_DB_VERSION:
        raise UnmigratedDbException()
    self._validate_schema()
63 def _available_migrations() -> list[str]:
64 """Validate migrations directory and return sorted entries."""
65 msg_too_big = 'Migration directory points beyond expected DB version.'
66 msg_bad_entry = 'Migration directory contains unexpected entry: '
67 msg_missing = 'Migration directory misses migration of number: '
69 for entry in listdir(MIGRATIONS_DIR):
70 if entry == FILENAME_DB_SCHEMA:
72 toks = entry.split('_', 1)
74 raise HandledException(msg_bad_entry + entry)
77 except ValueError as e:
78 raise HandledException(msg_bad_entry + entry) from e
79 if i > EXPECTED_DB_VERSION:
80 raise HandledException(msg_too_big)
81 migrations[i] = toks[1]
83 for i in range(EXPECTED_DB_VERSION + 1):
84 if i not in migrations:
85 raise HandledException(msg_missing + str(i))
86 migrations_list += [f'{i}_{migrations[i]}']
87 return migrations_list
90 def _get_version_of_db(path: str) -> int:
91 """Get DB user_version, fail if outside expected range."""
92 sql_for_db_version = 'PRAGMA user_version'
93 with sql_connect(path) as conn:
94 db_version = list(conn.execute(sql_for_db_version))[0][0]
95 if db_version > EXPECTED_DB_VERSION:
96 msg = f'Wrong DB version, expected '\
97 f'{EXPECTED_DB_VERSION}, got unknown {db_version}.'
98 raise HandledException(msg)
99 assert isinstance(db_version, int)
def _user_version(self) -> int:
    """Return user_version of the DB file at self.path."""
    # pylint: disable=protected-access
    # (since we remain within class)
    own_class = self.__class__
    return own_class._get_version_of_db(self.path)
109 def _validate_schema(self) -> None:
110 """Compare found schema with what's stored at PATH_DB_SCHEMA."""
112 def reformat_rows(rows: list[str]) -> list[str]:
116 for subrow in row.split('\n'):
117 subrow = subrow.rstrip()
120 for i, c in enumerate(subrow):
125 elif ',' == c and 0 == in_parentheses:
129 segment = subrow[prev_split:i].strip()
131 new_row += [f' {segment}']
133 segment = subrow[prev_split:].strip()
135 new_row += [f' {segment}']
136 new_row[0] = new_row[0].lstrip()
137 new_row[-1] = new_row[-1].lstrip()
138 if new_row[-1] != ')' and new_row[-3][-1] != ',':
139 new_row[-3] = new_row[-3] + ','
140 new_row[-2:] = [' ' + new_row[-1][:-1]] + [')']
141 new_rows += ['\n'.join(new_row)]
144 sql_for_schema = 'SELECT sql FROM sqlite_master ORDER BY sql'
145 msg_err = 'Database has wrong tables schema. Diff:\n'
146 with sql_connect(self.path) as conn:
147 schema_rows = [r[0] for r in conn.execute(sql_for_schema) if r[0]]
148 schema_rows = reformat_rows(schema_rows)
149 retrieved_schema = ';\n'.join(schema_rows) + ';'
150 with open(PATH_DB_SCHEMA, 'r', encoding='utf-8') as f:
151 stored_schema = f.read().rstrip()
152 if stored_schema != retrieved_schema:
153 diff_msg = Differ().compare(retrieved_schema.splitlines(),
154 stored_schema.splitlines())
155 raise HandledException(msg_err + '\n'.join(diff_msg))
158 class DatabaseConnection:
159 """A single connection to the database."""
161 def __init__(self, db_file: DatabaseFile) -> None:
162 self.conn = sql_connect(db_file.path)
164 def commit(self) -> None:
165 """Commit SQL transaction."""
168 def exec(self, code: str, inputs: tuple[Any, ...] = tuple()) -> Cursor:
169 """Add commands to SQL transaction."""
170 return self.conn.execute(code, inputs)
172 def exec_on_vals(self, code: str, inputs: tuple[Any, ...]) -> Cursor:
173 """Wrapper around .exec appending adequate " (?, …)" to code."""
174 q_marks_from_values = '(' + ','.join(['?'] * len(inputs)) + ')'
175 return self.exec(f'{code} {q_marks_from_values}', inputs)
177 def close(self) -> None:
178 """Close DB connection."""
181 def rewrite_relations(self, table_name: str, key: str, target: int | str,
182 rows: list[list[Any]], key_index: int = 0) -> None:
183 # pylint: disable=too-many-arguments
184 """Rewrite relations in table_name to target, with rows values.
186 Note that single rows are expected without the column and value
187 identified by key and target, which are inserted inside the function
190 self.delete_where(table_name, key, target)
192 values = tuple(row[:key_index] + [target] + row[key_index:])
193 self.exec_on_vals(f'INSERT INTO {table_name} VALUES', values)
195 def row_where(self, table_name: str, key: str,
196 target: int | str) -> list[Row]:
197 """Return list of Rows at table where key == target."""
198 return list(self.exec(f'SELECT * FROM {table_name} WHERE {key} = ?',
201 # def column_where_pattern(self,
205 # keys: list[str]) -> list[Any]:
206 # """Return column of rows where one of keys matches pattern."""
207 # targets = tuple([f'%{pattern}%'] * len(keys))
208 # haystack = ' OR '.join([f'{k} LIKE ?' for k in keys])
209 # sql = f'SELECT {column} FROM {table_name} WHERE {haystack}'
210 # return [row[0] for row in self.exec(sql, targets)]
212 def column_where(self, table_name: str, column: str, key: str,
213 target: int | str) -> list[Any]:
214 """Return column of table where key == target."""
215 return [row[0] for row in
216 self.exec(f'SELECT {column} FROM {table_name} '
217 f'WHERE {key} = ?', (target,))]
219 def column_all(self, table_name: str, column: str) -> list[Any]:
220 """Return complete column of table."""
221 return [row[0] for row in
222 self.exec(f'SELECT {column} FROM {table_name}')]
224 def delete_where(self, table_name: str, key: str,
225 target: int | str) -> None:
226 """Delete from table where key == target."""
227 self.exec(f'DELETE FROM {table_name} WHERE {key} = ?', (target,))
# ID type of a model: constrained to be either int or str.
BaseModelId = TypeVar('BaseModelId', int, str)
# Any concrete BaseModel (sub)class; used to type cls in its classmethods.
BaseModelInstance = TypeVar('BaseModelInstance', bound='BaseModel[Any]')
234 class BaseModel(Generic[BaseModelId]):
235 """Template for most of the models we use/derive from the DB."""
237 to_save: list[str] = []
238 to_save_versioned: list[str] = []
239 to_save_relations: list[tuple[str, str, str, int]] = []
240 id_: None | BaseModelId
241 cache_: dict[BaseModelId, Self]
242 to_search: list[str] = []
244 def __init__(self, id_: BaseModelId | None) -> None:
245 if isinstance(id_, int) and id_ < 1:
246 msg = f'illegal {self.__class__.__name__} ID, must be >=1: {id_}'
247 raise HandledException(msg)
250 def __eq__(self, other: object) -> bool:
251 if not isinstance(other, self.__class__):
253 to_hash_me = tuple([self.id_] +
254 [getattr(self, name) for name in self.to_save])
255 to_hash_other = tuple([other.id_] +
256 [getattr(other, name) for name in other.to_save])
257 return hash(to_hash_me) == hash(to_hash_other)
259 def __lt__(self, other: Any) -> bool:
260 if not isinstance(other, self.__class__):
261 msg = 'cannot compare to object of different class'
262 raise HandledException(msg)
263 assert isinstance(self.id_, int)
264 assert isinstance(other.id_, int)
265 return self.id_ < other.id_
268 def get_cached(cls: type[BaseModelInstance],
269 id_: BaseModelId) -> BaseModelInstance | None:
270 """Get object of id_ from class's cache, or None if not found."""
271 # pylint: disable=consider-iterating-dictionary
272 cache = cls.get_cache()
273 if id_ in cache.keys():
275 assert isinstance(obj, cls)
280 def empty_cache(cls) -> None:
281 """Empty class's cache."""
285 def get_cache(cls: type[BaseModelInstance]) -> dict[Any, BaseModel[Any]]:
286 """Get cache dictionary, create it if not yet existing."""
287 if not hasattr(cls, 'cache_'):
288 d: dict[Any, BaseModel[Any]] = {}
292 def cache(self) -> None:
293 """Update object in class's cache."""
295 raise HandledException('Cannot cache object without ID.')
296 cache = self.__class__.get_cache()
297 cache[self.id_] = self
299 def uncache(self) -> None:
300 """Remove self from cache."""
302 raise HandledException('Cannot un-cache object without ID.')
303 cache = self.__class__.get_cache()
307 def from_table_row(cls: type[BaseModelInstance],
308 # pylint: disable=unused-argument
309 db_conn: DatabaseConnection,
310 row: Row | list[Any]) -> BaseModelInstance:
311 """Make from DB row, write to DB cache."""
317 def by_id(cls, db_conn: DatabaseConnection,
318 id_: BaseModelId | None,
319 # pylint: disable=unused-argument
320 create: bool = False) -> Self:
321 """Retrieve by id_, on failure throw NotFoundException.
323 First try to get from cls.cache_, only then check DB; if found,
326 If create=True, make anew (but do not cache yet).
330 obj = cls.get_cached(id_)
332 for row in db_conn.row_where(cls.table_name, 'id', id_):
333 obj = cls.from_table_row(db_conn, row)
341 raise NotFoundException(f'found no object of ID {id_}')
344 def all(cls: type[BaseModelInstance],
345 db_conn: DatabaseConnection) -> list[BaseModelInstance]:
346 """Collect all objects of class into list.
348 Note that this primarily returns the contents of the cache, and only
349 _expands_ that by additional findings in the DB. This assumes the
350 cache is always instantly cleaned of any items that would be removed
353 items: dict[BaseModelId, BaseModelInstance] = {}
354 for k, v in cls.get_cache().items():
355 assert isinstance(v, cls)
357 already_recorded = items.keys()
358 for id_ in db_conn.column_all(cls.table_name, 'id'):
359 if id_ not in already_recorded:
360 item = cls.by_id(db_conn, id_)
361 assert item.id_ is not None
362 items[item.id_] = item
363 return list(items.values())
366 def by_date_range_with_limits(cls: type[BaseModelInstance],
367 db_conn: DatabaseConnection,
368 date_range: tuple[str, str],
369 date_col: str = 'day'
370 ) -> tuple[list[BaseModelInstance], str,
372 """Return list of items in database within (open) date_range interval.
374 If no range values provided, defaults them to 'yesterday' and
375 'tomorrow'. Knows to properly interpret these and 'today' as value.
377 start_str = date_range[0] if date_range[0] else 'yesterday'
378 end_str = date_range[1] if date_range[1] else 'tomorrow'
379 start_date = valid_date(start_str)
380 end_date = valid_date(end_str)
382 sql = f'SELECT id FROM {cls.table_name} '
383 sql += f'WHERE {date_col} >= ? AND {date_col} <= ?'
384 for row in db_conn.exec(sql, (start_date, end_date)):
385 items += [cls.by_id(db_conn, row[0])]
386 return items, start_date, end_date
389 def matching(cls: type[BaseModelInstance], db_conn: DatabaseConnection,
390 pattern: str) -> list[BaseModelInstance]:
391 """Return all objects whose .to_search match pattern."""
392 items = cls.all(db_conn)
396 for attr_name in cls.to_search:
397 toks = attr_name.split('.')
400 attr = getattr(parent, tok)
408 def save(self, db_conn: DatabaseConnection) -> None:
409 """Write self to DB and cache and ensure .id_.
411 Write both to DB, and to cache. To DB, write .id_ and attributes
412 listed in cls.to_save[_versioned|_relations].
414 Ensure self.id_ by setting it to what the DB command returns as the
415 last saved row's ID (cursor.lastrowid), EXCEPT if self.id_ already
416 exists as a 'str', which implies we do our own ID creation (so far
417 only the case with the Day class, where it's to be a date string.
419 values = tuple([self.id_] + [getattr(self, key)
420 for key in self.to_save])
421 table_name = self.table_name
422 cursor = db_conn.exec_on_vals(f'REPLACE INTO {table_name} VALUES',
424 if not isinstance(self.id_, str):
425 self.id_ = cursor.lastrowid # type: ignore[assignment]
427 for attr_name in self.to_save_versioned:
428 getattr(self, attr_name).save(db_conn)
429 for table, column, attr_name, key_index in self.to_save_relations:
430 assert isinstance(self.id_, (int, str))
431 db_conn.rewrite_relations(table, column, self.id_,
433 in getattr(self, attr_name)], key_index)
435 def remove(self, db_conn: DatabaseConnection) -> None:
436 """Remove from DB and cache, including dependencies."""
437 if self.id_ is None or self.__class__.get_cached(self.id_) is None:
438 raise HandledException('cannot remove unsaved item')
439 for attr_name in self.to_save_versioned:
440 getattr(self, attr_name).remove(db_conn)
441 for table, column, attr_name, _ in self.to_save_relations:
442 db_conn.delete_where(table, column, self.id_)
444 db_conn.delete_where(self.table_name, 'id', self.id_)