From e29512355f7ae0e84d5e37bd4ddea7faa594e6ad Mon Sep 17 00:00:00 2001
From: Christian Heller <c.heller@plomlompom.de>
Date: Thu, 22 Aug 2024 08:30:22 +0200
Subject: [PATCH] Refactor GenParams code.

---
 stable.py            | 74 ++++++++++++++++++++++++--------------------
 stable/core.py       | 58 +++++++++-------------------------
 stable/gen_params.py | 50 ++++++++++++++++++++++++++++++
 3 files changed, 105 insertions(+), 77 deletions(-)
 create mode 100644 stable/gen_params.py

diff --git a/stable.py b/stable.py
index 665b3a9..0bea6f9 100755
--- a/stable.py
+++ b/stable.py
@@ -3,6 +3,7 @@ from sys import argv, exit as sys_exit, stdin
 from os.path import dirname, basename, splitext, join as path_join, exists
 from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
 from random import randint
+from stable.gen_params import GenParams 
 
 
 DEFAULT_SCHEDULER = 'EulerAncestralDiscreteScheduler'
@@ -43,31 +44,23 @@ def parse_args():
                         help='name of denoising scheduler; --list-schedulers '
                         'prints available choices for chosen model')
     parser.add_argument('-D', '--defaults_from_stdin', action='store_true',
-                        help='parse stdin for image generation options '
+                        help='parse stdin for image generation parameters'
                         '(e.g. from image file EXIF comment)')
-    parser.add_argument('-P', '--model_path_prefix',
-                        help='optional directory path prefix to MODEL for '
-                        'where it\'s provided as a mere filename (as may '
-                        'happen due to -D)')
+    parser.add_argument('-M', '--models_dir',
+                        help='directory path prefix to MODEL (for where'
+                        'it\'s provided as a mere filename, as with -D)')
     parser.add_argument('-H', '--help', action='help')
     parser.add_argument('-L', '--list-schedulers', action='store_true',
                         help='list options for -S available with chosen model')
+
+    # Re-seed defaults from stdin, if -D.
     parsed_args = parser.parse_args()
-    parser.set_defaults(model_path_prefix='')
     if parsed_args.defaults_from_stdin:
-        defaults_string = stdin.read()
-        first_split = defaults_string.split('; PROMPT: ', maxsplit=1)
-        if 2 == len(first_split):
-            parser.set_defaults(prompt=first_split[1])
-        for section in first_split[0].split('; '):
-            key, val = section.split(': ', maxsplit=1)
-            key = key.lower()
-            if key in {'seed', 'height', 'width', 'n_steps'}:
-                val = int(val)
-            elif key in {'guidance'}:
-                val = float(val)
-            parser.set_defaults(**{key: val})
+        gen_params = GenParams.from_str(stdin.read())
+        parser.set_defaults(**gen_params.as_dict)
     parsed_args = parser.parse_args()
+
+    # Validate input.
     prefix = f'{argv[0]}: error: '
     if parsed_args.list_schedulers:
         required = {'model': 'm'}
@@ -84,20 +77,22 @@ def parse_args():
         parser.print_usage()
         print(f'{prefix}{", ".join(suffixes)}')
         sys_exit(1)
+
+    # Prepare generation parameters, model path, etc.
+    if not parsed_args.models_dir:
+        parsed_args.models_dir = dirname(parsed_args.model)
+        parsed_args.model = basename(parsed_args.model)
+    parsed_args.gen_params = GenParams(
+            parsed_args.seed, parsed_args.guidance, parsed_args.n_steps,
+            parsed_args.height, parsed_args.width, parsed_args.scheduler,
+            parsed_args.model, parsed_args.prompt)
+    parsed_args.model_path = path_join(parsed_args.models_dir,
+                                       parsed_args.gen_params.model_name)
     return parsed_args
 
 
+# guard against overwriting
 args = parse_args()
