From: Christian Heller <c.heller@plomlompom.de>
Date: Tue, 18 Jun 2024 02:37:57 +0000 (+0200)
Subject: Split BaseModel.by_id into .by_id and by_id_or_create, refactor tests.
X-Git-Url: https://plomlompom.com/repos/%7B%7Bprefix%7D%7D/%7B%7B%20web_path%20%7D%7D/static/tasks?a=commitdiff_plain;h=25b71c6f0b10db05907128daf50c6e543e514c35;p=plomtask

Split BaseModel.by_id into .by_id and by_id_or_create, refactor tests.
---

diff --git a/plomtask/conditions.py b/plomtask/conditions.py
index 70365ce..b60d0af 100644
--- a/plomtask/conditions.py
+++ b/plomtask/conditions.py
@@ -11,6 +11,7 @@ class Condition(BaseModel[int]):
     to_save = ['is_active']
     to_save_versioned = ['title', 'description']
     to_search = ['title.newest', 'description.newest']
+    can_create_by_id = True
 
     def __init__(self, id_: int | None, is_active: bool = False) -> None:
         super().__init__(id_)
@@ -25,13 +26,14 @@ class Condition(BaseModel[int]):
         Checks for Todos and Processes that depend on Condition, prohibits
         deletion if found.
         """
-        if self.id_ is None:
-            raise HandledException('cannot remove unsaved item')
-        for item in ('process', 'todo'):
-            for attr in ('conditions', 'blockers', 'enables', 'disables'):
-                table_name = f'{item}_{attr}'
-                for _ in db_conn.row_where(table_name, 'condition', self.id_):
-                    raise HandledException('cannot remove Condition in use')
+        if self.id_ is not None:
+            for item in ('process', 'todo'):
+                for attr in ('conditions', 'blockers', 'enables', 'disables'):
+                    table_name = f'{item}_{attr}'
+                    for _ in db_conn.row_where(table_name, 'condition',
+                                               self.id_):
+                        msg = 'cannot remove Condition in use'
+                        raise HandledException(msg)
         super().remove(db_conn)
 
 
diff --git a/plomtask/days.py b/plomtask/days.py
index a924bbf..267156d 100644
--- a/plomtask/days.py
+++ b/plomtask/days.py
@@ -12,6 +12,7 @@ class Day(BaseModel[str]):
     """Individual days defined by their dates."""
     table_name = 'days'
     to_save = ['comment']
+    can_create_by_id = True
 
     def __init__(self, date: str, comment: str = '') -> None:
         id_ = valid_date(date)
@@ -40,12 +41,9 @@ class Day(BaseModel[str]):
         return day
 
     @classmethod
-    def by_id(cls,
-              db_conn: DatabaseConnection, id_: str | None,
-              create: bool = False,
-              ) -> Day:
+    def by_id(cls, db_conn: DatabaseConnection, id_: str | None) -> Day:
         """Extend BaseModel.by_id checking for new/lost .todos."""
-        day = super().by_id(db_conn, id_, create)
+        day = super().by_id(db_conn, id_)
         assert day.id_ is not None
         if day.id_ in Todo.days_to_update:
             Todo.days_to_update.remove(day.id_)
diff --git a/plomtask/db.py b/plomtask/db.py
index 853b4c6..f6ef1cb 100644
--- a/plomtask/db.py
+++ b/plomtask/db.py
@@ -238,6 +238,7 @@ class BaseModel(Generic[BaseModelId]):
     id_: None | BaseModelId
     cache_: dict[BaseModelId, Self]
     to_search: list[str] = []
+    can_create_by_id = False
     _exists = True
 
     def __init__(self, id_: BaseModelId | None) -> None:
@@ -388,15 +389,12 @@ class BaseModel(Generic[BaseModelId]):
 
     @classmethod
     def by_id(cls, db_conn: DatabaseConnection,
-              id_: BaseModelId | None,
-              # pylint: disable=unused-argument
-              create: bool = False) -> Self:
+              id_: BaseModelId | None
+              ) -> Self:
         """Retrieve by id_, on failure throw NotFoundException.
 
         First try to get from cls.cache_, only then check DB; if found,
         put into cache.
