1 """Database management."""
2 from __future__ import annotations
4 from os.path import isfile
5 from difflib import Differ
6 from sqlite3 import connect as sql_connect, Cursor, Row
7 from typing import Any, Self, TypeVar, Generic
8 from plomtask.exceptions import HandledException, NotFoundException
9 from plomtask.dating import valid_date
# DB user_version this code expects; also names the current schema file.
EXPECTED_DB_VERSION = 5
# Directory holding the schema and migration SQL files.
MIGRATIONS_DIR = 'migrations'
# Schema file for the expected version, and its path inside MIGRATIONS_DIR.
FILENAME_DB_SCHEMA = 'init_' + str(EXPECTED_DB_VERSION) + '.sql'
PATH_DB_SCHEMA = '/'.join((MIGRATIONS_DIR, FILENAME_DB_SCHEMA))
class UnmigratedDbException(HandledException):
    """Signals a DB file that has not been migrated to the expected version."""
22 """Represents the sqlite3 database's file."""
23 # pylint: disable=too-few-public-methods
25 def __init__(self, path: str) -> None:
30 def create_at(cls, path: str) -> DatabaseFile:
31 """Make new DB file at path."""
32 with sql_connect(path) as conn:
33 with open(PATH_DB_SCHEMA, 'r', encoding='utf-8') as f:
34 conn.executescript(f.read())
35 conn.execute(f'PRAGMA user_version = {EXPECTED_DB_VERSION}')
39 def migrate(cls, path: str) -> DatabaseFile:
40 """Apply migrations from_version to EXPECTED_DB_VERSION."""
41 migrations = cls._available_migrations()
42 from_version = cls._get_version_of_db(path)
43 migrations_todo = migrations[from_version+1:]
44 for j, filename in enumerate(migrations_todo):
45 with sql_connect(path) as conn:
46 with open(f'{MIGRATIONS_DIR}/{filename}', 'r',
47 encoding='utf-8') as f:
48 conn.executescript(f.read())
49 user_version = from_version + j + 1
50 with sql_connect(path) as conn:
51 conn.execute(f'PRAGMA user_version = {user_version}')
def _check(self) -> None:
    """Ensure the DB file exists, is fully migrated, and has a valid schema.

    Raises NotFoundException if nothing is found at self.path, and
    UnmigratedDbException if the file's user_version is not
    EXPECTED_DB_VERSION; the schema itself is checked by
    self._validate_schema().
    """
    file_exists = isfile(self.path)
    if not file_exists:
        raise NotFoundException
    version_is_current = self._user_version == EXPECTED_DB_VERSION
    if not version_is_current:
        raise UnmigratedDbException()
    self._validate_schema()
63 def _available_migrations() -> list[str]:
64 """Validate migrations directory and return sorted entries."""
65 msg_too_big = 'Migration directory points beyond expected DB version.'
66 msg_bad_entry = 'Migration directory contains unexpected entry: '
67 msg_missing = 'Migration directory misses migration of number: '
69 for entry in listdir(MIGRATIONS_DIR):
70 if entry == FILENAME_DB_SCHEMA:
72 toks = entry.split('_', 1)
74 raise HandledException(msg_bad_entry + entry)
77 except ValueError as e:
78 raise HandledException(msg_bad_entry + entry) from e
79 if i > EXPECTED_DB_VERSION:
80 raise HandledException(msg_too_big)
81 migrations[i] = toks[1]
83 for i in range(EXPECTED_DB_VERSION + 1):
84 if i not in migrations:
85 raise HandledException(msg_missing + str(i))
86 migrations_list += [f'{i}_{migrations[i]}']
87 return migrations_list
90 def _get_version_of_db(path: str) -> int:
91 """Get DB user_version, fail if outside expected range."""
92 sql_for_db_version = 'PRAGMA user_version'
93 with sql_connect(path) as conn:
94 db_version = list(conn.execute(sql_for_db_version))[0][0]
95 if db_version > EXPECTED_DB_VERSION:
96 msg = f'Wrong DB version, expected '\
97 f'{EXPECTED_DB_VERSION}, got unknown {db_version}.'
98 raise HandledException(msg)
99 assert isinstance(db_version, int)
def _user_version(self) -> int:
    """Return the user_version of the DB file at self.path."""
    # pylint: disable=protected-access
    # (fine here: the helper called belongs to this very class)
    own_class = self.__class__
    return own_class._get_version_of_db(self.path)
