From: Christian Heller Date: Sun, 18 Aug 2024 21:29:04 +0000 (+0200) Subject: Do some code linting. X-Git-Url: https://plomlompom.com/repos/%7B%7Bdb.prefix%7D%7D/static/%7B%7B%20web_path%20%7D%7D/blog?a=commitdiff_plain;h=7e2e2ada25d0eb10610a91514e9abac4a383ce26;p=stable_plom Do some code linting. --- diff --git a/requirements.txt b/requirements.txt index 7464f22..8cca4b7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,3 +3,6 @@ pillow torch diffusers transformers +pylint +mypy +flake8 diff --git a/stable.py b/stable.py index 813bf09..3db8d5c 100755 --- a/stable.py +++ b/stable.py @@ -1,15 +1,15 @@ #!/usr/bin/env python3 +from os.path import dirname, basename, splitext, join as path_join, exists +from logging import (Formatter as LoggingFormatter, captureWarnings, + Filter as LoggingFilter) from argparse import ArgumentParser from random import randint -from PIL import Image -from exiftool import ExifToolHelper +from exiftool import ExifToolHelper # type: ignore from torch import Generator from diffusers import StableDiffusionPipeline -from os.path import dirname, basename, splitext, join as path_join, exists -from logging import Formatter as LoggingFormatter, captureWarnings, Filter as LoggingFilter from diffusers.utils import logging -DEFAULT_MODEL='./v1-5-pruned-emaonly.safetensors' +DEFAULT_MODEL = './v1-5-pruned-emaonly.safetensors' class FilterOutString(LoggingFilter): @@ -26,51 +26,80 @@ def parse_args(): parser = ArgumentParser(add_help=False) parser.add_argument('-p', '--prompt', required=True) parser.add_argument('-o', '--output', required=True) - parser.add_argument('-r', '--randomness_seed', default=1, type=int, help='default: 1; if set 0, chosen randomnly') - parser.add_argument('-g', '--guidance', default=7.5, type=float, help='default: 7.5') - parser.add_argument('-s', '--steps', default=15, type=int, help='default: 15') - parser.add_argument('-m', '--model', default=DEFAULT_MODEL, type=str, help=f'default: {DEFAULT_MODEL}') - parser.add_argument('-h', '--height', default=512, type=int, help='default: 512') - parser.add_argument('-w', '--width', default=512, type=int, help='default: 512') - parser.add_argument('-n', '--number', default=1, type=int, help='default: 1') + parser.add_argument('-r', '--randomness_seed', default=1, type=int, + help='default: 1; if set 0, chosen randomnly') + parser.add_argument('-g', '--guidance', default=7.5, type=float, + help='default: 7.5') + parser.add_argument('-s', '--steps', default=15, type=int, + help='default: 15') + parser.add_argument('-m', '--model', default=DEFAULT_MODEL, type=str, + help=f'default: {DEFAULT_MODEL}') + parser.add_argument('-h', '--height', default=512, type=int, + help='default: 512') + parser.add_argument('-w', '--width', default=512, type=int, + help='default: 512') + parser.add_argument('-n', '--number', default=1, type=int, + help='default: 1') parser.add_argument('-H', '--help', action='help') return parser.parse_args() -args = parse_args() +args = parse_args() dir_path = dirname(args.output) filename = basename(args.output) filename_sans_ext, ext = splitext(filename) ext = ext[1:] if ext else 'png' -def save_path(n: int) -> str: - filename_count = f'_{n:08}' if args.number > 1 else '' + +def save_path(count: int) -> str: + filename_count = f'_{count:08}' if args.number > 1 else '' return path_join(dir_path, f'{filename_sans_ext}{filename_count}.{ext}') + for n in range(args.number): path = save_path(n) if exists(path): raise Exception(f'Would overwrite file: {path}') + print(f'SETTING UP STABLE DIFFUSION PIPELINE FROM MODEL: {args.model}\n') diffusers_logging_handler = logging.get_logger('diffusers').handlers[0] -diffusers_logging_handler.setFormatter(LoggingFormatter(fmt='DIFFUSERS WARNING: %(message)s\n')) -diffusers_logging_handler.addFilter(FilterOutString('You have disabled the safety checker')) +LOG_FMT = 'DIFFUSERS WARNING: %(message)s\n' +diffusers_logging_handler.setFormatter(LoggingFormatter(fmt=LOG_FMT)) +SAFETY_MSG_PATTERN= 'You have disabled the safety checker' +diffusers_logging_handler.addFilter(FilterOutString(SAFETY_MSG_PATTERN)) captureWarnings(True) logging.disable_progress_bar() -pipe = StableDiffusionPipeline.from_single_file(args.model, local_files_only=True) +pipe = StableDiffusionPipeline.from_single_file(args.model, + local_files_only=True) print('PIPELINE READY\n') + +def make_metadata(seed, guidance, height, width, model, prompt): + return f'SEED: {seed}; GUIDANCE: {guidance}; HEIGHT: {height}; ' +\ + f'WIDTH: {width}; MODEL: {model}; PROMPT: {prompt}' + + pipe.to('cuda') generator = Generator() -start_seed = randint(-(2**31-1), 2**31) if args.randomness_seed == 0 else args.randomness_seed +start_seed = args.randomness_seed +start_seed = start_seed if start_seed != 0 else randint(-(2**31-1), 2**31) for n in range(args.number): - seed = start_seed + n + nth_seed = start_seed + n path = save_path(n) - metadata = f'SEED: {seed}; GUIDANCE: {args.guidance}; HEIGHT: {args.height}; WIDTH: {args.width}; MODEL: {args.model}; PROMPT: {args.prompt}' + metadata = make_metadata(nth_seed, args.guidance, args.height, args.width, + args.model, args.prompt) print(f'GENERATING: {path}; {metadata}') - generator.manual_seed(seed) - image = pipe(args.prompt, generator=generator, guidance_scale=args.guidance, num_inference_steps=args.steps, height=args.height, width=args.width).images[0] - image.save(path) + generator.manual_seed(nth_seed) + images = pipe(args.prompt, + generator=generator, + guidance_scale=args.guidance, + num_inference_steps=args.steps, + height=args.height, + width=args.width + ).images + images[0].save(path) with ExifToolHelper() as et: - et.set_tags([path], tags={'Comment': metadata}, params=['-overwrite_original']) + et.set_tags([path], + tags={'Comment': metadata}, + params=['-overwrite_original'])