home · contact · privacy
Add scheduler selection.
authorChristian Heller <c.heller@plomlompom.de>
Mon, 19 Aug 2024 03:30:08 +0000 (05:30 +0200)
committerChristian Heller <c.heller@plomlompom.de>
Mon, 19 Aug 2024 03:30:08 +0000 (05:30 +0200)
stable.py
stable/core.py

index 7a31a8541ed69d3f7962def6d6a3f436d23244aa..df335e8c2510a79fa423f7d171c46859dd50f389 100755 (executable)
--- a/stable.py
+++ b/stable.py
@@ -1,4 +1,5 @@
 #!/usr/bin/env python3
+from sys import exit as sys_exit
 from os.path import dirname, basename, splitext, join as path_join, exists
 from argparse import ArgumentParser
 from random import randint
@@ -6,37 +7,53 @@ from stable.core import ImageMaker
 
 
 def save_path(count: int) -> str:
-    filename_count = f'_{count:08}' if args.number > 1 else ''
+    filename_count = f'_{count:08}' if args.quantity > 1 else ''
     return path_join(dir_path, f'{filename_sans_ext}{filename_count}.{ext}')
 
 
 def parse_args():
     parser = ArgumentParser(add_help=False)
-    parser.add_argument('-p', '--prompt', type=str, required=True)
-    parser.add_argument('-o', '--output', type=str, required=True)
-    parser.add_argument('-m', '--model', type=str, required=True)
-    parser.add_argument('-r', '--randomness_seed', default=1, type=int,
-                        help='default: 1; if set 0, chosen randomnly')
+    parser.add_argument('-m', '--model', required=True)
+    parser.add_argument('-o', '--output')
+    parser.add_argument('-p', '--prompt')
+    parser.add_argument('-H', '--help', action='help')
+    parser.add_argument('-S', '--list_schedulers', action='store_true')
     parser.add_argument('-g', '--guidance', default=7.5, type=float,
                         help='default: 7.5')
-    parser.add_argument('-s', '--steps', default=15, type=int,
-                        help='default: 15')
     parser.add_argument('-h', '--height', default=512, type=int,
                         help='default: 512')
+    parser.add_argument('-n', '--n_steps', default=15, type=int,
+                        help='default: 15')
+    parser.add_argument('-q', '--quantity', default=1, type=int,
+                        help='default: 1')
+    parser.add_argument('-r', '--randomness_seed', default=1, type=int,
+                        help='default: 1; if set 0, chosen randomnly')
+    parser.add_argument('-s', '--scheduler', default='PNDMScheduler',
+                        help='default: PNDMScheduler')
     parser.add_argument('-w', '--width', default=512, type=int,
                         help='default: 512')
-    parser.add_argument('-n', '--number', default=1, type=int,
-                        help='default: 1')
-    parser.add_argument('-H', '--help', action='help')
-    return parser.parse_args()
+    parsed_args = parser.parse_args()
+    if not parsed_args.list_schedulers:
+        if not parsed_args.output:
+            raise Exception('Unless -H or -S, need --output set.')
+        if not parsed_args.prompt:
+            raise Exception('Unless -H or -S, need --prompt set.')
+    return parsed_args
 
 
 args = parse_args()
+if args.list_schedulers:
+    maker = ImageMaker(args.model)
+    print(f'AVAILABLE SCHEDULERS FOR MODEL {args.model}:\n')
+    for name in maker.compatible_schedulers:
+        print(name)
+    sys_exit(0)
+
 dir_path = dirname(args.output)
 filename = basename(args.output)
 filename_sans_ext, ext = splitext(filename)
 ext = ext[1:] if ext else 'png'
-for n in range(args.number):
+for n in range(args.quantity):
     path = save_path(n)
     if exists(path):
         raise Exception(f'Would overwrite file: {path}')
@@ -44,10 +61,10 @@ for n in range(args.number):
 maker = ImageMaker(args.model)
 start_seed = args.randomness_seed
 start_seed = start_seed if start_seed != 0 else randint(-(2**31-1), 2**31)
-for n in range(args.number):
+for n in range(args.quantity):
     nth_seed = start_seed + n
     path = save_path(n)
-    maker.set_gen_params(args.prompt, nth_seed, args.guidance,
-                         args.height, args.width, args.steps)
+    maker.set_gen_params(args.prompt, nth_seed, args.guidance, args.height,
+                         args.width, args.n_steps, args.scheduler)
     print(f'GENERATING: {path}; {maker.gen_params_to_exif_comment}')
     maker.gen_image_to(path)
index e85f5171c1ade8389a458e28f5268bfff422f6cb..32a93726f448b80c6ad441dc2944610db2284c24 100644 (file)
@@ -38,8 +38,8 @@ class ImageMaker:
                 FilterOut(SAFETY_CHECKER_WARNING_PATTERN))
         captureWarnings(True)
         logging.disable_progress_bar()
-        self.pipe = StableDiffusionPipeline.from_single_file(model_path, torch_dtype=float16,
-                                                             local_files_only=True)
+        self.pipe = StableDiffusionPipeline.from_single_file(
+                model_path, torch_dtype=float16, local_files_only=True)
         self.pipe.to('cuda')
         self.generator = Generator()
         print('PIPELINE READY\n')
@@ -48,7 +48,14 @@ class ImageMaker:
         self.seed = seed
         self.generator.manual_seed(seed)
 
-    def set_gen_params(self, prompt, seed, guidance, height, width, n_steps):
+    def set_gen_params(self, prompt, seed, guidance, height, width, n_steps,
+                       scheduler_name):
+        scheduler_selection = [s for s in self.pipe.scheduler.compatibles
+                               if s.__name__ == scheduler_name]
+        if not scheduler_selection:
+            raise Exception(f'unknown scheduler: {scheduler_name}')
+        self.pipe.scheduler = scheduler_selection[0].from_config(
+                self.pipe.scheduler.config)
         self.seed = seed
         self.prompt = prompt
         self.guidance = guidance
@@ -56,6 +63,10 @@ class ImageMaker:
         self.width = width
         self.n_steps = n_steps
 
+    @property
+    def compatible_schedulers(self):
+        return [s.__name__ for s in self.pipe.scheduler.compatibles]
+
     @property
     def gen_params_to_exif_comment(self):
         return f'SEED: {self.seed}; ' +\
@@ -64,6 +75,7 @@ class ImageMaker:
                 f'HEIGHT: {self.height}; ' +\
                 f'WIDTH: {self.width}; ' +\
                 f'MODEL_FILE: {self.model_filename}; ' +\
+                f'SCHEDULER: {self.pipe.scheduler.__class__.__name__}; ' +\
                 f'PROMPT: {self.prompt}'
 
     def gen_image_to(self, path):