Skip to content

Commit 8749fc6

Browse files
committed
refactor 2-cfg conditioning + better img_cond defaults
1 parent 6103d86 commit 8749fc6

File tree

1 file changed

+50
-44
lines changed

1 file changed

+50
-44
lines changed

stable-diffusion.cpp

Lines changed: 50 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1116,7 +1116,7 @@ class StableDiffusionGGML {
11161116
ggml_tensor* noise,
11171117
SDCondition cond,
11181118
SDCondition uncond,
1119-
SDCondition img_cond,
1119+
SDCondition img_uncond,
11201120
ggml_tensor* control_hint,
11211121
float control_strength,
11221122
sd_guidance_params_t guidance,
@@ -1143,7 +1143,7 @@ class StableDiffusionGGML {
11431143

11441144
if (img_cfg_scale != cfg_scale && !sd_version_is_inpaint_or_unet_edit(version)) {
11451145
LOG_WARN("2-conditioning CFG is not supported with this model, disabling it for better performance...");
1146-
img_cfg_scale = cfg_scale;
1146+
img_cfg_scale = 1.0f;
11471147
}
11481148

11491149
size_t steps = sigmas.size() - 1;
@@ -1155,10 +1155,11 @@ class StableDiffusionGGML {
11551155
}
11561156

11571157
struct ggml_tensor* noised_input = ggml_dup_tensor(work_ctx, x);
1158-
1159-
bool has_unconditioned = img_cfg_scale != 1.0 && uncond.c_crossattn != nullptr;
1160-
bool has_img_cond = cfg_scale != img_cfg_scale && img_cond.c_crossattn != nullptr;
1158+
11611159
bool has_skiplayer = slg_scale != 0.0 && skip_layers.size() > 0;
1160+
bool has_conditionned = (has_skiplayer || cfg_scale != 0.0) && cond.c_crossattn != nullptr;
1161+
bool has_unconditioned = cfg_scale != img_cfg_scale && uncond.c_crossattn != nullptr;
1162+
bool has_img_uncond = img_cfg_scale != 1.0 && img_uncond.c_crossattn != nullptr;
11621163

11631164
// denoise wrapper
11641165
struct ggml_tensor* out_cond = ggml_dup_tensor(work_ctx, x);
@@ -1177,7 +1178,7 @@ class StableDiffusionGGML {
11771178
LOG_WARN("SLG is incompatible with %s models", model_version_to_str[version]);
11781179
}
11791180
}
1180-
if (has_img_cond) {
1181+
if (has_img_uncond) {
11811182
out_img_cond = ggml_dup_tensor(work_ctx, x);
11821183
}
11831184
struct ggml_tensor* denoised = ggml_dup_tensor(work_ctx, x);
@@ -1240,21 +1241,23 @@ class StableDiffusionGGML {
12401241
diffusion_params.vace_context = vace_context;
12411242
diffusion_params.vace_strength = vace_strength;
12421243

1243-
if (start_merge_step == -1 || step <= start_merge_step) {
1244-
// cond
1245-
diffusion_params.context = cond.c_crossattn;
1246-
diffusion_params.c_concat = cond.c_concat;
1247-
diffusion_params.y = cond.c_vector;
1248-
work_diffusion_model->compute(n_threads,
1249-
diffusion_params,
1250-
&out_cond);
1251-
} else {
1252-
diffusion_params.context = id_cond.c_crossattn;
1253-
diffusion_params.c_concat = cond.c_concat;
1254-
diffusion_params.y = id_cond.c_vector;
1255-
work_diffusion_model->compute(n_threads,
1256-
diffusion_params,
1257-
&out_cond);
1244+
if (has_conditionned) {
1245+
if (start_merge_step == -1 || step <= start_merge_step) {
1246+
// cond
1247+
diffusion_params.context = cond.c_crossattn;
1248+
diffusion_params.c_concat = cond.c_concat;
1249+
diffusion_params.y = cond.c_vector;
1250+
work_diffusion_model->compute(n_threads,
1251+
diffusion_params,
1252+
&out_cond);
1253+
} else {
1254+
diffusion_params.context = id_cond.c_crossattn;
1255+
diffusion_params.c_concat = cond.c_concat;
1256+
diffusion_params.y = id_cond.c_vector;
1257+
work_diffusion_model->compute(n_threads,
1258+
diffusion_params,
1259+
&out_cond);
1260+
}
12581261
}
12591262

12601263
float* negative_data = nullptr;
@@ -1275,10 +1278,10 @@ class StableDiffusionGGML {
12751278
}
12761279

12771280
float* img_cond_data = nullptr;
1278-
if (has_img_cond) {
1279-
diffusion_params.context = img_cond.c_crossattn;
1280-
diffusion_params.c_concat = img_cond.c_concat;
1281-
diffusion_params.y = img_cond.c_vector;
1281+
if (has_img_uncond) {
1282+
diffusion_params.context = img_uncond.c_crossattn;
1283+
diffusion_params.c_concat = img_uncond.c_concat;
1284+
diffusion_params.y = img_uncond.c_vector;
12821285
work_diffusion_model->compute(n_threads,
12831286
diffusion_params,
12841287
&out_img_cond);
@@ -1321,19 +1324,19 @@ class StableDiffusionGGML {
13211324
float latent_result = positive_data[i];
13221325
if (has_unconditioned) {
13231326
// out_uncond + cfg_scale * (out_cond - out_uncond)
1324-
if (has_img_cond) {
1327+
if (has_img_uncond) {
13251328
// out_uncond + text_cfg_scale * (out_cond - out_img_cond) + image_cfg_scale * (out_img_cond - out_uncond)
13261329
latent_result = negative_data[i] + img_cfg_scale * (img_cond_data[i] - negative_data[i]) + cfg_scale * (positive_data[i] - img_cond_data[i]);
13271330
} else {
13281331
// no img_uncond pass: img_cfg_scale == 1 (or img_uncond has no context)
13291332
latent_result = negative_data[i] + cfg_scale * (positive_data[i] - negative_data[i]);
13301333
}
1331-
} else if (has_img_cond) {
1334+
} else if (has_img_uncond) {
13321335
// no uncond pass: cfg_scale == img_cfg_scale (or uncond has no context)
13331336
latent_result = img_cond_data[i] + cfg_scale * (positive_data[i] - img_cond_data[i]);
13341337
}
13351338
if (is_skiplayer_step) {
1336-
latent_result = latent_result + (positive_data[i] - skip_layer_data[i]) * slg_scale;
1339+
latent_result = latent_result + slg_scale * (positive_data[i] - skip_layer_data[i]);
13371340
}
13381341
// v = latent_result, eps = latent_result
13391342
// denoised = (v * c_out + input * c_skip) or (input + eps * c_out)
@@ -1977,7 +1980,7 @@ char* sd_sample_params_to_str(const sd_sample_params_t* sample_params) {
19771980
sample_params->guidance.txt_cfg,
19781981
std::isfinite(sample_params->guidance.img_cfg)
19791982
? sample_params->guidance.img_cfg
1980-
: sample_params->guidance.txt_cfg,
1983+
: 1.0f,
19811984
sample_params->guidance.distilled_guidance,
19821985
sample_params->guidance.slg.layer_count,
19831986
sample_params->guidance.slg.layer_start,
@@ -2142,7 +2145,8 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
21422145
}
21432146

21442147
if (!std::isfinite(guidance.img_cfg)) {
2145-
guidance.img_cfg = guidance.txt_cfg;
2148+
// default to 1
2149+
guidance.img_cfg = 1.0f;
21462150
}
21472151

21482152
// for (auto v : sigmas) {
@@ -2250,7 +2254,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
22502254

22512255
SDCondition uncond;
22522256
if (guidance.txt_cfg != 1.0 ||
2253-
(sd_version_is_inpaint_or_unet_edit(sd_ctx->sd->version) && guidance.txt_cfg != guidance.img_cfg)) {
2257+
(sd_version_is_inpaint_or_unet_edit(sd_ctx->sd->version) && guidance.txt_cfg != 1.0f)) {
22542258
bool zero_out_masked = false;
22552259
if (sd_version_is_sdxl(sd_ctx->sd->version) && negative_prompt.size() == 0 && !sd_ctx->sd->is_using_edm_v_parameterization) {
22562260
zero_out_masked = true;
@@ -2288,14 +2292,16 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
22882292
ggml_ext_tensor_scale_inplace(control_latent, control_strength);
22892293
}
22902294

2295+
struct ggml_tensor* empty_latent;
2296+
22912297
if (sd_version_is_inpaint(sd_ctx->sd->version)) {
22922298
int64_t mask_channels = 1;
22932299
if (sd_ctx->sd->version == VERSION_FLUX_FILL) {
22942300
mask_channels = 8 * 8; // flatten the whole mask
22952301
} else if (sd_ctx->sd->version == VERSION_FLEX_2) {
22962302
mask_channels = 1 + init_latent->ne[2];
22972303
}
2298-
auto empty_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], mask_channels + init_latent->ne[2], 1);
2304+
empty_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], mask_channels + init_latent->ne[2], 1);
22992305
// no mask, set the whole image as masked
23002306
for (int64_t x = 0; x < empty_latent->ne[0]; x++) {
23012307
for (int64_t y = 0; y < empty_latent->ne[1]; y++) {
@@ -2345,31 +2351,31 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
23452351
concat_latent = empty_latent;
23462352
}
23472353
cond.c_concat = concat_latent;
2348-
uncond.c_concat = empty_latent;
2354+
uncond.c_concat = concat_latent;
23492355
denoise_mask = nullptr;
23502356
} else if (sd_version_is_unet_edit(sd_ctx->sd->version)) {
2351-
auto empty_latent = ggml_dup_tensor(work_ctx, init_latent);
2357+
empty_latent = ggml_dup_tensor(work_ctx, init_latent);
23522358
ggml_set_f32(empty_latent, 0);
2353-
uncond.c_concat = empty_latent;
2354-
cond.c_concat = ref_latents[0];
2359+
cond.c_concat = ref_latents[0];
23552360
if (cond.c_concat == nullptr) {
23562361
cond.c_concat = empty_latent;
23572362
}
2363+
uncond.c_concat = cond.c_concat;
23582364
} else if (sd_version_is_control(sd_ctx->sd->version)) {
2359-
auto empty_latent = ggml_dup_tensor(work_ctx, init_latent);
2365+
empty_latent = ggml_dup_tensor(work_ctx, init_latent);
23602366
ggml_set_f32(empty_latent, 0);
2361-
uncond.c_concat = empty_latent;
23622367
if (sd_ctx->sd->control_net == nullptr) {
23632368
cond.c_concat = control_latent;
23642369
}
23652370
if (cond.c_concat == nullptr) {
23662371
cond.c_concat = empty_latent;
23672372
}
2373+
uncond.c_concat = cond.c_concat;
23682374
}
2369-
SDCondition img_cond;
2375+
SDCondition img_uncond = uncond;
23702376
if (uncond.c_crossattn != nullptr &&
2371-
(sd_version_is_inpaint_or_unet_edit(sd_ctx->sd->version) && guidance.txt_cfg != guidance.img_cfg)) {
2372-
img_cond = SDCondition(uncond.c_crossattn, uncond.c_vector, cond.c_concat);
2377+
(sd_version_is_inpaint_or_unet_edit(sd_ctx->sd->version) && guidance.img_cfg != 1.0)) {
2378+
img_uncond = SDCondition(uncond.c_crossattn, uncond.c_vector, empty_latent);
23732379
}
23742380
for (int b = 0; b < batch_count; b++) {
23752381
int64_t sampling_start = ggml_time_ms();
@@ -2396,7 +2402,7 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
23962402
noise,
23972403
cond,
23982404
uncond,
2399-
img_cond,
2405+
img_uncond,
24002406
image_hint,
24012407
control_strength,
24022408
guidance,
@@ -3027,7 +3033,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
30273033
noise,
30283034
cond,
30293035
uncond,
3030-
{},
3036+
uncond,
30313037
nullptr,
30323038
0,
30333039
sd_vid_gen_params->high_noise_sample_params.guidance,
@@ -3063,7 +3069,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
30633069
noise,
30643070
cond,
30653071
uncond,
3066-
{},
3072+
uncond,
30673073
nullptr,
30683074
0,
30693075
sd_vid_gen_params->sample_params.guidance,

0 commit comments

Comments
 (0)