home · contact · privacy
Protect against overwriting previous generates.
authorChristian Heller <c.heller@plomlompom.de>
Fri, 16 Aug 2024 05:54:41 +0000 (07:54 +0200)
committerChristian Heller <c.heller@plomlompom.de>
Fri, 16 Aug 2024 05:54:41 +0000 (07:54 +0200)
test.py

diff --git a/test.py b/test.py
index aeeb4766a5c457bd5819cf3f1d916b826e360600..ef64c446a98ff161c0e69a21f41b04f206b59b36 100755 (executable)
--- a/test.py
+++ b/test.py
@@ -5,7 +5,7 @@ from PIL import Image
 from exiftool import ExifToolHelper
 from torch import Generator
 from diffusers import StableDiffusionPipeline
-from os.path import dirname, basename, splitext, join as path_join
+from os.path import dirname, basename, splitext, join as path_join, exists
 
 DEFAULT_MODEL='./v1-5-pruned-emaonly.safetensors'
 
@@ -30,6 +30,15 @@ filename = basename(args.output)
 filename_sans_ext, ext = splitext(filename)
 ext = ext[1:] if ext else 'png'
 
+def save_path(n: int) -> str:
+    filename_count = f'_{n:08}' if args.number > 1 else ''
+    return path_join(dir_path, f'{filename_sans_ext}{filename_count}.{ext}')
+
+for n in range(args.number):
+    path = save_path(n)
+    if exists(path):
+        raise Exception(f'Would overwrite file: {path}')
+
 pipe = StableDiffusionPipeline.from_single_file(args.model)
 pipe.to('cuda')
 generator = Generator()
@@ -38,10 +47,9 @@ for n in range(args.number):
     seed = start_seed + n
     generator.manual_seed(seed)
     image = pipe(args.prompt, generator=generator, guidance_scale=args.guidance, num_inference_steps=args.steps, height=args.height, width=args.width).images[0]
-    filename_count = f'_{n:08}' if args.number > 1 else ''
-    filename = path_join(dir_path, f'{filename_sans_ext}{filename_count}.{ext}')
-    image.save(filename)
+    path = save_path(n)
+    image.save(path)
     metadata = f'seed: {seed}; guidance: {args.guidance}; height: {args.height}; width: {args.width}; model: {args.model}; prompt: {args.prompt}'
-    print(f'saved {filename} – metadata: {metadata}')
+    print(f'saved {path} – metadata: {metadata}')
     with ExifToolHelper() as et:
-        et.set_tags([filename], tags={'Comment': metadata}, params=['-overwrite_original'])
+        et.set_tags([path], tags={'Comment': metadata}, params=['-overwrite_original'])