Skip to content

Commit d939f6e

Browse files
authored
refactor: optimize the handling of LoRA models (#1070)
1 parent e72aea7 commit d939f6e

File tree

6 files changed

+169
-98
lines changed

6 files changed

+169
-98
lines changed

clip.hpp

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -7,31 +7,6 @@
77

88
/*================================================== CLIPTokenizer ===================================================*/
99

10-
/**
 * Removes every `<lora:NAME:MULTIPLIER>` tag from `text` and collects the
 * multipliers per NAME.
 *
 * @param text  Prompt text, taken by value and edited in place.
 * @return      A pair of (NAME -> summed multiplier, text with all tags removed).
 *
 * Tags whose multiplier is zero are removed from the text but not recorded.
 * Tags whose multiplier cannot be parsed by std::stof are likewise removed
 * and skipped, instead of letting the exception escape to the caller.
 * Multipliers of repeated NAMEs are added together.
 */
__STATIC_INLINE__ std::pair<std::unordered_map<std::string, float>, std::string> extract_and_remove_lora(std::string text) {
    // Compiled once: constructing a std::regex on every call is needlessly expensive.
    static const std::regex re("<lora:([^:]+):([^>]+)>");

    std::unordered_map<std::string, float> filename2multiplier;
    std::smatch matches;

    while (std::regex_search(text, matches, re)) {
        // `matches` refers into `text`, so copy out everything needed before
        // the tag is erased and the match object becomes stale.
        const std::string filename       = matches[1].str();
        const std::string multiplier_str = matches[2].str();
        const auto tag_pos               = static_cast<std::string::size_type>(matches.position(0));
        const auto tag_len               = static_cast<std::string::size_type>(matches.length(0));

        // Same effect as regex_replace(..., format_first_only) with an empty
        // replacement, without rescanning and copying the whole string.
        text.erase(tag_pos, tag_len);

        float multiplier = 0.f;
        try {
            multiplier = std::stof(multiplier_str);
        } catch (...) {
            // Not a number (or out of range): the tag is dropped, nothing is recorded.
            continue;
        }

        if (multiplier == 0.f) {
            continue;
        }

        // operator[] value-initialises a missing entry to 0.f, so this both
        // inserts a new name and accumulates onto an existing one.
        filename2multiplier[filename] += multiplier;
    }

    return std::make_pair(filename2multiplier, text);
}
34-
3510
__STATIC_INLINE__ std::vector<std::pair<int, std::u32string>> bytes_to_unicode() {
3611
std::vector<std::pair<int, std::u32string>> byte_unicode_pairs;
3712
std::set<int> byte_set;

examples/cli/main.cpp

Lines changed: 138 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -507,7 +507,7 @@ struct SDContextParams {
507507
std::string lora_model_dir;
508508

509509
std::map<std::string, std::string> embedding_map;
510-
std::vector<sd_embedding_t> embedding_array;
510+
std::vector<sd_embedding_t> embedding_vec;
511511

512512
rng_type_t rng_type = CUDA_RNG;
513513
rng_type_t sampler_rng_type = RNG_TYPE_COUNT;
@@ -952,13 +952,13 @@ struct SDContextParams {
952952
}
953953

954954
sd_ctx_params_t to_sd_ctx_params_t(bool vae_decode_only, bool free_params_immediately, bool taesd_preview) {
955-
embedding_array.clear();
956-
embedding_array.reserve(embedding_map.size());
955+
embedding_vec.clear();
956+
embedding_vec.reserve(embedding_map.size());
957957
for (const auto& kv : embedding_map) {
958958
sd_embedding_t item;
959959
item.name = kv.first.c_str();
960960
item.path = kv.second.c_str();
961-
embedding_array.emplace_back(item);
961+
embedding_vec.emplace_back(item);
962962
}
963963

964964
sd_ctx_params_t sd_ctx_params = {
@@ -975,8 +975,8 @@ struct SDContextParams {
975975
taesd_path.c_str(),
976976
control_net_path.c_str(),
977977
lora_model_dir.c_str(),
978-
embedding_array.data(),
979-
static_cast<uint32_t>(embedding_array.size()),
978+
embedding_vec.data(),
979+
static_cast<uint32_t>(embedding_vec.size()),
980980
photo_maker_path.c_str(),
981981
tensor_type_rules.c_str(),
982982
vae_decode_only,
@@ -1030,6 +1030,15 @@ static std::string vec_str_to_string(const std::vector<std::string>& v) {
10301030
return oss.str();
10311031
}
10321032

1033+
/**
 * Reports whether `p` is written as an absolute path on the platform this
 * file is compiled for. Purely lexical: the filesystem is never consulted.
 *
 * Windows: true for drive-qualified paths (`C:/path`, `C:\path`) and for
 *          paths that start with two separators (UNC and device paths such
 *          as `\\server\share\file` or `\\?\C:\file`).
 * Others : true when the path starts with `/`.
 *
 * An empty string is never absolute.
 */
static bool is_absolute_path(const std::string& p) {
#ifdef _WIN32
    const auto is_sep = [](char c) { return c == '/' || c == '\\'; };
    // UNC / device namespace: \\server\share\... or //server/share/...
    if (p.size() > 1 && is_sep(p[0]) && is_sep(p[1])) {
        return true;
    }
    // Drive-qualified: C:/path or C:\path
    // (a drive-relative "C:file" is also accepted, as before)
    return p.size() > 1 && std::isalpha(static_cast<unsigned char>(p[0])) && p[1] == ':';
#else
    return !p.empty() && p[0] == '/';
#endif
}
1041+
10331042
struct SDGenerationParams {
10341043
std::string prompt;
10351044
std::string negative_prompt;
@@ -1072,6 +1081,10 @@ struct SDGenerationParams {
10721081

10731082
int upscale_repeats = 1;
10741083

1084+
std::map<std::string, float> lora_map;
1085+
std::map<std::string, float> high_noise_lora_map;
1086+
std::vector<sd_lora_t> lora_vec;
1087+
10751088
SDGenerationParams() {
10761089
sd_sample_params_init(&sample_params);
10771090
sd_sample_params_init(&high_noise_sample_params);
@@ -1442,7 +1455,88 @@ struct SDGenerationParams {
14421455
return options;
14431456
}
14441457

1445-
bool process_and_check(SDMode mode) {
1458+
void extract_and_remove_lora(const std::string& lora_model_dir) {
1459+
static const std::regex re(R"(<lora:([^:>]+):([^>]+)>)");
1460+
static const std::vector<std::string> valid_ext = {".pt", ".safetensors", ".gguf"};
1461+
std::smatch m;
1462+
1463+
std::string tmp = prompt;
1464+
1465+
while (std::regex_search(tmp, m, re)) {
1466+
std::string raw_path = m[1].str();
1467+
const std::string raw_mul = m[2].str();
1468+
1469+
float mul = 0.f;
1470+
try {
1471+
mul = std::stof(raw_mul);
1472+
} catch (...) {
1473+
tmp = m.suffix().str();
1474+
prompt = std::regex_replace(prompt, re, "", std::regex_constants::format_first_only);
1475+
continue;
1476+
}
1477+
1478+
bool is_high_noise = false;
1479+
static const std::string prefix = "|high_noise|";
1480+
if (raw_path.rfind(prefix, 0) == 0) {
1481+
raw_path.erase(0, prefix.size());
1482+
is_high_noise = true;
1483+
}
1484+
1485+
fs::path final_path;
1486+
if (is_absolute_path(raw_path)) {
1487+
final_path = raw_path;
1488+
} else {
1489+
final_path = fs::path(lora_model_dir) / raw_path;
1490+
}
1491+
if (!fs::exists(final_path)) {
1492+
bool found = false;
1493+
for (const auto& ext : valid_ext) {
1494+
fs::path try_path = final_path;
1495+
try_path += ext;
1496+
if (fs::exists(try_path)) {
1497+
final_path = try_path;
1498+
found = true;
1499+
break;
1500+
}
1501+
}
1502+
if (!found) {
1503+
printf("can not found lora %s\n", final_path.lexically_normal().string().c_str());
1504+
tmp = m.suffix().str();
1505+
prompt = std::regex_replace(prompt, re, "", std::regex_constants::format_first_only);
1506+
continue;
1507+
}
1508+
}
1509+
1510+
const std::string key = final_path.lexically_normal().string();
1511+
1512+
if (is_high_noise)
1513+
high_noise_lora_map[key] += mul;
1514+
else
1515+
lora_map[key] += mul;
1516+
1517+
prompt = std::regex_replace(prompt, re, "", std::regex_constants::format_first_only);
1518+
1519+
tmp = m.suffix().str();
1520+
}
1521+
1522+
for (const auto& kv : lora_map) {
1523+
sd_lora_t item;
1524+
item.is_high_noise = false;
1525+
item.path = kv.first.c_str();
1526+
item.multiplier = kv.second;
1527+
lora_vec.emplace_back(item);
1528+
}
1529+
1530+
for (const auto& kv : high_noise_lora_map) {
1531+
sd_lora_t item;
1532+
item.is_high_noise = true;
1533+
item.path = kv.first.c_str();
1534+
item.multiplier = kv.second;
1535+
lora_vec.emplace_back(item);
1536+
}
1537+
}
1538+
1539+
bool process_and_check(SDMode mode, const std::string& lora_model_dir) {
14461540
if (width <= 0) {
14471541
fprintf(stderr, "error: the width must be greater than 0\n");
14481542
return false;
@@ -1553,14 +1647,44 @@ struct SDGenerationParams {
15531647
seed = rand();
15541648
}
15551649

1650+
extract_and_remove_lora(lora_model_dir);
1651+
15561652
return true;
15571653
}
15581654

15591655
std::string to_string() const {
15601656
char* sample_params_str = sd_sample_params_to_str(&sample_params);
15611657
char* high_noise_sample_params_str = sd_sample_params_to_str(&high_noise_sample_params);
1658+
1659+
std::ostringstream lora_ss;
1660+
lora_ss << "{\n";
1661+
for (auto it = lora_map.begin(); it != lora_map.end(); ++it) {
1662+
lora_ss << " \"" << it->first << "\": \"" << it->second << "\"";
1663+
if (std::next(it) != lora_map.end()) {
1664+
lora_ss << ",";
1665+
}
1666+
lora_ss << "\n";
1667+
}
1668+
lora_ss << " }";
1669+
std::string loras_str = lora_ss.str();
1670+
1671+
lora_ss = std::ostringstream();
1672+
;
1673+
lora_ss << "{\n";
1674+
for (auto it = high_noise_lora_map.begin(); it != high_noise_lora_map.end(); ++it) {
1675+
lora_ss << " \"" << it->first << "\": \"" << it->second << "\"";
1676+
if (std::next(it) != high_noise_lora_map.end()) {
1677+
lora_ss << ",";
1678+
}
1679+
lora_ss << "\n";
1680+
}
1681+
lora_ss << " }";
1682+
std::string high_noise_loras_str = lora_ss.str();
1683+
15621684
std::ostringstream oss;
15631685
oss << "SDGenerationParams {\n"
1686+
<< " loras: \"" << loras_str << "\",\n"
1687+
<< " high_noise_loras: \"" << high_noise_loras_str << "\",\n"
15641688
<< " prompt: \"" << prompt << "\",\n"
15651689
<< " negative_prompt: \"" << negative_prompt << "\",\n"
15661690
<< " clip_skip: " << clip_skip << ",\n"
@@ -1626,7 +1750,9 @@ void parse_args(int argc, const char** argv, SDCliParams& cli_params, SDContextP
16261750
exit(cli_params.normal_exit ? 0 : 1);
16271751
}
16281752

1629-
if (!cli_params.process_and_check() || !ctx_params.process_and_check(cli_params.mode) || !gen_params.process_and_check(cli_params.mode)) {
1753+
if (!cli_params.process_and_check() ||
1754+
!ctx_params.process_and_check(cli_params.mode) ||
1755+
!gen_params.process_and_check(cli_params.mode, ctx_params.lora_model_dir)) {
16301756
print_usage(argc, argv, options_vec);
16311757
exit(1);
16321758
}
@@ -2139,6 +2265,8 @@ int main(int argc, const char* argv[]) {
21392265

21402266
if (cli_params.mode == IMG_GEN) {
21412267
sd_img_gen_params_t img_gen_params = {
2268+
gen_params.lora_vec.data(),
2269+
static_cast<uint32_t>(gen_params.lora_vec.size()),
21422270
gen_params.prompt.c_str(),
21432271
gen_params.negative_prompt.c_str(),
21442272
gen_params.clip_skip,
@@ -2170,6 +2298,8 @@ int main(int argc, const char* argv[]) {
21702298
num_results = gen_params.batch_count;
21712299
} else if (cli_params.mode == VID_GEN) {
21722300
sd_vid_gen_params_t vid_gen_params = {
2301+
gen_params.lora_vec.data(),
2302+
static_cast<uint32_t>(gen_params.lora_vec.size()),
21732303
gen_params.prompt.c_str(),
21742304
gen_params.negative_prompt.c_str(),
21752305
gen_params.clip_skip,

stable-diffusion.cpp

Lines changed: 21 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -937,28 +937,17 @@ class StableDiffusionGGML {
937937
float multiplier,
938938
ggml_backend_t backend,
939939
LoraModel::filter_t lora_tensor_filter = nullptr) {
940-
std::string lora_name = lora_id;
941-
std::string high_noise_tag = "|high_noise|";
942-
bool is_high_noise = false;
943-
if (starts_with(lora_name, high_noise_tag)) {
944-
lora_name = lora_name.substr(high_noise_tag.size());
940+
std::string lora_path = lora_id;
941+
static std::string high_noise_tag = "|high_noise|";
942+
bool is_high_noise = false;
943+
if (starts_with(lora_path, high_noise_tag)) {
944+
lora_path = lora_path.substr(high_noise_tag.size());
945945
is_high_noise = true;
946-
LOG_DEBUG("high noise lora: %s", lora_name.c_str());
947-
}
948-
std::string st_file_path = path_join(lora_model_dir, lora_name + ".safetensors");
949-
std::string ckpt_file_path = path_join(lora_model_dir, lora_name + ".ckpt");
950-
std::string file_path;
951-
if (file_exists(st_file_path)) {
952-
file_path = st_file_path;
953-
} else if (file_exists(ckpt_file_path)) {
954-
file_path = ckpt_file_path;
955-
} else {
956-
LOG_WARN("can not find %s or %s for lora %s", st_file_path.c_str(), ckpt_file_path.c_str(), lora_name.c_str());
957-
return nullptr;
946+
LOG_DEBUG("high noise lora: %s", lora_path.c_str());
958947
}
959-
auto lora = std::make_shared<LoraModel>(lora_id, backend, file_path, is_high_noise ? "model.high_noise_" : "", version);
948+
auto lora = std::make_shared<LoraModel>(lora_id, backend, lora_path, is_high_noise ? "model.high_noise_" : "", version);
960949
if (!lora->load_from_file(n_threads, lora_tensor_filter)) {
961-
LOG_WARN("load lora tensors from %s failed", file_path.c_str());
950+
LOG_WARN("load lora tensors from %s failed", lora_path.c_str());
962951
return nullptr;
963952
}
964953

@@ -1143,12 +1132,15 @@ class StableDiffusionGGML {
11431132
}
11441133
}
11451134

1146-
std::string apply_loras_from_prompt(const std::string& prompt) {
1147-
auto result_pair = extract_and_remove_lora(prompt);
1148-
std::unordered_map<std::string, float> lora_f2m = result_pair.first; // lora_name -> multiplier
1149-
1150-
for (auto& kv : lora_f2m) {
1151-
LOG_DEBUG("lora %s:%.2f", kv.first.c_str(), kv.second);
1135+
void apply_loras(const sd_lora_t* loras, uint32_t lora_count) {
1136+
std::unordered_map<std::string, float> lora_f2m;
1137+
for (int i = 0; i < lora_count; i++) {
1138+
std::string lora_id = SAFE_STR(loras[i].path);
1139+
if (loras[i].is_high_noise) {
1140+
lora_id = "|high_noise|" + lora_id;
1141+
}
1142+
lora_f2m[lora_id] = loras[i].multiplier;
1143+
LOG_DEBUG("lora %s:%.2f", lora_id.c_str(), loras[i].multiplier);
11521144
}
11531145
int64_t t0 = ggml_time_ms();
11541146
if (apply_lora_immediately) {
@@ -1159,9 +1151,7 @@ class StableDiffusionGGML {
11591151
int64_t t1 = ggml_time_ms();
11601152
if (!lora_f2m.empty()) {
11611153
LOG_INFO("apply_loras completed, taking %.2fs", (t1 - t0) * 1.0f / 1000);
1162-
LOG_DEBUG("prompt after extract and remove lora: \"%s\"", result_pair.second.c_str());
11631154
}
1164-
return result_pair.second;
11651155
}
11661156

11671157
ggml_tensor* id_encoder(ggml_context* work_ctx,
@@ -2815,8 +2805,6 @@ sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx,
28152805
int sample_steps = sigmas.size() - 1;
28162806

28172807
int64_t t0 = ggml_time_ms();
2818-
// Apply lora
2819-
prompt = sd_ctx->sd->apply_loras_from_prompt(prompt);
28202808

28212809
// Photo Maker
28222810
std::string prompt_text_only;
@@ -3188,6 +3176,9 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_g
31883176

31893177
size_t t0 = ggml_time_ms();
31903178

3179+
// Apply lora
3180+
sd_ctx->sd->apply_loras(sd_img_gen_params->loras, sd_img_gen_params->lora_count);
3181+
31913182
enum sample_method_t sample_method = sd_img_gen_params->sample_params.sample_method;
31923183
if (sample_method == SAMPLE_METHOD_COUNT) {
31933184
sample_method = sd_get_default_sample_method(sd_ctx);
@@ -3487,7 +3478,7 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
34873478
int64_t t0 = ggml_time_ms();
34883479

34893480
// Apply lora
3490-
prompt = sd_ctx->sd->apply_loras_from_prompt(prompt);
3481+
sd_ctx->sd->apply_loras(sd_vid_gen_params->loras, sd_vid_gen_params->lora_count);
34913482

34923483
ggml_tensor* init_latent = nullptr;
34933484
ggml_tensor* clip_vision_output = nullptr;

stable-diffusion.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,14 @@ typedef struct {
242242
} sd_easycache_params_t;
243243

244244
/* One LoRA to apply for a generation request. Arrays of these are passed
 * through the `loras` / `lora_count` fields of the generation parameter
 * structs. */
typedef struct {
245+
bool is_high_noise; /* when set, apply_loras prefixes this entry's id with "|high_noise|" */
246+
float multiplier; /* weight stored for this entry's id by apply_loras */
247+
const char* path; /* LoRA file path; read through SAFE_STR by apply_loras. The string is not owned by this struct. */
248+
} sd_lora_t;
249+
250+
typedef struct {
251+
const sd_lora_t* loras;
252+
uint32_t lora_count;
245253
const char* prompt;
246254
const char* negative_prompt;
247255
int clip_skip;
@@ -265,6 +273,8 @@ typedef struct {
265273
} sd_img_gen_params_t;
266274

267275
typedef struct {
276+
const sd_lora_t* loras;
277+
uint32_t lora_count;
268278
const char* prompt;
269279
const char* negative_prompt;
270280
int clip_skip;

0 commit comments

Comments
 (0)