home · contact · privacy
Various minor code improvements. master
authorChristian Heller <c.heller@plomlompom.de>
Fri, 13 Sep 2024 04:44:39 +0000 (06:44 +0200)
committerChristian Heller <c.heller@plomlompom.de>
Fri, 13 Sep 2024 04:44:39 +0000 (06:44 +0200)
stable.py
stable/core.py

index 351add1b860d941c8863afaa274ed4513fb7ff9f..44833f623b318c3d7d9e45dc96807f92e7dc0ce4 100755 (executable)
--- a/stable.py
+++ b/stable.py
@@ -16,13 +16,14 @@ SEED_MAX = 2**31
 def parse_args():
     parser = ArgumentParser(add_help=False,
                             formatter_class=ArgumentDefaultsHelpFormatter)
+    parser.add_argument('output',
+                        help='output file path; an underscore followed by an '
+                        '8-digit counter may be added, before the extension, '
+                        'especially if if -q > 1 or name already exists; if '
+                        'no extension provided, will automatically add .png')
     parser.add_argument('-m', '--model',
-                        help='model filename (-P will pre prefixed, but may '
-                        'also be full path on its own)')
-    parser.add_argument('-o', '--output',
-                        help='output filename or path; if -q > 1 or name '
-                        'pre-existing, will insert incremented counter '
-                        'number; will append .png extension if not provided')
+                        help='model filename (may be prefixed via -M with a '
+                        'directory path; may also be full path on its own)')
     parser.add_argument('-p', '--prompt',
                         help='textual guidance to image generation')
     parser.add_argument('-q', '--quantity', default=1, type=int,
@@ -33,7 +34,7 @@ def parse_args():
     parser.add_argument('-w', '--width', default=512, type=int,
                         help='target width in pixels')
     parser.add_argument('-g', '--guidance', default=7.5, type=float,
-                        help='adherence to text prompt')
+                        help='adherence to text prompt (from 1.0 to 30.0)')
     parser.add_argument('-n', '--n_steps', default=15, type=int,
                         help='number of denoising steps')
     parser.add_argument('-s', '--seed', default=1, type=int,
index 4d0b05a808d6b779f1eaa6879a7add079f9eef22..dc627a50bc580f1a2af301da784f89d5dad2fafd 100644 (file)
@@ -4,66 +4,72 @@ from diffusers import StableDiffusionPipeline
 from diffusers.utils import logging
 from torch import Generator, float16
 from PIL.PngImagePlugin import PngInfo
+from stable.gen_params import GenParams
 
 SAFETY_CHECKER_WARNING_PATTERN = 'You have disabled the safety checker'
 
 
-class ImageMaker:
+class _safetyWarningFilter(LogFilter):
+    """Solely to remove the content safety checker warning."""
 
-    def __init__(self, model_path):
+    def __init__(self):
+        super().__init__()
 
-        class FilterOut(LogFilter):
+    def filter(self, record):
+        return SAFETY_CHECKER_WARNING_PATTERN not in record.getMessage()
 
-            def __init__(self, target):
-                super().__init__()
-                self.target = target
 
-            def filter(self, record):
-                return self.target not in record.getMessage()
+class ImageMaker:
+    """Set up model pipe and generate images from GenParams put in."""
 
-        prefix = 'SETTING UP STABLE DIFFUSION PIPELINE FROM MODEL'
-        print(f'{prefix}: {model_path}\n')
+    def __init__(self, model_path):
+        self._init_logging()
+        msg = f'SETTING UP STABLE DIFFUSION PIPELINE FROM MODEL: {model_path}'
+        print(f'{msg}\n')
+        self._pipe = StableDiffusionPipeline.from_single_file(
+                model_path, torch_dtype=float16, local_files_only=True)
+        self._pipe.to('cuda')
+        self._generator = Generator()
+        print('PIPELINE READY\n')
+        self._gen_params = None
+
+    def _init_logging(self):
         diffusers_logging_handler = logging.get_logger('diffusers').handlers[0]
-        diffusers_logging_handler.setFormatter(
-                LogFormatter(fmt='DIFFUSERS WARNING: %(message)s\n'))
-        diffusers_logging_handler.addFilter(
-                FilterOut(SAFETY_CHECKER_WARNING_PATTERN))
+        formatter = LogFormatter(fmt='DIFFUSERS WARNING: %(message)s\n')
+        diffusers_logging_handler.setFormatter(formatter)
+        diffusers_logging_handler.addFilter(_safetyWarningFilter())
         captureWarnings(True)
         logging.disable_progress_bar()
-        self.pipe = StableDiffusionPipeline.from_single_file(
-                model_path, torch_dtype=float16, local_files_only=True)
-        self.pipe.to('cuda')
-        self.generator = Generator()
-        print('PIPELINE READY\n')
-        self.gen_params = None
 
-    def set_gen_params(self, gen_params):
-        scheduler_selection = [s for s in self.pipe.scheduler.compatibles
+    def set_gen_params(self, gen_params: GenParams):
+        """Read in gen_params, select scheduler based on them."""
+        scheduler_selection = [s for s in self._pipe.scheduler.compatibles
                                if s.__name__ == gen_params.scheduler]
         if not scheduler_selection:
             raise Exception(f'unknown scheduler: {gen_params.scheduler}')
-        self.pipe.scheduler = scheduler_selection[0].from_config(
-                self.pipe.scheduler.config)
-        self.gen_params = gen_params
+        self._pipe.scheduler = scheduler_selection[0].from_config(
+                self._pipe.scheduler.config)
+        self._gen_params = gen_params
 
     @property
-    def compatible_schedulers(self):
-        return [s.__name__ for s in self.pipe.scheduler.compatibles]
+    def compatible_schedulers(self) -> list[str]:
+        """List of schedulers compatible to selected model."""
+        return [s.__name__ for s in self._pipe.scheduler.compatibles]
 
-    def gen_image_to(self, path):
+    def gen_image_to(self, path: str) -> None:
         """Create image and write as file with metadata to path."""
-        if None in {self.gen_params.seed, self.gen_params.prompt,
-                    self.gen_params.guidance, self.gen_params.height,
-                    self.gen_params.width, self.gen_params.n_steps}:
+        if None in {self._gen_params.seed, self._gen_params.prompt,
+                    self._gen_params.guidance, self._gen_params.height,
+                    self._gen_params.width, self._gen_params.n_steps}:
             raise Exception('Generation parameters not initialized.')
-        self.generator.manual_seed(self.gen_params.seed)
-        image = self.pipe(generator=self.generator,
-                          prompt=self.gen_params.prompt,
-                          guidance_scale=self.gen_params.guidance,
-                          height=self.gen_params.height,
-                          width=self.gen_params.width,
-                          num_inference_steps=self.gen_params.n_steps,
-                          ).images[0]
+        self._generator.manual_seed(self._gen_params.seed)
+        image = self._pipe(generator=self._generator,
+                           prompt=self._gen_params.prompt,
+                           guidance_scale=self._gen_params.guidance,
+                           height=self._gen_params.height,
+                           width=self._gen_params.width,
+                           num_inference_steps=self._gen_params.n_steps,
+                           ).images[0]
         png_info = PngInfo()
-        png_info.add_text('generation_parameters', self.gen_params.to_str)
+        png_info.add_text('generation_parameters', self._gen_params.to_str)
         image.save(path, pnginfo=png_info)