From 1be102192d2a102e3199ea29c5f4ee592d35578c Mon Sep 17 00:00:00 2001 From: Christian Heller Date: Thu, 22 Aug 2024 09:34:35 +0200 Subject: [PATCH] Allow multi-line input via -D/stdin to process any number of GenParams. --- stable.py | 126 ++++++++++++++++++++++++++++++------------------------ 1 file changed, 71 insertions(+), 55 deletions(-) diff --git a/stable.py b/stable.py index 0bea6f9..cf9cbcc 100755 --- a/stable.py +++ b/stable.py @@ -3,14 +3,15 @@ from sys import argv, exit as sys_exit, stdin from os.path import dirname, basename, splitext, join as path_join, exists from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter from random import randint -from stable.gen_params import GenParams +from stable.gen_params import GenParams DEFAULT_SCHEDULER = 'EulerAncestralDiscreteScheduler' def save_path(count: int) -> str: - filename_count = f'_{count:08}' if args.quantity > 1 else '' + n_iterations = len(args.gen_paramses) * args.quantity + filename_count = f'_{count:08}' if n_iterations > 1 else '' return path_join(dir_path, f'{filename_sans_ext}{filename_count}.{ext}') @@ -53,14 +54,8 @@ def parse_args(): parser.add_argument('-L', '--list-schedulers', action='store_true', help='list options for -S available with chosen model') - # Re-seed defaults from stdin, if -D. + # Prepare validation parsed_args = parser.parse_args() - if parsed_args.defaults_from_stdin: - gen_params = GenParams.from_str(stdin.read()) - parser.set_defaults(**gen_params.as_dict) - parsed_args = parser.parse_args() - - # Validate input. prefix = f'{argv[0]}: error: ' if parsed_args.list_schedulers: required = {'model': 'm'} @@ -69,60 +64,81 @@ def parse_args(): required = {'output': 'o', 'prompt': 'p', 'model': 'm'} prefix += 'unless calling with -H/--help or --list-schedulers, ' prefix += 'requiring ' - suffixes = [] - for k, v in required.items(): - if not getattr(parsed_args, k): - suffixes += [f'-{v}/--{k}'] - if suffixes: - parser.print_usage() - print(f'{prefix}{", ".join(suffixes)}') - sys_exit(1) - # Prepare generation parameters, model path, etc. - if not parsed_args.models_dir: - parsed_args.models_dir = dirname(parsed_args.model) - parsed_args.model = basename(parsed_args.model) - parsed_args.gen_params = GenParams( - parsed_args.seed, parsed_args.guidance, parsed_args.n_steps, - parsed_args.height, parsed_args.width, parsed_args.scheduler, - parsed_args.model, parsed_args.prompt) - parsed_args.model_path = path_join(parsed_args.models_dir, - parsed_args.gen_params.model_name) + # Re-seed defaults from stdin, if -D. + temp_gen_paramses = [None] + if parsed_args.defaults_from_stdin: + temp_gen_paramses = [] + for line in stdin.readlines(): + temp_gen_paramses += [GenParams.from_str(line)] + parsed_args.models = [] + parsed_args.gen_paramses = [] + for gen_params in temp_gen_paramses: + if gen_params: + parser.set_defaults(**gen_params.as_dict) + new_parsed_args = parser.parse_args() + + # Validate input. + suffixes = [] + for k, v in required.items(): + if not getattr(new_parsed_args, k): + suffixes += [f'-{v}/--{k}'] + if suffixes: + parser.print_usage() + print(f'{prefix}{", ".join(suffixes)}') + sys_exit(1) + + # Prepare generation parameters, model path, etc. + if not parsed_args.models_dir: + parsed_args.models_dir = dirname(new_parsed_args.model) + parsed_args.models += [basename(new_parsed_args.model)] + parsed_args.gen_paramses += [GenParams( + new_parsed_args.seed, new_parsed_args.guidance, + new_parsed_args.n_steps, new_parsed_args.height, + new_parsed_args.width, new_parsed_args.scheduler, + new_parsed_args.model, new_parsed_args.prompt)] return parsed_args # guard against overwriting args = parse_args() -dir_path = dirname(args.output) -filename = basename(args.output) -filename_sans_ext, ext = splitext(filename) -ext = ext[1:] if ext else 'png' -for n in range(args.quantity): - path = save_path(n) - if exists(path): - raise Exception(f'Would overwrite file: {path}') +if not args.list_schedulers: + dir_path = dirname(args.output) + filename = basename(args.output) + filename_sans_ext, ext = splitext(filename) + ext = ext[1:] if ext else 'png' + for n in range(args.quantity * len(args.gen_paramses)): + path = save_path(n) + if exists(path): + raise Exception(f'Would overwrite file: {path}') # pylint: disable=wrong-import-position from stable.core import ImageMaker # noqa: E402 +old_model_path = '' +maker = None +for i, model_name in enumerate(args.models): + new_model_path = path_join(args.models_dir, model_name) + if new_model_path != old_model_path: + maker = ImageMaker(new_model_path) + old_model_path = new_model_path -# only list available schedulers, if -L -if args.list_schedulers: - maker = ImageMaker(args.model_path) - print(f'AVAILABLE SCHEDULERS FOR MODEL {args.gen_params.model_name}:\n') - for name in maker.compatible_schedulers: - print(name) - sys_exit(0) + # only list available schedulers, if -L + if args.list_schedulers: + print(f'AVAILABLE SCHEDULERS FOR MODEL {new_model_path}:\n') + for name in maker.compatible_schedulers: + print(name) + continue -# otherwise generate pictures -maker = ImageMaker(args.model_path) -start_seed = args.gen_params.seed -start_seed = start_seed if start_seed != 0 else randint(-(2**31-1), 2**31) -seed_corrector = 0 -for n in range(args.quantity): - if 0 == start_seed + n + seed_corrector: - seed_corrector += 1 - args.gen_params.seed = start_seed + n + seed_corrector - path = save_path(n) - maker.set_gen_params(args.gen_params) - print(f'GENERATING: {path}; {args.gen_params.to_str}') - maker.gen_image_to(path) + # otherwise generate pictures + gen_params = args.gen_paramses[i] + start_seed = gen_params.seed + start_seed = start_seed if start_seed != 0 else randint(-(2**31-1), 2**31) + seed_corrector = 0 + for n in range(args.quantity): + if 0 == start_seed + n + seed_corrector: + seed_corrector += 1 + gen_params.seed = start_seed + n + seed_corrector + path = save_path(i*args.quantity + n) + maker.set_gen_params(gen_params) + print(f'GENERATING: {path}; {gen_params.to_str}') + maker.gen_image_to(path) -- 2.30.2