Skip to content

Commit 4fb8901

Browse files
committed
Fix lora loading when using multiple clip backends
1 parent 2e07c95 commit 4fb8901

2 files changed

Lines changed: 196 additions & 43 deletions

File tree

src/conditioner.hpp

Lines changed: 162 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ struct ConditionerParams {
3535
};
3636

3737
struct Conditioner {
38+
int model_count = 1;
3839
virtual SDCondition get_learned_condition(ggml_context* work_ctx,
3940
int n_threads,
4041
const ConditionerParams& conditioner_params) = 0;
@@ -53,6 +54,11 @@ struct Conditioner {
5354
const std::string& prompt) {
5455
GGML_ABORT("Not implemented yet!");
5556
}
57+
// Tells whether the tensor called `name` belongs to the text encoder at
// position `index`. This default implementation ignores both arguments and
// accepts every name; subclasses may override it to filter by encoder.
virtual bool is_cond_stage_model_name_at_index(const std::string& name, int index) {
58+
return true;
59+
}
60+
// Backend accessors for the text encoder at position `index`; every concrete
// conditioner must implement both. The overrides in this file forward to the
// selected runner's get_params_backend() / get_runtime_backend() and return
// nullptr when that runner does not exist, so callers should expect nullptr.
virtual ggml_backend_t get_params_backend_at_index(int index) = 0;
61+
virtual ggml_backend_t get_runtime_backend_at_index(int index) = 0;
5662
};
5763

