From c8edb703912f92d3a3a12748c2d4d2b9a8cb8138 Mon Sep 17 00:00:00 2001
From: Christian Heller <c.heller@plomlompom.de>
Date: Thu, 22 Aug 2024 15:16:51 +0200
Subject: [PATCH] Minor code style fixes.

---
 browser.py           |   8 +--
 stable.py            | 113 +++++++++++++++++++++++--------------------
 stable/core.py       |   4 +-
 stable/gen_params.py |   7 +--
 4 files changed, 70 insertions(+), 62 deletions(-)

diff --git a/browser.py b/browser.py
index cf59764..c2ab5b8 100755
--- a/browser.py
+++ b/browser.py
@@ -2,11 +2,11 @@
 from os import scandir
 from os.path import splitext
 from exiftool import ExifToolHelper  # type: ignore
-from stable.gen_params import GenParams
-import gi
+import gi  # type: ignore
 gi.require_version('Gtk', '4.0')
 # pylint: disable=wrong-import-position
-from gi.repository import Gtk  # noqa: E402
+from gi.repository import Gtk  # type: ignore  # noqa: E402
+from stable.gen_params import GenParams  # noqa: E402
 
 
 class Window(Gtk.ApplicationWindow):
@@ -44,7 +44,7 @@ class Window(Gtk.ApplicationWindow):
         self.current_index = -1
         self.newest(None)
 
-    def _reload(self, set_max = False):
+    def _reload(self, set_max=False):
         self.entries = [e for e in scandir('.')
                         if e.is_file()
                         and splitext(e)[1] in {'.png', '.jpg', '.jpeg'}]
diff --git a/stable.py b/stable.py
index cf9cbcc..71320d2 100755
--- a/stable.py
+++ b/stable.py
@@ -7,12 +7,8 @@ from stable.gen_params import GenParams
 
 
 DEFAULT_SCHEDULER = 'EulerAncestralDiscreteScheduler'
-
-
-def save_path(count: int) -> str:
-    n_iterations = len(args.gen_paramses) * args.quantity
-    filename_count = f'_{count:08}' if n_iterations > 1 else ''
-    return path_join(dir_path, f'{filename_sans_ext}{filename_count}.{ext}')
+SEED_MIN = -(2**31-1)
+SEED_MAX = 2**31
 
 
 def parse_args():
@@ -65,7 +61,7 @@ def parse_args():
         prefix += 'unless calling with -H/--help or --list-schedulers, '
     prefix += 'requiring '
 
-     # Re-seed defaults from stdin, if -D.
+    # Re-seed defaults from stdin, if -D.
     temp_gen_paramses = [None]
     if parsed_args.defaults_from_stdin:
         temp_gen_paramses = []
@@ -73,9 +69,9 @@ def parse_args():
             temp_gen_paramses += [GenParams.from_str(line)]
     parsed_args.models = []
     parsed_args.gen_paramses = []
-    for gen_params in temp_gen_paramses:
-        if gen_params:
-            parser.set_defaults(**gen_params.as_dict)
+    for gen_params_ in temp_gen_paramses:
+        if gen_params_:
+            parser.set_defaults(**gen_params_.as_dict)
         new_parsed_args = parser.parse_args()
 
         # Validate input.
@@ -100,45 +96,58 @@ def parse_args():
     return parsed_args
 
 
