home · contact · privacy
Allow multi-line input via -D/stdin to process any number of GenParams.
authorChristian Heller <c.heller@plomlompom.de>
Thu, 22 Aug 2024 07:34:35 +0000 (09:34 +0200)
committerChristian Heller <c.heller@plomlompom.de>
Thu, 22 Aug 2024 07:34:35 +0000 (09:34 +0200)
stable.py

index 0bea6f90f9eec28de8b5992b729e8ba68bbe03d2..cf9cbcc50f7dc267526cc5c156dacb3fc6998cbd 100755 (executable)
--- a/stable.py
+++ b/stable.py
@@ -3,14 +3,15 @@ from sys import argv, exit as sys_exit, stdin
 from os.path import dirname, basename, splitext, join as path_join, exists
 from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
 from random import randint
-from stable.gen_params import GenParams 
+from stable.gen_params import GenParams
 
 
 DEFAULT_SCHEDULER = 'EulerAncestralDiscreteScheduler'
 
 
 def save_path(count: int) -> str:
-    filename_count = f'_{count:08}' if args.quantity > 1 else ''
+    n_iterations = len(args.gen_paramses) * args.quantity
+    filename_count = f'_{count:08}' if n_iterations > 1 else ''
     return path_join(dir_path, f'{filename_sans_ext}{filename_count}.{ext}')
 
 
@@ -53,14 +54,8 @@ def parse_args():
     parser.add_argument('-L', '--list-schedulers', action='store_true',
                         help='list options for -S available with chosen model')
 
-    # Re-seed defaults from stdin, if -D.
+    # Prepare validation
     parsed_args = parser.parse_args()
-    if parsed_args.defaults_from_stdin:
-        gen_params = GenParams.from_str(stdin.read())
-        parser.set_defaults(**gen_params.as_dict)
-    parsed_args = parser.parse_args()
-
-    # Validate input.
     prefix = f'{argv[0]}: error: '
     if parsed_args.list_schedulers:
         required = {'model': 'm'}
@@ -69,60 +64,81 @@ def parse_args():
         required = {'output': 'o', 'prompt': 'p', 'model': 'm'}
         prefix += 'unless calling with -H/--help or --list-schedulers, '
     prefix += 'requiring '
-    suffixes = []
-    for k, v in required.items():
-        if not getattr(parsed_args, k):
-            suffixes += [f'-{v}/--{k}']
-    if suffixes:
-        parser.print_usage()
-        print(f'{prefix}{", ".join(suffixes)}')
-        sys_exit(1)
 
-    # Prepare generation parameters, model path, etc.
-    if not parsed_args.models_dir:
-        parsed_args.models_dir = dirname(parsed_args.model)
-        parsed_args.model = basename(parsed_args.model)
-    parsed_args.gen_params = GenParams(
-            parsed_args.seed, parsed_args.guidance, parsed_args.n_steps,
-            parsed_args.height, parsed_args.width, parsed_args.scheduler,
-            parsed_args.model, parsed_args.prompt)
-    parsed_args.model_path = path_join(parsed_args.models_dir,
-                                       parsed_args.gen_params.model_name)
+     # Re-seed defaults from stdin, if -D.
+    temp_gen_paramses = [None]
+    if parsed_args.defaults_from_stdin:
+        temp_gen_paramses = []
+        for line in stdin.readlines():
+            temp_gen_paramses += [GenParams.from_str(line)]
+    parsed_args.models = []
+    parsed_args.gen_paramses = []
+    for gen_params in temp_gen_paramses:
+        if gen_params:
+            parser.set_defaults(**gen_params.as_dict)
+        new_parsed_args = parser.parse_args()
+
+        # Validate input.
+        suffixes = []
+        for k, v in required.items():
+            if not getattr(new_parsed_args, k):
+                suffixes += [f'-{v}/--{k}']
+        if suffixes:
+            parser.print_usage()
+            print(f'{prefix}{", ".join(suffixes)}')
+            sys_exit(1)
+
+        # Prepare generation parameters, model path, etc.
+        if not parsed_args.models_dir:
+            parsed_args.models_dir = dirname(new_parsed_args.model)
+        parsed_args.models += [basename(new_parsed_args.model)]
+        parsed_args.gen_paramses += [GenParams(
+                new_parsed_args.seed, new_parsed_args.guidance,
+                new_parsed_args.n_steps, new_parsed_args.height,
+                new_parsed_args.width, new_parsed_args.scheduler,
+                new_parsed_args.model, new_parsed_args.prompt)]
     return parsed_args
 
 
 # guard against overwriting
 args = parse_args()
-dir_path = dirname(args.output)
-filename = basename(args.output)
-filename_sans_ext, ext = splitext(filename)
-ext = ext[1:] if ext else 'png'
-for n in range(args.quantity):
-    path = save_path(n)
-    if exists(path):
-        raise Exception(f'Would overwrite file: {path}')
+if not args.list_schedulers:
+    dir_path = dirname(args.output)
+    filename = basename(args.output)
+    filename_sans_ext, ext = splitext(filename)
+    ext = ext[1:] if ext else 'png'
+    for n in range(args.quantity * len(args.gen_paramses)):
+        path = save_path(n)
+        if exists(path):
+            raise Exception(f'Would overwrite file: {path}')
 
 # pylint: disable=wrong-import-position
 from stable.core import ImageMaker  # noqa: E402
+old_model_path = ''
+maker = None
+for i, model_name in enumerate(args.models):
+    new_model_path = path_join(args.models_dir, model_name)
+    if new_model_path != old_model_path:
+        maker = ImageMaker(new_model_path)
+    old_model_path = new_model_path
 
-# only list available schedulers, if -L
-if args.list_schedulers:
-    maker = ImageMaker(args.model_path)
-    print(f'AVAILABLE SCHEDULERS FOR MODEL {args.gen_params.model_name}:\n')
-    for name in maker.compatible_schedulers:
-        print(name)
-    sys_exit(0)
+    # only list available schedulers, if -L
+    if args.list_schedulers:
+        print(f'AVAILABLE SCHEDULERS FOR MODEL {new_model_path}:\n')
+        for name in maker.compatible_schedulers:
+            print(name)
+        continue
 
-# otherwise generate pictures
-maker = ImageMaker(args.model_path)
-start_seed = args.gen_params.seed
-start_seed = start_seed if start_seed != 0 else randint(-(2**31-1), 2**31)
-seed_corrector = 0
-for n in range(args.quantity):
-    if 0 == start_seed + n + seed_corrector:
-        seed_corrector += 1
-    args.gen_params.seed = start_seed + n + seed_corrector
-    path = save_path(n)
-    maker.set_gen_params(args.gen_params)
-    print(f'GENERATING: {path}; {args.gen_params.to_str}')
-    maker.gen_image_to(path)
+    # otherwise generate pictures
+    gen_params = args.gen_paramses[i]
+    start_seed = gen_params.seed
+    start_seed = start_seed if start_seed != 0 else randint(-(2**31-1), 2**31)
+    seed_corrector = 0
+    for n in range(args.quantity):
+        if 0 == start_seed + n + seed_corrector:
+            seed_corrector += 1
+        gen_params.seed = start_seed + n + seed_corrector
+        path = save_path(i*args.quantity + n)
+        maker.set_gen_params(gen_params)
+        print(f'GENERATING: {path}; {gen_params.to_str}')
+        maker.gen_image_to(path)