home · contact · privacy
Remove asserts no longer needed.
[plomtask] / plomtask / conditions.py
index 9fab77fc118d81bb650116ec307150fc390532ba..b2ecda14cb7cef0b5bf350d1a253389ea32aabb7 100644 (file)
@@ -1,12 +1,13 @@
 """Non-doable elements of ProcessStep/Todo chains."""
 from __future__ import annotations
 """Non-doable elements of ProcessStep/Todo chains."""
 from __future__ import annotations
+from typing import Any
 from sqlite3 import Row
 from plomtask.db import DatabaseConnection, BaseModel
 from plomtask.misc import VersionedAttribute
 from plomtask.exceptions import NotFoundException
 
 
 from sqlite3 import Row
 from plomtask.db import DatabaseConnection, BaseModel
 from plomtask.misc import VersionedAttribute
 from plomtask.exceptions import NotFoundException
 
 
-class Condition(BaseModel):
+class Condition(BaseModel[int]):
     """Non Process-dependency for ProcessSteps and Todos."""
     table_name = 'conditions'
     to_save = ['is_active']
     """Non Process-dependency for ProcessSteps and Todos."""
     table_name = 'conditions'
     to_save = ['is_active']
@@ -20,30 +21,26 @@ class Condition(BaseModel):
 
     @classmethod
     def from_table_row(cls, db_conn: DatabaseConnection,
 
     @classmethod
     def from_table_row(cls, db_conn: DatabaseConnection,
-                       row: Row) -> Condition:
+                       row: Row | list[Any]) -> Condition:
         """Build condition from row, including VersionedAttributes."""
         condition = super().from_table_row(db_conn, row)
         """Build condition from row, including VersionedAttributes."""
         condition = super().from_table_row(db_conn, row)
-        assert isinstance(condition, Condition)
-        for title_row in db_conn.exec('SELECT * FROM condition_titles '
-                                      'WHERE parent_id = ?', (row[0],)):
-            condition.title.history[title_row[1]]\
-                    = title_row[2]  # pylint: disable=no-member
-        for desc_row in db_conn.exec('SELECT * FROM condition_descriptions '
-                                     'WHERE parent_id = ?', (row[0],)):
-            condition.description.history[desc_row[1]]\
-                    = desc_row[2]  # pylint: disable=no-member
+        for name in ('title', 'description'):
+            table_name = f'condition_{name}s'
+            for row_ in db_conn.row_where(table_name, 'parent', row[0]):
+                getattr(condition, name).history_from_row(row_)
         return condition
 
     @classmethod
     def all(cls, db_conn: DatabaseConnection) -> list[Condition]:
         """Collect all Conditions and their VersionedAttributes."""
         conditions = {}
         return condition
 
     @classmethod
     def all(cls, db_conn: DatabaseConnection) -> list[Condition]:
         """Collect all Conditions and their VersionedAttributes."""
         conditions = {}
-        for id_, condition in db_conn.cached_conditions.items():
+        for id_, condition in cls.cache_.items():
             conditions[id_] = condition
         already_recorded = conditions.keys()
             conditions[id_] = condition
         already_recorded = conditions.keys()
-        for row in db_conn.exec('SELECT id FROM conditions'):
-            if row[0] not in already_recorded:
-                condition = cls.by_id(db_conn, row[0])
+        for id_ in db_conn.column_all('conditions', 'id'):
+            if id_ not in already_recorded:
+                condition = cls.by_id(db_conn, id_)
+                assert isinstance(condition.id_, int)
                 conditions[condition.id_] = condition
         return list(conditions.values())
 
                 conditions[condition.id_] = condition
         return list(conditions.values())
 
@@ -52,17 +49,13 @@ class Condition(BaseModel):
               create: bool = False) -> Condition:
         """Collect (or create) Condition and its VersionedAttributes."""
         condition = None
               create: bool = False) -> Condition:
         """Collect (or create) Condition and its VersionedAttributes."""
         condition = None
-        if id_ in db_conn.cached_conditions.keys():
-            condition = db_conn.cached_conditions[id_]
-        else:
-            for row in db_conn.exec('SELECT * FROM conditions WHERE id = ?',
-                                    (id_,)):
-                condition = cls.from_table_row(db_conn, row)
-                break
+        if id_:
+            condition, _ = super()._by_id(db_conn, id_)
         if not condition:
             if not create:
                 raise NotFoundException(f'Condition not found of id: {id_}')
             condition = cls(id_, False)
         if not condition:
             if not create:
                 raise NotFoundException(f'Condition not found of id: {id_}')
             condition = cls(id_, False)
+            condition.save(db_conn)
         return condition
 
     def save(self, db_conn: DatabaseConnection) -> None:
         return condition
 
     def save(self, db_conn: DatabaseConnection) -> None:
@@ -70,5 +63,26 @@ class Condition(BaseModel):
         self.save_core(db_conn)
         self.title.save(db_conn)
         self.description.save(db_conn)
         self.save_core(db_conn)
         self.title.save(db_conn)
         self.description.save(db_conn)
-        assert isinstance(self.id_, int)
-        db_conn.cached_conditions[self.id_] = self
+
+
+class ConditionsRelations:
+    """Methods for handling relations to Conditions, for Todo and Process."""
+
+    def set_conditions(self, db_conn: DatabaseConnection, ids: list[int],
+                       target: str = 'conditions') -> None:
+        """Set self.[target] to Conditions identified by ids."""
+        target_list = getattr(self, target)
+        while len(target_list) > 0:
+            target_list.pop()
+        for id_ in ids:
+            target_list += [Condition.by_id(db_conn, id_)]
+
+    def set_enables(self, db_conn: DatabaseConnection,
+                    ids: list[int]) -> None:
+        """Set self.enables to Conditions identified by ids."""
+        self.set_conditions(db_conn, ids, 'enables')
+
+    def set_disables(self, db_conn: DatabaseConnection,
+                     ids: list[int]) -> None:
+        """Set self.disables to Conditions identified by ids."""
+        self.set_conditions(db_conn, ids, 'disables')