-# pylint: disable=wrong-import-position
-from stable.core import ImageMaker  # noqa: E402
-model_path = f'{args.model_path_prefix}{args.model}'
-if args.list_schedulers:
-    maker = ImageMaker(model_path)
-    print(f'AVAILABLE SCHEDULERS FOR MODEL {args.model}:\n')
-    for name in maker.compatible_schedulers:
-        print(name)
-    sys_exit(0)
-
 dir_path = dirname(args.output)
 filename = basename(args.output)
 filename_sans_ext, ext = splitext(filename)
@@ -107,16 +102,27 @@ for n in range(args.quantity):
     if exists(path):
         raise Exception(f'Would overwrite file: {path}')
 
-maker = ImageMaker(model_path)
-start_seed = args.seed
+# pylint: disable=wrong-import-position
+from stable.core import ImageMaker  # noqa: E402
+
+# only list available schedulers, if -L
+if args.list_schedulers:
+    maker = ImageMaker(args.model_path)
+    print(f'AVAILABLE SCHEDULERS FOR MODEL {args.gen_params.model_name}:\n')
+    for name in maker.compatible_schedulers:
+        print(name)
+    sys_exit(0)
+
+# otherwise generate pictures
+maker = ImageMaker(args.model_path)
+start_seed = args.gen_params.seed
 start_seed = start_seed if start_seed != 0 else randint(-(2**31-1), 2**31)
 seed_corrector = 0
 for n in range(args.quantity):
     if 0 == start_seed + n + seed_corrector:
         seed_corrector += 1
-    nth_seed = start_seed + n + seed_corrector
+    args.gen_params.seed = start_seed + n + seed_corrector
     path = save_path(n)
-    maker.set_gen_params(args.prompt, nth_seed, args.guidance, args.height,
-                         args.width, args.n_steps, args.scheduler)
-    print(f'GENERATING: {path}; {maker.gen_params_to_exif_comment}')
+    maker.set_gen_params(args.gen_params)
+    print(f'GENERATING: {path}; {args.gen_params.to_str}')
     maker.gen_image_to(path)
diff --git a/stable/core.py b/stable/core.py
index a5f7b56..f2792fc 100644
--- a/stable/core.py
+++ b/stable/core.py
@@ -5,6 +5,7 @@ from diffusers import StableDiffusionPipeline
 from diffusers.utils import logging
 from torch import Generator, float16
 from exiftool import ExifToolHelper  # type: ignore
+from stable.gen_params import GenParams
 
 SAFETY_CHECKER_WARNING_PATTERN = 'You have disabled the safety checker'
 
@@ -22,15 +23,6 @@ class ImageMaker:
             def filter(self, record):
                 return self.target not in record.getMessage()
 
-        self.model_filename = basename(model_path)
-        if '; ' in self.model_filename:
-            raise Exception('illegal filename (must not contain "; ")')
-        self.seed = None
-        self.prompt = None
-        self.guidance = None
-        self.height = None
-        self.width = None
-        self.n_steps = None
         prefix = 'SETTING UP STABLE DIFFUSION PIPELINE FROM MODEL'
         print(f'{prefix}: {model_path}\n')
         diffusers_logging_handler = logging.get_logger('diffusers').handlers[0]
@@ -45,55 +37,35 @@ class ImageMaker:
         self.pipe.to('cuda')
         self.generator = Generator()
         print('PIPELINE READY\n')
+        self.gen_params = None
 
-    def set_seed(self, seed):
-        self.seed = seed
-        self.generator.manual_seed(seed)
-
-    def set_gen_params(self, prompt, seed, guidance, height, width, n_steps,
-                       scheduler_name):
+    def set_gen_params(self, gen_params):
         scheduler_selection = [s for s in self.pipe.scheduler.compatibles
-                               if s.__name__ == scheduler_name]
+                               if s.__name__ == gen_params.scheduler_name]
         if not scheduler_selection:
             raise Exception(f'unknown scheduler: {scheduler_name}')
         self.pipe.scheduler = scheduler_selection[0].from_config(
                 self.pipe.scheduler.config)
