home · contact · privacy
Extend BaseModel.by_id_or_create test.
authorChristian Heller <c.heller@plomlompom.de>
Tue, 18 Jun 2024 08:22:59 +0000 (10:22 +0200)
committerChristian Heller <c.heller@plomlompom.de>
Tue, 18 Jun 2024 08:22:59 +0000 (10:22 +0200)
tests/utils.py

index 25cc9ba1e79d663ec692570f6f4c1fce4eaaf911..0925b2d5b2adc0e415293526a4b01c04fc42b178 100644 (file)
@@ -192,18 +192,25 @@ class TestCaseWithDB(TestCase):
     @_within_checked_class
     def test_by_id_or_create(self) -> None:
         """Test .by_id_or_create."""
-        # check .by_id_or_create acts like normal instantiation (sans saving)
-        id_ = self.default_ids[0]
+        # check .by_id_or_create fails if wrong class
         if not self.checked_class.can_create_by_id:
             with self.assertRaises(HandledException):
-                self.checked_class.by_id_or_create(self.db_conn, id_)
-        # check .by_id_or_create fails if wrong class
-        else:
-            by_id_created = self.checked_class.by_id_or_create(self.db_conn,
-                                                               id_)
-            with self.assertRaises(NotFoundException):
-                self.checked_class.by_id(self.db_conn, id_)
-            self.assertEqual(self.checked_class(id_), by_id_created)
+                self.checked_class.by_id_or_create(self.db_conn, None)
+            return
+        # check ID input of None creates, on saving, ID=1,2,… for int IDs
+        if isinstance(self.default_ids[0], int):
+            for n in range(2):
+                item = self.checked_class.by_id_or_create(self.db_conn, None)
+                self.assertEqual(item.id_, None)
+                item.save(self.db_conn)
+                self.assertEqual(item.id_, n+1)
+        # check .by_id_or_create acts like normal instantiation (sans saving)
+        id_ = self.default_ids[2]
+        item = self.checked_class.by_id_or_create(self.db_conn, id_)
+        self.assertEqual(item.id_, id_)
+        with self.assertRaises(NotFoundException):
+            self.checked_class.by_id(self.db_conn, item.id_)
+        self.assertEqual(self.checked_class(item.id_), item)
 
     @_within_checked_class
     def test_from_table_row(self) -> None: