From 8f28c8c685fa91b9cbabb4b424da4091e52058cf Mon Sep 17 00:00:00 2001
From: Christian Heller <c.heller@plomlompom.de>
Date: Tue, 18 Jun 2024 07:02:04 +0200
Subject: [PATCH] Refactor saving and caching tests, treatment of None IDs.

---
 plomtask/days.py                 |  3 +-
 plomtask/db.py                   |  9 ++--
 plomtask/http.py                 | 19 ++++---
 plomtask/versioned_attributes.py |  5 +-
 tests/days.py                    |  9 ----
 tests/processes.py               | 10 ++--
 tests/todos.py                   |  2 +
 tests/utils.py                   | 92 +++++++++++++++++++++++---------
 8 files changed, 95 insertions(+), 54 deletions(-)

diff --git a/plomtask/days.py b/plomtask/days.py
index 267156d..0bd942c 100644
--- a/plomtask/days.py
+++ b/plomtask/days.py
@@ -41,10 +41,9 @@ class Day(BaseModel[str]):
         return day
 
     @classmethod
-    def by_id(cls, db_conn: DatabaseConnection, id_: str | None) -> Day:
+    def by_id(cls, db_conn: DatabaseConnection, id_: str) -> Day:
         """Extend BaseModel.by_id checking for new/lost .todos."""
         day = super().by_id(db_conn, id_)
-        assert day.id_ is not None
         if day.id_ in Todo.days_to_update:
             Todo.days_to_update.remove(day.id_)
             day.todos = Todo.by_date(db_conn, day.id_)
diff --git a/plomtask/db.py b/plomtask/db.py
index f6ef1cb..797b08e 100644
--- a/plomtask/db.py
+++ b/plomtask/db.py
@@ -388,9 +388,7 @@ class BaseModel(Generic[BaseModelId]):
         return obj
 
     @classmethod
