From: Christian Heller Date: Thu, 22 Aug 2024 13:16:51 +0000 (+0200) Subject: Minor code style fixes. X-Git-Url: https://plomlompom.com/repos/te"st.html?a=commitdiff_plain;h=c8edb703912f92d3a3a12748c2d4d2b9a8cb8138;p=stable_plom Minor code style fixes. --- diff --git a/browser.py b/browser.py index cf59764..c2ab5b8 100755 --- a/browser.py +++ b/browser.py @@ -2,11 +2,11 @@ from os import scandir from os.path import splitext from exiftool import ExifToolHelper # type: ignore -from stable.gen_params import GenParams -import gi +import gi # type: ignore gi.require_version('Gtk', '4.0') # pylint: disable=wrong-import-position -from gi.repository import Gtk # noqa: E402 +from gi.repository import Gtk # type: ignore # noqa: E402 +from stable.gen_params import GenParams # noqa: E402 class Window(Gtk.ApplicationWindow): @@ -44,7 +44,7 @@ class Window(Gtk.ApplicationWindow): self.current_index = -1 self.newest(None) - def _reload(self, set_max = False): + def _reload(self, set_max=False): self.entries = [e for e in scandir('.') if e.is_file() and splitext(e)[1] in {'.png', '.jpg', '.jpeg'}] diff --git a/stable.py b/stable.py index cf9cbcc..71320d2 100755 --- a/stable.py +++ b/stable.py @@ -7,12 +7,8 @@ from stable.gen_params import GenParams DEFAULT_SCHEDULER = 'EulerAncestralDiscreteScheduler' - - -def save_path(count: int) -> str: - n_iterations = len(args.gen_paramses) * args.quantity - filename_count = f'_{count:08}' if n_iterations > 1 else '' - return path_join(dir_path, f'{filename_sans_ext}{filename_count}.{ext}') +SEED_MIN = -(2**31-1) +SEED_MAX = 2**31 def parse_args(): @@ -65,7 +61,7 @@ def parse_args(): prefix += 'unless calling with -H/--help or --list-schedulers, ' prefix += 'requiring ' - # Re-seed defaults from stdin, if -D. + # Re-seed defaults from stdin, if -D. temp_gen_paramses = [None] if parsed_args.defaults_from_stdin: temp_gen_paramses = [] @@ -73,9 +69,9 @@ def parse_args(): temp_gen_paramses += [GenParams.from_str(line)] parsed_args.models = [] parsed_args.gen_paramses = [] - for gen_params in temp_gen_paramses: - if gen_params: - parser.set_defaults(**gen_params.as_dict) + for gen_params_ in temp_gen_paramses: + if gen_params_: + parser.set_defaults(**gen_params_.as_dict) new_parsed_args = parser.parse_args() # Validate input. @@ -100,45 +96,58 @@ def parse_args(): return parsed_args -# guard against overwriting -args = parse_args() -if not args.list_schedulers: - dir_path = dirname(args.output) - filename = basename(args.output) - filename_sans_ext, ext = splitext(filename) - ext = ext[1:] if ext else 'png' - for n in range(args.quantity * len(args.gen_paramses)): - path = save_path(n) - if exists(path): - raise Exception(f'Would overwrite file: {path}') - -# pylint: disable=wrong-import-position -from stable.core import ImageMaker # noqa: E402 -old_model_path = '' -maker = None -for i, model_name in enumerate(args.models): - new_model_path = path_join(args.models_dir, model_name) - if new_model_path != old_model_path: - maker = ImageMaker(new_model_path) - old_model_path = new_model_path - - # only list available schedulers, if -L - if args.list_schedulers: - print(f'AVAILABLE SCHEDULERS FOR MODEL {new_model_path}:\n') - for name in maker.compatible_schedulers: - print(name) - continue - - # otherwise generate pictures - gen_params = args.gen_paramses[i] - start_seed = gen_params.seed - start_seed = start_seed if start_seed != 0 else randint(-(2**31-1), 2**31) - seed_corrector = 0 - for n in range(args.quantity): - if 0 == start_seed + n + seed_corrector: - seed_corrector += 1 - gen_params.seed = start_seed + n + seed_corrector - path = save_path(i*args.quantity + n) - maker.set_gen_params(gen_params) - print(f'GENERATING: {path}; {gen_params.to_str}') - maker.gen_image_to(path) +def run(): + + def save_path(count: int) -> str: + n_iterations = len(args.gen_paramses) + filename_count = f'_{count:08}' if n_iterations > 1 else '' + # pylint: disable=possibly-used-before-assignment + return path_join(dir_path, + f'{filename_sans_ext}{filename_count}.{ext}') + + # guard against overwriting + args = parse_args() + if not args.list_schedulers: + dir_path = dirname(args.output) + filename = basename(args.output) + filename_sans_ext, ext = splitext(filename) + ext = ext[1:] if ext else 'png' + for n in range(args.quantity * len(args.gen_paramses)): + path = save_path(n) + if exists(path): + raise Exception(f'Would overwrite file: {path}') + + # pylint: disable=wrong-import-position, import-outside-toplevel + from stable.core import ImageMaker # noqa: E402 + old_model_path = '' + maker = None + for i, model_name in enumerate(args.models): + new_model_path = path_join(args.models_dir, model_name) + if new_model_path != old_model_path: + maker = ImageMaker(new_model_path) + old_model_path = new_model_path + + # only list available schedulers, if -L + if args.list_schedulers: + print(f'AVAILABLE SCHEDULERS FOR MODEL {new_model_path}:\n') + for name in maker.compatible_schedulers: + print(name) + continue + + # otherwise generate pictures + gen_params = args.gen_paramses[i] + start_seed = gen_params.seed + start_seed = start_seed if start_seed != 0 else randint(SEED_MIN, + SEED_MAX) + seed_corrector = 0 + for n in range(args.quantity): + if 0 == start_seed + n + seed_corrector: + seed_corrector += 1 + gen_params.seed = start_seed + n + seed_corrector + path = save_path(i*args.quantity + n) + maker.set_gen_params(gen_params) + print(f'GENERATING: {path}; {gen_params.to_str}') + maker.gen_image_to(path) + + +run() diff --git a/stable/core.py b/stable/core.py index f2792fc..5c506c7 100644 --- a/stable/core.py +++ b/stable/core.py @@ -1,11 +1,9 @@ from logging import (Formatter as LogFormatter, captureWarnings, Filter as LogFilter) -from os.path import basename from diffusers import StableDiffusionPipeline from diffusers.utils import logging from torch import Generator, float16 from exiftool import ExifToolHelper # type: ignore -from stable.gen_params import GenParams SAFETY_CHECKER_WARNING_PATTERN = 'You have disabled the safety checker' @@ -43,7 +41,7 @@ class ImageMaker: scheduler_selection = [s for s in self.pipe.scheduler.compatibles if s.__name__ == gen_params.scheduler_name] if not scheduler_selection: - raise Exception(f'unknown scheduler: {scheduler_name}') + raise Exception(f'unknown scheduler: {gen_params.scheduler_name}') self.pipe.scheduler = scheduler_selection[0].from_config( self.pipe.scheduler.config) self.gen_params = gen_params diff --git a/stable/gen_params.py b/stable/gen_params.py index e3ae86f..fc1eba1 100644 --- a/stable/gen_params.py +++ b/stable/gen_params.py @@ -1,6 +1,7 @@ class GenParams: - def __init__(self, seed, guidance, n_steps, height, width, scheduler, model, prompt): + def __init__(self, seed, guidance, n_steps, height, width, scheduler, + model, prompt): if '; ' in model: raise Exception('illegal model filename (must not contain "; ")') self.seed = seed @@ -9,7 +10,7 @@ class GenParams: self.height, self.width = height, width self.scheduler_name = scheduler self.model_name = model - self.prompt = prompt + self.prompt = prompt @property def to_str(self): @@ -27,7 +28,7 @@ class GenParams: d = {} first_split = string.split('; PROMPT: ', maxsplit=1) if 2 == len(first_split): - d['prompt'] = first_split[1] + d['prompt'] = first_split[1] for section in first_split[0].split('; '): key, val = section.split(': ', maxsplit=1) key = key.lower()