-        self.seed = seed
-        self.prompt = prompt
-        self.guidance = guidance
-        self.height = height
-        self.width = width
-        self.n_steps = n_steps
+        self.gen_params = gen_params
 
     @property
     def compatible_schedulers(self):
         return [s.__name__ for s in self.pipe.scheduler.compatibles]
 
-    @property
-    def gen_params_to_exif_comment(self):
-        return f'SEED: {self.seed}; ' +\
-                f'GUIDANCE: {self.guidance}; ' +\
-                f'N_STEPS: {self.n_steps}; ' +\
-                f'HEIGHT: {self.height}; ' +\
-                f'WIDTH: {self.width}; ' +\
-                f'SCHEDULER: {self.pipe.scheduler.__class__.__name__}; ' +\
-                f'MODEL: {self.model_filename}; ' +\
-                f'PROMPT: {self.prompt}'
-
     def gen_image_to(self, path):
-        if None in {self.seed, self.prompt, self.guidance, self.height,
-                    self.width, self.n_steps}:
+        if None in {self.gen_params.seed, self.gen_params.prompt,
+                    self.gen_params.guidance, self.gen_params.height,
+                    self.gen_params.width, self.gen_params.n_steps}:
             raise Exception('Generation parameters not initialized.')
-        self.generator.manual_seed(self.seed)
+        self.generator.manual_seed(self.gen_params.seed)
         image = self.pipe(generator=self.generator,
-                          prompt=self.prompt,
-                          guidance=self.guidance,
-                          height=self.height,
-                          width=self.width,
-                          num_inference_steps=self.n_steps,
+                          prompt=self.gen_params.prompt,
+                          guidance=self.gen_params.guidance,
+                          height=self.gen_params.height,
+                          width=self.gen_params.width,
+                          num_inference_steps=self.gen_params.n_steps,
                           ).images[0]
         image.save(path)
         with ExifToolHelper() as et:
-            et.set_tags([path],
-                        tags={'Comment': self.gen_params_to_exif_comment},
+            et.set_tags([path], tags={'Comment': self.gen_params.to_str},
                         params=['-overwrite_original'])
diff --git a/stable/gen_params.py b/stable/gen_params.py
new file mode 100644
index 0000000..e3ae86f
--- /dev/null
+++ b/stable/gen_params.py
@@ -0,0 +1,50 @@
+class GenParams:
+
+    def __init__(self, seed, guidance, n_steps, height, width, scheduler, model, prompt):
+        if '; ' in model:
+            raise Exception('illegal model filename (must not contain "; ")')
+        self.seed = seed
+        self.guidance = guidance
+        self.n_steps = n_steps
+        self.height, self.width = height, width
+        self.scheduler_name = scheduler
+        self.model_name = model
+        self.prompt = prompt 
+
+    @property
+    def to_str(self):
+        return f'SEED: {self.seed}; ' +\
+                f'GUIDANCE: {self.guidance}; ' +\
+                f'N_STEPS: {self.n_steps}; ' +\
+                f'HEIGHT: {self.height}; ' +\
+                f'WIDTH: {self.width}; ' +\
+                f'SCHEDULER: {self.scheduler_name}; ' +\
+                f'MODEL: {self.model_name}; ' +\
+                f'PROMPT: {self.prompt}'
+
+    @classmethod
+    def from_str(cls, string):
+        d = {}
+        first_split = string.split('; PROMPT: ', maxsplit=1)
+        if 2 == len(first_split):
+            d['prompt'] = first_split[1] 
+        for section in first_split[0].split('; '):
+            key, val = section.split(': ', maxsplit=1)
+            key = key.lower()
+            if key in {'seed', 'height', 'width', 'n_steps'}:
+                val = int(val)
+            elif key in {'guidance'}:
+                val = float(val)
+            d[key] = val
+        return cls(**d)
+
+    @property
+    def as_dict(self):
+        d = {}
+        just_names = {'scheduler', 'model'}
+        for k in ['seed', 'guidance', 'n_steps', 'height', 'width', 'prompt']\
+                + list(just_names):
+            v = getattr(self, f'{k}_name' if k in just_names else k)
+            if v is not None:
+                d[k] = v
+        return d
-- 
2.30.2