From 7e2e2ada25d0eb10610a91514e9abac4a383ce26 Mon Sep 17 00:00:00 2001
From: Christian Heller <c.heller@plomlompom.de>
Date: Sun, 18 Aug 2024 23:29:04 +0200
Subject: [PATCH] Do some code linting.

---
 requirements.txt |  3 ++
 stable.py        | 79 +++++++++++++++++++++++++++++++++---------------
 2 files changed, 57 insertions(+), 25 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index 7464f22..8cca4b7 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -3,3 +3,6 @@ pillow
 torch
 diffusers
 transformers
+pylint
+mypy
+flake8
diff --git a/stable.py b/stable.py
index 813bf09..3db8d5c 100755
--- a/stable.py
+++ b/stable.py
@@ -1,15 +1,15 @@
 #!/usr/bin/env python3
+from os.path import dirname, basename, splitext, join as path_join, exists
+from logging import (Formatter as LoggingFormatter, captureWarnings,
+                     Filter as LoggingFilter)
 from argparse import ArgumentParser
 from random import randint
-from PIL import Image
-from exiftool import ExifToolHelper
+from exiftool import ExifToolHelper  # type: ignore
 from torch import Generator
 from diffusers import StableDiffusionPipeline
-from os.path import dirname, basename, splitext, join as path_join, exists
-from logging import Formatter as LoggingFormatter, captureWarnings, Filter as LoggingFilter
 from diffusers.utils import logging
 
-DEFAULT_MODEL='./v1-5-pruned-emaonly.safetensors'
+DEFAULT_MODEL = './v1-5-pruned-emaonly.safetensors'
 
 
 class FilterOutString(LoggingFilter):
@@ -26,51 +26,80 @@ def parse_args():
     parser = ArgumentParser(add_help=False)
     parser.add_argument('-p', '--prompt', required=True)
     parser.add_argument('-o', '--output', required=True)
-    parser.add_argument('-r', '--randomness_seed', default=1, type=int, help='default: 1; if set 0, chosen randomnly')
-    parser.add_argument('-g', '--guidance', default=7.5, type=float, help='default: 7.5')
-    parser.add_argument('-s', '--steps', default=15, type=int, help='default: 15')
-    parser.add_argument('-m', '--model', default=DEFAULT_MODEL, type=str, help=f'default: {DEFAULT_MODEL}')
-    parser.add_argument('-h', '--height', default=512, type=int, help='default: 512')
-    parser.add_argument('-w', '--width', default=512, type=int, help='default: 512')
-    parser.add_argument('-n', '--number', default=1, type=int, help='default: 1')
+    parser.add_argument('-r', '--randomness_seed', default=1, type=int,
+                        help='default: 1; if set 0, chosen randomnly')
+    parser.add_argument('-g', '--guidance', default=7.5, type=float,
+                        help='default: 7.5')
+    parser.add_argument('-s', '--steps', default=15, type=int,
+                        help='default: 15')
+    parser.add_argument('-m', '--model', default=DEFAULT_MODEL, type=str,
+                        help=f'default: {DEFAULT_MODEL}')
+    parser.add_argument('-h', '--height', default=512, type=int,
+                        help='default: 512')
+    parser.add_argument('-w', '--width', default=512, type=int,
+                        help='default: 512')
+    parser.add_argument('-n', '--number', default=1, type=int,
+                        help='default: 1')
     parser.add_argument('-H', '--help', action='help')
     return parser.parse_args()
 
-args = parse_args()
 
+args = parse_args()
 dir_path = dirname(args.output)
 filename = basename(args.output)
 filename_sans_ext, ext = splitext(filename)
 ext = ext[1:] if ext else 'png'
 
-def save_path(n: int) -> str:
-    filename_count = f'_{n:08}' if args.number > 1 else ''
+
+def save_path(count: int) -> str:
+    filename_count = f'_{count:08}' if args.number > 1 else ''
     return path_join(dir_path, f'{filename_sans_ext}{filename_count}.{ext}')
 
+
 for n in range(args.number):
     path = save_path(n)
     if exists(path):
         raise Exception(f'Would overwrite file: {path}')
 
+
 print(f'SETTING UP STABLE DIFFUSION PIPELINE FROM MODEL: {args.model}\n')
 diffusers_logging_handler = logging.get_logger('diffusers').handlers[0]
-diffusers_logging_handler.setFormatter(LoggingFormatter(fmt='DIFFUSERS WARNING: %(message)s\n'))
-diffusers_logging_handler.addFilter(FilterOutString('You have disabled the safety checker'))
+LOG_FMT = 'DIFFUSERS WARNING: %(message)s\n'
+diffusers_logging_handler.setFormatter(LoggingFormatter(fmt=LOG_FMT))
+SAFETY_MSG_PATTERN= 'You have disabled the safety checker'
+diffusers_logging_handler.addFilter(FilterOutString(SAFETY_MSG_PATTERN))
 captureWarnings(True)
 logging.disable_progress_bar()
-pipe = StableDiffusionPipeline.from_single_file(args.model, local_files_only=True)
+pipe = StableDiffusionPipeline.from_single_file(args.model,
+                                                local_files_only=True)
 print('PIPELINE READY\n')
 
+
+def make_metadata(seed, guidance, height, width, model, prompt):
+    return f'SEED: {seed}; GUIDANCE: {guidance}; HEIGHT: {height}; ' +\
+            f'WIDTH: {width}; MODEL: {model}; PROMPT: {prompt}'
+
+
 pipe.to('cuda')
 generator = Generator()
-start_seed = randint(-(2**31-1), 2**31) if args.randomness_seed == 0 else args.randomness_seed
+start_seed = args.randomness_seed
+start_seed = start_seed if start_seed != 0 else randint(-(2**31-1), 2**31)
 for n in range(args.number):
-    seed = start_seed + n
+    nth_seed = start_seed + n
     path = save_path(n)
-    metadata = f'SEED: {seed}; GUIDANCE: {args.guidance}; HEIGHT: {args.height}; WIDTH: {args.width}; MODEL: {args.model}; PROMPT: {args.prompt}'
+    metadata = make_metadata(nth_seed, args.guidance, args.height, args.width,
+                             args.model, args.prompt)
     print(f'GENERATING: {path}; {metadata}')
-    generator.manual_seed(seed)
-    image = pipe(args.prompt, generator=generator, guidance_scale=args.guidance, num_inference_steps=args.steps, height=args.height, width=args.width).images[0]
-    image.save(path)
+    generator.manual_seed(nth_seed)
+    images = pipe(args.prompt,
+                  generator=generator,
+                  guidance_scale=args.guidance,
+                  num_inference_steps=args.steps,
+                  height=args.height,
+                  width=args.width
+                  ).images
+    images[0].save(path)
     with ExifToolHelper() as et:
-        et.set_tags([path], tags={'Comment': metadata}, params=['-overwrite_original'])
+        et.set_tags([path],
+                    tags={'Comment': metadata},
+                    params=['-overwrite_original'])
-- 
2.30.2