-    def by_id(cls, db_conn: DatabaseConnection,
-              id_: BaseModelId | None
-              ) -> Self:
+    def by_id(cls, db_conn: DatabaseConnection, id_: BaseModelId) -> Self:
         """Retrieve by id_, on failure throw NotFoundException.
 
         First try to get from cls.cache_, only then check DB; if found,
@@ -414,11 +412,12 @@ class BaseModel(Generic[BaseModelId]):
         """Wrapper around .by_id, creating (not caching/saving) if not find."""
         if not cls.can_create_by_id:
             raise HandledException('Class cannot .by_id_or_create.')
+        if id_ is None:
+            return cls(None)
         try:
             return cls.by_id(db_conn, id_)
         except NotFoundException:
-            obj = cls(id_)
-            return obj
+            return cls(id_)
 
     @classmethod
     def all(cls: type[BaseModelInstance],
diff --git a/plomtask/http.py b/plomtask/http.py
index be79159..7c7fbd4 100644
--- a/plomtask/http.py
+++ b/plomtask/http.py
@@ -335,7 +335,8 @@ class TaskHandler(BaseHTTPRequestHandler):
         adoptables: dict[int, list[Todo]] = {}
         any_adoptables = [Todo.by_id(self.conn, t.id_)
                           for t in Todo.by_date(self.conn, todo.date)
-                          if t != todo]
+                          if t.id_ is not None
+                          and t != todo]
         for id_ in collect_adoptables_keys(steps_todo_to_process):
             adoptables[id_] = [t for t in any_adoptables
                                if t.process.id_ == id_]
@@ -410,13 +411,13 @@ class TaskHandler(BaseHTTPRequestHandler):
 
     def do_GET_condition_titles(self) -> dict[str, object]:
         """Show title history of Condition of ?id=."""
-        id_ = self._params.get_int_or_none('id')
+        id_ = self._params.get_int('id')
         condition = Condition.by_id(self.conn, id_)
         return {'condition': condition}
 
     def do_GET_condition_descriptions(self) -> dict[str, object]:
         """Show description historys of Condition of ?id=."""
-        id_ = self._params.get_int_or_none('id')
+        id_ = self._params.get_int('id')
         condition = Condition.by_id(self.conn, id_)
         return {'condition': condition}
 
@@ -443,19 +444,19 @@ class TaskHandler(BaseHTTPRequestHandler):
 
     def do_GET_process_titles(self) -> dict[str, object]:
         """Show title history of Process of ?id=."""
-        id_ = self._params.get_int_or_none('id')
+        id_ = self._params.get_int('id')
         process = Process.by_id(self.conn, id_)
         return {'process': process}
 
     def do_GET_process_descriptions(self) -> dict[str, object]:
         """Show description historys of Process of ?id=."""
-        id_ = self._params.get_int_or_none('id')
+        id_ = self._params.get_int('id')
         process = Process.by_id(self.conn, id_)
         return {'process': process}
 
     def do_GET_process_efforts(self) -> dict[str, object]:
         """Show default effort history of Process of ?id=."""
-        id_ = self._params.get_int_or_none('id')
+        id_ = self._params.get_int('id')
         process = Process.by_id(self.conn, id_)
         return {'process': process}
 
@@ -597,6 +598,8 @@ class TaskHandler(BaseHTTPRequestHandler):
         # pylint: disable=too-many-branches
         id_ = self._params.get_int_or_none('id')
         for _ in self._form_data.get_all_str('delete'):
+            if id_ is None:
+                raise NotFoundException('trying to delete non-saved Process')
             process = Process.by_id(self.conn, id_)
             process.remove(self.conn)
             return '/processes'
@@ -673,7 +676,9 @@ class TaskHandler(BaseHTTPRequestHandler):
         """Update/insert Condition of ?id= and fields defined in postvars."""
         id_ = self._params.get_int_or_none('id')
         for _ in self._form_data.get_all_str('delete'):
-            condition = Condition.by_id(self.conn, id_)
+            if id_ is None:
+                raise NotFoundException('trying to delete non-saved Condition')
+            condition = Condition.by_id_or_create(self.conn, id_)
             condition.remove(self.conn)
             return '/conditions'
         condition = Condition.by_id_or_create(self.conn, id_)
diff --git a/plomtask/versioned_attributes.py b/plomtask/versioned_attributes.py
index cbd1c8e..8861c98 100644
--- a/plomtask/versioned_attributes.py
+++ b/plomtask/versioned_attributes.py
@@ -4,7 +4,8 @@ from typing import Any
 from sqlite3 import Row
 from time import sleep
 from plomtask.db import DatabaseConnection
-from plomtask.exceptions import HandledException, BadFormatException
+from plomtask.exceptions import (HandledException, BadFormatException,
+                                 NotFoundException)
 
 TIMESTAMP_FMT = '%Y-%m-%d %H:%M:%S.%f'
 
@@ -98,6 +99,8 @@ class VersionedAttribute:
 
     def save(self, db_conn: DatabaseConnection) -> None:
         """Save as self.history entries, but first wipe old ones."""
+        if self.parent.id_ is None:
+            raise NotFoundException('cannot save attribute to parent if no ID')
         db_conn.rewrite_relations(self.table_name, 'parent', self.parent.id_,
                                   [[item[0], item[1]]
                                    for item in self.history.items()])
diff --git a/tests/days.py b/tests/days.py
index 02b6c22..9fb12ad 100644
--- a/tests/days.py
+++ b/tests/days.py
@@ -44,15 +44,6 @@ class TestsWithDB(TestCaseWithDB):
     checked_class = Day
     default_ids = ('2024-01-01', '2024-01-02', '2024-01-03')
 
-    def test_saving_and_caching(self) -> None:
-        """Test storage of instances.
-
-        We don't use the parent class's method here because the checked class
-        has too different a handling of IDs.
-        """
-        kwargs = {'date': self.default_ids[0], 'comment': 'foo'}
-        self.check_saving_and_caching(**kwargs)
-
     def test_Day_by_date_range_filled(self) -> None:
         """Test Day.by_date_range_filled."""
         date1, date2, date3 = self.default_ids
diff --git a/tests/processes.py b/tests/processes.py
index 0f43a4d..d33aa8f 100644
--- a/tests/processes.py
+++ b/tests/processes.py
@@ -58,6 +58,7 @@ class TestsWithDB(TestCaseWithDB):
     def test_Process_conditions_saving(self) -> None:
         """Test .save/.save_core."""
         p, set1, set2, set3 = self.p_of_conditions()
+        assert p.id_ is not None
         r = Process.by_id(self.db_conn, p.id_)
         self.assertEqual(sorted(r.conditions), sorted(set1))
         self.assertEqual(sorted(r.enables), sorted(set2))
@@ -200,13 +201,15 @@ class TestsWithDB(TestCaseWithDB):
             p1.remove(self.db_conn)
         p2.set_steps(self.db_conn, [])
         with self.assertRaises(NotFoundException):
+            assert step_id is not None
             ProcessStep.by_id(self.db_conn, step_id)
         p1.remove(self.db_conn)
         step = ProcessStep(None, p2.id_, p3.id_, None)
-        step_id = step.id_
         p2.set_steps(self.db_conn, [step])
+        step_id = step.id_
         p2.remove(self.db_conn)
         with self.assertRaises(NotFoundException):
+            assert step_id is not None
             ProcessStep.by_id(self.db_conn, step_id)
         todo = Todo(None, p3, False, '2024-01-01')
         todo.save(self.db_conn)
@@ -229,10 +232,6 @@ class TestsWithDBForProcessStep(TestCaseWithDB):
         p = Process(2)
         p.save(self.db_conn)
 
-    def test_saving_and_caching(self) -> None:
-        """Test storage and initialization of instances and attributes."""
-        self.check_saving_and_caching(id_=1, **self.default_init_kwargs)
-
     def test_ProcessStep_remove(self) -> None:
         """Test .remove and unsetting of owner's .explicit_steps entry."""
         p1 = Process(None)
