From: Christian Heller Date: Sun, 18 Aug 2024 21:46:15 +0000 (+0200) Subject: Get rid of default assumption for model, and some minor refactoring. X-Git-Url: https://plomlompom.com/repos/%7B%7Bdb.prefix%7D%7D/%7B%7B%20web_path%20%7D%7D/static/git-favicon.png?a=commitdiff_plain;h=cd498b3b557535b9a9f533f9d0baf5139111a9e7;p=stable_plom Get rid of default assumption for model, and some minor refactoring. --- diff --git a/stable.py b/stable.py index 49bb1d1..641ef48 100755 --- a/stable.py +++ b/stable.py @@ -4,9 +4,7 @@ from argparse import ArgumentParser from random import randint from exiftool import ExifToolHelper # type: ignore from torch import Generator -from stable.core import init_pipeline - -DEFAULT_MODEL = './v1-5-pruned-emaonly.safetensors' +from stable.core import init_pipeline, make_metadata def save_path(count: int) -> str: @@ -14,23 +12,17 @@ def save_path(count: int) -> str: return path_join(dir_path, f'{filename_sans_ext}{filename_count}.{ext}') -def make_metadata(seed, guidance, height, width, model, prompt): - return f'SEED: {seed}; GUIDANCE: {guidance}; HEIGHT: {height}; ' +\ - f'WIDTH: {width}; MODEL: {model}; PROMPT: {prompt}' - - def parse_args(): parser = ArgumentParser(add_help=False) - parser.add_argument('-p', '--prompt', required=True) - parser.add_argument('-o', '--output', required=True) + parser.add_argument('-p', '--prompt', type=str, required=True) + parser.add_argument('-o', '--output', type=str, required=True) + parser.add_argument('-m', '--model', type=str, required=True) parser.add_argument('-r', '--randomness_seed', default=1, type=int, help='default: 1; if set 0, chosen randomnly') parser.add_argument('-g', '--guidance', default=7.5, type=float, help='default: 7.5') parser.add_argument('-s', '--steps', default=15, type=int, help='default: 15') - parser.add_argument('-m', '--model', default=DEFAULT_MODEL, type=str, - help=f'default: {DEFAULT_MODEL}') parser.add_argument('-h', '--height', default=512, type=int, help='default: 512') parser.add_argument('-w', '--width', default=512, type=int, diff --git a/stable/core.py b/stable/core.py index 93cc897..d7f9d23 100644 --- a/stable/core.py +++ b/stable/core.py @@ -29,3 +29,8 @@ def init_pipeline(model): pipe.to('cuda') print('PIPELINE READY\n') return pipe + + +def make_metadata(seed, guidance, height, width, model, prompt): + return f'SEED: {seed}; GUIDANCE: {guidance}; HEIGHT: {height}; ' +\ + f'WIDTH: {width}; MODEL: {model}; PROMPT: {prompt}'