5864
// ldm.modules.encoders.modules.FrozenCLIPEmbedder
@@ -95,8 +101,9 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
95101
LOG_INFO("CLIP-H: using %s backend", ggml_backend_name(clip_backend));
96102
text_model = std::make_shared<CLIPTextModelRunner>(clip_backend, offload_params_to_cpu, tensor_storage_map, "cond_stage_model.transformer.text_model", OPEN_CLIP_VIT_H_14, true, force_clip_f32);
97103
} else if (sd_version_is_sdxl(version)) {
104+
model_count = 2;
98105
ggml_backend_t clip_g_backend = clip_backend;
99-
if (backends.size() >= 2){
106+
if (backends.size() >= 2) {
100107
clip_g_backend = backends[1];
101108
if (backends.size() > 2) {
102109
LOG_WARN("More than 2 clip backends provided, but the model only supports 2 text encoders. Ignoring the rest.");
@@ -665,6 +672,42 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
665672
conditioner_params.adm_in_channels,
666673
conditioner_params.zero_out_masked);
667674
}
675+
676+
bool is_cond_stage_model_name_at_index(const std::string& name, int index) override {
677+
if (sd_version_is_sdxl(version)) {
678+
if (index == 0) {
679+
return contains(name, "cond_stage_model.model.transformer");
680+
} else if (index == 1) {
681+
return contains(name, "cond_stage_model.model.1");
682+
} else {
683+
return false;
684+
}
685+
}
686+
return true;
687+
}
688+
689+
// Returns the params backend of the text encoder selected by `index`.
//
// On SDXL, index 1 selects the second encoder (text_model2); every other
// combination resolves to the primary encoder (text_model). Returns nullptr
// when the selected encoder has not been created.
ggml_backend_t get_params_backend_at_index(int index) override {
    const bool wants_second_encoder = sd_version_is_sdxl(version) && index == 1;
    if (wants_second_encoder) {
        // No fallback to text_model here: a missing second encoder yields nullptr.
        if (text_model2) {
            return text_model2->get_params_backend();
        }
        return nullptr;
    }
    if (text_model) {
        return text_model->get_params_backend();
    }
    return nullptr;
}
699+
700+
// Returns the runtime backend of the text encoder selected by `index`.
//
// On SDXL, index 1 selects the second encoder (text_model2); every other
// combination resolves to the primary encoder (text_model). Returns nullptr
// when the selected encoder has not been created.
ggml_backend_t get_runtime_backend_at_index(int index) override {
    const bool wants_second_encoder = sd_version_is_sdxl(version) && index == 1;
    if (wants_second_encoder) {
        // No fallback to text_model here: a missing second encoder yields nullptr.
        if (text_model2) {
            return text_model2->get_runtime_backend();
        }
        return nullptr;
    }
    if (text_model) {
        return text_model->get_runtime_backend();
    }
    return nullptr;
}
710+
668711
};
669712

670713
struct FrozenCLIPVisionEmbedder : public GGMLRunner {
@@ -740,12 +783,14 @@ struct SD3CLIPEmbedder : public Conditioner {
740783
bool use_clip_g = false;
741784
bool use_t5 = false;
742785

786+
model_count = 3;
787+
743788
ggml_backend_t clip_l_backend, clip_g_backend, t5_backend;
744789
if (backends.size() == 1) {
745790
clip_l_backend = clip_g_backend = t5_backend = backends[0];
746791
} else if (backends.size() == 2) {
747792
clip_l_backend = clip_g_backend = backends[0];
748-
t5_backend = backends[1];
793+
t5_backend = backends[1];
749794
} else if (backends.size() >= 3) {
750795
clip_l_backend = backends[0];
751796
clip_g_backend = backends[1];
@@ -1175,6 +1220,42 @@ struct SD3CLIPEmbedder : public Conditioner {
11751220
conditioner_params.clip_skip,
11761221
conditioner_params.zero_out_masked);
11771222
}
1223+
1224+
bool is_cond_stage_model_name_at_index(const std::string& name, int index) override {
1225+
if (index == 0) {
1226+
return contains(name, "text_encoders.clip_l");
1227+
} else if (index == 1) {
1228+
return contains(name, "text_encoders.clip_g");
1229+
} else if (index == 2) {
1230+
return contains(name, "text_encoders.t5xxl");
1231+
} else {
1232+
return false;
1233+
}
1234+
}
1235+
1236+
// Returns the params backend of the encoder at `index`
// (0 = clip_l, 1 = clip_g, 2 = t5). Returns nullptr when the index is out of
// range or the selected encoder has not been created.
ggml_backend_t get_params_backend_at_index(int index) override {
    switch (index) {
        case 0:
            if (clip_l) {
                return clip_l->get_params_backend();
            }
            break;
        case 1:
            if (clip_g) {
                return clip_g->get_params_backend();
            }
            break;
        case 2:
            if (t5) {
                return t5->get_params_backend();
            }
            break;
        default:
            break;
    }
    return nullptr;
}
1247+
1248+
// Returns the runtime backend of the encoder at `index`
// (0 = clip_l, 1 = clip_g, 2 = t5). Returns nullptr when the index is out of
// range or the selected encoder has not been created.
ggml_backend_t get_runtime_backend_at_index(int index) override {
    switch (index) {
        case 0:
            if (clip_l) {
                return clip_l->get_runtime_backend();
            }
            break;
        case 1:
            if (clip_g) {
                return clip_g->get_runtime_backend();
            }
            break;
        case 2:
            if (t5) {
                return t5->get_runtime_backend();
            }
            break;
        default:
            break;
    }
    return nullptr;
}
11781259
};
11791260

11801261
struct FluxCLIPEmbedder : public Conditioner {
@@ -1190,19 +1271,19 @@ struct FluxCLIPEmbedder : public Conditioner {
11901271
bool use_clip_l = false;
11911272
bool use_t5 = false;
11921273

1274+
model_count = 2;
11931275

11941276
ggml_backend_t clip_l_backend, t5_backend;
11951277
if (backends.size() == 1) {
11961278
clip_l_backend = t5_backend = backends[0];
11971279
} else if (backends.size() >= 2) {
11981280
clip_l_backend = backends[0];
1199-
t5_backend = backends[1];
1281+
t5_backend = backends[1];
12001282
if (backends.size() > 2) {
12011283
LOG_WARN("More than 2 clip backends provided, but the model only supports 2 text encoders. Ignoring the rest.");
12021284
}
12031285
}
12041286

1205-
12061287
for (auto pair : tensor_storage_map) {
12071288
if (pair.first.find("text_encoders.clip_l") != std::string::npos) {
12081289
use_clip_l = true;
@@ -1468,6 +1549,36 @@ struct FluxCLIPEmbedder : public Conditioner {
14681549
conditioner_params.clip_skip,
14691550
conditioner_params.zero_out_masked);
14701551
}
1552+
1553+
bool is_cond_stage_model_name_at_index(const std::string& name, int index) override {
1554+
if (index == 0) {
1555+
return contains(name, "text_encoders.clip_l");
1556+
} else if (index == 1) {
1557+
return contains(name, "text_encoders.t5xxl");
1558+
} else {
1559+
return false;
1560+
}
1561+
}
1562+
1563+
// Returns the params backend of the encoder at `index` (0 = clip_l, 1 = t5).
// Returns nullptr when the index is out of range or the selected encoder has
// not been created.
ggml_backend_t get_params_backend_at_index(int index) override {
    switch (index) {
        case 0:
            if (clip_l) {
                return clip_l->get_params_backend();
            }
            break;
        case 1:
            if (t5) {
                return t5->get_params_backend();
            }
            break;
        default:
            break;
    }
    return nullptr;
}
1572+
1573+
// Returns the runtime backend of the encoder at `index` (0 = clip_l, 1 = t5).
// Returns nullptr when the index is out of range or the selected encoder has
// not been created.
ggml_backend_t get_runtime_backend_at_index(int index) override {
    switch (index) {
        case 0:
            if (clip_l) {
                return clip_l->get_runtime_backend();
            }
            break;
        case 1:
            if (t5) {
                return t5->get_runtime_backend();
            }
            break;
        default:
            break;
    }
    return nullptr;
}
14711582
};
14721583

14731584
struct T5CLIPEmbedder : public Conditioner {
@@ -1691,6 +1802,20 @@ struct T5CLIPEmbedder : public Conditioner {
16911802
conditioner_params.clip_skip,
16921803
conditioner_params.zero_out_masked);
16931804
}
1805+
1806+
// Returns the params backend of the t5 runner, or nullptr when it has not
// been created. `index` is ignored: this method only ever consults t5.
ggml_backend_t get_params_backend_at_index(int index) override {
    (void)index;
    if (!t5) {
        return nullptr;
    }
    return t5->get_params_backend();
}
1812+
1813+
// Returns the runtime backend of the t5 runner, or nullptr when it has not
// been created. `index` is ignored: this method only ever consults t5.
ggml_backend_t get_runtime_backend_at_index(int index) override {
    (void)index;
    if (!t5) {
        return nullptr;
    }
    return t5->get_runtime_backend();
}
16941819
};
16951820

16961821
struct AnimaConditioner : public Conditioner {
@@ -1703,11 +1828,11 @@ struct AnimaConditioner : public Conditioner {
17031828
const String2TensorStorage& tensor_storage_map = {}) {
17041829
qwen_tokenizer = std::make_shared<LLM::Qwen2Tokenizer>();
17051830
llm = std::make_shared<LLM::LLMRunner>(LLM::LLMArch::QWEN3,
1706-
backend,
1707-
offload_params_to_cpu,
1708-
tensor_storage_map,
1709-
"text_encoders.llm",
1710-
false);
1831+
backend,
1832+
offload_params_to_cpu,
1833+
tensor_storage_map,
1834+
"text_encoders.llm",
1835+
false);
17111836
}
17121837

17131838
void get_param_tensors(std::map<std::string, ggml_tensor*>& tensors) override {
@@ -1827,6 +1952,20 @@ struct AnimaConditioner : public Conditioner {
18271952

18281953
return {hidden_states, t5_weight_tensor, t5_ids_tensor};
18291954
}
1955+
1956+
// Returns the params backend of the llm runner, or nullptr when it has not
// been created. `index` is ignored: this method only ever consults llm.
ggml_backend_t get_params_backend_at_index(int index) override {
    (void)index;
    if (!llm) {
        return nullptr;
    }
    return llm->get_params_backend();
}
1962+
1963+
// Returns the runtime backend of the llm runner, or nullptr when it has not
// been created. `index` is ignored: this method only ever consults llm.
ggml_backend_t get_runtime_backend_at_index(int index) override {
    (void)index;
    if (!llm) {
        return nullptr;
    }
    return llm->get_runtime_backend();
}
18301969
};
18311970

18321971
struct LLMEmbedder : public Conditioner {
@@ -2201,6 +2340,20 @@ struct LLMEmbedder : public Conditioner {
22012340
LOG_DEBUG("computing condition graph completed, taking %" PRId64 " ms", t1 - t0);
22022341
return {hidden_states, nullptr, nullptr, extra_hidden_states_vec};
22032342
}
2343+
2344+
// Returns the params backend of the llm runner, or nullptr when it has not
// been created. `index` is ignored: this method only ever consults llm.
ggml_backend_t get_params_backend_at_index(int index) override {
    (void)index;
    if (!llm) {
        return nullptr;
    }
    return llm->get_params_backend();
}
2350+
2351+
// Returns the runtime backend of the llm runner, or nullptr when it has not
// been created. `index` is ignored: this method only ever consults llm.
ggml_backend_t get_runtime_backend_at_index(int index) override {
    (void)index;
    if (!llm) {
        return nullptr;
    }
    return llm->get_runtime_backend();
}
22042357
};
22052358

22062359
#endif

src/stable-diffusion.cpp

Lines changed: 34 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1637,14 +1637,6 @@ class StableDiffusionGGML {
16371637
for (auto& kv : lora_state_diff) {
16381638
bool applied = false;
16391639
int64_t t0 = ggml_time_ms();
1640-
// TODO: Fix that
1641-
bool are_clip_backends_similar = true;
1642-
for (auto backend: clip_backends){
1643-
are_clip_backends_similar = are_clip_backends_similar && (clip_backends[0]==backend || ggml_backend_is_cpu(backend));
1644-
}
1645-
if(!are_clip_backends_similar){
1646-
LOG_WARN("Text encoders are running on different backends. This may cause issues when immediately applying LoRAs.");
1647-
}
16481640
auto lora_tensor_filter_diff = [&](const std::string& tensor_name) {
16491641
if (is_diffusion_model_name(tensor_name)) {
16501642
return true;
@@ -1660,19 +1652,22 @@ class StableDiffusionGGML {
16601652
applied = true;
16611653
}
16621654

1663-
auto lora_tensor_filter_cond = [&](const std::string& tensor_name) {
1664-
if (is_cond_stage_model_name(tensor_name)) {
1665-
return true;
1655+
for (int i = 0; i < cond_stage_model->model_count; i++) {
1656+
auto lora_tensor_filter_cond = [&](const std::string& tensor_name) {
1657+
if (is_cond_stage_model_name(tensor_name)) {
1658+
return cond_stage_model->is_cond_stage_model_name_at_index(tensor_name, i);
1659+
}
1660+
return false;
1661+
};
1662+
// TODO: split by model
1663+
LOG_INFO("applying lora to text encoder (%d)", i);
1664+
auto backend = cond_stage_model->get_params_backend_at_index(i);
1665+
lora = load_lora_model_from_file(kv.first, kv.second, backend, lora_tensor_filter_cond);
1666+
if (lora && !lora->lora_tensors.empty()) {
1667+
lora->apply(tensors, version, n_threads);
1668+
lora->free_params_buffer();
1669+
applied = true;
16661670
}
1667-
return false;
1668-
};
1669-
// TODO: split by model
1670-
LOG_INFO("applying lora to text encoders");
1671-
lora = load_lora_model_from_file(kv.first, kv.second, clip_backends[0], lora_tensor_filter_cond);
1672-
if (lora && !lora->lora_tensors.empty()) {
1673-
lora->apply(tensors, version, n_threads);
1674-
lora->free_params_buffer();
1675-
applied = true;
16761671
}
16771672

16781673
auto lora_tensor_filter_first = [&](const std::string& tensor_name) {
@@ -1734,22 +1729,27 @@ class StableDiffusionGGML {
17341729
}
17351730
}
17361731
cond_stage_lora_models = lora_models;
1737-
auto lora_tensor_filter = [&](const std::string& tensor_name) {
1738-
if (is_cond_stage_model_name(tensor_name)) {
1739-
return true;
1740-
}
1741-
return false;
1742-
};
1743-
for (auto& kv : lora_state_diff) {
1744-
const std::string& lora_id = kv.first;
1745-
float multiplier = kv.second;
1746-
//TODO: split by model
1747-
auto lora = load_lora_model_from_file(lora_id, multiplier, clip_backends[0], lora_tensor_filter);
1748-
if (lora && !lora->lora_tensors.empty()) {
1749-
lora->preprocess_lora_tensors(tensors);
1750-
cond_stage_lora_models.push_back(lora);
1732+
1733+
1734+
for(int i=0;i<cond_stage_model->model_count;i++){
1735+
auto lora_tensor_filter_cond = [&](const std::string& tensor_name) {
1736+
if (is_cond_stage_model_name(tensor_name)) {
1737+
return cond_stage_model->is_cond_stage_model_name_at_index(tensor_name, i);
1738+
}
1739+
return false;
1740+
};
1741+
for (auto& kv : lora_state_diff) {
1742+
const std::string& lora_id = kv.first;
1743+
float multiplier = kv.second;
1744+
auto backend = cond_stage_model->get_runtime_backend_at_index(i);
1745+
auto lora = load_lora_model_from_file(kv.first, kv.second, backend, lora_tensor_filter_cond);
1746+
if (lora && !lora->lora_tensors.empty()) {
1747+
lora->preprocess_lora_tensors(tensors);
1748+
cond_stage_lora_models.push_back(lora);
1749+
}
17511750
}
17521751
}
1752+
17531753
auto multi_lora_adapter = std::make_shared<MultiLoraAdapter>(cond_stage_lora_models);
17541754
cond_stage_model->set_weight_adapter(multi_lora_adapter);
17551755
}

0 commit comments

Comments
 (0)