@@ -507,7 +507,7 @@ struct SDContextParams {
507507 std::string lora_model_dir;
508508
509509 std::map<std::string, std::string> embedding_map;
510- std::vector<sd_embedding_t > embedding_array ;
510+ std::vector<sd_embedding_t > embedding_vec ;
511511
512512 rng_type_t rng_type = CUDA_RNG;
513513 rng_type_t sampler_rng_type = RNG_TYPE_COUNT;
@@ -952,13 +952,13 @@ struct SDContextParams {
952952 }
953953
954954 sd_ctx_params_t to_sd_ctx_params_t (bool vae_decode_only, bool free_params_immediately, bool taesd_preview) {
955- embedding_array .clear ();
956- embedding_array .reserve (embedding_map.size ());
955+ embedding_vec .clear ();
956+ embedding_vec .reserve (embedding_map.size ());
957957 for (const auto & kv : embedding_map) {
958958 sd_embedding_t item;
959959 item.name = kv.first .c_str ();
960960 item.path = kv.second .c_str ();
961- embedding_array .emplace_back (item);
961+ embedding_vec .emplace_back (item);
962962 }
963963
964964 sd_ctx_params_t sd_ctx_params = {
@@ -975,8 +975,8 @@ struct SDContextParams {
975975 taesd_path.c_str (),
976976 control_net_path.c_str (),
977977 lora_model_dir.c_str (),
978- embedding_array .data (),
979- static_cast <uint32_t >(embedding_array .size ()),
978+ embedding_vec .data (),
979+ static_cast <uint32_t >(embedding_vec .size ()),
980980 photo_maker_path.c_str (),
981981 tensor_type_rules.c_str (),
982982 vae_decode_only,
@@ -1030,6 +1030,15 @@ static std::string vec_str_to_string(const std::vector<std::string>& v) {
10301030 return oss.str ();
10311031}
10321032
1033+ static bool is_absolute_path (const std::string& p) {
1034+ #ifdef _WIN32
1035+ // Windows: C:/path or C:\path
1036+ return p.size () > 1 && std::isalpha (static_cast <unsigned char >(p[0 ])) && p[1 ] == ' :' ;
1037+ #else
1038+ return !p.empty () && p[0 ] == ' /' ;
1039+ #endif
1040+ }
1041+
10331042struct SDGenerationParams {
10341043 std::string prompt;
10351044 std::string negative_prompt;
@@ -1072,6 +1081,10 @@ struct SDGenerationParams {
10721081
10731082 int upscale_repeats = 1 ;
10741083
1084+ std::map<std::string, float > lora_map;
1085+ std::map<std::string, float > high_noise_lora_map;
1086+ std::vector<sd_lora_t > lora_vec;
1087+
10751088 SDGenerationParams () {
10761089 sd_sample_params_init (&sample_params);
10771090 sd_sample_params_init (&high_noise_sample_params);
@@ -1442,7 +1455,88 @@ struct SDGenerationParams {
14421455 return options;
14431456 }
14441457
1445- bool process_and_check (SDMode mode) {
1458+ void extract_and_remove_lora (const std::string& lora_model_dir) {
1459+ static const std::regex re (R"( <lora:([^:>]+):([^>]+)>)" );
1460+ static const std::vector<std::string> valid_ext = {" .pt" , " .safetensors" , " .gguf" };
1461+ std::smatch m;
1462+
1463+ std::string tmp = prompt;
1464+
1465+ while (std::regex_search (tmp, m, re)) {
1466+ std::string raw_path = m[1 ].str ();
1467+ const std::string raw_mul = m[2 ].str ();
1468+
1469+ float mul = 0 .f ;
1470+ try {
1471+ mul = std::stof (raw_mul);
1472+ } catch (...) {
1473+ tmp = m.suffix ().str ();
1474+ prompt = std::regex_replace (prompt, re, " " , std::regex_constants::format_first_only);
1475+ continue ;
1476+ }
1477+
1478+ bool is_high_noise = false ;
1479+ static const std::string prefix = " |high_noise|" ;
1480+ if (raw_path.rfind (prefix, 0 ) == 0 ) {
1481+ raw_path.erase (0 , prefix.size ());
1482+ is_high_noise = true ;
1483+ }
1484+
1485+ fs::path final_path;
1486+ if (is_absolute_path (raw_path)) {
1487+ final_path = raw_path;
1488+ } else {
1489+ final_path = fs::path (lora_model_dir) / raw_path;
1490+ }
1491+ if (!fs::exists (final_path)) {
1492+ bool found = false ;
1493+ for (const auto & ext : valid_ext) {
1494+ fs::path try_path = final_path;
1495+ try_path += ext;
1496+ if (fs::exists (try_path)) {
1497+ final_path = try_path;
1498+ found = true ;
1499+ break ;
1500+ }
1501+ }
1502+ if (!found) {
1503+ printf (" can not found lora %s\n " , final_path.lexically_normal ().string ().c_str ());
1504+ tmp = m.suffix ().str ();
1505+ prompt = std::regex_replace (prompt, re, " " , std::regex_constants::format_first_only);
1506+ continue ;
1507+ }
1508+ }
1509+
1510+ const std::string key = final_path.lexically_normal ().string ();
1511+
1512+ if (is_high_noise)
1513+ high_noise_lora_map[key] += mul;
1514+ else
1515+ lora_map[key] += mul;
1516+
1517+ prompt = std::regex_replace (prompt, re, " " , std::regex_constants::format_first_only);
1518+
1519+ tmp = m.suffix ().str ();
1520+ }
1521+
1522+ for (const auto & kv : lora_map) {
1523+ sd_lora_t item;
1524+ item.is_high_noise = false ;
1525+ item.path = kv.first .c_str ();
1526+ item.multiplier = kv.second ;
1527+ lora_vec.emplace_back (item);
1528+ }
1529+
1530+ for (const auto & kv : high_noise_lora_map) {
1531+ sd_lora_t item;
1532+ item.is_high_noise = true ;
1533+ item.path = kv.first .c_str ();
1534+ item.multiplier = kv.second ;
1535+ lora_vec.emplace_back (item);
1536+ }
1537+ }
1538+
1539+ bool process_and_check (SDMode mode, const std::string& lora_model_dir) {
14461540 if (width <= 0 ) {
14471541 fprintf (stderr, " error: the width must be greater than 0\n " );
14481542 return false ;
@@ -1553,14 +1647,44 @@ struct SDGenerationParams {
15531647 seed = rand ();
15541648 }
15551649
1650+ extract_and_remove_lora (lora_model_dir);
1651+
15561652 return true ;
15571653 }
15581654
15591655 std::string to_string () const {
15601656 char * sample_params_str = sd_sample_params_to_str (&sample_params);
15611657 char * high_noise_sample_params_str = sd_sample_params_to_str (&high_noise_sample_params);
1658+
1659+ std::ostringstream lora_ss;
1660+ lora_ss << " {\n " ;
1661+ for (auto it = lora_map.begin (); it != lora_map.end (); ++it) {
1662+ lora_ss << " \" " << it->first << " \" : \" " << it->second << " \" " ;
1663+ if (std::next (it) != lora_map.end ()) {
1664+ lora_ss << " ," ;
1665+ }
1666+ lora_ss << " \n " ;
1667+ }
1668+ lora_ss << " }" ;
1669+ std::string loras_str = lora_ss.str ();
1670+
1671+ lora_ss = std::ostringstream ();
1672+ ;
1673+ lora_ss << " {\n " ;
1674+ for (auto it = high_noise_lora_map.begin (); it != high_noise_lora_map.end (); ++it) {
1675+ lora_ss << " \" " << it->first << " \" : \" " << it->second << " \" " ;
1676+ if (std::next (it) != high_noise_lora_map.end ()) {
1677+ lora_ss << " ," ;
1678+ }
1679+ lora_ss << " \n " ;
1680+ }
1681+ lora_ss << " }" ;
1682+ std::string high_noise_loras_str = lora_ss.str ();
1683+
15621684 std::ostringstream oss;
15631685 oss << " SDGenerationParams {\n "
1686+ << " loras: \" " << loras_str << " \" ,\n "
1687+ << " high_noise_loras: \" " << high_noise_loras_str << " \" ,\n "
15641688 << " prompt: \" " << prompt << " \" ,\n "
15651689 << " negative_prompt: \" " << negative_prompt << " \" ,\n "
15661690 << " clip_skip: " << clip_skip << " ,\n "
@@ -1626,7 +1750,9 @@ void parse_args(int argc, const char** argv, SDCliParams& cli_params, SDContextP
16261750 exit (cli_params.normal_exit ? 0 : 1 );
16271751 }
16281752
1629- if (!cli_params.process_and_check () || !ctx_params.process_and_check (cli_params.mode ) || !gen_params.process_and_check (cli_params.mode )) {
1753+ if (!cli_params.process_and_check () ||
1754+ !ctx_params.process_and_check (cli_params.mode ) ||
1755+ !gen_params.process_and_check (cli_params.mode , ctx_params.lora_model_dir )) {
16301756 print_usage (argc, argv, options_vec);
16311757 exit (1 );
16321758 }
@@ -2139,6 +2265,8 @@ int main(int argc, const char* argv[]) {
21392265
21402266 if (cli_params.mode == IMG_GEN) {
21412267 sd_img_gen_params_t img_gen_params = {
2268+ gen_params.lora_vec .data (),
2269+ static_cast <uint32_t >(gen_params.lora_vec .size ()),
21422270 gen_params.prompt .c_str (),
21432271 gen_params.negative_prompt .c_str (),
21442272 gen_params.clip_skip ,
@@ -2170,6 +2298,8 @@ int main(int argc, const char* argv[]) {
21702298 num_results = gen_params.batch_count ;
21712299 } else if (cli_params.mode == VID_GEN) {
21722300 sd_vid_gen_params_t vid_gen_params = {
2301+ gen_params.lora_vec .data (),
2302+ static_cast <uint32_t >(gen_params.lora_vec .size ()),
21732303 gen_params.prompt .c_str (),
21742304 gen_params.negative_prompt .c_str (),
21752305 gen_params.clip_skip ,
0 commit comments