@@ -300,6 +299,7 @@ class TestsWithServer(TestCaseWithServer):
         self.post_process(1, form_data_1)
         retrieved_process = Process.by_id(self.db_conn, 1)
         self.assertEqual(retrieved_process.explicit_steps, [])
+        assert retrieved_step_id is not None
         with self.assertRaises(NotFoundException):
             ProcessStep.by_id(self.db_conn, retrieved_step_id)
         # post new first (top_level) step of process 3 to process 1
diff --git a/tests/todos.py b/tests/todos.py
index 56aaf48..7632f39 100644
--- a/tests/todos.py
+++ b/tests/todos.py
@@ -206,6 +206,7 @@ class TestsWithDB(TestCaseWithDB, TestCaseSansDB):
         """Test removal."""
         todo_1 = Todo(None, self.proc, False, self.date1)
         todo_1.save(self.db_conn)
+        assert todo_1.id_ is not None
         todo_0 = Todo(None, self.proc, False, self.date1)
         todo_0.save(self.db_conn)
         todo_0.add_child(todo_1)
@@ -233,6 +234,7 @@ class TestsWithDB(TestCaseWithDB, TestCaseSansDB):
         todo_1.comment = 'foo'
         todo_1.effort = -0.1
         todo_1.save(self.db_conn)
+        assert todo_1.id_ is not None
         Todo.by_id(self.db_conn, todo_1.id_)
         todo_1.comment = ''
         todo_1_id = todo_1.id_
diff --git a/tests/utils.py b/tests/utils.py
index 55c948a..6015710 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -69,23 +69,49 @@ class TestCaseWithDB(TestCase):
                 f(self)
         return wrapper
 
+    def _load_from_db(self, id_: int | str) -> list[object]:
+        db_found: list[object] = []
+        for row in self.db_conn.row_where(self.checked_class.table_name,
+                                          'id', id_):
+            db_found += [self.checked_class.from_table_row(self.db_conn,
+                                                           row)]
+        return db_found
+
     @_within_checked_class
-    def test_saving_and_caching(self) -> None:
-        """Test storage and initialization of instances and attributes."""
-        self.check_saving_and_caching(id_=1, **self.default_init_kwargs)
-        obj = self.checked_class(None, **self.default_init_kwargs)
-        obj.save(self.db_conn)
-        self.assertEqual(obj.id_, 2)
+    def test_saving_versioned(self) -> None:
+        """Test storage and initialization of versioned attributes."""
+        def retrieve_attr_vals() -> list[object]:
+            attr_vals_saved: list[object] = []
+            assert hasattr(retrieved, 'id_')
+            for row in self.db_conn.row_where(attr.table_name, 'parent',
+                                              retrieved.id_):
+                attr_vals_saved += [row[2]]
+            return attr_vals_saved
         for attr_name, type_ in self.test_versioneds.items():
-            owner = self.checked_class(None)
+            # fail saving attributes on non-saved owner
+            owner = self.checked_class(None, **self.default_init_kwargs)
             vals: list[Any] = ['t1', 't2'] if type_ == str else [0.9, 1.1]
             attr = getattr(owner, attr_name)
             attr.set(vals[0])
             attr.set(vals[1])
+            with self.assertRaises(NotFoundException):
+                attr.save(self.db_conn)
             owner.save(self.db_conn)
