From 5e87cc0397c0aaf5b4f15eeb7518b25776bcef71 Mon Sep 17 00:00:00 2001
From: Christian Heller <c.heller@plomlompom.de>
Date: Tue, 18 Jun 2024 10:22:59 +0200
Subject: [PATCH] Extend BaseModel.by_id_or_create test.

---
 tests/utils.py | 27 +++++++++++++++++----------
 1 file changed, 17 insertions(+), 10 deletions(-)

diff --git a/tests/utils.py b/tests/utils.py
index 25cc9ba..0925b2d 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -192,18 +192,25 @@ class TestCaseWithDB(TestCase):
     @_within_checked_class
     def test_by_id_or_create(self) -> None:
         """Test .by_id_or_create."""
-        # check .by_id_or_create acts like normal instantiation (sans saving)
-        id_ = self.default_ids[0]
+        # check .by_id_or_create fails if wrong class
         if not self.checked_class.can_create_by_id:
             with self.assertRaises(HandledException):
-                self.checked_class.by_id_or_create(self.db_conn, id_)
-        # check .by_id_or_create fails if wrong class
-        else:
-            by_id_created = self.checked_class.by_id_or_create(self.db_conn,
-                                                               id_)
-            with self.assertRaises(NotFoundException):
-                self.checked_class.by_id(self.db_conn, id_)
-            self.assertEqual(self.checked_class(id_), by_id_created)
+                self.checked_class.by_id_or_create(self.db_conn, None)
+            return
+        # check ID input of None creates, on saving, ID=1,2,… for int IDs
+        if isinstance(self.default_ids[0], int):
+            for n in range(2):
+                item = self.checked_class.by_id_or_create(self.db_conn, None)
+                self.assertEqual(item.id_, None)
+                item.save(self.db_conn)
+                self.assertEqual(item.id_, n+1)
+        # check .by_id_or_create acts like normal instantiation (sans saving)
+        id_ = self.default_ids[2]
+        item = self.checked_class.by_id_or_create(self.db_conn, id_)
+        self.assertEqual(item.id_, id_)
+        with self.assertRaises(NotFoundException):
+            self.checked_class.by_id(self.db_conn, item.id_)
+        self.assertEqual(self.checked_class(item.id_), item)
 
     @_within_checked_class
     def test_from_table_row(self) -> None:
-- 
2.30.2