Skip to content

Commit 78e15bd

Browse files
wbrunaleejet
andauthored
feat: default to LCM scheduler for LCM sampling (#1109)
* feat: default to LCM scheduler for LCM sampling * fix bug and attempt to get default scheduler for vid_gen when none is set --------- Co-authored-by: leejet <leejet714@gmail.com>
1 parent 97cf2ef commit 78e15bd

File tree

3 files changed

+17
-6
lines changed

3 files changed

+17
-6
lines changed

examples/cli/main.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -579,7 +579,7 @@ int main(int argc, const char* argv[]) {
579579
}
580580

581581
if (gen_params.sample_params.scheduler == SCHEDULER_COUNT) {
582-
gen_params.sample_params.scheduler = sd_get_default_scheduler(sd_ctx);
582+
gen_params.sample_params.scheduler = sd_get_default_scheduler(sd_ctx, gen_params.sample_params.sample_method);
583583
}
584584

585585
if (cli_params.mode == IMG_GEN) {
@@ -752,4 +752,4 @@ int main(int argc, const char* argv[]) {
752752
release_all_resources();
753753

754754
return 0;
755-
}
755+
}

stable-diffusion.cpp

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2777,13 +2777,16 @@ enum sample_method_t sd_get_default_sample_method(const sd_ctx_t* sd_ctx) {
27772777
return EULER_A_SAMPLE_METHOD;
27782778
}
27792779

2780-
enum scheduler_t sd_get_default_scheduler(const sd_ctx_t* sd_ctx) {
2780+
enum scheduler_t sd_get_default_scheduler(const sd_ctx_t* sd_ctx, enum sample_method_t sample_method) {
27812781
if (sd_ctx != nullptr && sd_ctx->sd != nullptr) {
27822782
auto edm_v_denoiser = std::dynamic_pointer_cast<EDMVDenoiser>(sd_ctx->sd->denoiser);
27832783
if (edm_v_denoiser) {
27842784
return EXPONENTIAL_SCHEDULER;
27852785
}
27862786
}
2787+
if (sample_method == LCM_SAMPLE_METHOD) {
2788+
return LCM_SCHEDULER;
2789+
}
27872790
return DISCRETE_SCHEDULER;
27882791
}
27892792

@@ -3218,9 +3221,13 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
32183221
LOG_WARN("sample_steps != custom_sigmas_count - 1, set sample_steps to %d", sample_steps);
32193222
}
32203223
} else {
3224+
scheduler_t scheduler = sd_img_gen_params->sample_params.scheduler;
3225+
if (scheduler == SCHEDULER_COUNT) {
3226+
scheduler = sd_get_default_scheduler(sd_ctx, sample_method);
3227+
}
32213228
sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps,
32223229
sd_ctx->sd->get_image_seq_len(height, width),
3223-
sd_img_gen_params->sample_params.scheduler,
3230+
scheduler,
32243231
sd_ctx->sd->version);
32253232
}
32263233

@@ -3503,9 +3510,13 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
35033510
}
35043511
}
35053512
} else {
3513+
scheduler_t scheduler = sd_vid_gen_params->sample_params.scheduler;
3514+
if (scheduler == SCHEDULER_COUNT) {
3515+
scheduler = sd_get_default_scheduler(sd_ctx, sample_method);
3516+
}
35063517
sigmas = sd_ctx->sd->denoiser->get_sigmas(total_steps,
35073518
0,
3508-
sd_vid_gen_params->sample_params.scheduler,
3519+
scheduler,
35093520
sd_ctx->sd->version);
35103521
}
35113522

stable-diffusion.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -335,7 +335,7 @@ SD_API void sd_sample_params_init(sd_sample_params_t* sample_params);
335335
SD_API char* sd_sample_params_to_str(const sd_sample_params_t* sample_params);
336336

337337
SD_API enum sample_method_t sd_get_default_sample_method(const sd_ctx_t* sd_ctx);
338-
SD_API enum scheduler_t sd_get_default_scheduler(const sd_ctx_t* sd_ctx);
338+
SD_API enum scheduler_t sd_get_default_scheduler(const sd_ctx_t* sd_ctx, enum sample_method_t sample_method);
339339

340340
SD_API void sd_img_gen_params_init(sd_img_gen_params_t* sd_img_gen_params);
341341
SD_API char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params);

0 commit comments

Comments
 (0)