88
99#include " stable-diffusion.h"
1010
11+ #define STB_IMAGE_IMPLEMENTATION
12+ #include " stb_image.h"
13+
1114#define STB_IMAGE_WRITE_IMPLEMENTATION
1215#define STB_IMAGE_WRITE_STATIC
1316#include " stb_image_write.h"
1417
1518#if defined(__APPLE__) && defined(__MACH__)
16- #include < sys/types.h>
1719#include < sys/sysctl.h>
20+ #include < sys/types.h>
1821#endif
1922
2023#if !defined(_WIN32)
2124#include < sys/ioctl.h>
2225#include < unistd.h>
2326#endif
2427
28+ #define TXT2IMG " txt2img"
29+ #define IMG2IMG " img2img"
30+
// get_num_physical_cores is copied from
// https://github.com/ggerganov/llama.cpp/blob/master/examples/common.cpp
// LICENSE: https://github.com/ggerganov/llama.cpp/blob/master/LICENSE
@@ -63,30 +69,36 @@ int32_t get_num_physical_cores() {
6369
6470struct Option {
6571 int n_threads = -1 ;
72+ std::string mode = TXT2IMG;
6673 std::string model_path;
6774 std::string output_path = " output.png" ;
75+ std::string init_img;
6876 std::string prompt;
6977 std::string negative_prompt;
7078 float cfg_scale = 7 .0f ;
7179 int w = 512 ;
7280 int h = 512 ;
7381 SampleMethod sample_method = EULAR_A;
7482 int sample_steps = 20 ;
83+ float strength = 0 .75f ;
7584 int seed = 42 ;
7685 bool verbose = false ;
7786
7887 void print () {
7988 printf (" Option: \n " );
8089 printf (" n_threads: %d\n " , n_threads);
90+ printf (" mode: %s\n " , mode.c_str ());
8191 printf (" model_path: %s\n " , model_path.c_str ());
8292 printf (" output_path: %s\n " , output_path.c_str ());
93+ printf (" init_img: %s\n " , init_img.c_str ());
8394 printf (" prompt: %s\n " , prompt.c_str ());
8495 printf (" negative_prompt: %s\n " , negative_prompt.c_str ());
8596 printf (" cfg_scale: %.2f\n " , cfg_scale);
8697 printf (" width: %d\n " , w);
8798 printf (" height: %d\n " , h);
8899 printf (" sample_method: %s\n " , " eular a" );
89100 printf (" sample_steps: %d\n " , sample_steps);
101+ printf (" strength: %.2f\n " , strength);
90102 printf (" seed: %d\n " , seed);
91103 }
92104};
@@ -96,13 +108,17 @@ void print_usage(int argc, const char* argv[]) {
96108 printf (" \n " );
97109 printf (" arguments:\n " );
98110 printf (" -h, --help show this help message and exit\n " );
111+ printf (" -M, --mode [txt2img or img2img] generation mode (default: txt2img)\n " );
99112 printf (" -t, --threads N number of threads to use during computation (default: -1).\n " );
100113 printf (" If threads <= 0, then threads will be set to the number of CPU physical cores\n " );
101114 printf (" -m, --model [MODEL] path to model\n " );
115+ printf (" -i, --init-img [IMAGE] path to the input image, required by img2img\n " );
102116 printf (" -o, --output OUTPUT path to write result image to (default: .\\ output.png)\n " );
103117 printf (" -p, --prompt [PROMPT] the prompt to render\n " );
104118 printf (" -n, --negative-prompt PROMPT the negative prompt (default: \"\" )\n " );
105119 printf (" --cfg-scale SCALE unconditional guidance scale: (default: 7.0)\n " );
120+ printf (" --strength STRENGTH strength for noising/unnoising (default: 0.75)\n " );
121+ printf (" 1.0 corresponds to full destruction of information in init image\n " );
106122 printf (" -H, --height H image height, in pixel space (default: 512)\n " );
107123 printf (" -W, --width W image width, in pixel space (default: 512)\n " );
108124 printf (" --sample-method SAMPLE_METHOD sample method (default: \" eular a\" )\n " );
@@ -123,12 +139,25 @@ void parse_args(int argc, const char* argv[], Option* opt) {
123139 break ;
124140 }
125141 opt->n_threads = std::stoi (argv[i]);
142+ } else if (arg == " -M" || arg == " --mode" ) {
143+ if (++i >= argc) {
144+ invalid_arg = true ;
145+ break ;
146+ }
147+ opt->mode = argv[i];
148+
126149 } else if (arg == " -m" || arg == " --model" ) {
127150 if (++i >= argc) {
128151 invalid_arg = true ;
129152 break ;
130153 }
131154 opt->model_path = argv[i];
155+ } else if (arg == " -i" || arg == " --init-img" ) {
156+ if (++i >= argc) {
157+ invalid_arg = true ;
158+ break ;
159+ }
160+ opt->init_img = argv[i];
132161 } else if (arg == " -o" || arg == " --output" ) {
133162 if (++i >= argc) {
134163 invalid_arg = true ;
@@ -153,6 +182,12 @@ void parse_args(int argc, const char* argv[], Option* opt) {
153182 break ;
154183 }
155184 opt->cfg_scale = std::stof (argv[i]);
185+ } else if (arg == " --strength" ) {
186+ if (++i >= argc) {
187+ invalid_arg = true ;
188+ break ;
189+ }
190+ opt->strength = std::stof (argv[i]);
156191 } else if (arg == " -H" || arg == " --height" ) {
157192 if (++i >= argc) {
158193 invalid_arg = true ;
@@ -198,6 +233,12 @@ void parse_args(int argc, const char* argv[], Option* opt) {
198233 opt->n_threads = get_num_physical_cores ();
199234 }
200235
236+ if (opt->mode != TXT2IMG && opt->mode != IMG2IMG) {
237+ fprintf (stderr, " error: invalid mode %s, must be one of ['%s', '%s']\n " ,
238+ opt->mode .c_str (), TXT2IMG, IMG2IMG);
239+ exit (1 );
240+ }
241+
201242 if (opt->prompt .length () == 0 ) {
202243 fprintf (stderr, " error: the following arguments are required: prompt\n " );
203244 print_usage (argc, argv);
@@ -210,6 +251,12 @@ void parse_args(int argc, const char* argv[], Option* opt) {
210251 exit (1 );
211252 }
212253
254+ if (opt->mode == IMG2IMG && opt->init_img .length () == 0 ) {
255+ fprintf (stderr, " error: when using the img2img mode, the following arguments are required: init-img\n " );
256+ print_usage (argc, argv);
257+ exit (1 );
258+ }
259+
213260 if (opt->output_path .length () == 0 ) {
214261 fprintf (stderr, " error: the following arguments are required: output_path\n " );
215262 print_usage (argc, argv);
@@ -230,6 +277,11 @@ void parse_args(int argc, const char* argv[], Option* opt) {
230277 fprintf (stderr, " error: the sample_steps must be greater than 0\n " );
231278 exit (1 );
232279 }
280+
281+ if (opt->strength < 0 .f || opt->strength > 1 .f ) {
282+ fprintf (stderr, " error: can only work with strength in [0.0, 1.0]\n " );
283+ exit (1 );
284+ }
233285}
234286
235287int main (int argc, const char * argv[]) {
@@ -242,19 +294,66 @@ int main(int argc, const char* argv[]) {
242294 set_sd_log_level (SDLogLevel::DEBUG);
243295 }
244296
245- StableDiffusion sd (opt.n_threads );
297+ bool vae_decode_only = true ;
298+ std::vector<uint8_t > init_img;
299+ if (opt.mode == IMG2IMG) {
300+ vae_decode_only = false ;
301+
302+ int c = 0 ;
303+ unsigned char * img_data = stbi_load (opt.init_img .c_str (), &opt.w , &opt.h , &c, 3 );
304+ if (img_data == NULL ) {
305+ fprintf (stderr, " load image from '%s' failed\n " , opt.init_img .c_str ());
306+ return 1 ;
307+ }
308+ if (c != 3 ) {
309+ fprintf (stderr, " input image must be a 3 channels RGB image, but got %d channels\n " , c);
310+ free (img_data);
311+ return 1 ;
312+ }
313+ if (opt.w <= 0 || opt.w % 32 != 0 ) {
314+ fprintf (stderr, " error: the width of image must be a multiple of 32\n " );
315+ free (img_data);
316+ return 1 ;
317+ }
318+ if (opt.h <= 0 || opt.h % 32 != 0 ) {
319+ fprintf (stderr, " error: the height of image must be a multiple of 32\n " );
320+ free (img_data);
321+ return 1 ;
322+ }
323+ init_img.assign (img_data, img_data + (opt.w * opt.h * c));
324+ }
325+ StableDiffusion sd (opt.n_threads , vae_decode_only);
246326 if (!sd.load_from_file (opt.model_path )) {
247327 return 1 ;
248328 }
249329
250- std::vector<uint8_t > img = sd.txt2img (opt.prompt ,
251- opt.negative_prompt ,
252- opt.cfg_scale ,
253- opt.w ,
254- opt.h ,
255- opt.sample_method ,
256- opt.sample_steps ,
257- opt.seed );
330+ std::vector<uint8_t > img;
331+ if (opt.mode == TXT2IMG) {
332+ img = sd.txt2img (opt.prompt ,
333+ opt.negative_prompt ,
334+ opt.cfg_scale ,
335+ opt.w ,
336+ opt.h ,
337+ opt.sample_method ,
338+ opt.sample_steps ,
339+ opt.seed );
340+ } else {
341+ img = sd.img2img (init_img,
342+ opt.prompt ,
343+ opt.negative_prompt ,
344+ opt.cfg_scale ,
345+ opt.w ,
346+ opt.h ,
347+ opt.sample_method ,
348+ opt.sample_steps ,
349+ opt.strength ,
350+ opt.seed );
351+ }
352+
353+ if (img.size () == 0 ) {
354+ fprintf (stderr, " generate failed\n " );
355+ return 1 ;
356+ }
258357
259358 stbi_write_png (opt.output_path .c_str (), opt.w , opt.h , 3 , img.data (), 0 );
260359 printf (" save result image to '%s'\n " , opt.output_path .c_str ());
0 commit comments