Skip to content

Commit 97cf2ef

Browse files
feat: add KL Optimal scheduler (#1098)
1 parent bda7fab commit 97cf2ef

File tree

6 files changed

+47
-6
lines changed

6 files changed

+47
-6
lines changed

denoiser.hpp

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,41 @@ struct SmoothStepScheduler : SigmaScheduler {
347347
}
348348
};
349349

350+
// Implementation adapted from https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/15608
351+
struct KLOptimalScheduler : SigmaScheduler {
352+
std::vector<float> get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) override {
353+
std::vector<float> sigmas;
354+
355+
if (n == 0) {
356+
return sigmas;
357+
}
358+
if (n == 1) {
359+
sigmas.push_back(sigma_max);
360+
sigmas.push_back(0.0f);
361+
return sigmas;
362+
}
363+
364+
float alpha_min = std::atan(sigma_min);
365+
float alpha_max = std::atan(sigma_max);
366+
367+
for (uint32_t i = 0; i < n; ++i) {
368+
// t goes from 0.0 to 1.0
369+
float t = static_cast<float>(i) / static_cast<float>(n-1);
370+
371+
// Interpolate in the angle domain
372+
float angle = t * alpha_min + (1.0f - t) * alpha_max;
373+
374+
// Convert back to sigma
375+
sigmas.push_back(std::tan(angle));
376+
}
377+
378+
// Append the final zero to sigma
379+
sigmas.push_back(0.0f);
380+
381+
return sigmas;
382+
}
383+
};
384+
350385
struct Denoiser {
351386
virtual float sigma_min() = 0;
352387
virtual float sigma_max() = 0;
@@ -392,6 +427,10 @@ struct Denoiser {
392427
LOG_INFO("get_sigmas with SmoothStep scheduler");
393428
scheduler = std::make_shared<SmoothStepScheduler>();
394429
break;
430+
case KL_OPTIMAL_SCHEDULER:
431+
LOG_INFO("get_sigmas with KL Optimal scheduler");
432+
scheduler = std::make_shared<KLOptimalScheduler>();
433+
break;
395434
case LCM_SCHEDULER:
396435
LOG_INFO("get_sigmas with LCM scheduler");
397436
scheduler = std::make_shared<LCMScheduler>();

examples/cli/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ Generation Options:
120120
tcd] (default: euler for Flux/SD3/Wan, euler_a otherwise)
121121
--high-noise-sampling-method (high noise) sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm,
122122
ddim_trailing, tcd] default: euler for Flux/SD3/Wan, euler_a otherwise
123-
--scheduler denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple, lcm],
123+
--scheduler denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple, kl_optimal, lcm],
124124
default: discrete
125125
--sigmas custom sigma values for the sampler, comma-separated (e.g., "14.61,7.8,3.5,0.0").
126126
--skip-layers layers to skip for SLG steps (default: [7,8,9])

examples/common/common.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1409,7 +1409,7 @@ struct SDGenerationParams {
14091409
on_high_noise_sample_method_arg},
14101410
{"",
14111411
"--scheduler",
1412-
"denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple, lcm], default: discrete",
1412+
"denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple, kl_optimal, lcm], default: discrete",
14131413
on_scheduler_arg},
14141414
{"",
14151415
"--sigmas",
@@ -1911,4 +1911,4 @@ uint8_t* load_image_from_memory(const char* image_bytes,
19111911
int expected_height = 0,
19121912
int expected_channel = 3) {
19131913
return load_image_common(true, image_bytes, len, width, height, expected_width, expected_height, expected_channel);
1914-
}
1914+
}

examples/server/README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,11 +114,11 @@ Default Generation Options:
114114
tcd] (default: euler for Flux/SD3/Wan, euler_a otherwise)
115115
--high-noise-sampling-method (high noise) sampling method, one of [euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm,
116116
ddim_trailing, tcd] default: euler for Flux/SD3/Wan, euler_a otherwise
117-
--scheduler denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple, lcm],
117+
--scheduler denoiser sigma scheduler, one of [discrete, karras, exponential, ays, gits, smoothstep, sgm_uniform, simple, kl_optimal, lcm],
118118
default: discrete
119119
--sigmas custom sigma values for the sampler, comma-separated (e.g., "14.61,7.8,3.5,0.0").
120120
--skip-layers layers to skip for SLG steps (default: [7,8,9])
121121
--high-noise-skip-layers (high noise) layers to skip for SLG steps (default: [7,8,9])
122122
-r, --ref-image reference image for Flux Kontext models (can be used multiple times)
123123
--easycache enable EasyCache for DiT models with optional "threshold,start_percent,end_percent" (default: 0.2,0.15,0.95)
124-
```
124+
```

stable-diffusion.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2412,6 +2412,7 @@ const char* scheduler_to_str[] = {
24122412
"sgm_uniform",
24132413
"simple",
24142414
"smoothstep",
2415+
"kl_optimal",
24152416
"lcm",
24162417
};
24172418

@@ -3888,4 +3889,4 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
38883889
LOG_INFO("generate_video completed in %.2fs", (t5 - t0) * 1.0f / 1000);
38893890

38903891
return result_images;
3891-
}
3892+
}

stable-diffusion.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ enum scheduler_t {
6060
SGM_UNIFORM_SCHEDULER,
6161
SIMPLE_SCHEDULER,
6262
SMOOTHSTEP_SCHEDULER,
63+
KL_OPTIMAL_SCHEDULER,
6364
LCM_SCHEDULER,
6465
SCHEDULER_COUNT
6566
};

0 commit comments

Comments
 (0)