From: Christian Heller Date: Mon, 19 Aug 2024 02:55:18 +0000 (+0200) Subject: Switch to faster torch_dtype=float16 for Stable Diffusion pipelines. X-Git-Url: https://plomlompom.com/repos/%7B%7Bprefix%7D%7D/static/gitweb.css?a=commitdiff_plain;h=3a540563d44ded8e18a65c85a1041b172374e9a0;p=stable_plom Switch to faster torch_dtype=float16 for Stable Diffusion pipelines. --- diff --git a/stable/core.py b/stable/core.py index 5612ab2..e85f517 100644 --- a/stable/core.py +++ b/stable/core.py @@ -3,7 +3,7 @@ from logging import (Formatter as LogFormatter, captureWarnings, from os.path import basename from diffusers import StableDiffusionPipeline from diffusers.utils import logging -from torch import Generator +from torch import Generator, float16 from exiftool import ExifToolHelper # type: ignore SAFETY_CHECKER_WARNING_PATTERN = 'You have disabled the safety checker' @@ -38,8 +38,8 @@ class ImageMaker: FilterOut(SAFETY_CHECKER_WARNING_PATTERN)) captureWarnings(True) logging.disable_progress_bar() - self.pipe = StableDiffusionPipeline.from_single_file( - model_path, local_files_only=True) + self.pipe = StableDiffusionPipeline.from_single_file(model_path, torch_dtype=float16, + local_files_only=True) self.pipe.to('cuda') self.generator = Generator() print('PIPELINE READY\n')