109 def _validate_schema(self) -> None:
110 """Compare found schema with what's stored at PATH_DB_SCHEMA."""
112 def reformat_rows(rows: list[str]) -> list[str]:
116 for subrow in row.split('\n'):
117 subrow = subrow.rstrip()
120 for i, c in enumerate(subrow):
125 elif ',' == c and 0 == in_parentheses:
129 segment = subrow[prev_split:i].strip()
131 new_row += [f' {segment}']
133 segment = subrow[prev_split:].strip()
135 new_row += [f' {segment}']
136 new_row[0] = new_row[0].lstrip()
137 new_row[-1] = new_row[-1].lstrip()
138 if new_row[-1] != ')' and new_row[-3][-1] != ',':
139 new_row[-3] = new_row[-3] + ','
140 new_row[-2:] = [' ' + new_row[-1][:-1]] + [')']
141 new_rows += ['\n'.join(new_row)]
144 sql_for_schema = 'SELECT sql FROM sqlite_master ORDER BY sql'
145 msg_err = 'Database has wrong tables schema. Diff:\n'
146 with sql_connect(self.path) as conn:
147 schema_rows = [r[0] for r in conn.execute(sql_for_schema) if r[0]]
148 schema_rows = reformat_rows(schema_rows)
149 retrieved_schema = ';\n'.join(schema_rows) + ';'
150 with open(PATH_DB_SCHEMA, 'r', encoding='utf-8') as f:
151 stored_schema = f.read().rstrip()
152 if stored_schema != retrieved_schema:
153 diff_msg = Differ().compare(retrieved_schema.splitlines(),
154 stored_schema.splitlines())
155 raise HandledException(msg_err + '\n'.join(diff_msg))
158 class DatabaseConnection:
159 """A single connection to the database."""
def __init__(self, db_file: DatabaseFile) -> None:
    """Open a sqlite3 connection to the file at db_file.path."""
    db_path = db_file.path
    self.conn = sql_connect(db_path)
164 def commit(self) -> None:
165 """Commit SQL transaction."""
def exec(self, code: str, inputs: tuple[Any, ...] = tuple()) -> Cursor:
    """Run one SQL statement on the connection and return its Cursor.

    The statement in code is passed to the connection's execute() together
    with inputs, which fill its "?" placeholders. This method itself does
    not call commit.
    """
    cursor = self.conn.execute(code, inputs)
    return cursor
def exec_on_vals(self, code: str, inputs: tuple[Any, ...]) -> Cursor:
    """Call .exec with code extended by a " (?,…)" group sized to inputs.

    One "?" placeholder is generated per element of inputs, so that e.g. an
    'INSERT INTO … VALUES' prefix becomes a complete statement.
    """
    placeholders = ','.join('?' for _ in inputs)
    return self.exec(f'{code} ({placeholders})', inputs)
