From 7dcaf2bf4082113add30cccdb3cae0107f6e8ba4 Mon Sep 17 00:00:00 2001
From: Christian Heller <c.heller@plomlompom.de>
Date: Sat, 24 Aug 2024 05:05:23 +0200
Subject: [PATCH] Refactor GenParams usage.

---
 browser.py           | 50 ++++++++++++++++++++------------------------
 stable.py            |  8 +++----
 stable/core.py       |  4 ++--
 stable/gen_params.py | 48 ++++++++++++++++++------------------------
 4 files changed, 48 insertions(+), 62 deletions(-)

diff --git a/browser.py b/browser.py
index 6f13d7e..26b67a4 100755
--- a/browser.py
+++ b/browser.py
@@ -1,16 +1,19 @@
 #!/usr/bin/env python3
-from exiftool import ExifToolHelper  # type: ignore
 from json import dump as json_dump, load as json_load
 from os.path import exists as path_exists, join as path_join, abspath
+from exiftool import ExifToolHelper  # type: ignore
 import gi  # type: ignore
 gi.require_version('Gtk', '4.0')
 gi.require_version('Gio', '2.0')
 # pylint: disable=wrong-import-position
 from gi.repository import Gtk, Gio, GObject  # type: ignore  # noqa: E402
-from stable.gen_params import GenParams  # noqa: E402
+# pylint: disable=no-name-in-module
+from stable.gen_params import (GenParams,  # noqa: E402
+                               GEN_PARAMS, GEN_PARAMS_STR)  # noqa: E402
 
 
-IMG_DIR='.'
+IMG_DIR = '.'
+CACHE_PATH = 'cache.json'
 
 
 class FileItem(GObject.GObject):
@@ -20,14 +23,11 @@ class FileItem(GObject.GObject):
         self.name = info.get_name()
         self.last_mod_time = info.get_modification_date_time().format_iso8601()
         self.full_path = path_join(path, self.name)
-        self.seed = ''
-        self.guidance = 0.0
-        self.n_steps = 0
-        self.height = 0
-        self.width = 0
-        self.scheduler = ''
-        self.model = ''
-        self.prompt = ''
+        for param_name in GEN_PARAMS:
+            if param_name in GEN_PARAMS_STR:
+                setattr(self, param_name.lower(), '')
+            else:
+                setattr(self, param_name.lower(), 0)
         if self.full_path in cache:
             if self.last_mod_time in cache[self.full_path]:
                 cached = cache[self.full_path][self.last_mod_time]
@@ -35,17 +35,14 @@ class FileItem(GObject.GObject):
                     setattr(self, k, cached[k])
 
     def set_metadata(self, et, cache):
-        self.metadata = 'no SD comment'
         for d in et.get_tags([self.name], ['Comment']):
             for k, v in d.items():
                 if k.endswith('Comment'):
-                    self.metadata = ''
                     gen_params = GenParams.from_str(v)
                     for k, v_ in gen_params.as_dict.items():
                         setattr(self, k, v_)
         cached = {}
-        for k in ('seed', 'guidance', 'n_steps', 'height',
-                  'width', 'scheduler', 'model', 'prompt'):
+        for k in (k.lower() for k in GEN_PARAMS):
             cached[k] = getattr(self, k)
         cache[self.full_path] = {self.last_mod_time: cached}
 
@@ -94,10 +91,10 @@ class Window(Gtk.ApplicationWindow):
         box_outer.append(self.viewer)
         self.props.child = box_outer
 
-        if not path_exists('cache.json'):
-            with open('cache.json', 'w') as f:
+        if not path_exists(CACHE_PATH):
+            with open(CACHE_PATH, 'w', encoding='utf8') as f:
                 json_dump({}, f)
-        with open('cache.json', 'r') as f:
+        with open(CACHE_PATH, 'r', encoding='utf8') as f:
             cache = json_load(f)
         self.max_index = 0
         self.item = None
@@ -115,16 +112,17 @@ class Window(Gtk.ApplicationWindow):
             item = FileItem(img_dir_absolute, info, cache)
             self.unsorted += [item]
         with ExifToolHelper() as et:
-            for item in [item for item in self.unsorted if item.seed == '']:
+            for item in [item for item in self.unsorted if '' == item.model]:
                 item.set_metadata(et, cache)
         self.max_index = len(self.unsorted) - 1
         self.sort('last_mod_time')
-        with open('cache.json', 'w') as f:
+        with open(CACHE_PATH, 'w', encoding='utf8') as f:
             json_dump(cache, f)
 
     def sort(self, attr_name):
         self.list_store.remove_all()
-        for file_item in sorted(self.unsorted, key=lambda i: getattr(i, attr_name)):
+        for file_item in sorted(self.unsorted,
+                                key=lambda i: getattr(i, attr_name)):
             self.list_store.append(file_item)
         self.update_selected()
 
@@ -137,12 +135,10 @@ class Window(Gtk.ApplicationWindow):
     def reload(self):
         self.viewer.remove(self.viewer.get_last_child())
         if self.item:
-            metadata = f'{self.item.full_path}: PROMPT: {self.item.prompt}\n' +\
-                    f'SEED: {self.item.seed} / MODEL: {self.item.model} / ' +\
-                    f'SCHEDULER: {self.item.scheduler}\nGUIDANCE: {self.item.guidance}\n' +\
-                    f'N_STEPS: {self.item.n_steps}\nHEIGHT: {self.item.height} / ' +\
-                    f'WIDTH: {self.item.width}'
-            self.metadata.props.label = metadata
+            params_strs = [f'{k}: ' + str(getattr(self.item, k.lower()))
+                           for k in GEN_PARAMS]
+            self.metadata.props.label = '\n'.join([self.item.full_path]
+                                                  + params_strs)
             pic = Gtk.Picture.new_for_filename(self.item.name)
             self.viewer.append(pic)
         else:
