From: Christian Heller <c.heller@plomlompom.de>
Date: Sat, 11 Jan 2025 04:47:53 +0000 (+0100)
Subject: Simplify BaseModel now that .id_ cannot be str anymore.
X-Git-Url: https://plomlompom.com/repos/%7B%7Bprefix%7D%7D/%7B%7Bdb.prefix%7D%7D/%7B%7B%20web_path%20%7D%7D/%7B%7Bprefix%7D%7D/edit?a=commitdiff_plain;h=adaf5a5349eab2c1f217d38e110f2c4d98c64116;p=plomtask

Simplify BaseModel now that .id_ cannot be str anymore.
---

diff --git a/plomtask/days.py b/plomtask/days.py
index aac59bb..ebd16f1 100644
--- a/plomtask/days.py
+++ b/plomtask/days.py
@@ -3,7 +3,7 @@ from __future__ import annotations
 from typing import Any, Self
 from sqlite3 import Row
 from datetime import date as dt_date, timedelta
-from plomtask.db import DatabaseConnection, BaseModel, BaseModelId
+from plomtask.db import DatabaseConnection, BaseModel
 from plomtask.todos import Todo
 from plomtask.dating import dt_date_from_days_n, days_n_from_dt_date
 
@@ -29,7 +29,7 @@ class Day(BaseModel):
         return day
 
     @classmethod
-    def by_id(cls, db_conn: DatabaseConnection, id_: BaseModelId) -> Self:
+    def by_id(cls, db_conn: DatabaseConnection, id_: int) -> Self:
         """Checks Todo.days_to_update if we need to a retrieved Day's .todos"""
         day = super().by_id(db_conn, id_)
         assert isinstance(day.id_, int)
diff --git a/plomtask/db.py b/plomtask/db.py
index 2b2c18e..2ce7a61 100644
--- a/plomtask/db.py
+++ b/plomtask/db.py
@@ -258,9 +258,6 @@ class DatabaseConnection:
         self.exec(f'DELETE FROM {table_name} WHERE {key} =', (target,))
 
 
-BaseModelId = int | str
-
-
 class BaseModel:
     """Template for most of the models we use/derive from the DB."""
     table_name = ''
@@ -268,20 +265,17 @@ class BaseModel:
     to_save_relations: list[tuple[str, str, str, int]] = []
     versioned_defaults: dict[str, str | float] = {}
     add_to_dict: list[str] = []
-    id_: None | BaseModelId
-    cache_: dict[BaseModelId, Self]
+    id_: None | int
+    cache_: dict[int, Self]
     to_search: list[str] = []
     can_create_by_id = False
     _exists = True
     sorters: dict[str, Callable[..., Any]] = {}
 
-    def __init__(self, id_: BaseModelId | None) -> None:
+    def __init__(self, id_: int | None) -> None:
         if isinstance(id_, int) and id_ < 1:
             msg = f'illegal {self.__class__.__name__} ID, must be >=1: {id_}'
             raise BadFormatException(msg)
-        if isinstance(id_, str) and "" == id_:
-            msg = f'illegal {self.__class__.__name__} ID, must be non-empty'
-            raise BadFormatException(msg)
         self.id_ = id_
 
     def __hash__(self) -> int:
@@ -313,10 +307,10 @@ class BaseModel:
         return list(cls.versioned_defaults.keys())
 
     @property
-    def as_dict_and_refs(self) -> tuple[dict[str, object], list[BaseModel]]:
+    def as_dict_and_refs(self) -> tuple[dict[str, object], list[Self]]:
         """Return self as json.dumps-ready dict, list of referenced objects."""
         d: dict[str, object] = {'id': self.id_}
-        refs: list[BaseModel] = []
+        refs: list[Self] = []
         for to_save in self.to_save_simples:
             d[to_save] = getattr(self, to_save)
         if len(self.to_save_versioned()) > 0:
@@ -397,15 +391,15 @@ class BaseModel:
         cls.cache_ = {}
 
     @classmethod
-    def get_cache(cls) -> dict[BaseModelId, Self]:
+    def get_cache(cls) -> dict[int, Self]:
         """Get cache dictionary, create it if not yet existing."""
         if not hasattr(cls, 'cache_'):
-            d: dict[BaseModelId, BaseModel] = {}
+            d: dict[int, Self] = {}
             cls.cache_ = d
         return cls.cache_
 
     @classmethod
-    def _get_cached(cls, id_: BaseModelId) -> Self | None:
+    def _get_cached(cls, id_: int) -> Self | None:
         """Get object of id_ from class's cache, or None if not found."""
         cache = cls.get_cache()
         if id_ in cache:
@@ -455,7 +449,7 @@ class BaseModel:
         return obj
 
     @classmethod
-    def by_id(cls, db_conn: DatabaseConnection, id_: BaseModelId) -> Self:
+    def by_id(cls, db_conn: DatabaseConnection, id_: int) -> Self:
         """Retrieve by id_, on failure throw NotFoundException.
 
         First try to get from cls.cache_, only then check DB; if found,
@@ -475,8 +469,7 @@ class BaseModel:
         raise NotFoundException(f'found no object of ID {id_}')
 
     @classmethod
-    def by_id_or_create(cls, db_conn: DatabaseConnection,
-                        id_: BaseModelId | None
+    def by_id_or_create(cls, db_conn: DatabaseConnection, id_: int | None
                         ) -> Self:
         """Wrapper around .by_id, creating (not caching/saving) if no find."""
         if not cls.can_create_by_id:
@@ -497,7 +490,7 @@ class BaseModel:
         cache is always instantly cleaned of any items that would be removed
         from the DB.
         """
-        items: dict[BaseModelId, Self] = {}
+        items: dict[int, Self] = {}
         for k, v in cls.get_cache().items():
             items[k] = v
         already_recorded = items.keys()
@@ -550,7 +543,7 @@ class BaseModel:
         for attr_name in self.to_save_versioned():
             getattr(self, attr_name).save(db_conn)
         for table, column, attr_name, key_index in self.to_save_relations:
-            assert isinstance(self.id_, (int, str))
+            assert isinstance(self.id_, int)
             db_conn.rewrite_relations(table, column, self.id_,
                                       [[i.id_] for i
                                        in getattr(self, attr_name)], key_index)
diff --git a/plomtask/http.py b/plomtask/http.py
index 348feb0..b6c6845 100644
--- a/plomtask/http.py
+++ b/plomtask/http.py
@@ -525,7 +525,6 @@ class TaskHandler(BaseHTTPRequestHandler):
         for process_id in owned_ids:
             Process.by_id(self._conn, process_id)  # to ensure ID exists
             preset_top_step = process_id
-        assert not isinstance(process.id_, str)
         return {'process': process,
                 'is_new': not exists,
                 'preset_top_step': preset_top_step,