-# guard against overwriting
-args = parse_args()
-if not args.list_schedulers:
-    dir_path = dirname(args.output)
-    filename = basename(args.output)
-    filename_sans_ext, ext = splitext(filename)
-    ext = ext[1:] if ext else 'png'
-    for n in range(args.quantity * len(args.gen_paramses)):
-        path = save_path(n)
-        if exists(path):
-            raise Exception(f'Would overwrite file: {path}')
-
-# pylint: disable=wrong-import-position
-from stable.core import ImageMaker  # noqa: E402
-old_model_path = ''
-maker = None
-for i, model_name in enumerate(args.models):
-    new_model_path = path_join(args.models_dir, model_name)
-    if new_model_path != old_model_path:
-        maker = ImageMaker(new_model_path)
-    old_model_path = new_model_path
-
-    # only list available schedulers, if -L
-    if args.list_schedulers:
-        print(f'AVAILABLE SCHEDULERS FOR MODEL {new_model_path}:\n')
-        for name in maker.compatible_schedulers:
-            print(name)
-        continue
-
-    # otherwise generate pictures
-    gen_params = args.gen_paramses[i]
-    start_seed = gen_params.seed
-    start_seed = start_seed if start_seed != 0 else randint(-(2**31-1), 2**31)
-    seed_corrector = 0
-    for n in range(args.quantity):
-        if 0 == start_seed + n + seed_corrector:
-            seed_corrector += 1
-        gen_params.seed = start_seed + n + seed_corrector
-        path = save_path(i*args.quantity + n)
-        maker.set_gen_params(gen_params)
-        print(f'GENERATING: {path}; {gen_params.to_str}')
-        maker.gen_image_to(path)
+def run():
+
+    def save_path(count: int) -> str:
+        n_iterations = len(args.gen_paramses)
+        filename_count = f'_{count:08}' if n_iterations > 1 else ''
+        # pylint: disable=possibly-used-before-assignment
+        return path_join(dir_path,
+                         f'{filename_sans_ext}{filename_count}.{ext}')
+
+    # guard against overwriting
+    args = parse_args()
+    if not args.list_schedulers:
+        dir_path = dirname(args.output)
+        filename = basename(args.output)
+        filename_sans_ext, ext = splitext(filename)
+        ext = ext[1:] if ext else 'png'
+        for n in range(args.quantity * len(args.gen_paramses)):
+            path = save_path(n)
+            if exists(path):
+                raise Exception(f'Would overwrite file: {path}')
+
+    # pylint: disable=wrong-import-position, import-outside-toplevel
+    from stable.core import ImageMaker  # noqa: E402
+    old_model_path = ''
+    maker = None
+    for i, model_name in enumerate(args.models):
+        new_model_path = path_join(args.models_dir, model_name)
+        if new_model_path != old_model_path:
+            maker = ImageMaker(new_model_path)
+        old_model_path = new_model_path
+
+        # only list available schedulers, if -L
+        if args.list_schedulers:
+            print(f'AVAILABLE SCHEDULERS FOR MODEL {new_model_path}:\n')
+            for name in maker.compatible_schedulers:
+                print(name)
+            continue
+
+        # otherwise generate pictures
+        gen_params = args.gen_paramses[i]
+        start_seed = gen_params.seed
+        start_seed = start_seed if start_seed != 0 else randint(SEED_MIN,
+                                                                SEED_MAX)
+        seed_corrector = 0
+        for n in range(args.quantity):
+            if 0 == start_seed + n + seed_corrector:
+                seed_corrector += 1
+            gen_params.seed = start_seed + n + seed_corrector
+            path = save_path(i*args.quantity + n)
+            maker.set_gen_params(gen_params)
+            print(f'GENERATING: {path}; {gen_params.to_str}')
+            maker.gen_image_to(path)
+
+
+run()
diff --git a/stable/core.py b/stable/core.py
index f2792fc..5c506c7 100644
--- a/stable/core.py
+++ b/stable/core.py
@@ -1,11 +1,9 @@
 from logging import (Formatter as LogFormatter, captureWarnings,
                      Filter as LogFilter)
-from os.path import basename
 from diffusers import StableDiffusionPipeline
 from diffusers.utils import logging
 from torch import Generator, float16
 from exiftool import ExifToolHelper  # type: ignore
-from stable.gen_params import GenParams
 
 SAFETY_CHECKER_WARNING_PATTERN = 'You have disabled the safety checker'
 
@@ -43,7 +41,7 @@ class ImageMaker:
         scheduler_selection = [s for s in self.pipe.scheduler.compatibles
                                if s.__name__ == gen_params.scheduler_name]
         if not scheduler_selection:
-            raise Exception(f'unknown scheduler: {scheduler_name}')
+            raise Exception(f'unknown scheduler: {gen_params.scheduler_name}')
         self.pipe.scheduler = scheduler_selection[0].from_config(
                 self.pipe.scheduler.config)
         self.gen_params = gen_params
diff --git a/stable/gen_params.py b/stable/gen_params.py
index e3ae86f..fc1eba1 100644
--- a/stable/gen_params.py
+++ b/stable/gen_params.py
@@ -1,6 +1,7 @@
 class GenParams:
 
-    def __init__(self, seed, guidance, n_steps, height, width, scheduler, model, prompt):
+    def __init__(self, seed, guidance, n_steps, height, width, scheduler,
+                 model, prompt):
         if '; ' in model:
             raise Exception('illegal model filename (must not contain "; ")')
         self.seed = seed
@@ -9,7 +10,7 @@ class GenParams:
         self.height, self.width = height, width
         self.scheduler_name = scheduler
         self.model_name = model
-        self.prompt = prompt 
+        self.prompt = prompt
 
     @property
     def to_str(self):
@@ -27,7 +28,7 @@ class GenParams:
         d = {}
         first_split = string.split('; PROMPT: ', maxsplit=1)
         if 2 == len(first_split):
-            d['prompt'] = first_split[1] 
+            d['prompt'] = first_split[1]
         for section in first_split[0].split('; '):
             key, val = section.split(': ', maxsplit=1)
             key = key.lower()
-- 
2.30.2