From: Christian Heller Date: Thu, 22 Aug 2024 06:30:22 +0000 (+0200) Subject: Refactor GenParams code. X-Git-Url: https://plomlompom.com/repos/condition_descriptions?a=commitdiff_plain;h=e29512355f7ae0e84d5e37bd4ddea7faa594e6ad;p=stable_plom Refactor GenParams code. --- diff --git a/stable.py b/stable.py index 665b3a9..0bea6f9 100755 --- a/stable.py +++ b/stable.py @@ -3,6 +3,7 @@ from sys import argv, exit as sys_exit, stdin from os.path import dirname, basename, splitext, join as path_join, exists from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter from random import randint +from stable.gen_params import GenParams DEFAULT_SCHEDULER = 'EulerAncestralDiscreteScheduler' @@ -43,31 +44,23 @@ def parse_args(): help='name of denoising scheduler; --list-schedulers ' 'prints available choices for chosen model') parser.add_argument('-D', '--defaults_from_stdin', action='store_true', - help='parse stdin for image generation options ' + help='parse stdin for image generation parameters' '(e.g. from image file EXIF comment)') - parser.add_argument('-P', '--model_path_prefix', - help='optional directory path prefix to MODEL for ' - 'where it\'s provided as a mere filename (as may ' - 'happen due to -D)') + parser.add_argument('-M', '--models_dir', + help='directory path prefix to MODEL (for where' + 'it\'s provided as a mere filename, as with -D)') parser.add_argument('-H', '--help', action='help') parser.add_argument('-L', '--list-schedulers', action='store_true', help='list options for -S available with chosen model') + + # Re-seed defaults from stdin, if -D. parsed_args = parser.parse_args() - parser.set_defaults(model_path_prefix='') if parsed_args.defaults_from_stdin: - defaults_string = stdin.read() - first_split = defaults_string.split('; PROMPT: ', maxsplit=1) - if 2 == len(first_split): - parser.set_defaults(prompt=first_split[1]) - for section in first_split[0].split('; '): - key, val = section.split(': ', maxsplit=1) - key = key.lower() - if key in {'seed', 'height', 'width', 'n_steps'}: - val = int(val) - elif key in {'guidance'}: - val = float(val) - parser.set_defaults(**{key: val}) + gen_params = GenParams.from_str(stdin.read()) + parser.set_defaults(**gen_params.as_dict) parsed_args = parser.parse_args() + + # Validate input. prefix = f'{argv[0]}: error: ' if parsed_args.list_schedulers: required = {'model': 'm'} @@ -84,20 +77,22 @@ def parse_args(): parser.print_usage() print(f'{prefix}{", ".join(suffixes)}') sys_exit(1) + + # Prepare generation parameters, model path, etc. + if not parsed_args.models_dir: + parsed_args.models_dir = dirname(parsed_args.model) + parsed_args.model = basename(parsed_args.model) + parsed_args.gen_params = GenParams( + parsed_args.seed, parsed_args.guidance, parsed_args.n_steps, + parsed_args.height, parsed_args.width, parsed_args.scheduler, + parsed_args.model, parsed_args.prompt) + parsed_args.model_path = path_join(parsed_args.models_dir, + parsed_args.gen_params.model_name) return parsed_args +# guard against overwriting args = parse_args() -# pylint: disable=wrong-import-position -from stable.core import ImageMaker # noqa: E402 -model_path = f'{args.model_path_prefix}{args.model}' -if args.list_schedulers: - maker = ImageMaker(model_path) - print(f'AVAILABLE SCHEDULERS FOR MODEL {args.model}:\n') - for name in maker.compatible_schedulers: - print(name) - sys_exit(0) - dir_path = dirname(args.output) filename = basename(args.output) filename_sans_ext, ext = splitext(filename) @@ -107,16 +102,27 @@ for n in range(args.quantity): if exists(path): raise Exception(f'Would overwrite file: {path}') -maker = ImageMaker(model_path) -start_seed = args.seed +# pylint: disable=wrong-import-position +from stable.core import ImageMaker # noqa: E402 + +# only list available schedulers, if -L +if args.list_schedulers: + maker = ImageMaker(args.model_path) + print(f'AVAILABLE SCHEDULERS FOR MODEL {args.gen_params.model_name}:\n') + for name in maker.compatible_schedulers: + print(name) + sys_exit(0) + +# otherwise generate pictures +maker = ImageMaker(args.model_path) +start_seed = args.gen_params.seed start_seed = start_seed if start_seed != 0 else randint(-(2**31-1), 2**31) seed_corrector = 0 for n in range(args.quantity): if 0 == start_seed + n + seed_corrector: seed_corrector += 1 - nth_seed = start_seed + n + seed_corrector + args.gen_params.seed = start_seed + n + seed_corrector path = save_path(n) - maker.set_gen_params(args.prompt, nth_seed, args.guidance, args.height, - args.width, args.n_steps, args.scheduler) - print(f'GENERATING: {path}; {maker.gen_params_to_exif_comment}') + maker.set_gen_params(args.gen_params) + print(f'GENERATING: {path}; {args.gen_params.to_str}') maker.gen_image_to(path) diff --git a/stable/core.py b/stable/core.py index a5f7b56..f2792fc 100644 --- a/stable/core.py +++ b/stable/core.py @@ -5,6 +5,7 @@ from diffusers import StableDiffusionPipeline from diffusers.utils import logging from torch import Generator, float16 from exiftool import ExifToolHelper # type: ignore +from stable.gen_params import GenParams SAFETY_CHECKER_WARNING_PATTERN = 'You have disabled the safety checker' @@ -22,15 +23,6 @@ class ImageMaker: def filter(self, record): return self.target not in record.getMessage() - self.model_filename = basename(model_path) - if '; ' in self.model_filename: - raise Exception('illegal filename (must not contain "; ")') - self.seed = None - self.prompt = None - self.guidance = None - self.height = None - self.width = None - self.n_steps = None prefix = 'SETTING UP STABLE DIFFUSION PIPELINE FROM MODEL' print(f'{prefix}: {model_path}\n') diffusers_logging_handler = logging.get_logger('diffusers').handlers[0] @@ -45,55 +37,35 @@ class ImageMaker: self.pipe.to('cuda') self.generator = Generator() print('PIPELINE READY\n') + self.gen_params = None - def set_seed(self, seed): - self.seed = seed - self.generator.manual_seed(seed) - - def set_gen_params(self, prompt, seed, guidance, height, width, n_steps, - scheduler_name): + def set_gen_params(self, gen_params): scheduler_selection = [s for s in self.pipe.scheduler.compatibles - if s.__name__ == scheduler_name] + if s.__name__ == gen_params.scheduler_name] if not scheduler_selection: raise Exception(f'unknown scheduler: {scheduler_name}') self.pipe.scheduler = scheduler_selection[0].from_config( self.pipe.scheduler.config) - self.seed = seed - self.prompt = prompt - self.guidance = guidance - self.height = height - self.width = width - self.n_steps = n_steps + self.gen_params = gen_params @property def compatible_schedulers(self): return [s.__name__ for s in self.pipe.scheduler.compatibles] - @property - def gen_params_to_exif_comment(self): - return f'SEED: {self.seed}; ' +\ - f'GUIDANCE: {self.guidance}; ' +\ - f'N_STEPS: {self.n_steps}; ' +\ - f'HEIGHT: {self.height}; ' +\ - f'WIDTH: {self.width}; ' +\ - f'SCHEDULER: {self.pipe.scheduler.__class__.__name__}; ' +\ - f'MODEL: {self.model_filename}; ' +\ - f'PROMPT: {self.prompt}' - def gen_image_to(self, path): - if None in {self.seed, self.prompt, self.guidance, self.height, - self.width, self.n_steps}: + if None in {self.gen_params.seed, self.gen_params.prompt, + self.gen_params.guidance, self.gen_params.height, + self.gen_params.width, self.gen_params.n_steps}: raise Exception('Generation parameters not initialized.') - self.generator.manual_seed(self.seed) + self.generator.manual_seed(self.gen_params.seed) image = self.pipe(generator=self.generator, - prompt=self.prompt, - guidance=self.guidance, - height=self.height, - width=self.width, - num_inference_steps=self.n_steps, + prompt=self.gen_params.prompt, + guidance=self.gen_params.guidance, + height=self.gen_params.height, + width=self.gen_params.width, + num_inference_steps=self.gen_params.n_steps, ).images[0] image.save(path) with ExifToolHelper() as et: - et.set_tags([path], - tags={'Comment': self.gen_params_to_exif_comment}, + et.set_tags([path], tags={'Comment': self.gen_params.to_str}, params=['-overwrite_original']) diff --git a/stable/gen_params.py b/stable/gen_params.py new file mode 100644 index 0000000..e3ae86f --- /dev/null +++ b/stable/gen_params.py @@ -0,0 +1,50 @@ +class GenParams: + + def __init__(self, seed, guidance, n_steps, height, width, scheduler, model, prompt): + if '; ' in model: + raise Exception('illegal model filename (must not contain "; ")') + self.seed = seed + self.guidance = guidance + self.n_steps = n_steps + self.height, self.width = height, width + self.scheduler_name = scheduler + self.model_name = model + self.prompt = prompt + + @property + def to_str(self): + return f'SEED: {self.seed}; ' +\ + f'GUIDANCE: {self.guidance}; ' +\ + f'N_STEPS: {self.n_steps}; ' +\ + f'HEIGHT: {self.height}; ' +\ + f'WIDTH: {self.width}; ' +\ + f'SCHEDULER: {self.scheduler_name}; ' +\ + f'MODEL: {self.model_name}; ' +\ + f'PROMPT: {self.prompt}' + + @classmethod + def from_str(cls, string): + d = {} + first_split = string.split('; PROMPT: ', maxsplit=1) + if 2 == len(first_split): + d['prompt'] = first_split[1] + for section in first_split[0].split('; '): + key, val = section.split(': ', maxsplit=1) + key = key.lower() + if key in {'seed', 'height', 'width', 'n_steps'}: + val = int(val) + elif key in {'guidance'}: + val = float(val) + d[key] = val + return cls(**d) + + @property + def as_dict(self): + d = {} + just_names = {'scheduler', 'model'} + for k in ['seed', 'guidance', 'n_steps', 'height', 'width', 'prompt']\ + + list(just_names): + v = getattr(self, f'{k}_name' if k in just_names else k) + if v is not None: + d[k] = v + return d