Skip to content

Commit 9de8cbd

Browse files
committed
add ucache
1 parent 0392273 commit 9de8cbd

File tree

5 files changed

+530
-15
lines changed

5 files changed

+530
-15
lines changed

examples/cli/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,4 +125,5 @@ Generation Options:
125125
--high-noise-skip-layers (high noise) layers to skip for SLG steps (default: [7,8,9])
126126
-r, --ref-image reference image for Flux Kontext models (can be used multiple times)
127127
--easycache enable EasyCache for DiT models with optional "threshold,start_percent,end_percent" (default: 0.2,0.15,0.95)
128+
--ucache enable UCache for UNET models with optional "threshold,start_percent,end_percent" (default: 1,0.15,0.95)
128129
```

examples/cli/main.cpp

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -991,6 +991,9 @@ struct SDGenerationParams {
991991
std::string easycache_option;
992992
sd_easycache_params_t easycache_params;
993993

994+
std::string ucache_option;
995+
sd_ucache_params_t ucache_params;
996+
994997
float moe_boundary = 0.875f;
995998
int video_frames = 1;
996999
int fps = 16;
@@ -1337,6 +1340,38 @@ struct SDGenerationParams {
13371340
return consumed;
13381341
};
13391342

1343+
auto on_ucache_arg = [&](int argc, const char** argv, int index) {
1344+
const std::string default_values = "1.0,0.15,0.95";
1345+
auto looks_like_value = [](const std::string& token) {
1346+
if (token.empty()) {
1347+
return false;
1348+
}
1349+
if (token[0] != '-') {
1350+
return true;
1351+
}
1352+
if (token.size() == 1) {
1353+
return false;
1354+
}
1355+
unsigned char next = static_cast<unsigned char>(token[1]);
1356+
return std::isdigit(next) || token[1] == '.';
1357+
};
1358+
1359+
std::string option_value;
1360+
int consumed = 0;
1361+
if (index + 1 < argc) {
1362+
std::string next_arg = argv[index + 1];
1363+
if (looks_like_value(next_arg)) {
1364+
option_value = argv_to_utf8(index + 1, argv);
1365+
consumed = 1;
1366+
}
1367+
}
1368+
if (option_value.empty()) {
1369+
option_value = default_values;
1370+
}
1371+
ucache_option = option_value;
1372+
return consumed;
1373+
};
1374+
13401375
options.manual_options = {
13411376
{"-s",
13421377
"--seed",
@@ -1372,6 +1407,10 @@ struct SDGenerationParams {
13721407
"--easycache",
13731408
"enable EasyCache for DiT models with optional \"threshold,start_percent,end_percent\" (default: 0.2,0.15,0.95)",
13741409
on_easycache_arg},
1410+
{"",
1411+
"--ucache",
1412+
"enable UCache for UNET models (SD1.x/SD2.x/SDXL) with optional \"threshold,start_percent,end_percent\" (default: 1.0,0.15,0.95)",
1413+
on_ucache_arg},
13751414

13761415
};
13771416

@@ -1456,6 +1495,59 @@ struct SDGenerationParams {
14561495
easycache_params.enabled = false;
14571496
}
14581497

1498+
if (!ucache_option.empty()) {
1499+
float values[3] = {0.0f, 0.0f, 0.0f};
1500+
std::stringstream ss(ucache_option);
1501+
std::string token;
1502+
int idx = 0;
1503+
while (std::getline(ss, token, ',')) {
1504+
auto trim = [](std::string& s) {
1505+
const char* whitespace = " \t\r\n";
1506+
auto start = s.find_first_not_of(whitespace);
1507+
if (start == std::string::npos) {
1508+
s.clear();
1509+
return;
1510+
}
1511+
auto end = s.find_last_not_of(whitespace);
1512+
s = s.substr(start, end - start + 1);
1513+
};
1514+
trim(token);
1515+
if (token.empty()) {
1516+
fprintf(stderr, "error: invalid ucache option '%s'\n", ucache_option.c_str());
1517+
return false;
1518+
}
1519+
if (idx >= 3) {
1520+
fprintf(stderr, "error: ucache expects exactly 3 comma-separated values (threshold,start,end)\n");
1521+
return false;
1522+
}
1523+
try {
1524+
values[idx] = std::stof(token);
1525+
} catch (const std::exception&) {
1526+
fprintf(stderr, "error: invalid ucache value '%s'\n", token.c_str());
1527+
return false;
1528+
}
1529+
idx++;
1530+
}
1531+
if (idx != 3) {
1532+
fprintf(stderr, "error: ucache expects exactly 3 comma-separated values (threshold,start,end)\n");
1533+
return false;
1534+
}
1535+
if (values[0] < 0.0f) {
1536+
fprintf(stderr, "error: ucache threshold must be non-negative\n");
1537+
return false;
1538+
}
1539+
if (values[1] < 0.0f || values[1] >= 1.0f || values[2] <= 0.0f || values[2] > 1.0f || values[1] >= values[2]) {
1540+
fprintf(stderr, "error: ucache start/end percents must satisfy 0.0 <= start < end <= 1.0\n");
1541+
return false;
1542+
}
1543+
ucache_params.enabled = true;
1544+
ucache_params.reuse_threshold = values[0];
1545+
ucache_params.start_percent = values[1];
1546+
ucache_params.end_percent = values[2];
1547+
} else {
1548+
ucache_params.enabled = false;
1549+
}
1550+
14591551
sample_params.guidance.slg.layers = skip_layers.data();
14601552
sample_params.guidance.slg.layer_count = skip_layers.size();
14611553
high_noise_sample_params.guidance.slg.layers = high_noise_skip_layers.data();
@@ -2087,6 +2179,7 @@ int main(int argc, const char* argv[]) {
20872179
}, // pm_params
20882180
ctx_params.vae_tiling_params,
20892181
gen_params.easycache_params,
2182+
gen_params.ucache_params,
20902183
};
20912184

20922185
results = generate_image(sd_ctx, &img_gen_params);
@@ -2110,6 +2203,7 @@ int main(int argc, const char* argv[]) {
21102203
gen_params.video_frames,
21112204
gen_params.vace_strength,
21122205
gen_params.easycache_params,
2206+
gen_params.ucache_params,
21132207
};
21142208

21152209
results = generate_video(sd_ctx, &vid_gen_params, &num_results);

0 commit comments

Comments
 (0)