home · contact · privacy
Split BaseModel.by_id into .by_id and by_id_or_create, refactor tests.
[plomtask] / plomtask / http.py
index fc0059c530e20c5e347a264703905b3fcadcb3df..be791599ff21868fc82ed7080dfefe6b1130e3fd 100644 (file)
@@ -6,6 +6,7 @@ from base64 import b64encode, b64decode
 from http.server import BaseHTTPRequestHandler
 from http.server import HTTPServer
 from urllib.parse import urlparse, parse_qs
 from http.server import BaseHTTPRequestHandler
 from http.server import HTTPServer
 from urllib.parse import urlparse, parse_qs
+from json import dumps as json_dumps
 from os.path import split as path_split
 from jinja2 import Environment as JinjaEnv, FileSystemLoader as JinjaFSLoader
 from plomtask.dating import date_in_n_days
 from os.path import split as path_split
 from jinja2 import Environment as JinjaEnv, FileSystemLoader as JinjaFSLoader
 from plomtask.dating import date_in_n_days
@@ -16,6 +17,7 @@ from plomtask.db import DatabaseConnection, DatabaseFile
 from plomtask.processes import Process, ProcessStep, ProcessStepsNode
 from plomtask.conditions import Condition
 from plomtask.todos import Todo
 from plomtask.processes import Process, ProcessStep, ProcessStepsNode
 from plomtask.conditions import Condition
 from plomtask.todos import Todo
+from plomtask.db import BaseModel
 
 TEMPLATES_DIR = 'templates'
 
 
 TEMPLATES_DIR = 'templates'
 
@@ -27,7 +29,37 @@ class TaskServer(HTTPServer):
                  *args: Any, **kwargs: Any) -> None:
         super().__init__(*args, **kwargs)
         self.db = db_file
                  *args: Any, **kwargs: Any) -> None:
         super().__init__(*args, **kwargs)
         self.db = db_file
-        self.jinja = JinjaEnv(loader=JinjaFSLoader(TEMPLATES_DIR))
+        self.headers: list[tuple[str, str]] = []
+        self._render_mode = 'html'
+        self._jinja = JinjaEnv(loader=JinjaFSLoader(TEMPLATES_DIR))
+
+    def set_json_mode(self) -> None:
+        """Make server send JSON instead of HTML responses."""
+        self._render_mode = 'json'
+        self.headers += [('Content-Type', 'application/json')]
+
+    @staticmethod
+    def ctx_to_json(ctx: dict[str, object]) -> str:
+        """Render ctx into JSON string."""
+        def walk_ctx(node: object) -> Any:
+            if isinstance(node, BaseModel):
+                return node.as_dict
+            if isinstance(node, (list, tuple)):
+                return [walk_ctx(x) for x in node]
+            if isinstance(node, HandledException):
+                return str(node)
+            return node
+        for k, v in ctx.items():
+            ctx[k] = walk_ctx(v)
+        return json_dumps(ctx)
+
+    def render(self, ctx: dict[str, object], tmpl_name: str = '') -> str:
+        """Render ctx according to self._render_mode.."""
+        tmpl_name = f'{tmpl_name}.{self._render_mode}'
+        if 'html' == self._render_mode:
+            template = self._jinja.get_template(tmpl_name)
+            return template.render(ctx)
+        return self.__class__.ctx_to_json(ctx)
 
 
 class InputsParser:
 
 
 class InputsParser:
@@ -106,11 +138,18 @@ class TaskHandler(BaseHTTPRequestHandler):
     _form_data: InputsParser
     _params: InputsParser
 
     _form_data: InputsParser
     _params: InputsParser
 
-    def _send_html(self, html: str, code: int = 200) -> None:
+    def _send_page(self,
+                   ctx: dict[str, Any],
+                   tmpl_name: str,
+                   code: int = 200
+                   ) -> None:
         """Send HTML as proper HTTP response."""
         """Send HTML as proper HTTP response."""
+        body = self.server.render(ctx, tmpl_name)
         self.send_response(code)
         self.send_response(code)
+        for header_tuple in self.server.headers:
+            self.send_header(*header_tuple)
         self.end_headers()
         self.end_headers()
