home · contact · privacy
Re-organize code into multiple files.
authorChristian Heller <c.heller@plomlompom.de>
Sun, 18 Aug 2024 21:42:25 +0000 (23:42 +0200)
committerChristian Heller <c.heller@plomlompom.de>
Sun, 18 Aug 2024 21:42:25 +0000 (23:42 +0200)
stable.py
stable/__init__.py [new file with mode: 0644]
stable/core.py [new file with mode: 0644]

index 3db8d5cfa372674c6a6cebd0d525c5f0dd0d6fad..49bb1d17af6dbb3b0c77ea9415b2285affdf612f 100755 (executable)
--- a/stable.py
+++ b/stable.py
@@ -1,25 +1,22 @@
 #!/usr/bin/env python3
 from os.path import dirname, basename, splitext, join as path_join, exists
-from logging import (Formatter as LoggingFormatter, captureWarnings,
-                     Filter as LoggingFilter)
 from argparse import ArgumentParser
 from random import randint
 from exiftool import ExifToolHelper  # type: ignore
 from torch import Generator
-from diffusers import StableDiffusionPipeline
-from diffusers.utils import logging
+from stable.core import init_pipeline
 
 DEFAULT_MODEL = './v1-5-pruned-emaonly.safetensors'
 
 
-class FilterOutString(LoggingFilter):
+def save_path(count: int) -> str:
+    filename_count = f'_{count:08}' if args.number > 1 else ''
+    return path_join(dir_path, f'{filename_sans_ext}{filename_count}.{ext}')
 
-    def __init__(self, target):
-        super().__init__()
-        self.target = target
 
-    def filter(self, record):
-        return self.target not in record.getMessage()
+def make_metadata(seed, guidance, height, width, model, prompt):
+    return f'SEED: {seed}; GUIDANCE: {guidance}; HEIGHT: {height}; ' +\
+            f'WIDTH: {width}; MODEL: {model}; PROMPT: {prompt}'
 
 
 def parse_args():
@@ -49,38 +46,12 @@ dir_path = dirname(args.output)
 filename = basename(args.output)
 filename_sans_ext, ext = splitext(filename)
 ext = ext[1:] if ext else 'png'
-
-
-def save_path(count: int) -> str:
-    filename_count = f'_{count:08}' if args.number > 1 else ''
-    return path_join(dir_path, f'{filename_sans_ext}{filename_count}.{ext}')
-
-
 for n in range(args.number):
     path = save_path(n)
     if exists(path):
         raise Exception(f'Would overwrite file: {path}')
 
-
-print(f'SETTING UP STABLE DIFFUSION PIPELINE FROM MODEL: {args.model}\n')
-diffusers_logging_handler = logging.get_logger('diffusers').handlers[0]
-LOG_FMT = 'DIFFUSERS WARNING: %(message)s\n'
-diffusers_logging_handler.setFormatter(LoggingFormatter(fmt=LOG_FMT))
-SAFETY_MSG_PATTERN= 'You have disabled the safety checker'
-diffusers_logging_handler.addFilter(FilterOutString(SAFETY_MSG_PATTERN))
-captureWarnings(True)
-logging.disable_progress_bar()
-pipe = StableDiffusionPipeline.from_single_file(args.model,
-                                                local_files_only=True)
-print('PIPELINE READY\n')
-
-
-def make_metadata(seed, guidance, height, width, model, prompt):
-    return f'SEED: {seed}; GUIDANCE: {guidance}; HEIGHT: {height}; ' +\
-            f'WIDTH: {width}; MODEL: {model}; PROMPT: {prompt}'
-
-
-pipe.to('cuda')
+pipe = init_pipeline(args.model)
 generator = Generator()
 start_seed = args.randomness_seed
 start_seed = start_seed if start_seed != 0 else randint(-(2**31-1), 2**31)
diff --git a/stable/__init__.py b/stable/__init__.py
new file mode 100644 (file)
index 0000000..e69de29
diff --git a/stable/core.py b/stable/core.py
new file mode 100644 (file)
index 0000000..93cc897
--- /dev/null
@@ -0,0 +1,31 @@
+from logging import (Formatter as LoggingFormatter, captureWarnings,
+                     Filter as LoggingFilter)
+from diffusers import StableDiffusionPipeline
+from diffusers.utils import logging
+
+LOG_FMT = 'DIFFUSERS WARNING: %(message)s\n'
+SAFETY_MSG_PATTERN = 'You have disabled the safety checker'
+
+
+class _FilterOutString(LoggingFilter):
+
+    def __init__(self, target):
+        super().__init__()
+        self.target = target
+
+    def filter(self, record):
+        return self.target not in record.getMessage()
+
+
+def init_pipeline(model):
+    print(f'SETTING UP STABLE DIFFUSION PIPELINE FROM MODEL: {model}\n')
+    diffusers_logging_handler = logging.get_logger('diffusers').handlers[0]
+    diffusers_logging_handler.setFormatter(LoggingFormatter(fmt=LOG_FMT))
+    diffusers_logging_handler.addFilter(_FilterOutString(SAFETY_MSG_PATTERN))
+    captureWarnings(True)
+    logging.disable_progress_bar()
+    pipe = StableDiffusionPipeline.from_single_file(model,
+                                                    local_files_only=True)
+    pipe.to('cuda')
+    print('PIPELINE READY\n')
+    return pipe