From: Christian Heller Date: Fri, 16 Aug 2024 05:54:41 +0000 (+0200) Subject: Protect against overwriting previous generates. X-Git-Url: https://plomlompom.com/repos/%7B%7Bprefix%7D%7D/static/%7B%7Bdb.prefix%7D%7D/test?a=commitdiff_plain;h=0081eaa508bee8a8b6c04d5c5cf3746d01c67b0b;p=stable_plom Protect against overwriting previous generates. --- diff --git a/test.py b/test.py index aeeb476..ef64c44 100755 --- a/test.py +++ b/test.py @@ -5,7 +5,7 @@ from PIL import Image from exiftool import ExifToolHelper from torch import Generator from diffusers import StableDiffusionPipeline -from os.path import dirname, basename, splitext, join as path_join +from os.path import dirname, basename, splitext, join as path_join, exists DEFAULT_MODEL='./v1-5-pruned-emaonly.safetensors' @@ -30,6 +30,15 @@ filename = basename(args.output) filename_sans_ext, ext = splitext(filename) ext = ext[1:] if ext else 'png' +def save_path(n: int) -> str: + filename_count = f'_{n:08}' if args.number > 1 else '' + return path_join(dir_path, f'{filename_sans_ext}{filename_count}.{ext}') + +for n in range(args.number): + path = save_path(n) + if exists(path): + raise Exception(f'Would overwrite file: {path}') + pipe = StableDiffusionPipeline.from_single_file(args.model) pipe.to('cuda') generator = Generator() @@ -38,10 +47,9 @@ for n in range(args.number): seed = start_seed + n generator.manual_seed(seed) image = pipe(args.prompt, generator=generator, guidance_scale=args.guidance, num_inference_steps=args.steps, height=args.height, width=args.width).images[0] - filename_count = f'_{n:08}' if args.number > 1 else '' - filename = path_join(dir_path, f'{filename_sans_ext}{filename_count}.{ext}') - image.save(filename) + path = save_path(n) + image.save(path) metadata = f'seed: {seed}; guidance: {args.guidance}; height: {args.height}; width: {args.width}; model: {args.model}; prompt: {args.prompt}' - print(f'saved {filename} – metadata: {metadata}') + print(f'saved {path} – metadata: {metadata}') with ExifToolHelper() as et: - et.set_tags([filename], tags={'Comment': metadata}, params=['-overwrite_original']) + et.set_tags([path], tags={'Comment': metadata}, params=['-overwrite_original'])