From: Christian Heller Date: Sun, 18 Aug 2024 21:42:25 +0000 (+0200) Subject: Re-organize code into multiple files. X-Git-Url: https://plomlompom.com/repos/te"st.html?a=commitdiff_plain;h=b001b3f2d49caadbd3fd5e6f7a572cd1408c1fd3;p=stable_plom Re-organize code into multiple files. --- diff --git a/stable.py b/stable.py index 3db8d5c..49bb1d1 100755 --- a/stable.py +++ b/stable.py @@ -1,25 +1,22 @@ #!/usr/bin/env python3 from os.path import dirname, basename, splitext, join as path_join, exists -from logging import (Formatter as LoggingFormatter, captureWarnings, - Filter as LoggingFilter) from argparse import ArgumentParser from random import randint from exiftool import ExifToolHelper # type: ignore from torch import Generator -from diffusers import StableDiffusionPipeline -from diffusers.utils import logging +from stable.core import init_pipeline DEFAULT_MODEL = './v1-5-pruned-emaonly.safetensors' -class FilterOutString(LoggingFilter): +def save_path(count: int) -> str: + filename_count = f'_{count:08}' if args.number > 1 else '' + return path_join(dir_path, f'{filename_sans_ext}{filename_count}.{ext}') - def __init__(self, target): - super().__init__() - self.target = target - def filter(self, record): - return self.target not in record.getMessage() +def make_metadata(seed, guidance, height, width, model, prompt): + return f'SEED: {seed}; GUIDANCE: {guidance}; HEIGHT: {height}; ' +\ + f'WIDTH: {width}; MODEL: {model}; PROMPT: {prompt}' def parse_args(): @@ -49,38 +46,12 @@ dir_path = dirname(args.output) filename = basename(args.output) filename_sans_ext, ext = splitext(filename) ext = ext[1:] if ext else 'png' - - -def save_path(count: int) -> str: - filename_count = f'_{count:08}' if args.number > 1 else '' - return path_join(dir_path, f'{filename_sans_ext}{filename_count}.{ext}') - - for n in range(args.number): path = save_path(n) if exists(path): raise Exception(f'Would overwrite file: {path}') - -print(f'SETTING UP STABLE DIFFUSION PIPELINE FROM MODEL: {args.model}\n') -diffusers_logging_handler = logging.get_logger('diffusers').handlers[0] -LOG_FMT = 'DIFFUSERS WARNING: %(message)s\n' -diffusers_logging_handler.setFormatter(LoggingFormatter(fmt=LOG_FMT)) -SAFETY_MSG_PATTERN= 'You have disabled the safety checker' -diffusers_logging_handler.addFilter(FilterOutString(SAFETY_MSG_PATTERN)) -captureWarnings(True) -logging.disable_progress_bar() -pipe = StableDiffusionPipeline.from_single_file(args.model, - local_files_only=True) -print('PIPELINE READY\n') - - -def make_metadata(seed, guidance, height, width, model, prompt): - return f'SEED: {seed}; GUIDANCE: {guidance}; HEIGHT: {height}; ' +\ - f'WIDTH: {width}; MODEL: {model}; PROMPT: {prompt}' - - -pipe.to('cuda') +pipe = init_pipeline(args.model) generator = Generator() start_seed = args.randomness_seed start_seed = start_seed if start_seed != 0 else randint(-(2**31-1), 2**31) diff --git a/stable/__init__.py b/stable/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/stable/core.py b/stable/core.py new file mode 100644 index 0000000..93cc897 --- /dev/null +++ b/stable/core.py @@ -0,0 +1,31 @@ +from logging import (Formatter as LoggingFormatter, captureWarnings, + Filter as LoggingFilter) +from diffusers import StableDiffusionPipeline +from diffusers.utils import logging + +LOG_FMT = 'DIFFUSERS WARNING: %(message)s\n' +SAFETY_MSG_PATTERN = 'You have disabled the safety checker' + + +class _FilterOutString(LoggingFilter): + + def __init__(self, target): + super().__init__() + self.target = target + + def filter(self, record): + return self.target not in record.getMessage() + + +def init_pipeline(model): + print(f'SETTING UP STABLE DIFFUSION PIPELINE FROM MODEL: {model}\n') + diffusers_logging_handler = logging.get_logger('diffusers').handlers[0] + diffusers_logging_handler.setFormatter(LoggingFormatter(fmt=LOG_FMT)) + diffusers_logging_handler.addFilter(_FilterOutString(SAFETY_MSG_PATTERN)) + captureWarnings(True) + logging.disable_progress_bar() + pipe = StableDiffusionPipeline.from_single_file(model, + local_files_only=True) + pipe.to('cuda') + print('PIPELINE READY\n') + return pipe