def parse_args():
parser = ArgumentParser(add_help=False,
formatter_class=ArgumentDefaultsHelpFormatter)
+ parser.add_argument('output',
+ help='output file path; an underscore followed by an '
+ '8-digit counter may be added, before the extension, '
+ 'especially if if -q > 1 or name already exists; if '
+ 'no extension provided, will automatically add .png')
parser.add_argument('-m', '--model',
- help='model filename (-P will pre prefixed, but may '
- 'also be full path on its own)')
- parser.add_argument('-o', '--output',
- help='output filename or path; if -q > 1 or name '
- 'pre-existing, will insert incremented counter '
- 'number; will append .png extension if not provided')
+ help='model filename (may be prefixed via -M with a '
+ 'directory path; may also be full path on its own)')
parser.add_argument('-p', '--prompt',
help='textual guidance to image generation')
parser.add_argument('-q', '--quantity', default=1, type=int,
parser.add_argument('-w', '--width', default=512, type=int,
help='target width in pixels')
parser.add_argument('-g', '--guidance', default=7.5, type=float,
- help='adherence to text prompt')
+ help='adherence to text prompt (from 1.0 to 30.0)')
parser.add_argument('-n', '--n_steps', default=15, type=int,
help='number of denoising steps')
parser.add_argument('-s', '--seed', default=1, type=int,
from diffusers.utils import logging
from torch import Generator, float16
from PIL.PngImagePlugin import PngInfo
+from stable.gen_params import GenParams
SAFETY_CHECKER_WARNING_PATTERN = 'You have disabled the safety checker'
-class ImageMaker:
+class _safetyWarningFilter(LogFilter):
+ """Solely to remove the content safety checker warning."""
- def __init__(self, model_path):
+ def __init__(self):
+ super().__init__()
- class FilterOut(LogFilter):
+ def filter(self, record):
+ return SAFETY_CHECKER_WARNING_PATTERN not in record.getMessage()
- def __init__(self, target):
- super().__init__()
- self.target = target
- def filter(self, record):
- return self.target not in record.getMessage()
+class ImageMaker:
+ """Set up model pipe and generate images from GenParams put in."""
- prefix = 'SETTING UP STABLE DIFFUSION PIPELINE FROM MODEL'
- print(f'{prefix}: {model_path}\n')
+ def __init__(self, model_path):
+ self._init_logging()
+ msg = f'SETTING UP STABLE DIFFUSION PIPELINE FROM MODEL: {model_path}'
+ print(f'{msg}\n')
+ self._pipe = StableDiffusionPipeline.from_single_file(
+ model_path, torch_dtype=float16, local_files_only=True)
+ self._pipe.to('cuda')
+ self._generator = Generator()
+ print('PIPELINE READY\n')
+ self._gen_params = None
+
+ def _init_logging(self):
diffusers_logging_handler = logging.get_logger('diffusers').handlers[0]
- diffusers_logging_handler.setFormatter(
- LogFormatter(fmt='DIFFUSERS WARNING: %(message)s\n'))
- diffusers_logging_handler.addFilter(
- FilterOut(SAFETY_CHECKER_WARNING_PATTERN))
+ formatter = LogFormatter(fmt='DIFFUSERS WARNING: %(message)s\n')
+ diffusers_logging_handler.setFormatter(formatter)
+ diffusers_logging_handler.addFilter(_safetyWarningFilter())
captureWarnings(True)
logging.disable_progress_bar()
- self.pipe = StableDiffusionPipeline.from_single_file(
- model_path, torch_dtype=float16, local_files_only=True)
- self.pipe.to('cuda')
- self.generator = Generator()
- print('PIPELINE READY\n')
- self.gen_params = None
- def set_gen_params(self, gen_params):
- scheduler_selection = [s for s in self.pipe.scheduler.compatibles
+ def set_gen_params(self, gen_params: GenParams):
+ """Read in gen_params, select scheduler based on them."""
+ scheduler_selection = [s for s in self._pipe.scheduler.compatibles
if s.__name__ == gen_params.scheduler]
if not scheduler_selection:
raise Exception(f'unknown scheduler: {gen_params.scheduler}')
- self.pipe.scheduler = scheduler_selection[0].from_config(
- self.pipe.scheduler.config)
- self.gen_params = gen_params
+ self._pipe.scheduler = scheduler_selection[0].from_config(
+ self._pipe.scheduler.config)
+ self._gen_params = gen_params
@property
- def compatible_schedulers(self):
- return [s.__name__ for s in self.pipe.scheduler.compatibles]
+ def compatible_schedulers(self) -> list[str]:
+ """List of schedulers compatible to selected model."""
+ return [s.__name__ for s in self._pipe.scheduler.compatibles]
- def gen_image_to(self, path):
+ def gen_image_to(self, path: str) -> None:
"""Create image and write as file with metadata to path."""
- if None in {self.gen_params.seed, self.gen_params.prompt,
- self.gen_params.guidance, self.gen_params.height,
- self.gen_params.width, self.gen_params.n_steps}:
+ if None in {self._gen_params.seed, self._gen_params.prompt,
+ self._gen_params.guidance, self._gen_params.height,
+ self._gen_params.width, self._gen_params.n_steps}:
raise Exception('Generation parameters not initialized.')
- self.generator.manual_seed(self.gen_params.seed)
- image = self.pipe(generator=self.generator,
- prompt=self.gen_params.prompt,
- guidance_scale=self.gen_params.guidance,
- height=self.gen_params.height,
- width=self.gen_params.width,
- num_inference_steps=self.gen_params.n_steps,
- ).images[0]
+ self._generator.manual_seed(self._gen_params.seed)
+ image = self._pipe(generator=self._generator,
+ prompt=self._gen_params.prompt,
+ guidance_scale=self._gen_params.guidance,
+ height=self._gen_params.height,
+ width=self._gen_params.width,
+ num_inference_steps=self._gen_params.n_steps,
+ ).images[0]
png_info = PngInfo()
- png_info.add_text('generation_parameters', self.gen_params.to_str)
+ png_info.add_text('generation_parameters', self._gen_params.to_str)
image.save(path, pnginfo=png_info)