177 def close(self) -> None:
178 """Close DB connection."""
181 def rewrite_relations(self, table_name: str, key: str, target: int | str,
182 rows: list[list[Any]], key_index: int = 0) -> None:
183 # pylint: disable=too-many-arguments
184 """Rewrite relations in table_name to target, with rows values.
186 Note that single rows are expected without the column and value
187 identified by key and target, which are inserted inside the function
190 self.delete_where(table_name, key, target)
192 values = tuple(row[:key_index] + [target] + row[key_index:])
193 self.exec_on_vals(f'INSERT INTO {table_name} VALUES', values)
195 def row_where(self, table_name: str, key: str,
196 target: int | str) -> list[Row]:
197 """Return list of Rows at table where key == target."""
198 return list(self.exec(f'SELECT * FROM {table_name} WHERE {key} = ?',
201 # def column_where_pattern(self,
205 # keys: list[str]) -> list[Any]:
206 # """Return column of rows where one of keys matches pattern."""
207 # targets = tuple([f'%{pattern}%'] * len(keys))
208 # haystack = ' OR '.join([f'{k} LIKE ?' for k in keys])
209 # sql = f'SELECT {column} FROM {table_name} WHERE {haystack}'
210 # return [row[0] for row in self.exec(sql, targets)]
212 def column_where(self, table_name: str, column: str, key: str,
213 target: int | str) -> list[Any]:
214 """Return column of table where key == target."""
215 return [row[0] for row in
216 self.exec(f'SELECT {column} FROM {table_name} '
217 f'WHERE {key} = ?', (target,))]
def column_all(self, table_name: str, column: str) -> list[Any]:
    """Collect the values of column from every row of table_name."""
    all_rows = self.exec(f'SELECT {column} FROM {table_name}')
    return [found[0] for found in all_rows]
224 def delete_where(self, table_name: str, key: str,
225 target: int | str) -> None:
226 """Delete from table where key == target."""
227 self.exec(f'DELETE FROM {table_name} WHERE {key} = ?', (target,))
# Type of model objects themselves: bounded by BaseModel, so subclasses fit.
BaseModelInstance = TypeVar('BaseModelInstance', bound='BaseModel[Any]')
# Type of a model's ID: constrained to exactly int or str.
BaseModelId = TypeVar('BaseModelId', int, str)
234 class BaseModel(Generic[BaseModelId]):
235 """Template for most of the models we use/derive from the DB."""
237 to_save: list[str] = []
238 to_save_versioned: list[str] = []
239 to_save_relations: list[tuple[str, str, str, int]] = []
240 id_: None | BaseModelId
241 cache_: dict[BaseModelId, Self]
242 to_search: list[str] = []
244 def __init__(self, id_: BaseModelId | None) -> None:
245 if isinstance(id_, int) and id_ < 1:
246 msg = f'illegal {self.__class__.__name__} ID, must be >=1: {id_}'
247 raise HandledException(msg)
248 if isinstance(id_, str) and "" == id_:
249 msg = f'illegal {self.__class__.__name__} ID, must be non-empty'
250 raise HandledException(msg)
253 def __eq__(self, other: object) -> bool:
254 if not isinstance(other, self.__class__):
256 to_hash_me = tuple([self.id_] +
257 [getattr(self, name) for name in self.to_save])
258 to_hash_other = tuple([other.id_] +
259 [getattr(other, name) for name in other.to_save])
260 return hash(to_hash_me) == hash(to_hash_other)
def __lt__(self, other: Any) -> bool:
    """Order by .id_; both objects must be of self's class, with int IDs.

    Raises HandledException if other is not an instance of self's class.
    """
    if isinstance(other, self.__class__):
        own_id, their_id = self.id_, other.id_
        # Only integer IDs are ordered here.
        assert isinstance(own_id, int)
        assert isinstance(their_id, int)
        return own_id < their_id
    msg = 'cannot compare to object of different class'
    raise HandledException(msg)
271 def get_cached(cls: type[BaseModelInstance],
272 id_: BaseModelId) -> BaseModelInstance | None:
273 """Get object of id_ from class's cache, or None if not found."""
274 # pylint: disable=consider-iterating-dictionary
275 cache = cls.get_cache()
276 if id_ in cache.keys():
278 assert isinstance(obj, cls)
283 def empty_cache(cls) -> None:
284 """Empty class's cache."""
288 def get_cache(cls: type[BaseModelInstance]) -> dict[Any, BaseModel[Any]]:
289 """Get cache dictionary, create it if not yet existing."""
290 if not hasattr(cls, 'cache_'):
291 d: dict[Any, BaseModel[Any]] = {}
295 def cache(self) -> None:
296 """Update object in class's cache."""
298 raise HandledException('Cannot cache object without ID.')
299 cache = self.__class__.get_cache()
300 cache[self.id_] = self
302 def uncache(self) -> None:
303 """Remove self from cache."""
305 raise HandledException('Cannot un-cache object without ID.')
306 cache = self.__class__.get_cache()
310 def from_table_row(cls: type[BaseModelInstance],
311 # pylint: disable=unused-argument
312 db_conn: DatabaseConnection,
313 row: Row | list[Any]) -> BaseModelInstance:
314 """Make from DB row, write to DB cache."""
320 def by_id(cls, db_conn: DatabaseConnection,
321 id_: BaseModelId | None,
322 # pylint: disable=unused-argument
323 create: bool = False) -> Self:
324 """Retrieve by id_, on failure throw NotFoundException.
326 First try to get from cls.cache_, only then check DB; if found,
329 If create=True, make anew (but do not cache yet).
333 obj = cls.get_cached(id_)
335 for row in db_conn.row_where(cls.table_name, 'id', id_):
336 obj = cls.from_table_row(db_conn, row)
344 raise NotFoundException(f'found no object of ID {id_}')
347 def all(cls: type[BaseModelInstance],
348 db_conn: DatabaseConnection) -> list[BaseModelInstance]:
349 """Collect all objects of class into list.
351 Note that this primarily returns the contents of the cache, and only
352 _expands_ that by additional findings in the DB. This assumes the
353 cache is always instantly cleaned of any items that would be removed
356 items: dict[BaseModelId, BaseModelInstance] = {}
357 for k, v in cls.get_cache().items():
358 assert isinstance(v, cls)
360 already_recorded = items.keys()
361 for id_ in db_conn.column_all(cls.table_name, 'id'):
362 if id_ not in already_recorded:
363 item = cls.by_id(db_conn, id_)
364 assert item.id_ is not None
365 items[item.id_] = item
366 return list(items.values())
369 def by_date_range_with_limits(cls: type[BaseModelInstance],
370 db_conn: DatabaseConnection,
371 date_range: tuple[str, str],
372 date_col: str = 'day'
373 ) -> tuple[list[BaseModelInstance], str,
"""Return list of items in database within inclusive date_range interval.
377 If no range values provided, defaults them to 'yesterday' and
378 'tomorrow'. Knows to properly interpret these and 'today' as value.
380 start_str = date_range[0] if date_range[0] else 'yesterday'
381 end_str = date_range[1] if date_range[1] else 'tomorrow'
382 start_date = valid_date(start_str)
383 end_date = valid_date(end_str)
385 sql = f'SELECT id FROM {cls.table_name} '
386 sql += f'WHERE {date_col} >= ? AND {date_col} <= ?'
387 for row in db_conn.exec(sql, (start_date, end_date)):
388 items += [cls.by_id(db_conn, row[0])]
389 return items, start_date, end_date
392 def matching(cls: type[BaseModelInstance], db_conn: DatabaseConnection,
393 pattern: str) -> list[BaseModelInstance]:
394 """Return all objects whose .to_search match pattern."""
395 items = cls.all(db_conn)
399 for attr_name in cls.to_search:
400 toks = attr_name.split('.')
403 attr = getattr(parent, tok)
411 def save(self, db_conn: DatabaseConnection) -> None:
412 """Write self to DB and cache and ensure .id_.
414 Write both to DB, and to cache. To DB, write .id_ and attributes
415 listed in cls.to_save[_versioned|_relations].
417 Ensure self.id_ by setting it to what the DB command returns as the
418 last saved row's ID (cursor.lastrowid), EXCEPT if self.id_ already
419 exists as a 'str', which implies we do our own ID creation (so far
only the case with the Day class, where it's to be a date string).
422 values = tuple([self.id_] + [getattr(self, key)
423 for key in self.to_save])
424 table_name = self.table_name
425 cursor = db_conn.exec_on_vals(f'REPLACE INTO {table_name} VALUES',
427 if not isinstance(self.id_, str):
428 self.id_ = cursor.lastrowid # type: ignore[assignment]
430 for attr_name in self.to_save_versioned:
431 getattr(self, attr_name).save(db_conn)
432 for table, column, attr_name, key_index in self.to_save_relations:
433 assert isinstance(self.id_, (int, str))
434 db_conn.rewrite_relations(table, column, self.id_,
436 in getattr(self, attr_name)], key_index)
438 def remove(self, db_conn: DatabaseConnection) -> None:
439 """Remove from DB and cache, including dependencies."""
440 if self.id_ is None or self.__class__.get_cached(self.id_) is None:
441 raise HandledException('cannot remove unsaved item')
442 for attr_name in self.to_save_versioned:
443 getattr(self, attr_name).remove(db_conn)
444 for table, column, attr_name, _ in self.to_save_relations:
445 db_conn.delete_where(table, column, self.id_)
447 db_conn.delete_where(self.table_name, 'id', self.id_)