#!/usr/bin/env python3
from os.path import dirname, basename, splitext, join as path_join, exists
-from logging import (Formatter as LoggingFormatter, captureWarnings,
- Filter as LoggingFilter)
from argparse import ArgumentParser
from random import randint
from exiftool import ExifToolHelper # type: ignore
from torch import Generator
-from diffusers import StableDiffusionPipeline
-from diffusers.utils import logging
+from stable.core import init_pipeline
DEFAULT_MODEL = './v1-5-pruned-emaonly.safetensors'
-class FilterOutString(LoggingFilter):
+def save_path(count: int) -> str:
+ filename_count = f'_{count:08}' if args.number > 1 else ''
+ return path_join(dir_path, f'{filename_sans_ext}{filename_count}.{ext}')
- def __init__(self, target):
- super().__init__()
- self.target = target
- def filter(self, record):
- return self.target not in record.getMessage()
+def make_metadata(seed, guidance, height, width, model, prompt):
+ return f'SEED: {seed}; GUIDANCE: {guidance}; HEIGHT: {height}; ' +\
+ f'WIDTH: {width}; MODEL: {model}; PROMPT: {prompt}'
def parse_args():
filename = basename(args.output)
filename_sans_ext, ext = splitext(filename)
ext = ext[1:] if ext else 'png'
-
-
-def save_path(count: int) -> str:
- filename_count = f'_{count:08}' if args.number > 1 else ''
- return path_join(dir_path, f'{filename_sans_ext}{filename_count}.{ext}')
-
-
for n in range(args.number):
path = save_path(n)
if exists(path):
raise Exception(f'Would overwrite file: {path}')
-
-print(f'SETTING UP STABLE DIFFUSION PIPELINE FROM MODEL: {args.model}\n')
-diffusers_logging_handler = logging.get_logger('diffusers').handlers[0]
-LOG_FMT = 'DIFFUSERS WARNING: %(message)s\n'
-diffusers_logging_handler.setFormatter(LoggingFormatter(fmt=LOG_FMT))
-SAFETY_MSG_PATTERN= 'You have disabled the safety checker'
-diffusers_logging_handler.addFilter(FilterOutString(SAFETY_MSG_PATTERN))
-captureWarnings(True)
-logging.disable_progress_bar()
-pipe = StableDiffusionPipeline.from_single_file(args.model,
- local_files_only=True)
-print('PIPELINE READY\n')
-
-
-def make_metadata(seed, guidance, height, width, model, prompt):
- return f'SEED: {seed}; GUIDANCE: {guidance}; HEIGHT: {height}; ' +\
- f'WIDTH: {width}; MODEL: {model}; PROMPT: {prompt}'
-
-
-pipe.to('cuda')
+pipe = init_pipeline(args.model)
generator = Generator()
start_seed = args.randomness_seed
start_seed = start_seed if start_seed != 0 else randint(-(2**31-1), 2**31)
--- /dev/null
+from logging import (Formatter as LoggingFormatter, captureWarnings,
+ Filter as LoggingFilter)
+from diffusers import StableDiffusionPipeline
+from diffusers.utils import logging
+
+LOG_FMT = 'DIFFUSERS WARNING: %(message)s\n'
+SAFETY_MSG_PATTERN = 'You have disabled the safety checker'
+
+
+class _FilterOutString(LoggingFilter):
+
+ def __init__(self, target):
+ super().__init__()
+ self.target = target
+
+ def filter(self, record):
+ return self.target not in record.getMessage()
+
+
+def init_pipeline(model):
+ print(f'SETTING UP STABLE DIFFUSION PIPELINE FROM MODEL: {model}\n')
+ diffusers_logging_handler = logging.get_logger('diffusers').handlers[0]
+ diffusers_logging_handler.setFormatter(LoggingFormatter(fmt=LOG_FMT))
+ diffusers_logging_handler.addFilter(_FilterOutString(SAFETY_MSG_PATTERN))
+ captureWarnings(True)
+ logging.disable_progress_bar()
+ pipe = StableDiffusionPipeline.from_single_file(model,
+ local_files_only=True)
+ pipe.to('cuda')
+ print('PIPELINE READY\n')
+ return pipe