home · contact · privacy
Refactor GenParams usage.
[stable_plom] / stable.py
1 #!/usr/bin/env python3
2 from sys import argv, exit as sys_exit, stdin
3 from os.path import dirname, basename, splitext, join as path_join, exists
4 from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
5 from random import randint
6 from stable.gen_params import GenParams, GEN_PARAMS
7
8
9 DEFAULT_SCHEDULER = 'EulerAncestralDiscreteScheduler'
10 SEED_MIN = -(2**31-1)
11 SEED_MAX = 2**31
12
13
14 def parse_args():
15     parser = ArgumentParser(add_help=False,
16                             formatter_class=ArgumentDefaultsHelpFormatter)
17     parser.add_argument('-m', '--model',
18                         help='model filename (-P will pre prefixed, but may '
19                         'also be full path on its own)')
20     parser.add_argument('-o', '--output',
21                         help='output filename or path; if -q > 1, will insert '
22                         'incremented counter number; if no image file '
23                         'extension included, defaults to .png')
24     parser.add_argument('-p', '--prompt',
25                         help='textual guidance to image generation')
26     parser.add_argument('-q', '--quantity', default=1, type=int,
27                         help='how many pictures to generate (with seed '
28                         'auto-incrementing)')
29     parser.add_argument('-h', '--height', default=512, type=int,
30                         help='target height in pixels')
31     parser.add_argument('-w', '--width', default=512, type=int,
32                         help='target width in pixels')
33     parser.add_argument('-g', '--guidance', default=7.5, type=float,
34                         help='adherence to text prompt')
35     parser.add_argument('-n', '--n_steps', default=15, type=int,
36                         help='number of denoising steps')
37     parser.add_argument('-s', '--seed', default=1, type=int,
38                         help='randomness seed; set 0 to choose randomly; '
39                         'increments if -q > 1')
40     parser.add_argument('-S', '--scheduler', default=DEFAULT_SCHEDULER,
41                         help='name of denoising scheduler; --list-schedulers '
42                         'prints available choices for chosen model')
43     parser.add_argument('-D', '--defaults_from_stdin', action='store_true',
44                         help='parse stdin for image generation parameters'
45                         '(e.g. from image file EXIF comment)')
46     parser.add_argument('-M', '--models_dir',
47                         help='directory path prefix to MODEL (for where'
48                         'it\'s provided as a mere filename, as with -D)')
49     parser.add_argument('-H', '--help', action='help')
50     parser.add_argument('-L', '--list-schedulers', action='store_true',
51                         help='list options for -S available with chosen model')
52
53     # Prepare validation
54     parsed_args = parser.parse_args()
55     prefix = f'{argv[0]}: error: '
56     if parsed_args.list_schedulers:
57         required = {'model': 'm'}
58         prefix += 'for --list-schedulers '
59     else:
60         required = {'output': 'o', 'prompt': 'p', 'model': 'm'}
61         prefix += 'unless calling with -H/--help or --list-schedulers, '
62     prefix += 'requiring '
63
64     # Re-seed defaults from stdin, if -D.
65     temp_gen_paramses = [None]
66     if parsed_args.defaults_from_stdin:
67         temp_gen_paramses = []
68         for line in stdin.readlines():
69             temp_gen_paramses += [GenParams.from_str(line)]
70     parsed_args.models = []
71     parsed_args.gen_paramses = []
72     for gen_params_ in temp_gen_paramses:
73         if gen_params_:
74             parser.set_defaults(**gen_params_.as_dict)
75         new_parsed_args = parser.parse_args()
76
77         # Validate input.
78         suffixes = []
79         for k, v in required.items():
80             if not getattr(new_parsed_args, k):
81                 suffixes += [f'-{v}/--{k}']
82         if suffixes:
83             parser.print_usage()
84             print(f'{prefix}{", ".join(suffixes)}')
85             sys_exit(1)
86
87         # Prepare generation parameters, model path, etc.
88         if not parsed_args.models_dir:
89             parsed_args.models_dir = dirname(new_parsed_args.model)
90         parsed_args.models += [basename(new_parsed_args.model)]
91         parsed_args.gen_paramses += [GenParams(
92                 **{k: getattr(new_parsed_args, k)
93                    for k in (k.lower() for k in GEN_PARAMS)})]
94     return parsed_args
95
96
97 def run():
98
99     def save_path(count: int) -> str:
100         n_iterations = len(args.gen_paramses) * args.quantity
101         filename_count = f'_{count:08}' if n_iterations > 1 else ''
102         # pylint: disable=possibly-used-before-assignment
103         return path_join(dir_path,
104                          f'{filename_sans_ext}{filename_count}.{ext}')
105
106     # guard against overwriting
107     args = parse_args()
108     if not args.list_schedulers:
109         dir_path = dirname(args.output)
110         filename = basename(args.output)
111         filename_sans_ext, ext = splitext(filename)
112         ext = ext[1:] if ext else 'png'
113         for n in range(args.quantity * len(args.gen_paramses)):
114             path = save_path(n)
115             if exists(path):
116                 raise Exception(f'Would overwrite file: {path}')
117
118     # pylint: disable=wrong-import-position, import-outside-toplevel
119     from stable.core import ImageMaker  # noqa: E402
120     old_model_path = ''
121     maker = None
122     for i, model_name in enumerate(args.models):
123         new_model_path = path_join(args.models_dir, model_name)
124         if new_model_path != old_model_path:
125             maker = ImageMaker(new_model_path)
126         old_model_path = new_model_path
127
128         # only list available schedulers, if -L
129         if args.list_schedulers:
130             print(f'AVAILABLE SCHEDULERS FOR MODEL {new_model_path}:\n')
131             for name in maker.compatible_schedulers:
132                 print(name)
133             continue
134
135         # otherwise generate pictures
136         gen_params = args.gen_paramses[i]
137         start_seed = gen_params.seed
138         start_seed = start_seed if start_seed != 0 else randint(SEED_MIN,
139                                                                 SEED_MAX)
140         seed_corrector = 0
141         for n in range(args.quantity):
142             if 0 == start_seed + n + seed_corrector:
143                 seed_corrector += 1
144             gen_params.seed = start_seed + n + seed_corrector
145             path = save_path(i*args.quantity + n)
146             maker.set_gen_params(gen_params)
147             print(f'GENERATING: {path}; {gen_params.to_str}')
148             maker.gen_image_to(path)
149
150
151 run()