diff --git a/stable.py b/stable.py
index 241783f..64ee704 100755
--- a/stable.py
+++ b/stable.py
@@ -3,7 +3,7 @@ from sys import argv, exit as sys_exit, stdin
 from os.path import dirname, basename, splitext, join as path_join, exists
 from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
 from random import randint
-from stable.gen_params import GenParams
+from stable.gen_params import GenParams, GEN_PARAMS
 
 
 DEFAULT_SCHEDULER = 'EulerAncestralDiscreteScheduler'
@@ -89,10 +89,8 @@ def parse_args():
             parsed_args.models_dir = dirname(new_parsed_args.model)
         parsed_args.models += [basename(new_parsed_args.model)]
         parsed_args.gen_paramses += [GenParams(
-                new_parsed_args.seed, new_parsed_args.guidance,
-                new_parsed_args.n_steps, new_parsed_args.height,
-                new_parsed_args.width, new_parsed_args.scheduler,
-                new_parsed_args.model, new_parsed_args.prompt)]
+                **{k: getattr(new_parsed_args, k)
+                   for k in (k.lower() for k in GEN_PARAMS)})]
     return parsed_args
 
 
diff --git a/stable/core.py b/stable/core.py
index 5c506c7..02a44ed 100644
--- a/stable/core.py
+++ b/stable/core.py
@@ -39,9 +39,9 @@ class ImageMaker:
 
     def set_gen_params(self, gen_params):
         scheduler_selection = [s for s in self.pipe.scheduler.compatibles
-                               if s.__name__ == gen_params.scheduler_name]
+                               if s.__name__ == gen_params.scheduler]
         if not scheduler_selection:
-            raise Exception(f'unknown scheduler: {gen_params.scheduler_name}')
+            raise Exception(f'unknown scheduler: {gen_params.scheduler}')
         self.pipe.scheduler = scheduler_selection[0].from_config(
                 self.pipe.scheduler.config)
         self.gen_params = gen_params
diff --git a/stable/gen_params.py b/stable/gen_params.py
index fc1eba1..160735c 100644
--- a/stable/gen_params.py
+++ b/stable/gen_params.py
@@ -1,27 +1,26 @@
+GEN_PARAMS_STR = ('MODEL', 'SCHEDULER', 'PROMPT')
+GEN_PARAMS_INT = ('SEED', 'N_STEPS', 'HEIGHT', 'WIDTH')
+GEN_PARAMS_FLOAT = ('GUIDANCE',)
+GEN_PARAMS = GEN_PARAMS_STR + GEN_PARAMS_INT + GEN_PARAMS_FLOAT
+
+
 class GenParams:
 
-    def __init__(self, seed, guidance, n_steps, height, width, scheduler,
-                 model, prompt):
-        if '; ' in model:
+    def __init__(self, **kwargs):
+        if '; ' in kwargs['model']:
             raise Exception('illegal model filename (must not contain "; ")')
-        self.seed = seed
-        self.guidance = guidance
-        self.n_steps = n_steps
-        self.height, self.width = height, width
-        self.scheduler_name = scheduler
-        self.model_name = model
-        self.prompt = prompt
+        for param_name in (n.lower() for n in GEN_PARAMS):
+            setattr(self, param_name, kwargs[param_name])
 
     @property
     def to_str(self):
-        return f'SEED: {self.seed}; ' +\
-                f'GUIDANCE: {self.guidance}; ' +\
-                f'N_STEPS: {self.n_steps}; ' +\
-                f'HEIGHT: {self.height}; ' +\
-                f'WIDTH: {self.width}; ' +\
-                f'SCHEDULER: {self.scheduler_name}; ' +\
-                f'MODEL: {self.model_name}; ' +\
-                f'PROMPT: {self.prompt}'
+        s = ''
+        for param_name in GEN_PARAMS:
+            if 'PROMPT' == param_name:
+                continue  # can create fewer parsing problems at the end
+            value = getattr(self, param_name.lower())
+            s += f'{param_name}: {value}; '
+        return f'{s}PROMPT: {self.prompt}'  # pylint: disable=no-member
 
     @classmethod
     def from_str(cls, string):
@@ -32,20 +31,13 @@ class GenParams:
         for section in first_split[0].split('; '):
             key, val = section.split(': ', maxsplit=1)
             key = key.lower()
-            if key in {'seed', 'height', 'width', 'n_steps'}:
+            if key.upper() in GEN_PARAMS_INT:
                 val = int(val)
-            elif key in {'guidance'}:
+            elif key.upper() in GEN_PARAMS_FLOAT:
                 val = float(val)
             d[key] = val
         return cls(**d)
 
     @property
     def as_dict(self):
-        d = {}
-        just_names = {'scheduler', 'model'}
-        for k in ['seed', 'guidance', 'n_steps', 'height', 'width', 'prompt']\
-                + list(just_names):
-            v = getattr(self, f'{k}_name' if k in just_names else k)
-            if v is not None:
-                d[k] = v
-        return d
+        return {k: getattr(self, k) for k in (k.lower() for k in GEN_PARAMS)}
-- 
2.30.2