From cd498b3b557535b9a9f533f9d0baf5139111a9e7 Mon Sep 17 00:00:00 2001
From: Christian Heller <c.heller@plomlompom.de>
Date: Sun, 18 Aug 2024 23:46:15 +0200
Subject: [PATCH] Get rid of default assumption for model, and some minor
 refactoring.

---
 stable.py      | 16 ++++------------
 stable/core.py |  5 +++++
 2 files changed, 9 insertions(+), 12 deletions(-)

diff --git a/stable.py b/stable.py
index 49bb1d1..641ef48 100755
--- a/stable.py
+++ b/stable.py
@@ -4,9 +4,7 @@ from argparse import ArgumentParser
 from random import randint
 from exiftool import ExifToolHelper  # type: ignore
 from torch import Generator
-from stable.core import init_pipeline
-
-DEFAULT_MODEL = './v1-5-pruned-emaonly.safetensors'
+from stable.core import init_pipeline, make_metadata
 
 
 def save_path(count: int) -> str:
@@ -14,23 +12,17 @@ def save_path(count: int) -> str:
     return path_join(dir_path, f'{filename_sans_ext}{filename_count}.{ext}')
 
 
-def make_metadata(seed, guidance, height, width, model, prompt):
-    return f'SEED: {seed}; GUIDANCE: {guidance}; HEIGHT: {height}; ' +\
-            f'WIDTH: {width}; MODEL: {model}; PROMPT: {prompt}'
-
-
 def parse_args():
     parser = ArgumentParser(add_help=False)
-    parser.add_argument('-p', '--prompt', required=True)
-    parser.add_argument('-o', '--output', required=True)
+    parser.add_argument('-p', '--prompt', type=str, required=True)
+    parser.add_argument('-o', '--output', type=str, required=True)
+    parser.add_argument('-m', '--model', type=str, required=True)
     parser.add_argument('-r', '--randomness_seed', default=1, type=int,
                         help='default: 1; if set 0, chosen randomnly')
     parser.add_argument('-g', '--guidance', default=7.5, type=float,
                         help='default: 7.5')
     parser.add_argument('-s', '--steps', default=15, type=int,
                         help='default: 15')
-    parser.add_argument('-m', '--model', default=DEFAULT_MODEL, type=str,
-                        help=f'default: {DEFAULT_MODEL}')
     parser.add_argument('-h', '--height', default=512, type=int,
                         help='default: 512')
     parser.add_argument('-w', '--width', default=512, type=int,
diff --git a/stable/core.py b/stable/core.py
index 93cc897..d7f9d23 100644
--- a/stable/core.py
+++ b/stable/core.py
@@ -29,3 +29,8 @@ def init_pipeline(model):
     pipe.to('cuda')
     print('PIPELINE READY\n')
     return pipe
+
+
+def make_metadata(seed, guidance, height, width, model, prompt):
+    return f'SEED: {seed}; GUIDANCE: {guidance}; HEIGHT: {height}; ' +\
+            f'WIDTH: {width}; MODEL: {model}; PROMPT: {prompt}'
-- 
2.30.2