-        self.wfile.write(bytes(html, 'utf-8'))
+        self.wfile.write(bytes(body, 'utf-8'))
 
     @staticmethod
     def _request_wrapper(http_method: str, not_found_msg: str
 
     @staticmethod
     def _request_wrapper(http_method: str, not_found_msg: str
@@ -142,9 +181,8 @@ class TaskHandler(BaseHTTPRequestHandler):
                     for cls in (Day, Todo, Condition, Process, ProcessStep):
                         assert hasattr(cls, 'empty_cache')
                         cls.empty_cache()
                     for cls in (Day, Todo, Condition, Process, ProcessStep):
                         assert hasattr(cls, 'empty_cache')
                         cls.empty_cache()
-                    tmpl = self.server.jinja.get_template('msg.html')
-                    html = tmpl.render(msg=error)
-                    self._send_html(html, error.http_code)
+                    ctx = {'msg': error}
+                    self._send_page(ctx, 'msg', error.http_code)
                 finally:
                     self.conn.close()
             return wrapper
                 finally:
                     self.conn.close()
             return wrapper
@@ -154,13 +192,11 @@ class TaskHandler(BaseHTTPRequestHandler):
     def do_GET(self, handler: Callable[[], str | dict[str, object]]
                ) -> str | None:
         """Render page with result of handler, or redirect if result is str."""
     def do_GET(self, handler: Callable[[], str | dict[str, object]]
                ) -> str | None:
         """Render page with result of handler, or redirect if result is str."""
-        template = f'{self._site}.html'
-        ctx_or_redir = handler()
-        if str == type(ctx_or_redir):
-            return ctx_or_redir
-        assert isinstance(ctx_or_redir, dict)
-        html = self.server.jinja.get_template(template).render(**ctx_or_redir)
-        self._send_html(html)
+        tmpl_name = f'{self._site}'
+        ctx_or_redir_target = handler()
+        if isinstance(ctx_or_redir_target, str):
+            return ctx_or_redir_target
+        self._send_page(ctx_or_redir_target, tmpl_name)
         return None
 
     @_request_wrapper('POST', 'Unknown POST target')
         return None
 
     @_request_wrapper('POST', 'Unknown POST target')
@@ -208,7 +244,7 @@ class TaskHandler(BaseHTTPRequestHandler):
     def do_GET_day(self) -> dict[str, object]:
         """Show single Day of ?date=."""
         date = self._params.get_str('date', date_in_n_days(0))
     def do_GET_day(self) -> dict[str, object]:
         """Show single Day of ?date=."""
         date = self._params.get_str('date', date_in_n_days(0))
-        day = Day.by_id(self.conn, date, create=True)
+        day = Day.by_id_or_create(self.conn, date)
         make_type = self._params.get_str('make_type')
         conditions_present = []
         enablers_for = {}
         make_type = self._params.get_str('make_type')
         conditions_present = []
         enablers_for = {}
@@ -338,6 +374,7 @@ class TaskHandler(BaseHTTPRequestHandler):
             todos.sort(key=lambda t: t.date, reverse=True)
         else:
             todos.sort(key=lambda t: t.date)
             todos.sort(key=lambda t: t.date, reverse=True)
         else:
             todos.sort(key=lambda t: t.date)
+            sort_by = 'title'
         return {'start': start, 'end': end, 'process_id': process_id,
                 'comment_pattern': comment_pattern, 'todos': todos,
                 'all_processes': Process.all(self.conn), 'sort_by': sort_by}
         return {'start': start, 'end': end, 'process_id': process_id,
                 'comment_pattern': comment_pattern, 'todos': todos,
                 'all_processes': Process.all(self.conn), 'sort_by': sort_by}
@@ -355,6 +392,7 @@ class TaskHandler(BaseHTTPRequestHandler):
             conditions.sort(key=lambda c: c.title.newest, reverse=True)
         else:
             conditions.sort(key=lambda c: c.title.newest)
             conditions.sort(key=lambda c: c.title.newest, reverse=True)
         else:
             conditions.sort(key=lambda c: c.title.newest)
+            sort_by = 'title'
         return {'conditions': conditions,
                 'sort_by': sort_by,
                 'pattern': pattern}
         return {'conditions': conditions,
                 'sort_by': sort_by,
                 'pattern': pattern}
@@ -362,7 +400,7 @@ class TaskHandler(BaseHTTPRequestHandler):
     def do_GET_condition(self) -> dict[str, object]:
         """Show Condition of ?id=."""
         id_ = self._params.get_int_or_none('id')
     def do_GET_condition(self) -> dict[str, object]:
         """Show Condition of ?id=."""
         id_ = self._params.get_int_or_none('id')
-        c = Condition.by_id(self.conn, id_, create=True)
+        c = Condition.by_id_or_create(self.conn, id_)
         ps = Process.all(self.conn)
         return {'condition': c, 'is_new': c.id_ is None,
                 'enabled_processes': [p for p in ps if c in p.conditions],
         ps = Process.all(self.conn)
         return {'condition': c, 'is_new': c.id_ is None,
                 'enabled_processes': [p for p in ps if c in p.conditions],
@@ -385,7 +423,7 @@ class TaskHandler(BaseHTTPRequestHandler):
     def do_GET_process(self) -> dict[str, object]:
         """Show Process of ?id=."""
         id_ = self._params.get_int_or_none('id')
     def do_GET_process(self) -> dict[str, object]:
         """Show Process of ?id=."""
         id_ = self._params.get_int_or_none('id')
-        process = Process.by_id(self.conn, id_, create=True)
+        process = Process.by_id_or_create(self.conn, id_)
         title_64 = self._params.get_str('title_b64')
         if title_64:
             title = b64decode(title_64.encode()).decode()
         title_64 = self._params.get_str('title_b64')
         if title_64:
             title = b64decode(title_64.encode()).decode()
@@ -442,6 +480,7 @@ class TaskHandler(BaseHTTPRequestHandler):
             processes.sort(key=lambda p: p.title.newest, reverse=True)
         else:
             processes.sort(key=lambda p: p.title.newest)
             processes.sort(key=lambda p: p.title.newest, reverse=True)
         else:
             processes.sort(key=lambda p: p.title.newest)
+            sort_by = 'title'
         return {'processes': processes, 'sort_by': sort_by, 'pattern': pattern}
 
     # POST handlers
         return {'processes': processes, 'sort_by': sort_by, 'pattern': pattern}
 
     # POST handlers
@@ -462,7 +501,7 @@ class TaskHandler(BaseHTTPRequestHandler):
     def do_POST_day(self) -> str:
         """Update or insert Day of date and Todos mapped to it."""
         date = self._params.get_str('date')
     def do_POST_day(self) -> str:
         """Update or insert Day of date and Todos mapped to it."""
         date = self._params.get_str('date')
-        day = Day.by_id(self.conn, date, create=True)
+        day = Day.by_id_or_create(self.conn, date)
         day.comment = self._form_data.get_str('day_comment')
         day.save(self.conn)
         make_type = self._form_data.get_str('make_type')
         day.comment = self._form_data.get_str('day_comment')
         day.save(self.conn)
         make_type = self._form_data.get_str('make_type')
@@ -561,7 +600,7 @@ class TaskHandler(BaseHTTPRequestHandler):
             process = Process.by_id(self.conn, id_)
             process.remove(self.conn)
             return '/processes'
             process = Process.by_id(self.conn, id_)
             process.remove(self.conn)
             return '/processes'
-        process = Process.by_id(self.conn, id_, create=True)
+        process = Process.by_id_or_create(self.conn, id_)
         process.title.set(self._form_data.get_str('title'))
         process.description.set(self._form_data.get_str('description'))
         process.effort.set(self._form_data.get_float('effort'))
         process.title.set(self._form_data.get_str('title'))
         process.description.set(self._form_data.get_str('description'))
         process.effort.set(self._form_data.get_float('effort'))
@@ -637,8 +676,8 @@ class TaskHandler(BaseHTTPRequestHandler):
             condition = Condition.by_id(self.conn, id_)
             condition.remove(self.conn)
             return '/conditions'
             condition = Condition.by_id(self.conn, id_)
             condition.remove(self.conn)
             return '/conditions'
-        condition = Condition.by_id(self.conn, id_, create=True)
-        condition.is_active = self._form_data.get_all_str('is_active') != []
+        condition = Condition.by_id_or_create(self.conn, id_)
+        condition.is_active = self._form_data.get_str('is_active') == 'True'
         condition.title.set(self._form_data.get_str('title'))
         condition.description.set(self._form_data.get_str('description'))
         condition.save(self.conn)
         condition.title.set(self._form_data.get_str('title'))
         condition.description.set(self._form_data.get_str('description'))
         condition.save(self.conn)