From 2bd60c9db8a8a82f0b8bbf0380135a1ea7060259 Mon Sep 17 00:00:00 2001 From: Christian Heller Date: Mon, 19 Aug 2024 05:30:08 +0200 Subject: [PATCH] Add scheduler selection. --- stable.py | 49 +++++++++++++++++++++++++++++++++---------------- stable/core.py | 18 +++++++++++++++--- 2 files changed, 48 insertions(+), 19 deletions(-) diff --git a/stable.py b/stable.py index 7a31a85..df335e8 100755 --- a/stable.py +++ b/stable.py @@ -1,4 +1,5 @@ #!/usr/bin/env python3 +from sys import exit as sys_exit from os.path import dirname, basename, splitext, join as path_join, exists from argparse import ArgumentParser from random import randint @@ -6,37 +7,53 @@ from stable.core import ImageMaker def save_path(count: int) -> str: - filename_count = f'_{count:08}' if args.number > 1 else '' + filename_count = f'_{count:08}' if args.quantity > 1 else '' return path_join(dir_path, f'{filename_sans_ext}{filename_count}.{ext}') def parse_args(): parser = ArgumentParser(add_help=False) - parser.add_argument('-p', '--prompt', type=str, required=True) - parser.add_argument('-o', '--output', type=str, required=True) - parser.add_argument('-m', '--model', type=str, required=True) - parser.add_argument('-r', '--randomness_seed', default=1, type=int, - help='default: 1; if set 0, chosen randomnly') + parser.add_argument('-m', '--model', required=True) + parser.add_argument('-o', '--output') + parser.add_argument('-p', '--prompt') + parser.add_argument('-H', '--help', action='help') + parser.add_argument('-S', '--list_schedulers', action='store_true') parser.add_argument('-g', '--guidance', default=7.5, type=float, help='default: 7.5') - parser.add_argument('-s', '--steps', default=15, type=int, - help='default: 15') parser.add_argument('-h', '--height', default=512, type=int, help='default: 512') + parser.add_argument('-n', '--n_steps', default=15, type=int, + help='default: 15') + parser.add_argument('-q', '--quantity', default=1, type=int, + help='default: 1') + parser.add_argument('-r', '--randomness_seed', default=1, type=int, + help='default: 1; if set 0, chosen randomnly') + parser.add_argument('-s', '--scheduler', default='PNDMScheduler', + help='default: PNDMScheduler') parser.add_argument('-w', '--width', default=512, type=int, help='default: 512') - parser.add_argument('-n', '--number', default=1, type=int, - help='default: 1') - parser.add_argument('-H', '--help', action='help') - return parser.parse_args() + parsed_args = parser.parse_args() + if not parsed_args.list_schedulers: + if not parsed_args.output: + raise Exception('Unless -H or -S, need --output set.') + if not parsed_args.prompt: + raise Exception('Unless -H or -S, need --prompt set.') + return parsed_args args = parse_args() +if args.list_schedulers: + maker = ImageMaker(args.model) + print(f'AVAILABLE SCHEDULERS FOR MODEL {args.model}:\n') + for name in maker.compatible_schedulers: + print(name) + sys_exit(0) + dir_path = dirname(args.output) filename = basename(args.output) filename_sans_ext, ext = splitext(filename) ext = ext[1:] if ext else 'png' -for n in range(args.number): +for n in range(args.quantity): path = save_path(n) if exists(path): raise Exception(f'Would overwrite file: {path}') @@ -44,10 +61,10 @@ for n in range(args.number): maker = ImageMaker(args.model) start_seed = args.randomness_seed start_seed = start_seed if start_seed != 0 else randint(-(2**31-1), 2**31) -for n in range(args.number): +for n in range(args.quantity): nth_seed = start_seed + n path = save_path(n) - maker.set_gen_params(args.prompt, nth_seed, args.guidance, - args.height, args.width, args.steps) + maker.set_gen_params(args.prompt, nth_seed, args.guidance, args.height, + args.width, args.n_steps, args.scheduler) print(f'GENERATING: {path}; {maker.gen_params_to_exif_comment}') maker.gen_image_to(path) diff --git a/stable/core.py b/stable/core.py index e85f517..32a9372 100644 --- a/stable/core.py +++ b/stable/core.py @@ -38,8 +38,8 @@ class ImageMaker: FilterOut(SAFETY_CHECKER_WARNING_PATTERN)) captureWarnings(True) logging.disable_progress_bar() - self.pipe = StableDiffusionPipeline.from_single_file(model_path, torch_dtype=float16, - local_files_only=True) + self.pipe = StableDiffusionPipeline.from_single_file( + model_path, torch_dtype=float16, local_files_only=True) self.pipe.to('cuda') self.generator = Generator() print('PIPELINE READY\n') @@ -48,7 +48,14 @@ class ImageMaker: self.seed = seed self.generator.manual_seed(seed) - def set_gen_params(self, prompt, seed, guidance, height, width, n_steps): + def set_gen_params(self, prompt, seed, guidance, height, width, n_steps, + scheduler_name): + scheduler_selection = [s for s in self.pipe.scheduler.compatibles + if s.__name__ == scheduler_name] + if not scheduler_selection: + raise Exception(f'unknown scheduler: {scheduler_name}') + self.pipe.scheduler = scheduler_selection[0].from_config( + self.pipe.scheduler.config) self.seed = seed self.prompt = prompt self.guidance = guidance @@ -56,6 +63,10 @@ class ImageMaker: self.width = width self.n_steps = n_steps + @property + def compatible_schedulers(self): + return [s.__name__ for s in self.pipe.scheduler.compatibles] + @property def gen_params_to_exif_comment(self): return f'SEED: {self.seed}; ' +\ @@ -64,6 +75,7 @@ class ImageMaker: f'HEIGHT: {self.height}; ' +\ f'WIDTH: {self.width}; ' +\ f'MODEL_FILE: {self.model_filename}; ' +\ + f'SCHEDULER: {self.pipe.scheduler.__class__.__name__}; ' +\ f'PROMPT: {self.prompt}' def gen_image_to(self, path): -- 2.30.2