From: Christian Heller <c.heller@plomlompom.de>
Date: Mon, 19 Aug 2024 02:55:18 +0000 (+0200)
Subject: Switch to faster torch_dtype=float16 for Stable Diffusion pipelines.
X-Git-Url: https://plomlompom.com/repos/%7B%7Bdb.prefix%7D%7D/static/%7B%7Bprefix%7D%7D/booking/condition?a=commitdiff_plain;h=3a540563d44ded8e18a65c85a1041b172374e9a0;p=stable_plom

Switch to faster torch_dtype=float16 for Stable Diffusion pipelines.
---

diff --git a/stable/core.py b/stable/core.py
index 5612ab2..e85f517 100644
--- a/stable/core.py
+++ b/stable/core.py
@@ -3,7 +3,7 @@ from logging import (Formatter as LogFormatter, captureWarnings,
 from os.path import basename
 from diffusers import StableDiffusionPipeline
 from diffusers.utils import logging
-from torch import Generator
+from torch import Generator, float16
 from exiftool import ExifToolHelper  # type: ignore
 
 SAFETY_CHECKER_WARNING_PATTERN = 'You have disabled the safety checker'
@@ -38,8 +38,8 @@ class ImageMaker:
                 FilterOut(SAFETY_CHECKER_WARNING_PATTERN))
         captureWarnings(True)
         logging.disable_progress_bar()
-        self.pipe = StableDiffusionPipeline.from_single_file(
-                model_path, local_files_only=True)
+        self.pipe = StableDiffusionPipeline.from_single_file(model_path, torch_dtype=float16,
+                                                             local_files_only=True)
         self.pipe.to('cuda')
         self.generator = Generator()
         print('PIPELINE READY\n')