@@ -35,6 +35,7 @@ struct ConditionerParams {
3535};
3636
3737struct Conditioner {
38+ int model_count = 1 ;
3839 virtual SDCondition get_learned_condition (ggml_context* work_ctx,
3940 int n_threads,
4041 const ConditionerParams& conditioner_params) = 0;
@@ -53,6 +54,11 @@ struct Conditioner {
5354 const std::string& prompt) {
5455 GGML_ABORT (" Not implemented yet!" );
5556 }
57+ virtual bool is_cond_stage_model_name_at_index (const std::string& name, int index) {
58+ return true ;
59+ }
60+ virtual ggml_backend_t get_params_backend_at_index (int index) = 0;
61+ virtual ggml_backend_t get_runtime_backend_at_index (int index) = 0;
5662};
5763
5864// ldm.modules.encoders.modules.FrozenCLIPEmbedder
@@ -95,8 +101,9 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
95101 LOG_INFO (" CLIP-H: using %s backend" , ggml_backend_name (clip_backend));
96102 text_model = std::make_shared<CLIPTextModelRunner>(clip_backend, offload_params_to_cpu, tensor_storage_map, " cond_stage_model.transformer.text_model" , OPEN_CLIP_VIT_H_14, true , force_clip_f32);
97103 } else if (sd_version_is_sdxl (version)) {
104+ model_count = 2 ;
98105 ggml_backend_t clip_g_backend = clip_backend;
99- if (backends.size () >= 2 ){
106+ if (backends.size () >= 2 ) {
100107 clip_g_backend = backends[1 ];
101108 if (backends.size () > 2 ) {
102109 LOG_WARN (" More than 2 clip backends provided, but the model only supports 2 text encoders. Ignoring the rest." );
@@ -665,6 +672,42 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner {
665672 conditioner_params.adm_in_channels ,
666673 conditioner_params.zero_out_masked );
667674 }
675+
676+ bool is_cond_stage_model_name_at_index (const std::string& name, int index) override {
677+ if (sd_version_is_sdxl (version)) {
678+ if (index == 0 ) {
679+ return contains (name, " cond_stage_model.model.transformer" );
680+ } else if (index == 1 ) {
681+ return contains (name, " cond_stage_model.model.1" );
682+ } else {
683+ return false ;
684+ }
685+ }
686+ return true ;
687+ }
688+
689+ ggml_backend_t get_params_backend_at_index (int index){
690+ if (sd_version_is_sdxl (version) && index == 1 ){
691+ if (text_model2) {
692+ return text_model2->get_params_backend ();
693+ }
694+ } else if (text_model) {
695+ return text_model->get_params_backend ();
696+ }
697+ return nullptr ;
698+ }
699+
700+ ggml_backend_t get_runtime_backend_at_index (int index){
701+ if (sd_version_is_sdxl (version) && index == 1 ){
702+ if (text_model2) {
703+ return text_model2->get_runtime_backend ();
704+ }
705+ } else if (text_model) {
706+ return text_model->get_runtime_backend ();
707+ }
708+ return nullptr ;
709+ }
710+
668711};
669712
670713struct FrozenCLIPVisionEmbedder : public GGMLRunner {
@@ -740,12 +783,14 @@ struct SD3CLIPEmbedder : public Conditioner {
740783 bool use_clip_g = false ;
741784 bool use_t5 = false ;
742785
786+ model_count = 3 ;
787+
743788 ggml_backend_t clip_l_backend, clip_g_backend, t5_backend;
744789 if (backends.size () == 1 ) {
745790 clip_l_backend = clip_g_backend = t5_backend = backends[0 ];
746791 } else if (backends.size () == 2 ) {
747792 clip_l_backend = clip_g_backend = backends[0 ];
748- t5_backend = backends[1 ];
793+ t5_backend = backends[1 ];
749794 } else if (backends.size () >= 3 ) {
750795 clip_l_backend = backends[0 ];
751796 clip_g_backend = backends[1 ];
@@ -1175,6 +1220,42 @@ struct SD3CLIPEmbedder : public Conditioner {
11751220 conditioner_params.clip_skip ,
11761221 conditioner_params.zero_out_masked );
11771222 }
1223+
1224+ bool is_cond_stage_model_name_at_index (const std::string& name, int index) override {
1225+ if (index == 0 ) {
1226+ return contains (name, " text_encoders.clip_l" );
1227+ } else if (index == 1 ) {
1228+ return contains (name, " text_encoders.clip_g" );
1229+ } else if (index == 2 ) {
1230+ return contains (name, " text_encoders.t5xxl" );
1231+ } else {
1232+ return false ;
1233+ }
1234+ }
1235+
1236+ ggml_backend_t get_params_backend_at_index (int index){
1237+ if (index == 0 && clip_l) {
1238+ return clip_l->get_params_backend ();
1239+ } else if (index == 1 && clip_g) {
1240+ return clip_g->get_params_backend ();
1241+ } else if (index == 2 && t5) {
1242+ return t5->get_params_backend ();
1243+ } else {
1244+ return nullptr ;
1245+ }
1246+ }
1247+
1248+ ggml_backend_t get_runtime_backend_at_index (int index){
1249+ if (index == 0 && clip_l) {
1250+ return clip_l->get_runtime_backend ();
1251+ } else if (index == 1 && clip_g) {
1252+ return clip_g->get_runtime_backend ();
1253+ } else if (index == 2 && t5) {
1254+ return t5->get_runtime_backend ();
1255+ } else {
1256+ return nullptr ;
1257+ }
1258+ }
11781259};
11791260
11801261struct FluxCLIPEmbedder : public Conditioner {
@@ -1190,19 +1271,19 @@ struct FluxCLIPEmbedder : public Conditioner {
11901271 bool use_clip_l = false ;
11911272 bool use_t5 = false ;
11921273
1274+ model_count = 2 ;
11931275
11941276 ggml_backend_t clip_l_backend, t5_backend;
11951277 if (backends.size () == 1 ) {
11961278 clip_l_backend = t5_backend = backends[0 ];
11971279 } else if (backends.size () >= 2 ) {
11981280 clip_l_backend = backends[0 ];
1199- t5_backend = backends[1 ];
1281+ t5_backend = backends[1 ];
12001282 if (backends.size () > 2 ) {
12011283 LOG_WARN (" More than 2 clip backends provided, but the model only supports 2 text encoders. Ignoring the rest." );
12021284 }
12031285 }
12041286
1205-
12061287 for (auto pair : tensor_storage_map) {
12071288 if (pair.first .find (" text_encoders.clip_l" ) != std::string::npos) {
12081289 use_clip_l = true ;
@@ -1468,6 +1549,36 @@ struct FluxCLIPEmbedder : public Conditioner {
14681549 conditioner_params.clip_skip ,
14691550 conditioner_params.zero_out_masked );
14701551 }
1552+
1553+ bool is_cond_stage_model_name_at_index (const std::string& name, int index) override {
1554+ if (index == 0 ) {
1555+ return contains (name, " text_encoders.clip_l" );
1556+ } else if (index == 1 ) {
1557+ return contains (name, " text_encoders.t5xxl" );
1558+ } else {
1559+ return false ;
1560+ }
1561+ }
1562+
1563+ ggml_backend_t get_params_backend_at_index (int index){
1564+ if (index == 0 && clip_l) {
1565+ return clip_l->get_params_backend ();
1566+ } else if (index == 1 && t5) {
1567+ return t5->get_params_backend ();
1568+ } else {
1569+ return nullptr ;
1570+ }
1571+ }
1572+
1573+ ggml_backend_t get_runtime_backend_at_index (int index){
1574+ if (index == 0 && clip_l) {
1575+ return clip_l->get_runtime_backend ();
1576+ } else if (index == 1 && t5) {
1577+ return t5->get_runtime_backend ();
1578+ } else {
1579+ return nullptr ;
1580+ }
1581+ }
14711582};
14721583
14731584struct T5CLIPEmbedder : public Conditioner {
@@ -1691,6 +1802,20 @@ struct T5CLIPEmbedder : public Conditioner {
16911802 conditioner_params.clip_skip ,
16921803 conditioner_params.zero_out_masked );
16931804 }
1805+
1806+ ggml_backend_t get_params_backend_at_index (int index){
1807+ if (t5){
1808+ return t5->get_params_backend ();
1809+ }
1810+ return nullptr ;
1811+ }
1812+
1813+ ggml_backend_t get_runtime_backend_at_index (int index){
1814+ if (t5){
1815+ return t5->get_runtime_backend ();
1816+ }
1817+ return nullptr ;
1818+ }
16941819};
16951820
16961821struct AnimaConditioner : public Conditioner {
@@ -1703,11 +1828,11 @@ struct AnimaConditioner : public Conditioner {
17031828 const String2TensorStorage& tensor_storage_map = {}) {
17041829 qwen_tokenizer = std::make_shared<LLM::Qwen2Tokenizer>();
17051830 llm = std::make_shared<LLM::LLMRunner>(LLM::LLMArch::QWEN3,
1706- backend,
1707- offload_params_to_cpu,
1708- tensor_storage_map,
1709- " text_encoders.llm" ,
1710- false );
1831+ backend,
1832+ offload_params_to_cpu,
1833+ tensor_storage_map,
1834+ " text_encoders.llm" ,
1835+ false );
17111836 }
17121837
17131838 void get_param_tensors (std::map<std::string, ggml_tensor*>& tensors) override {
@@ -1827,6 +1952,20 @@ struct AnimaConditioner : public Conditioner {
18271952
18281953 return {hidden_states, t5_weight_tensor, t5_ids_tensor};
18291954 }
1955+
1956+ ggml_backend_t get_params_backend_at_index (int index){
1957+ if (llm){
1958+ return llm->get_params_backend ();
1959+ }
1960+ return nullptr ;
1961+ }
1962+
1963+ ggml_backend_t get_runtime_backend_at_index (int index){
1964+ if (llm){
1965+ return llm->get_runtime_backend ();
1966+ }
1967+ return nullptr ;
1968+ }
18301969};
18311970
18321971struct LLMEmbedder : public Conditioner {
@@ -2201,6 +2340,20 @@ struct LLMEmbedder : public Conditioner {
22012340 LOG_DEBUG (" computing condition graph completed, taking %" PRId64 " ms" , t1 - t0);
22022341 return {hidden_states, nullptr , nullptr , extra_hidden_states_vec};
22032342 }
2343+
2344+ ggml_backend_t get_params_backend_at_index (int index){
2345+ if (llm){
2346+ return llm->get_params_backend ();
2347+ }
2348+ return nullptr ;
2349+ }
2350+
2351+ ggml_backend_t get_runtime_backend_at_index (int index){
2352+ if (llm){
2353+ return llm->get_runtime_backend ();
2354+ }
2355+ return nullptr ;
2356+ }
22042357};
22052358
22062359#endif
0 commit comments