from os.path import dirname, basename, splitext, join as path_join, exists
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
from random import randint
-from stable.gen_params import GenParams
+from stable.gen_params import GenParams
DEFAULT_SCHEDULER = 'EulerAncestralDiscreteScheduler'
def save_path(count: int) -> str:
- filename_count = f'_{count:08}' if args.quantity > 1 else ''
+ n_iterations = len(args.gen_paramses) * args.quantity
+ filename_count = f'_{count:08}' if n_iterations > 1 else ''
return path_join(dir_path, f'{filename_sans_ext}{filename_count}.{ext}')
parser.add_argument('-L', '--list-schedulers', action='store_true',
help='list options for -S available with chosen model')
- # Re-seed defaults from stdin, if -D.
+ # Prepare validation
parsed_args = parser.parse_args()
- if parsed_args.defaults_from_stdin:
- gen_params = GenParams.from_str(stdin.read())
- parser.set_defaults(**gen_params.as_dict)
- parsed_args = parser.parse_args()
-
- # Validate input.
prefix = f'{argv[0]}: error: '
if parsed_args.list_schedulers:
required = {'model': 'm'}
required = {'output': 'o', 'prompt': 'p', 'model': 'm'}
prefix += 'unless calling with -H/--help or --list-schedulers, '
prefix += 'requiring '
- suffixes = []
- for k, v in required.items():
- if not getattr(parsed_args, k):
- suffixes += [f'-{v}/--{k}']
- if suffixes:
- parser.print_usage()
- print(f'{prefix}{", ".join(suffixes)}')
- sys_exit(1)
- # Prepare generation parameters, model path, etc.
- if not parsed_args.models_dir:
- parsed_args.models_dir = dirname(parsed_args.model)
- parsed_args.model = basename(parsed_args.model)
- parsed_args.gen_params = GenParams(
- parsed_args.seed, parsed_args.guidance, parsed_args.n_steps,
- parsed_args.height, parsed_args.width, parsed_args.scheduler,
- parsed_args.model, parsed_args.prompt)
- parsed_args.model_path = path_join(parsed_args.models_dir,
- parsed_args.gen_params.model_name)
+ # Re-seed defaults from stdin, if -D.
+ temp_gen_paramses = [None]
+ if parsed_args.defaults_from_stdin:
+ temp_gen_paramses = []
+ for line in stdin.readlines():
+ temp_gen_paramses += [GenParams.from_str(line)]
+ parsed_args.models = []
+ parsed_args.gen_paramses = []
+ for gen_params in temp_gen_paramses:
+ if gen_params:
+ parser.set_defaults(**gen_params.as_dict)
+ new_parsed_args = parser.parse_args()
+
+ # Validate input.
+ suffixes = []
+ for k, v in required.items():
+ if not getattr(new_parsed_args, k):
+ suffixes += [f'-{v}/--{k}']
+ if suffixes:
+ parser.print_usage()
+ print(f'{prefix}{", ".join(suffixes)}')
+ sys_exit(1)
+
+ # Prepare generation parameters, model path, etc.
+ if not parsed_args.models_dir:
+ parsed_args.models_dir = dirname(new_parsed_args.model)
+ parsed_args.models += [basename(new_parsed_args.model)]
+ parsed_args.gen_paramses += [GenParams(
+ new_parsed_args.seed, new_parsed_args.guidance,
+ new_parsed_args.n_steps, new_parsed_args.height,
+ new_parsed_args.width, new_parsed_args.scheduler,
+ new_parsed_args.model, new_parsed_args.prompt)]
return parsed_args
# guard against overwriting
args = parse_args()
-dir_path = dirname(args.output)
-filename = basename(args.output)
-filename_sans_ext, ext = splitext(filename)
-ext = ext[1:] if ext else 'png'
-for n in range(args.quantity):
- path = save_path(n)
- if exists(path):
- raise Exception(f'Would overwrite file: {path}')
+if not args.list_schedulers:
+ dir_path = dirname(args.output)
+ filename = basename(args.output)
+ filename_sans_ext, ext = splitext(filename)
+ ext = ext[1:] if ext else 'png'
+ for n in range(args.quantity * len(args.gen_paramses)):
+ path = save_path(n)
+ if exists(path):
+ raise Exception(f'Would overwrite file: {path}')
# pylint: disable=wrong-import-position
from stable.core import ImageMaker # noqa: E402
+old_model_path = ''
+maker = None
+for i, model_name in enumerate(args.models):
+ new_model_path = path_join(args.models_dir, model_name)
+ if new_model_path != old_model_path:
+ maker = ImageMaker(new_model_path)
+ old_model_path = new_model_path
-# only list available schedulers, if -L
-if args.list_schedulers:
- maker = ImageMaker(args.model_path)
- print(f'AVAILABLE SCHEDULERS FOR MODEL {args.gen_params.model_name}:\n')
- for name in maker.compatible_schedulers:
- print(name)
- sys_exit(0)
+ # only list available schedulers, if -L
+ if args.list_schedulers:
+ print(f'AVAILABLE SCHEDULERS FOR MODEL {new_model_path}:\n')
+ for name in maker.compatible_schedulers:
+ print(name)
+ continue
-# otherwise generate pictures
-maker = ImageMaker(args.model_path)
-start_seed = args.gen_params.seed
-start_seed = start_seed if start_seed != 0 else randint(-(2**31-1), 2**31)
-seed_corrector = 0
-for n in range(args.quantity):
- if 0 == start_seed + n + seed_corrector:
- seed_corrector += 1
- args.gen_params.seed = start_seed + n + seed_corrector
- path = save_path(n)
- maker.set_gen_params(args.gen_params)
- print(f'GENERATING: {path}; {args.gen_params.to_str}')
- maker.gen_image_to(path)
+ # otherwise generate pictures
+ gen_params = args.gen_paramses[i]
+ start_seed = gen_params.seed
+ start_seed = start_seed if start_seed != 0 else randint(-(2**31-1), 2**31)
+ seed_corrector = 0
+ for n in range(args.quantity):
+ if 0 == start_seed + n + seed_corrector:
+ seed_corrector += 1
+ gen_params.seed = start_seed + n + seed_corrector
+ path = save_path(i*args.quantity + n)
+ maker.set_gen_params(gen_params)
+ print(f'GENERATING: {path}; {gen_params.to_str}')
+ maker.gen_image_to(path)