From: Christian Heller Date: Fri, 13 Sep 2024 04:44:39 +0000 (+0200) Subject: Various minor code improvements. X-Git-Url: https://plomlompom.com/repos/%7B%7Bprefix%7D%7D/template?a=commitdiff_plain;h=a0010f5a9c958a17e8e5e11296e3b5b8185167d6;p=stable_plom Various minor code improvements. --- diff --git a/stable.py b/stable.py index 351add1..44833f6 100755 --- a/stable.py +++ b/stable.py @@ -16,13 +16,14 @@ SEED_MAX = 2**31 def parse_args(): parser = ArgumentParser(add_help=False, formatter_class=ArgumentDefaultsHelpFormatter) + parser.add_argument('output', + help='output file path; an underscore followed by an ' + '8-digit counter may be added, before the extension, ' + 'especially if if -q > 1 or name already exists; if ' + 'no extension provided, will automatically add .png') parser.add_argument('-m', '--model', - help='model filename (-P will pre prefixed, but may ' - 'also be full path on its own)') - parser.add_argument('-o', '--output', - help='output filename or path; if -q > 1 or name ' - 'pre-existing, will insert incremented counter ' - 'number; will append .png extension if not provided') + help='model filename (may be prefixed via -M with a ' + 'directory path; may also be full path on its own)') parser.add_argument('-p', '--prompt', help='textual guidance to image generation') parser.add_argument('-q', '--quantity', default=1, type=int, @@ -33,7 +34,7 @@ def parse_args(): parser.add_argument('-w', '--width', default=512, type=int, help='target width in pixels') parser.add_argument('-g', '--guidance', default=7.5, type=float, - help='adherence to text prompt') + help='adherence to text prompt (from 1.0 to 30.0)') parser.add_argument('-n', '--n_steps', default=15, type=int, help='number of denoising steps') parser.add_argument('-s', '--seed', default=1, type=int, diff --git a/stable/core.py b/stable/core.py index 4d0b05a..dc627a5 100644 --- a/stable/core.py +++ b/stable/core.py @@ -4,66 +4,72 @@ from diffusers import StableDiffusionPipeline from diffusers.utils import logging from torch import Generator, float16 from PIL.PngImagePlugin import PngInfo +from stable.gen_params import GenParams SAFETY_CHECKER_WARNING_PATTERN = 'You have disabled the safety checker' -class ImageMaker: +class _safetyWarningFilter(LogFilter): + """Solely to remove the content safety checker warning.""" - def __init__(self, model_path): + def __init__(self): + super().__init__() - class FilterOut(LogFilter): + def filter(self, record): + return SAFETY_CHECKER_WARNING_PATTERN not in record.getMessage() - def __init__(self, target): - super().__init__() - self.target = target - def filter(self, record): - return self.target not in record.getMessage() +class ImageMaker: + """Set up model pipe and generate images from GenParams put in.""" - prefix = 'SETTING UP STABLE DIFFUSION PIPELINE FROM MODEL' - print(f'{prefix}: {model_path}\n') + def __init__(self, model_path): + self._init_logging() + msg = f'SETTING UP STABLE DIFFUSION PIPELINE FROM MODEL: {model_path}' + print(f'{msg}\n') + self._pipe = StableDiffusionPipeline.from_single_file( + model_path, torch_dtype=float16, local_files_only=True) + self._pipe.to('cuda') + self._generator = Generator() + print('PIPELINE READY\n') + self._gen_params = None + + def _init_logging(self): diffusers_logging_handler = logging.get_logger('diffusers').handlers[0] - diffusers_logging_handler.setFormatter( - LogFormatter(fmt='DIFFUSERS WARNING: %(message)s\n')) - diffusers_logging_handler.addFilter( - FilterOut(SAFETY_CHECKER_WARNING_PATTERN)) + formatter = LogFormatter(fmt='DIFFUSERS WARNING: %(message)s\n') + diffusers_logging_handler.setFormatter(formatter) + diffusers_logging_handler.addFilter(_safetyWarningFilter()) captureWarnings(True) logging.disable_progress_bar() - self.pipe = StableDiffusionPipeline.from_single_file( - model_path, torch_dtype=float16, local_files_only=True) - self.pipe.to('cuda') - self.generator = Generator() - print('PIPELINE READY\n') - self.gen_params = None - def set_gen_params(self, gen_params): - scheduler_selection = [s for s in self.pipe.scheduler.compatibles + def set_gen_params(self, gen_params: GenParams): + """Read in gen_params, select scheduler based on them.""" + scheduler_selection = [s for s in self._pipe.scheduler.compatibles if s.__name__ == gen_params.scheduler] if not scheduler_selection: raise Exception(f'unknown scheduler: {gen_params.scheduler}') - self.pipe.scheduler = scheduler_selection[0].from_config( - self.pipe.scheduler.config) - self.gen_params = gen_params + self._pipe.scheduler = scheduler_selection[0].from_config( + self._pipe.scheduler.config) + self._gen_params = gen_params @property - def compatible_schedulers(self): - return [s.__name__ for s in self.pipe.scheduler.compatibles] + def compatible_schedulers(self) -> list[str]: + """List of schedulers compatible to selected model.""" + return [s.__name__ for s in self._pipe.scheduler.compatibles] - def gen_image_to(self, path): + def gen_image_to(self, path: str) -> None: """Create image and write as file with metadata to path.""" - if None in {self.gen_params.seed, self.gen_params.prompt, - self.gen_params.guidance, self.gen_params.height, - self.gen_params.width, self.gen_params.n_steps}: + if None in {self._gen_params.seed, self._gen_params.prompt, + self._gen_params.guidance, self._gen_params.height, + self._gen_params.width, self._gen_params.n_steps}: raise Exception('Generation parameters not initialized.') - self.generator.manual_seed(self.gen_params.seed) - image = self.pipe(generator=self.generator, - prompt=self.gen_params.prompt, - guidance_scale=self.gen_params.guidance, - height=self.gen_params.height, - width=self.gen_params.width, - num_inference_steps=self.gen_params.n_steps, - ).images[0] + self._generator.manual_seed(self._gen_params.seed) + image = self._pipe(generator=self._generator, + prompt=self._gen_params.prompt, + guidance_scale=self._gen_params.guidance, + height=self._gen_params.height, + width=self._gen_params.width, + num_inference_steps=self._gen_params.n_steps, + ).images[0] png_info = PngInfo() - png_info.add_text('generation_parameters', self.gen_params.to_str) + png_info.add_text('generation_parameters', self._gen_params.to_str) image.save(path, pnginfo=png_info)