From adaf5a5349eab2c1f217d38e110f2c4d98c64116 Mon Sep 17 00:00:00 2001 From: Christian Heller Date: Sat, 11 Jan 2025 05:47:53 +0100 Subject: [PATCH] Simplify BaseModel now that .id_ cannot be str anymore. --- plomtask/days.py | 4 ++-- plomtask/db.py | 31 ++++++++++++------------------- plomtask/http.py | 1 - 3 files changed, 14 insertions(+), 22 deletions(-) diff --git a/plomtask/days.py b/plomtask/days.py index aac59bb..ebd16f1 100644 --- a/plomtask/days.py +++ b/plomtask/days.py @@ -3,7 +3,7 @@ from __future__ import annotations from typing import Any, Self from sqlite3 import Row from datetime import date as dt_date, timedelta -from plomtask.db import DatabaseConnection, BaseModel, BaseModelId +from plomtask.db import DatabaseConnection, BaseModel from plomtask.todos import Todo from plomtask.dating import dt_date_from_days_n, days_n_from_dt_date @@ -29,7 +29,7 @@ class Day(BaseModel): return day @classmethod - def by_id(cls, db_conn: DatabaseConnection, id_: BaseModelId) -> Self: + def by_id(cls, db_conn: DatabaseConnection, id_: int) -> Self: """Checks Todo.days_to_update if we need to a retrieved Day's .todos""" day = super().by_id(db_conn, id_) assert isinstance(day.id_, int) diff --git a/plomtask/db.py b/plomtask/db.py index 2b2c18e..2ce7a61 100644 --- a/plomtask/db.py +++ b/plomtask/db.py @@ -258,9 +258,6 @@ class DatabaseConnection: self.exec(f'DELETE FROM {table_name} WHERE {key} =', (target,)) -BaseModelId = int | str - - class BaseModel: """Template for most of the models we use/derive from the DB.""" table_name = '' @@ -268,20 +265,17 @@ class BaseModel: to_save_relations: list[tuple[str, str, str, int]] = [] versioned_defaults: dict[str, str | float] = {} add_to_dict: list[str] = [] - id_: None | BaseModelId - cache_: dict[BaseModelId, Self] + id_: None | int + cache_: dict[int, Self] to_search: list[str] = [] can_create_by_id = False _exists = True sorters: dict[str, Callable[..., Any]] = {} - def __init__(self, id_: BaseModelId | None) -> None: + def __init__(self, id_: int | None) -> None: if isinstance(id_, int) and id_ < 1: msg = f'illegal {self.__class__.__name__} ID, must be >=1: {id_}' raise BadFormatException(msg) - if isinstance(id_, str) and "" == id_: - msg = f'illegal {self.__class__.__name__} ID, must be non-empty' - raise BadFormatException(msg) self.id_ = id_ def __hash__(self) -> int: @@ -313,10 +307,10 @@ class BaseModel: return list(cls.versioned_defaults.keys()) @property - def as_dict_and_refs(self) -> tuple[dict[str, object], list[BaseModel]]: + def as_dict_and_refs(self) -> tuple[dict[str, object], list[Self]]: """Return self as json.dumps-ready dict, list of referenced objects.""" d: dict[str, object] = {'id': self.id_} - refs: list[BaseModel] = [] + refs: list[Self] = [] for to_save in self.to_save_simples: d[to_save] = getattr(self, to_save) if len(self.to_save_versioned()) > 0: @@ -397,15 +391,15 @@ class BaseModel: cls.cache_ = {} @classmethod - def get_cache(cls) -> dict[BaseModelId, Self]: + def get_cache(cls) -> dict[int, Self]: """Get cache dictionary, create it if not yet existing.""" if not hasattr(cls, 'cache_'): - d: dict[BaseModelId, BaseModel] = {} + d: dict[int, Self] = {} cls.cache_ = d return cls.cache_ @classmethod - def _get_cached(cls, id_: BaseModelId) -> Self | None: + def _get_cached(cls, id_: int) -> Self | None: """Get object of id_ from class's cache, or None if not found.""" cache = cls.get_cache() if id_ in cache: @@ -455,7 +449,7 @@ class BaseModel: return obj @classmethod - def by_id(cls, db_conn: DatabaseConnection, id_: BaseModelId) -> Self: + def by_id(cls, db_conn: DatabaseConnection, id_: int) -> Self: """Retrieve by id_, on failure throw NotFoundException. First try to get from cls.cache_, only then check DB; if found, @@ -475,8 +469,7 @@ class BaseModel: raise NotFoundException(f'found no object of ID {id_}') @classmethod - def by_id_or_create(cls, db_conn: DatabaseConnection, - id_: BaseModelId | None + def by_id_or_create(cls, db_conn: DatabaseConnection, id_: int | None ) -> Self: """Wrapper around .by_id, creating (not caching/saving) if no find.""" if not cls.can_create_by_id: @@ -497,7 +490,7 @@ class BaseModel: cache is always instantly cleaned of any items that would be removed from the DB. """ - items: dict[BaseModelId, Self] = {} + items: dict[int, Self] = {} for k, v in cls.get_cache().items(): items[k] = v already_recorded = items.keys() @@ -550,7 +543,7 @@ class BaseModel: for attr_name in self.to_save_versioned(): getattr(self, attr_name).save(db_conn) for table, column, attr_name, key_index in self.to_save_relations: - assert isinstance(self.id_, (int, str)) + assert isinstance(self.id_, int) db_conn.rewrite_relations(table, column, self.id_, [[i.id_] for i in getattr(self, attr_name)], key_index) diff --git a/plomtask/http.py b/plomtask/http.py index 348feb0..b6c6845 100644 --- a/plomtask/http.py +++ b/plomtask/http.py @@ -525,7 +525,6 @@ class TaskHandler(BaseHTTPRequestHandler): for process_id in owned_ids: Process.by_id(self._conn, process_id) # to ensure ID exists preset_top_step = process_id - assert not isinstance(process.id_, str) return {'process': process, 'is_new': not exists, 'preset_top_step': preset_top_step, -- 2.30.2