Skip to content

Commit 4ceff8d

Browse files
committed
add cache-mode and cache-option
1 parent 9de8cbd commit 4ceff8d

File tree

2 files changed

+59
-133
lines changed

2 files changed

+59
-133
lines changed

examples/cli/README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,6 @@ Generation Options:
124124
--skip-layers layers to skip for SLG steps (default: [7,8,9])
125125
--high-noise-skip-layers (high noise) layers to skip for SLG steps (default: [7,8,9])
126126
-r, --ref-image reference image for Flux Kontext models (can be used multiple times)
127-
--easycache enable EasyCache for DiT models with optional "threshold,start_percent,end_percent" (default: 0.2,0.15,0.95)
128-
--ucache enable UCache for UNET models with optional "threshold,start_percent,end_percent" (default: 1,0.15,0.95)
127+
--cache-mode caching method: 'easycache' for DiT models, 'ucache' for UNET models (SD1.x/SD2.x/SDXL)
128+
--cache-option cache parameters "threshold,start_percent,end_percent" (default: 0.2,0.15,0.95 for easycache, 1.0,0.15,0.95 for ucache)
129129
```

examples/cli/main.cpp

Lines changed: 57 additions & 131 deletions
Original file line numberDiff line numberDiff line change
@@ -988,10 +988,9 @@ struct SDGenerationParams {
988988
std::vector<int> high_noise_skip_layers = {7, 8, 9};
989989
sd_sample_params_t high_noise_sample_params;
990990

991-
std::string easycache_option;
991+
std::string cache_mode;
992+
std::string cache_option;
992993
sd_easycache_params_t easycache_params;
993-
994-
std::string ucache_option;
995994
sd_ucache_params_t ucache_params;
996995

997996
float moe_boundary = 0.875f;
@@ -1308,68 +1307,24 @@ struct SDGenerationParams {
13081307
return 1;
13091308
};
13101309

1311-
auto on_easycache_arg = [&](int argc, const char** argv, int index) {
1312-
const std::string default_values = "0.2,0.15,0.95";
1313-
auto looks_like_value = [](const std::string& token) {
1314-
if (token.empty()) {
1315-
return false;
1316-
}
1317-
if (token[0] != '-') {
1318-
return true;
1319-
}
1320-
if (token.size() == 1) {
1321-
return false;
1322-
}
1323-
unsigned char next = static_cast<unsigned char>(token[1]);
1324-
return std::isdigit(next) || token[1] == '.';
1325-
};
1326-
1327-
std::string option_value;
1328-
int consumed = 0;
1329-
if (index + 1 < argc) {
1330-
std::string next_arg = argv[index + 1];
1331-
if (looks_like_value(next_arg)) {
1332-
option_value = argv_to_utf8(index + 1, argv);
1333-
consumed = 1;
1334-
}
1310+
auto on_cache_mode_arg = [&](int argc, const char** argv, int index) {
1311+
if (++index >= argc) {
1312+
return -1;
13351313
}
1336-
if (option_value.empty()) {
1337-
option_value = default_values;
1314+
cache_mode = argv_to_utf8(index, argv);
1315+
if (cache_mode != "easycache" && cache_mode != "ucache") {
1316+
fprintf(stderr, "error: invalid cache mode '%s', must be 'easycache' or 'ucache'\n", cache_mode.c_str());
1317+
return -1;
13381318
}
1339-
easycache_option = option_value;
1340-
return consumed;
1319+
return 1;
13411320
};
13421321

1343-
auto on_ucache_arg = [&](int argc, const char** argv, int index) {
1344-
const std::string default_values = "1.0,0.15,0.95";
1345-
auto looks_like_value = [](const std::string& token) {
1346-
if (token.empty()) {
1347-
return false;
1348-
}
1349-
if (token[0] != '-') {
1350-
return true;
1351-
}
1352-
if (token.size() == 1) {
1353-
return false;
1354-
}
1355-
unsigned char next = static_cast<unsigned char>(token[1]);
1356-
return std::isdigit(next) || token[1] == '.';
1357-
};
1358-
1359-
std::string option_value;
1360-
int consumed = 0;
1361-
if (index + 1 < argc) {
1362-
std::string next_arg = argv[index + 1];
1363-
if (looks_like_value(next_arg)) {
1364-
option_value = argv_to_utf8(index + 1, argv);
1365-
consumed = 1;
1366-
}
1367-
}
1368-
if (option_value.empty()) {
1369-
option_value = default_values;
1322+
auto on_cache_option_arg = [&](int argc, const char** argv, int index) {
1323+
if (++index >= argc) {
1324+
return -1;
13701325
}
1371-
ucache_option = option_value;
1372-
return consumed;
1326+
cache_option = argv_to_utf8(index, argv);
1327+
return 1;
13731328
};
13741329

13751330
options.manual_options = {
@@ -1404,13 +1359,13 @@ struct SDGenerationParams {
14041359
"reference image for Flux Kontext models (can be used multiple times)",
14051360
on_ref_image_arg},
14061361
{"",
1407-
"--easycache",
1408-
"enable EasyCache for DiT models with optional \"threshold,start_percent,end_percent\" (default: 0.2,0.15,0.95)",
1409-
on_easycache_arg},
1362+
"--cache-mode",
1363+
"caching method: 'easycache' for DiT models, 'ucache' for UNET models (SD1.x/SD2.x/SDXL)",
1364+
on_cache_mode_arg},
14101365
{"",
1411-
"--ucache",
1412-
"enable UCache for UNET models (SD1.x/SD2.x/SDXL) with optional \"threshold,start_percent,end_percent\" (default: 1.0,0.15,0.95)",
1413-
on_ucache_arg},
1366+
"--cache-option",
1367+
"cache parameters \"threshold,start_percent,end_percent\" (default: 0.2,0.15,0.95 for easycache, 1.0,0.15,0.95 for ucache)",
1368+
on_cache_option_arg},
14141369

14151370
};
14161371

@@ -1442,62 +1397,21 @@ struct SDGenerationParams {
14421397
return false;
14431398
}
14441399

1445-
if (!easycache_option.empty()) {
1446-
float values[3] = {0.0f, 0.0f, 0.0f};
1447-
std::stringstream ss(easycache_option);
1448-
std::string token;
1449-
int idx = 0;
1450-
while (std::getline(ss, token, ',')) {
1451-
auto trim = [](std::string& s) {
1452-
const char* whitespace = " \t\r\n";
1453-
auto start = s.find_first_not_of(whitespace);
1454-
if (start == std::string::npos) {
1455-
s.clear();
1456-
return;
1457-
}
1458-
auto end = s.find_last_not_of(whitespace);
1459-
s = s.substr(start, end - start + 1);
1460-
};
1461-
trim(token);
1462-
if (token.empty()) {
1463-
fprintf(stderr, "error: invalid easycache option '%s'\n", easycache_option.c_str());
1464-
return false;
1465-
}
1466-
if (idx >= 3) {
1467-
fprintf(stderr, "error: easycache expects exactly 3 comma-separated values (threshold,start,end)\n");
1468-
return false;
1469-
}
1470-
try {
1471-
values[idx] = std::stof(token);
1472-
} catch (const std::exception&) {
1473-
fprintf(stderr, "error: invalid easycache value '%s'\n", token.c_str());
1474-
return false;
1400+
easycache_params.enabled = false;
1401+
ucache_params.enabled = false;
1402+
1403+
if (!cache_mode.empty()) {
1404+
std::string option_str = cache_option;
1405+
if (option_str.empty()) {
1406+
if (cache_mode == "easycache") {
1407+
option_str = "0.2,0.15,0.95";
1408+
} else {
1409+
option_str = "1.0,0.15,0.95";
14751410
}
1476-
idx++;
1477-
}
1478-
if (idx != 3) {
1479-
fprintf(stderr, "error: easycache expects exactly 3 comma-separated values (threshold,start,end)\n");
1480-
return false;
14811411
}
1482-
if (values[0] < 0.0f) {
1483-
fprintf(stderr, "error: easycache threshold must be non-negative\n");
1484-
return false;
1485-
}
1486-
if (values[1] < 0.0f || values[1] >= 1.0f || values[2] <= 0.0f || values[2] > 1.0f || values[1] >= values[2]) {
1487-
fprintf(stderr, "error: easycache start/end percents must satisfy 0.0 <= start < end <= 1.0\n");
1488-
return false;
1489-
}
1490-
easycache_params.enabled = true;
1491-
easycache_params.reuse_threshold = values[0];
1492-
easycache_params.start_percent = values[1];
1493-
easycache_params.end_percent = values[2];
1494-
} else {
1495-
easycache_params.enabled = false;
1496-
}
14971412

1498-
if (!ucache_option.empty()) {
14991413
float values[3] = {0.0f, 0.0f, 0.0f};
1500-
std::stringstream ss(ucache_option);
1414+
std::stringstream ss(option_str);
15011415
std::string token;
15021416
int idx = 0;
15031417
while (std::getline(ss, token, ',')) {
@@ -1513,39 +1427,45 @@ struct SDGenerationParams {
15131427
};
15141428
trim(token);
15151429
if (token.empty()) {
1516-
fprintf(stderr, "error: invalid ucache option '%s'\n", ucache_option.c_str());
1430+
fprintf(stderr, "error: invalid cache option '%s'\n", option_str.c_str());
15171431
return false;
15181432
}
15191433
if (idx >= 3) {
1520-
fprintf(stderr, "error: ucache expects exactly 3 comma-separated values (threshold,start,end)\n");
1434+
fprintf(stderr, "error: cache option expects exactly 3 comma-separated values (threshold,start,end)\n");
15211435
return false;
15221436
}
15231437
try {
15241438
values[idx] = std::stof(token);
15251439
} catch (const std::exception&) {
1526-
fprintf(stderr, "error: invalid ucache value '%s'\n", token.c_str());
1440+
fprintf(stderr, "error: invalid cache option value '%s'\n", token.c_str());
15271441
return false;
15281442
}
15291443
idx++;
15301444
}
15311445
if (idx != 3) {
1532-
fprintf(stderr, "error: ucache expects exactly 3 comma-separated values (threshold,start,end)\n");
1446+
fprintf(stderr, "error: cache option expects exactly 3 comma-separated values (threshold,start,end)\n");
15331447
return false;
15341448
}
15351449
if (values[0] < 0.0f) {
1536-
fprintf(stderr, "error: ucache threshold must be non-negative\n");
1450+
fprintf(stderr, "error: cache threshold must be non-negative\n");
15371451
return false;
15381452
}
15391453
if (values[1] < 0.0f || values[1] >= 1.0f || values[2] <= 0.0f || values[2] > 1.0f || values[1] >= values[2]) {
1540-
fprintf(stderr, "error: ucache start/end percents must satisfy 0.0 <= start < end <= 1.0\n");
1454+
fprintf(stderr, "error: cache start/end percents must satisfy 0.0 <= start < end <= 1.0\n");
15411455
return false;
15421456
}
1543-
ucache_params.enabled = true;
1544-
ucache_params.reuse_threshold = values[0];
1545-
ucache_params.start_percent = values[1];
1546-
ucache_params.end_percent = values[2];
1547-
} else {
1548-
ucache_params.enabled = false;
1457+
1458+
if (cache_mode == "easycache") {
1459+
easycache_params.enabled = true;
1460+
easycache_params.reuse_threshold = values[0];
1461+
easycache_params.start_percent = values[1];
1462+
easycache_params.end_percent = values[2];
1463+
} else {
1464+
ucache_params.enabled = true;
1465+
ucache_params.reuse_threshold = values[0];
1466+
ucache_params.start_percent = values[1];
1467+
ucache_params.end_percent = values[2];
1468+
}
15491469
}
15501470

15511471
sample_params.guidance.slg.layers = skip_layers.data();
@@ -1610,12 +1530,18 @@ struct SDGenerationParams {
16101530
<< " sample_params: " << sample_params_str << ",\n"
16111531
<< " high_noise_skip_layers: " << vec_to_string(high_noise_skip_layers) << ",\n"
16121532
<< " high_noise_sample_params: " << high_noise_sample_params_str << ",\n"
1613-
<< " easycache_option: \"" << easycache_option << "\",\n"
1533+
<< " cache_mode: \"" << cache_mode << "\",\n"
1534+
<< " cache_option: \"" << cache_option << "\",\n"
16141535
<< " easycache: "
16151536
<< (easycache_params.enabled ? "enabled" : "disabled")
16161537
<< " (threshold=" << easycache_params.reuse_threshold
16171538
<< ", start=" << easycache_params.start_percent
16181539
<< ", end=" << easycache_params.end_percent << "),\n"
1540+
<< " ucache: "
1541+
<< (ucache_params.enabled ? "enabled" : "disabled")
1542+
<< " (threshold=" << ucache_params.reuse_threshold
1543+
<< ", start=" << ucache_params.start_percent
1544+
<< ", end=" << ucache_params.end_percent << "),\n"
16191545
<< " moe_boundary: " << moe_boundary << ",\n"
16201546
<< " video_frames: " << video_frames << ",\n"
16211547
<< " fps: " << fps << ",\n"

0 commit comments

Comments
 (0)