@@ -988,10 +988,9 @@ struct SDGenerationParams {
988988 std::vector<int > high_noise_skip_layers = {7 , 8 , 9 };
989989 sd_sample_params_t high_noise_sample_params;
990990
991- std::string easycache_option;
991+ std::string cache_mode;
992+ std::string cache_option;
992993 sd_easycache_params_t easycache_params;
993-
994- std::string ucache_option;
995994 sd_ucache_params_t ucache_params;
996995
997996 float moe_boundary = 0 .875f ;
@@ -1308,68 +1307,24 @@ struct SDGenerationParams {
13081307 return 1 ;
13091308 };
13101309
1311- auto on_easycache_arg = [&](int argc, const char ** argv, int index) {
1312- const std::string default_values = " 0.2,0.15,0.95" ;
1313- auto looks_like_value = [](const std::string& token) {
1314- if (token.empty ()) {
1315- return false ;
1316- }
1317- if (token[0 ] != ' -' ) {
1318- return true ;
1319- }
1320- if (token.size () == 1 ) {
1321- return false ;
1322- }
1323- unsigned char next = static_cast <unsigned char >(token[1 ]);
1324- return std::isdigit (next) || token[1 ] == ' .' ;
1325- };
1326-
1327- std::string option_value;
1328- int consumed = 0 ;
1329- if (index + 1 < argc) {
1330- std::string next_arg = argv[index + 1 ];
1331- if (looks_like_value (next_arg)) {
1332- option_value = argv_to_utf8 (index + 1 , argv);
1333- consumed = 1 ;
1334- }
1310+ auto on_cache_mode_arg = [&](int argc, const char ** argv, int index) {
1311+ if (++index >= argc) {
1312+ return -1 ;
13351313 }
1336- if (option_value.empty ()) {
1337- option_value = default_values;
1314+ cache_mode = argv_to_utf8 (index, argv);
1315+ if (cache_mode != " easycache" && cache_mode != " ucache" ) {
1316+ fprintf (stderr, " error: invalid cache mode '%s', must be 'easycache' or 'ucache'\n " , cache_mode.c_str ());
1317+ return -1 ;
13381318 }
1339- easycache_option = option_value;
1340- return consumed;
1319+ return 1 ;
13411320 };
13421321
1343- auto on_ucache_arg = [&](int argc, const char ** argv, int index) {
1344- const std::string default_values = " 1.0,0.15,0.95" ;
1345- auto looks_like_value = [](const std::string& token) {
1346- if (token.empty ()) {
1347- return false ;
1348- }
1349- if (token[0 ] != ' -' ) {
1350- return true ;
1351- }
1352- if (token.size () == 1 ) {
1353- return false ;
1354- }
1355- unsigned char next = static_cast <unsigned char >(token[1 ]);
1356- return std::isdigit (next) || token[1 ] == ' .' ;
1357- };
1358-
1359- std::string option_value;
1360- int consumed = 0 ;
1361- if (index + 1 < argc) {
1362- std::string next_arg = argv[index + 1 ];
1363- if (looks_like_value (next_arg)) {
1364- option_value = argv_to_utf8 (index + 1 , argv);
1365- consumed = 1 ;
1366- }
1367- }
1368- if (option_value.empty ()) {
1369- option_value = default_values;
1322+ auto on_cache_option_arg = [&](int argc, const char ** argv, int index) {
1323+ if (++index >= argc) {
1324+ return -1 ;
13701325 }
1371- ucache_option = option_value ;
1372- return consumed ;
1326+ cache_option = argv_to_utf8 (index, argv) ;
1327+ return 1 ;
13731328 };
13741329
13751330 options.manual_options = {
@@ -1404,13 +1359,13 @@ struct SDGenerationParams {
14041359 " reference image for Flux Kontext models (can be used multiple times)" ,
14051360 on_ref_image_arg},
14061361 {" " ,
1407- " --easycache " ,
1408- " enable EasyCache for DiT models with optional \" threshold,start_percent,end_percent \" (default: 0.2,0.15,0.95 )" ,
1409- on_easycache_arg },
1362+ " --cache-mode " ,
1363+ " caching method: 'easycache' for DiT models, 'ucache' for UNET models (SD1.x/SD2.x/SDXL )" ,
1364+ on_cache_mode_arg },
14101365 {" " ,
1411- " --ucache " ,
1412- " enable UCache for UNET models (SD1.x/SD2.x/SDXL) with optional \" threshold,start_percent,end_percent\" (default: 1.0,0.15,0.95)" ,
1413- on_ucache_arg },
1366+ " --cache-option " ,
1367+ " cache parameters \" threshold,start_percent,end_percent\" (default: 0.2,0.15,0.95 for easycache, 1.0,0.15,0.95 for ucache )" ,
1368+ on_cache_option_arg },
14141369
14151370 };
14161371
@@ -1442,62 +1397,21 @@ struct SDGenerationParams {
14421397 return false ;
14431398 }
14441399
1445- if (!easycache_option.empty ()) {
1446- float values[3 ] = {0 .0f , 0 .0f , 0 .0f };
1447- std::stringstream ss (easycache_option);
1448- std::string token;
1449- int idx = 0 ;
1450- while (std::getline (ss, token, ' ,' )) {
1451- auto trim = [](std::string& s) {
1452- const char * whitespace = " \t\r\n " ;
1453- auto start = s.find_first_not_of (whitespace);
1454- if (start == std::string::npos) {
1455- s.clear ();
1456- return ;
1457- }
1458- auto end = s.find_last_not_of (whitespace);
1459- s = s.substr (start, end - start + 1 );
1460- };
1461- trim (token);
1462- if (token.empty ()) {
1463- fprintf (stderr, " error: invalid easycache option '%s'\n " , easycache_option.c_str ());
1464- return false ;
1465- }
1466- if (idx >= 3 ) {
1467- fprintf (stderr, " error: easycache expects exactly 3 comma-separated values (threshold,start,end)\n " );
1468- return false ;
1469- }
1470- try {
1471- values[idx] = std::stof (token);
1472- } catch (const std::exception&) {
1473- fprintf (stderr, " error: invalid easycache value '%s'\n " , token.c_str ());
1474- return false ;
1400+ easycache_params.enabled = false ;
1401+ ucache_params.enabled = false ;
1402+
1403+ if (!cache_mode.empty ()) {
1404+ std::string option_str = cache_option;
1405+ if (option_str.empty ()) {
1406+ if (cache_mode == " easycache" ) {
1407+ option_str = " 0.2,0.15,0.95" ;
1408+ } else {
1409+ option_str = " 1.0,0.15,0.95" ;
14751410 }
1476- idx++;
1477- }
1478- if (idx != 3 ) {
1479- fprintf (stderr, " error: easycache expects exactly 3 comma-separated values (threshold,start,end)\n " );
1480- return false ;
14811411 }
1482- if (values[0 ] < 0 .0f ) {
1483- fprintf (stderr, " error: easycache threshold must be non-negative\n " );
1484- return false ;
1485- }
1486- if (values[1 ] < 0 .0f || values[1 ] >= 1 .0f || values[2 ] <= 0 .0f || values[2 ] > 1 .0f || values[1 ] >= values[2 ]) {
1487- fprintf (stderr, " error: easycache start/end percents must satisfy 0.0 <= start < end <= 1.0\n " );
1488- return false ;
1489- }
1490- easycache_params.enabled = true ;
1491- easycache_params.reuse_threshold = values[0 ];
1492- easycache_params.start_percent = values[1 ];
1493- easycache_params.end_percent = values[2 ];
1494- } else {
1495- easycache_params.enabled = false ;
1496- }
14971412
1498- if (!ucache_option.empty ()) {
14991413 float values[3 ] = {0 .0f , 0 .0f , 0 .0f };
1500- std::stringstream ss (ucache_option );
1414+ std::stringstream ss (option_str );
15011415 std::string token;
15021416 int idx = 0 ;
15031417 while (std::getline (ss, token, ' ,' )) {
@@ -1513,39 +1427,45 @@ struct SDGenerationParams {
15131427 };
15141428 trim (token);
15151429 if (token.empty ()) {
1516- fprintf (stderr, " error: invalid ucache option '%s'\n " , ucache_option .c_str ());
1430+ fprintf (stderr, " error: invalid cache option '%s'\n " , option_str .c_str ());
15171431 return false ;
15181432 }
15191433 if (idx >= 3 ) {
1520- fprintf (stderr, " error: ucache expects exactly 3 comma-separated values (threshold,start,end)\n " );
1434+ fprintf (stderr, " error: cache option expects exactly 3 comma-separated values (threshold,start,end)\n " );
15211435 return false ;
15221436 }
15231437 try {
15241438 values[idx] = std::stof (token);
15251439 } catch (const std::exception&) {
1526- fprintf (stderr, " error: invalid ucache value '%s'\n " , token.c_str ());
1440+ fprintf (stderr, " error: invalid cache option value '%s'\n " , token.c_str ());
15271441 return false ;
15281442 }
15291443 idx++;
15301444 }
15311445 if (idx != 3 ) {
1532- fprintf (stderr, " error: ucache expects exactly 3 comma-separated values (threshold,start,end)\n " );
1446+ fprintf (stderr, " error: cache option expects exactly 3 comma-separated values (threshold,start,end)\n " );
15331447 return false ;
15341448 }
15351449 if (values[0 ] < 0 .0f ) {
1536- fprintf (stderr, " error: ucache threshold must be non-negative\n " );
1450+ fprintf (stderr, " error: cache threshold must be non-negative\n " );
15371451 return false ;
15381452 }
15391453 if (values[1 ] < 0 .0f || values[1 ] >= 1 .0f || values[2 ] <= 0 .0f || values[2 ] > 1 .0f || values[1 ] >= values[2 ]) {
1540- fprintf (stderr, " error: ucache start/end percents must satisfy 0.0 <= start < end <= 1.0\n " );
1454+ fprintf (stderr, " error: cache start/end percents must satisfy 0.0 <= start < end <= 1.0\n " );
15411455 return false ;
15421456 }
1543- ucache_params.enabled = true ;
1544- ucache_params.reuse_threshold = values[0 ];
1545- ucache_params.start_percent = values[1 ];
1546- ucache_params.end_percent = values[2 ];
1547- } else {
1548- ucache_params.enabled = false ;
1457+
1458+ if (cache_mode == " easycache" ) {
1459+ easycache_params.enabled = true ;
1460+ easycache_params.reuse_threshold = values[0 ];
1461+ easycache_params.start_percent = values[1 ];
1462+ easycache_params.end_percent = values[2 ];
1463+ } else {
1464+ ucache_params.enabled = true ;
1465+ ucache_params.reuse_threshold = values[0 ];
1466+ ucache_params.start_percent = values[1 ];
1467+ ucache_params.end_percent = values[2 ];
1468+ }
15491469 }
15501470
15511471 sample_params.guidance .slg .layers = skip_layers.data ();
@@ -1610,12 +1530,18 @@ struct SDGenerationParams {
16101530 << " sample_params: " << sample_params_str << " ,\n "
16111531 << " high_noise_skip_layers: " << vec_to_string (high_noise_skip_layers) << " ,\n "
16121532 << " high_noise_sample_params: " << high_noise_sample_params_str << " ,\n "
1613- << " easycache_option: \" " << easycache_option << " \" ,\n "
1533+ << " cache_mode: \" " << cache_mode << " \" ,\n "
1534+ << " cache_option: \" " << cache_option << " \" ,\n "
16141535 << " easycache: "
16151536 << (easycache_params.enabled ? " enabled" : " disabled" )
16161537 << " (threshold=" << easycache_params.reuse_threshold
16171538 << " , start=" << easycache_params.start_percent
16181539 << " , end=" << easycache_params.end_percent << " ),\n "
1540+ << " ucache: "
1541+ << (ucache_params.enabled ? " enabled" : " disabled" )
1542+ << " (threshold=" << ucache_params.reuse_threshold
1543+ << " , start=" << ucache_params.start_percent
1544+ << " , end=" << ucache_params.end_percent << " ),\n "
16191545 << " moe_boundary: " << moe_boundary << " ,\n "
16201546 << " video_frames: " << video_frames << " ,\n "
16211547 << " fps: " << fps << " ,\n "
0 commit comments