From 0c5c4da0cb306ca35f49c55909c18268d67fc0ff Mon Sep 17 00:00:00 2001
From: Christian Heller <c.heller@plomlompom.de>
Date: Mon, 19 Aug 2024 01:56:23 +0200
Subject: [PATCH] Reorganize code.

---
 stable.py      | 29 ++++-----------
 stable/core.py | 99 +++++++++++++++++++++++++++++++++++++-------------
 2 files changed, 82 insertions(+), 46 deletions(-)

diff --git a/stable.py b/stable.py
index 641ef48..ea0f598 100755
--- a/stable.py
+++ b/stable.py
@@ -2,9 +2,7 @@
 from os.path import dirname, basename, splitext, join as path_join, exists
 from argparse import ArgumentParser
 from random import randint
-from exiftool import ExifToolHelper  # type: ignore
-from torch import Generator
-from stable.core import init_pipeline, make_metadata
+from stable.core import ImageMaker
 
 
 def save_path(count: int) -> str:
@@ -43,26 +41,15 @@ for n in range(args.number):
     if exists(path):
         raise Exception(f'Would overwrite file: {path}')
 
-pipe = init_pipeline(args.model)
-generator = Generator()
+maker = ImageMaker(args.model)
 start_seed = args.randomness_seed
 start_seed = start_seed if start_seed != 0 else randint(-(2**31-1), 2**31)
 for n in range(args.number):
     nth_seed = start_seed + n
     path = save_path(n)
-    metadata = make_metadata(nth_seed, args.guidance, args.height, args.width,
-                             args.model, args.prompt)
-    print(f'GENERATING: {path}; {metadata}')
-    generator.manual_seed(nth_seed)
-    images = pipe(args.prompt,
-                  generator=generator,
-                  guidance_scale=args.guidance,
-                  num_inference_steps=args.steps,
-                  height=args.height,
-                  width=args.width
-                  ).images
-    images[0].save(path)
-    with ExifToolHelper() as et:
-        et.set_tags([path],
-                    tags={'Comment': metadata},
-                    params=['-overwrite_original'])
+    maker.set_gen_params(args.prompt, nth_seed, args.guidance,
+                         args.height, args.width, args.steps)
+    print(f'GENERATING: {path}; {maker.gen_params_to_exif_comment}')
+    maker.set_gen_params(args.prompt, nth_seed, args.guidance, args.height,
+                         args.width, args.steps)
+    maker.gen_image_to(path)
diff --git a/stable/core.py b/stable/core.py
index d7f9d23..5612ab2 100644
--- a/stable/core.py
+++ b/stable/core.py
@@ -1,36 +1,85 @@
-from logging import (Formatter as LoggingFormatter, captureWarnings,
-                     Filter as LoggingFilter)
+from logging import (Formatter as LogFormatter, captureWarnings,
+                     Filter as LogFilter)
+from os.path import basename
 from diffusers import StableDiffusionPipeline
 from diffusers.utils import logging
+from torch import Generator
+from exiftool import ExifToolHelper  # type: ignore
 
-LOG_FMT = 'DIFFUSERS WARNING: %(message)s\n'
-SAFETY_MSG_PATTERN = 'You have disabled the safety checker'
+SAFETY_CHECKER_WARNING_PATTERN = 'You have disabled the safety checker'
 
 
-class _FilterOutString(LoggingFilter):
+class ImageMaker:
 
-    def __init__(self, target):
-        super().__init__()
-        self.target = target
+    def __init__(self, model_path):
 
-    def filter(self, record):
-        return self.target not in record.getMessage()
+        class FilterOut(LogFilter):
 
+            def __init__(self, target):
+                super().__init__()
+                self.target = target
 
-def init_pipeline(model):
-    print(f'SETTING UP STABLE DIFFUSION PIPELINE FROM MODEL: {model}\n')
-    diffusers_logging_handler = logging.get_logger('diffusers').handlers[0]
-    diffusers_logging_handler.setFormatter(LoggingFormatter(fmt=LOG_FMT))
-    diffusers_logging_handler.addFilter(_FilterOutString(SAFETY_MSG_PATTERN))
-    captureWarnings(True)
-    logging.disable_progress_bar()
-    pipe = StableDiffusionPipeline.from_single_file(model,
-                                                    local_files_only=True)
-    pipe.to('cuda')
-    print('PIPELINE READY\n')
-    return pipe
+            def filter(self, record):
+                return self.target not in record.getMessage()
 
+        self.model_filename = basename(model_path)
+        self.seed = None
+        self.prompt = None
+        self.guidance = None
+        self.height = None
+        self.width = None
+        self.n_steps = None
+        prefix = 'SETTING UP STABLE DIFFUSION PIPELINE FROM MODEL'
+        print(f'{prefix}: {model_path}\n')
+        diffusers_logging_handler = logging.get_logger('diffusers').handlers[0]
+        diffusers_logging_handler.setFormatter(
+                LogFormatter(fmt='DIFFUSERS WARNING: %(message)s\n'))
+        diffusers_logging_handler.addFilter(
+                FilterOut(SAFETY_CHECKER_WARNING_PATTERN))
+        captureWarnings(True)
+        logging.disable_progress_bar()
+        self.pipe = StableDiffusionPipeline.from_single_file(
+                model_path, local_files_only=True)
+        self.pipe.to('cuda')
+        self.generator = Generator()
+        print('PIPELINE READY\n')
 
-def make_metadata(seed, guidance, height, width, model, prompt):
-    return f'SEED: {seed}; GUIDANCE: {guidance}; HEIGHT: {height}; ' +\
-            f'WIDTH: {width}; MODEL: {model}; PROMPT: {prompt}'
+    def set_seed(self, seed):
+        self.seed = seed
+        self.generator.manual_seed(seed)
+
+    def set_gen_params(self, prompt, seed, guidance, height, width, n_steps):
+        self.seed = seed
+        self.prompt = prompt
+        self.guidance = guidance
+        self.height = height
+        self.width = width
+        self.n_steps = n_steps
+
+    @property
+    def gen_params_to_exif_comment(self):
+        return f'SEED: {self.seed}; ' +\
+                f'GUIDANCE: {self.guidance}; ' +\
+                f'NUMBER OF STEPS: {self.n_steps}; ' +\
+                f'HEIGHT: {self.height}; ' +\
+                f'WIDTH: {self.width}; ' +\
+                f'MODEL_FILE: {self.model_filename}; ' +\
+                f'PROMPT: {self.prompt}'
+
+    def gen_image_to(self, path):
+        if None in {self.seed, self.prompt, self.guidance, self.height,
+                    self.width, self.n_steps}:
+            raise Exception('Generation parameters not initialized.')
+        self.generator.manual_seed(self.seed)
+        image = self.pipe(generator=self.generator,
+                          prompt=self.prompt,
+                          guidance=self.guidance,
+                          height=self.height,
+                          width=self.width,
+                          num_inference_steps=self.n_steps,
+                          ).images[0]
+        image.save(path)
+        with ExifToolHelper() as et:
+            et.set_tags([path],
+                        tags={'Comment': self.gen_params_to_exif_comment},
+                        params=['-overwrite_original'])
-- 
2.30.2