From: Christian Heller Date: Sat, 24 Aug 2024 03:05:23 +0000 (+0200) Subject: Refactor GenParams usage. X-Git-Url: https://plomlompom.com/repos/%27%29;%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20chunks.push%28escapeHTML%28span%5B2%5D%29%29;%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20%20chunks.push%28%27?a=commitdiff_plain;h=7dcaf2bf4082113add30cccdb3cae0107f6e8ba4;p=stable_plom Refactor GenParams usage. --- diff --git a/browser.py b/browser.py index 6f13d7e..26b67a4 100755 --- a/browser.py +++ b/browser.py @@ -1,16 +1,19 @@ #!/usr/bin/env python3 -from exiftool import ExifToolHelper # type: ignore from json import dump as json_dump, load as json_load from os.path import exists as path_exists, join as path_join, abspath +from exiftool import ExifToolHelper # type: ignore import gi # type: ignore gi.require_version('Gtk', '4.0') gi.require_version('Gio', '2.0') # pylint: disable=wrong-import-position from gi.repository import Gtk, Gio, GObject # type: ignore # noqa: E402 -from stable.gen_params import GenParams # noqa: E402 +# pylint: disable=no-name-in-module +from stable.gen_params import (GenParams, # noqa: E402 + GEN_PARAMS, GEN_PARAMS_STR) # noqa: E402 -IMG_DIR='.' +IMG_DIR = '.' +CACHE_PATH = 'cache.json' class FileItem(GObject.GObject): @@ -20,14 +23,11 @@ class FileItem(GObject.GObject): self.name = info.get_name() self.last_mod_time = info.get_modification_date_time().format_iso8601() self.full_path = path_join(path, self.name) - self.seed = '' - self.guidance = 0.0 - self.n_steps = 0 - self.height = 0 - self.width = 0 - self.scheduler = '' - self.model = '' - self.prompt = '' + for param_name in GEN_PARAMS: + if param_name in GEN_PARAMS_STR: + setattr(self, param_name.lower(), '') + else: + setattr(self, param_name.lower(), 0) if self.full_path in cache: if self.last_mod_time in cache[self.full_path]: cached = cache[self.full_path][self.last_mod_time] @@ -35,17 +35,14 @@ class FileItem(GObject.GObject): setattr(self, k, cached[k]) def set_metadata(self, et, cache): - self.metadata = 'no SD comment' for d in et.get_tags([self.name], ['Comment']): for k, v in d.items(): if k.endswith('Comment'): - self.metadata = '' gen_params = GenParams.from_str(v) for k, v_ in gen_params.as_dict.items(): setattr(self, k, v_) cached = {} - for k in ('seed', 'guidance', 'n_steps', 'height', - 'width', 'scheduler', 'model', 'prompt'): + for k in (k.lower() for k in GEN_PARAMS): cached[k] = getattr(self, k) cache[self.full_path] = {self.last_mod_time: cached} @@ -94,10 +91,10 @@ class Window(Gtk.ApplicationWindow): box_outer.append(self.viewer) self.props.child = box_outer - if not path_exists('cache.json'): - with open('cache.json', 'w') as f: + if not path_exists(CACHE_PATH): + with open(CACHE_PATH, 'w', encoding='utf8') as f: json_dump({}, f) - with open('cache.json', 'r') as f: + with open(CACHE_PATH, 'r', encoding='utf8') as f: cache = json_load(f) self.max_index = 0 self.item = None @@ -115,16 +112,17 @@ class Window(Gtk.ApplicationWindow): item = FileItem(img_dir_absolute, info, cache) self.unsorted += [item] with ExifToolHelper() as et: - for item in [item for item in self.unsorted if item.seed == '']: + for item in [item for item in self.unsorted if '' == item.model]: item.set_metadata(et, cache) self.max_index = len(self.unsorted) - 1 self.sort('last_mod_time') - with open('cache.json', 'w') as f: + with open(CACHE_PATH, 'w', encoding='utf8') as f: json_dump(cache, f) def sort(self, attr_name): self.list_store.remove_all() - for file_item in sorted(self.unsorted, key=lambda i: getattr(i, attr_name)): + for file_item in sorted(self.unsorted, + key=lambda i: getattr(i, attr_name)): self.list_store.append(file_item) self.update_selected() @@ -137,12 +135,10 @@ class Window(Gtk.ApplicationWindow): def reload(self): self.viewer.remove(self.viewer.get_last_child()) if self.item: - metadata = f'{self.item.full_path}: PROMPT: {self.item.prompt}\n' +\ - f'SEED: {self.item.seed} / MODEL: {self.item.model} / ' +\ - f'SCHEDULER: {self.item.scheduler}\nGUIDANCE: {self.item.guidance}\n' +\ - f'N_STEPS: {self.item.n_steps}\nHEIGHT: {self.item.height} / ' +\ - f'WIDTH: {self.item.width}' - self.metadata.props.label = metadata + params_strs = [f'{k}: ' + str(getattr(self.item, k.lower())) + for k in GEN_PARAMS] + self.metadata.props.label = '\n'.join([self.item.full_path] + + params_strs) pic = Gtk.Picture.new_for_filename(self.item.name) self.viewer.append(pic) else: diff --git a/stable.py b/stable.py index 241783f..64ee704 100755 --- a/stable.py +++ b/stable.py @@ -3,7 +3,7 @@ from sys import argv, exit as sys_exit, stdin from os.path import dirname, basename, splitext, join as path_join, exists from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter from random import randint -from stable.gen_params import GenParams +from stable.gen_params import GenParams, GEN_PARAMS DEFAULT_SCHEDULER = 'EulerAncestralDiscreteScheduler' @@ -89,10 +89,8 @@ def parse_args(): parsed_args.models_dir = dirname(new_parsed_args.model) parsed_args.models += [basename(new_parsed_args.model)] parsed_args.gen_paramses += [GenParams( - new_parsed_args.seed, new_parsed_args.guidance, - new_parsed_args.n_steps, new_parsed_args.height, - new_parsed_args.width, new_parsed_args.scheduler, - new_parsed_args.model, new_parsed_args.prompt)] + **{k: getattr(new_parsed_args, k) + for k in (k.lower() for k in GEN_PARAMS)})] return parsed_args diff --git a/stable/core.py b/stable/core.py index 5c506c7..02a44ed 100644 --- a/stable/core.py +++ b/stable/core.py @@ -39,9 +39,9 @@ class ImageMaker: def set_gen_params(self, gen_params): scheduler_selection = [s for s in self.pipe.scheduler.compatibles - if s.__name__ == gen_params.scheduler_name] + if s.__name__ == gen_params.scheduler] if not scheduler_selection: - raise Exception(f'unknown scheduler: {gen_params.scheduler_name}') + raise Exception(f'unknown scheduler: {gen_params.scheduler}') self.pipe.scheduler = scheduler_selection[0].from_config( self.pipe.scheduler.config) self.gen_params = gen_params diff --git a/stable/gen_params.py b/stable/gen_params.py index fc1eba1..160735c 100644 --- a/stable/gen_params.py +++ b/stable/gen_params.py @@ -1,27 +1,26 @@ +GEN_PARAMS_STR = ('MODEL', 'SCHEDULER', 'PROMPT') +GEN_PARAMS_INT = ('SEED', 'N_STEPS', 'HEIGHT', 'WIDTH') +GEN_PARAMS_FLOAT = ('GUIDANCE',) +GEN_PARAMS = GEN_PARAMS_STR + GEN_PARAMS_INT + GEN_PARAMS_FLOAT + + class GenParams: - def __init__(self, seed, guidance, n_steps, height, width, scheduler, - model, prompt): - if '; ' in model: + def __init__(self, **kwargs): + if '; ' in kwargs['model']: raise Exception('illegal model filename (must not contain "; ")') - self.seed = seed - self.guidance = guidance - self.n_steps = n_steps - self.height, self.width = height, width - self.scheduler_name = scheduler - self.model_name = model - self.prompt = prompt + for param_name in (n.lower() for n in GEN_PARAMS): + setattr(self, param_name, kwargs[param_name]) @property def to_str(self): - return f'SEED: {self.seed}; ' +\ - f'GUIDANCE: {self.guidance}; ' +\ - f'N_STEPS: {self.n_steps}; ' +\ - f'HEIGHT: {self.height}; ' +\ - f'WIDTH: {self.width}; ' +\ - f'SCHEDULER: {self.scheduler_name}; ' +\ - f'MODEL: {self.model_name}; ' +\ - f'PROMPT: {self.prompt}' + s = '' + for param_name in GEN_PARAMS: + if 'PROMPT' == param_name: + continue # can create fewer parsing problems at the end + value = getattr(self, param_name.lower()) + s += f'{param_name}: {value}; ' + return f'{s}PROMPT: {self.prompt}' # pylint: disable=no-member @classmethod def from_str(cls, string): @@ -32,20 +31,13 @@ class GenParams: for section in first_split[0].split('; '): key, val = section.split(': ', maxsplit=1) key = key.lower() - if key in {'seed', 'height', 'width', 'n_steps'}: + if key.upper() in GEN_PARAMS_INT: val = int(val) - elif key in {'guidance'}: + elif key.upper() in GEN_PARAMS_FLOAT: val = float(val) d[key] = val return cls(**d) @property def as_dict(self): - d = {} - just_names = {'scheduler', 'model'} - for k in ['seed', 'guidance', 'n_steps', 'height', 'width', 'prompt']\ - + list(just_names): - v = getattr(self, f'{k}_name' if k in just_names else k) - if v is not None: - d[k] = v - return d + return {k: getattr(self, k) for k in (k.lower() for k in GEN_PARAMS)}