-            retrieved = owner.__class__.by_id(self.db_conn, owner.id_)
+            # check stored attribute is as expected
+            retrieved = self._load_from_db(owner.id_)[0]
             attr = getattr(retrieved, attr_name)
             self.assertEqual(sorted(attr.history.values()), vals)
+            # check owner.save() created entries in attr table
+            attr_vals_saved = retrieve_attr_vals()
+            self.assertEqual(vals, attr_vals_saved)
+            # check setting new val to attr inconsequential to DB without save
+            attr.set(vals[0])
+            attr_vals_saved = retrieve_attr_vals()
+            self.assertEqual(vals, attr_vals_saved)
+            # check save finally adds new val
+            attr.save(self.db_conn)
+            attr_vals_saved = retrieve_attr_vals()
+            self.assertEqual(vals + [vals[0]], attr_vals_saved)
 
     def check_identity_with_cache_and_db(self, content: list[Any]) -> None:
         """Test both cache and DB equal content."""
@@ -97,24 +123,42 @@ class TestCaseWithDB(TestCase):
         db_found: list[Any] = []
         for item in content:
             assert isinstance(item.id_, type(self.default_ids[0]))
-            for row in self.db_conn.row_where(self.checked_class.table_name,
-                                              'id', item.id_):
-                db_found += [self.checked_class.from_table_row(self.db_conn,
-                                                               row)]
+            db_found += self._load_from_db(item.id_)
         hashes_db_found = [hash(x) for x in db_found]
         self.assertEqual(sorted(hashes_content), sorted(hashes_db_found))
 
-    def check_saving_and_caching(self, **kwargs: Any) -> None:
-        """Test instance.save in its core without relations."""
-        obj = self.checked_class(**kwargs)  # pylint: disable=not-callable
-        # check object init itself doesn't store anything yet
-        self.check_identity_with_cache_and_db([])
-        # check saving sets core attributes properly
-        obj.save(self.db_conn)
-        for key, value in kwargs.items():
-            self.assertEqual(getattr(obj, key), value)
-        # check saving stored properly in cache and DB
-        self.check_identity_with_cache_and_db([obj])
+    @_within_checked_class
+    def test_saving_and_caching(self) -> None:
+        """Test effects of .cache() and .save()."""
+        id1 = self.default_ids[0]
+        # check failure to cache without ID (if None-ID input possible)
+        if isinstance(id1, int):
+            obj0 = self.checked_class(None, **self.default_init_kwargs)
+            with self.assertRaises(HandledException):
+                obj0.cache()
+        # check mere object init itself doesn't even store in cache
+        obj1 = self.checked_class(id1, **self.default_init_kwargs)
+        self.assertEqual(self.checked_class.get_cache(), {})
+        # check .cache() fills cache, but not DB
+        obj1.cache()
+        self.assertEqual(self.checked_class.get_cache(), {id1: obj1})
+        db_found = self._load_from_db(id1)
+        self.assertEqual(db_found, [])
+        # check .save() sets ID (for int IDs), updates cache, and fills DB
+        # (expect ID to be set to id1, despite obj1 already having that as ID:
+        # it's generated by cursor.lastrowid on the DB table, and with obj1
+        # not written there, obj2 should get it first!)
+        id_input = None if isinstance(id1, int) else id1
+        obj2 = self.checked_class(id_input, **self.default_init_kwargs)
+        obj2.save(self.db_conn)
+        obj2_hash = hash(obj2)
+        self.assertEqual(self.checked_class.get_cache(), {id1: obj2})
+        db_found += self._load_from_db(id1)
+        self.assertEqual([hash(o) for o in db_found], [obj2_hash])
+        # check we cannot overwrite obj2 with obj1 despite its same ID,
+        # since it has disappeared now
+        with self.assertRaises(HandledException):
+            obj1.save(self.db_conn)
 
     @_within_checked_class
     def test_by_id(self) -> None:
@@ -131,8 +175,6 @@ class TestCaseWithDB(TestCase):
         obj2 = self.checked_class(id2, **self.default_init_kwargs)
         obj2.save(self.db_conn)
         self.assertEqual(obj2, self.checked_class.by_id(self.db_conn, id2))
-        # obj1.save(self.db_conn)
-        # self.check_identity_with_cache_and_db([obj1, obj2])
 
     @_within_checked_class
     def test_by_id_or_create(self) -> None:
-- 
2.30.2