-
-        If create=True, make anew (but do not cache yet).
         """
         obj = None
         if id_ is not None:
@@ -407,10 +405,20 @@ class BaseModel(Generic[BaseModelId]):
                     break
         if obj:
             return obj
-        if create:
+        raise NotFoundException(f'found no object of ID {id_}')
+
+    @classmethod
+    def by_id_or_create(cls, db_conn: DatabaseConnection,
+                        id_: BaseModelId | None
+                        ) -> Self:
+        """Wrapper around .by_id, creating (not caching/saving) if not find."""
+        if not cls.can_create_by_id:
+            raise HandledException('Class cannot .by_id_or_create.')
+        try:
+            return cls.by_id(db_conn, id_)
+        except NotFoundException:
             obj = cls(id_)
             return obj
-        raise NotFoundException(f'found no object of ID {id_}')
 
     @classmethod
     def all(cls: type[BaseModelInstance],
diff --git a/plomtask/http.py b/plomtask/http.py
index a5e4613..be79159 100644
--- a/plomtask/http.py
+++ b/plomtask/http.py
@@ -244,7 +244,7 @@ class TaskHandler(BaseHTTPRequestHandler):
     def do_GET_day(self) -> dict[str, object]:
         """Show single Day of ?date=."""
         date = self._params.get_str('date', date_in_n_days(0))
-        day = Day.by_id(self.conn, date, create=True)
+        day = Day.by_id_or_create(self.conn, date)
         make_type = self._params.get_str('make_type')
         conditions_present = []
         enablers_for = {}
@@ -400,7 +400,7 @@ class TaskHandler(BaseHTTPRequestHandler):
     def do_GET_condition(self) -> dict[str, object]:
         """Show Condition of ?id=."""
         id_ = self._params.get_int_or_none('id')
-        c = Condition.by_id(self.conn, id_, create=True)
+        c = Condition.by_id_or_create(self.conn, id_)
         ps = Process.all(self.conn)
         return {'condition': c, 'is_new': c.id_ is None,
                 'enabled_processes': [p for p in ps if c in p.conditions],
@@ -423,7 +423,7 @@ class TaskHandler(BaseHTTPRequestHandler):
     def do_GET_process(self) -> dict[str, object]:
         """Show Process of ?id=."""
         id_ = self._params.get_int_or_none('id')
-        process = Process.by_id(self.conn, id_, create=True)
+        process = Process.by_id_or_create(self.conn, id_)
         title_64 = self._params.get_str('title_b64')
         if title_64:
             title = b64decode(title_64.encode()).decode()
@@ -501,7 +501,7 @@ class TaskHandler(BaseHTTPRequestHandler):
     def do_POST_day(self) -> str:
         """Update or insert Day of date and Todos mapped to it."""
         date = self._params.get_str('date')
-        day = Day.by_id(self.conn, date, create=True)
+        day = Day.by_id_or_create(self.conn, date)
         day.comment = self._form_data.get_str('day_comment')
         day.save(self.conn)
         make_type = self._form_data.get_str('make_type')
@@ -600,7 +600,7 @@ class TaskHandler(BaseHTTPRequestHandler):
             process = Process.by_id(self.conn, id_)
             process.remove(self.conn)
             return '/processes'
-        process = Process.by_id(self.conn, id_, create=True)
+        process = Process.by_id_or_create(self.conn, id_)
         process.title.set(self._form_data.get_str('title'))
         process.description.set(self._form_data.get_str('description'))
         process.effort.set(self._form_data.get_float('effort'))
@@ -676,7 +676,7 @@ class TaskHandler(BaseHTTPRequestHandler):
             condition = Condition.by_id(self.conn, id_)
             condition.remove(self.conn)
             return '/conditions'
-        condition = Condition.by_id(self.conn, id_, create=True)
+        condition = Condition.by_id_or_create(self.conn, id_)
         condition.is_active = self._form_data.get_str('is_active') == 'True'
         condition.title.set(self._form_data.get_str('title'))
         condition.description.set(self._form_data.get_str('description'))
diff --git a/plomtask/processes.py b/plomtask/processes.py
index 4ff90ef..d007d0f 100644
--- a/plomtask/processes.py
+++ b/plomtask/processes.py
@@ -34,6 +34,7 @@ class Process(BaseModel[int], ConditionsRelations):
                          ('process_step_suppressions', 'process',
                           'suppressed_steps', 0)]
     to_search = ['title.newest', 'description.newest']
+    can_create_by_id = True
 
     def __init__(self, id_: int | None, calendarize: bool = False) -> None:
         BaseModel.__init__(self, id_)
diff --git a/tests/conditions.py b/tests/conditions.py
index 562dcd9..969942b 100644
--- a/tests/conditions.py
+++ b/tests/conditions.py
@@ -25,10 +25,6 @@ class TestsWithDB(TestCaseWithDB):
         self.check_versioned_from_table_row('title', str)
         self.check_versioned_from_table_row('description', str)
 
-    def test_Condition_by_id(self) -> None:
-        """Test .by_id(), including creation."""
-        self.check_by_id()
-
     def test_Condition_versioned_attributes_singularity(self) -> None:
         """Test behavior of VersionedAttributes on saving (with .title)."""
         self.check_versioned_singularity()
diff --git a/tests/days.py b/tests/days.py
index 1972dbd..02b6c22 100644
--- a/tests/days.py
+++ b/tests/days.py
@@ -53,10 +53,6 @@ class TestsWithDB(TestCaseWithDB):
         kwargs = {'date': self.default_ids[0], 'comment': 'foo'}
         self.check_saving_and_caching(**kwargs)
 
-    def test_Day_by_id(self) -> None:
-        """Test .by_id()."""
-        self.check_by_id()
-
     def test_Day_by_date_range_filled(self) -> None:
         """Test Day.by_date_range_filled."""
         date1, date2, date3 = self.default_ids
@@ -87,7 +83,7 @@ class TestsWithDB(TestCaseWithDB):
         self.assertEqual(Day.by_date_range_filled(self.db_conn,
                                                   day5.date, day7.date),
                          [day5, day6, day7])
-        self.check_storage([day1, day2, day3, day6])
+        self.check_identity_with_cache_and_db([day1, day2, day3, day6])
         # check 'today' is interpreted as today's date
         today = Day(date_in_n_days(0))
         today.save(self.db_conn)
diff --git a/tests/misc.py b/tests/misc.py
index b0fb872..a27f0d0 100644
--- a/tests/misc.py
+++ b/tests/misc.py
@@ -151,7 +151,7 @@ class TestsWithServer(TestCaseWithServer):
     """Tests against our HTTP server/handler (and database)."""
 
     def test_do_GET(self) -> None:
-        """Test / redirect, and unknown targets failing."""
+        """Test GET / redirect, and unknown targets failing."""
         self.conn.request('GET', '/')
         self.check_redirect('/day')
         self.check_get('/foo', 404)
diff --git a/tests/processes.py b/tests/processes.py
index f495fd5..0f43a4d 100644
--- a/tests/processes.py
+++ b/tests/processes.py
@@ -182,10 +182,6 @@ class TestsWithDB(TestCaseWithDB):
             method(self.db_conn, [c1.id_, c2.id_])
             self.assertEqual(getattr(p, target), [c1, c2])
 
-    def test_Process_by_id(self) -> None:
-        """Test .by_id(), including creation"""
-        self.check_by_id()
-
     def test_Process_versioned_attributes_singularity(self) -> None:
         """Test behavior of VersionedAttributes on saving (with .title)."""
         self.check_versioned_singularity()
@@ -249,7 +245,7 @@ class TestsWithDBForProcessStep(TestCaseWithDB):
         p1.set_steps(self.db_conn, [step])
         step.remove(self.db_conn)
         self.assertEqual(p1.explicit_steps, [])
-        self.check_storage([])
+        self.check_identity_with_cache_and_db([])
 
 
 class TestsWithServer(TestCaseWithServer):
diff --git a/tests/utils.py b/tests/utils.py
index 9d3d11d..55c948a 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -87,8 +87,8 @@ class TestCaseWithDB(TestCase):
             attr = getattr(retrieved, attr_name)
             self.assertEqual(sorted(attr.history.values()), vals)
 
-    def check_storage(self, content: list[Any]) -> None:
-        """Test cache and DB equal content."""
+    def check_identity_with_cache_and_db(self, content: list[Any]) -> None:
+        """Test both cache and DB equal content."""
         expected_cache = {}
         for item in content:
             expected_cache[item.id_] = item
@@ -108,30 +108,47 @@ class TestCaseWithDB(TestCase):
         """Test instance.save in its core without relations."""
         obj = self.checked_class(**kwargs)  # pylint: disable=not-callable
         # check object init itself doesn't store anything yet
-        self.check_storage([])
+        self.check_identity_with_cache_and_db([])
         # check saving sets core attributes properly
         obj.save(self.db_conn)
         for key, value in kwargs.items():
             self.assertEqual(getattr(obj, key), value)
         # check saving stored properly in cache and DB
-        self.check_storage([obj])
+        self.check_identity_with_cache_and_db([obj])
 
-    def check_by_id(self) -> None:
-        """Test .by_id(), including creation."""
+    @_within_checked_class
+    def test_by_id(self) -> None:
+        """Test .by_id()."""
+        id1, id2, _ = self.default_ids
         # check failure if not yet saved
-        id1, id2 = self.default_ids[0], self.default_ids[1]
-        obj = self.checked_class(id1)  # pylint: disable=not-callable
+        obj1 = self.checked_class(id1, **self.default_init_kwargs)
         with self.assertRaises(NotFoundException):
             self.checked_class.by_id(self.db_conn, id1)
+        # check identity of cached and retrieved
+        obj1.cache()
+        self.assertEqual(obj1, self.checked_class.by_id(self.db_conn, id1))
         # check identity of saved and retrieved
-        obj.save(self.db_conn)
-        self.assertEqual(obj, self.checked_class.by_id(self.db_conn, id1))
-        # check create=True acts like normal instantiation (sans saving)
-        by_id_created = self.checked_class.by_id(self.db_conn, id2,
-                                                 create=True)
-        # pylint: disable=not-callable
-        self.assertEqual(self.checked_class(id2), by_id_created)
-        self.check_storage([obj])
+        obj2 = self.checked_class(id2, **self.default_init_kwargs)
+        obj2.save(self.db_conn)
+        self.assertEqual(obj2, self.checked_class.by_id(self.db_conn, id2))
+        # obj1.save(self.db_conn)
+        # self.check_identity_with_cache_and_db([obj1, obj2])
+
+    @_within_checked_class
+    def test_by_id_or_create(self) -> None:
+        """Test .by_id_or_create."""
+        # check .by_id_or_create acts like normal instantiation (sans saving)
+        id_ = self.default_ids[0]
+        if not self.checked_class.can_create_by_id:
+            with self.assertRaises(HandledException):
+                self.checked_class.by_id_or_create(self.db_conn, id_)
+        # check .by_id_or_create fails if wrong class
+        else:
+            by_id_created = self.checked_class.by_id_or_create(self.db_conn,
+                                                               id_)
+            with self.assertRaises(NotFoundException):
+                self.checked_class.by_id(self.db_conn, id_)
+            self.assertEqual(self.checked_class(id_), by_id_created)
 
     @_within_checked_class
     def test_from_table_row(self) -> None:
@@ -230,7 +247,7 @@ class TestCaseWithDB(TestCase):
             obj.remove(self.db_conn)
         obj.save(self.db_conn)
         obj.remove(self.db_conn)
-        self.check_storage([])
+        self.check_identity_with_cache_and_db([])
 
 
 class TestCaseWithServer(TestCaseWithDB):