Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
247 changes: 136 additions & 111 deletions otherarch/sdcpp/stable-diffusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,10 @@ class StableDiffusionGGML {
}
}

bool init(const sd_ctx_params_t* sd_ctx_params) {
bool init(const sd_ctx_params_t* sd_ctx_params_kcpp) {
// kcpp make sd_ctx_params mutable
sd_ctx_params_t sd_ctx_params_local = *sd_ctx_params_kcpp;
sd_ctx_params_t *sd_ctx_params = &sd_ctx_params_local;
n_threads = sd_ctx_params->n_threads;
vae_decode_only = sd_ctx_params->vae_decode_only;
free_params_immediately = sd_ctx_params->free_params_immediately;
Expand All @@ -304,10 +307,12 @@ class StableDiffusionGGML {

init_backend();

std::string taesd_path_fixed = taesd_path;
std::string t5_path_fixed = SAFE_STR(sd_ctx_params->t5xxl_path);
std::string clipl_path_fixed = SAFE_STR(sd_ctx_params->clip_l_path);
std::string clipg_path_fixed = SAFE_STR(sd_ctx_params->clip_g_path);
std::string clip_vision_fixed = SAFE_STR(sd_ctx_params->clip_vision_path);
std::string clipg_path_fixed = SAFE_STR(sd_ctx_params->clip_g_path);
std::string clipl_path_fixed = SAFE_STR(sd_ctx_params->clip_l_path);
std::string llm_path_fixed = SAFE_STR(sd_ctx_params->llm_path);
std::string t5_path_fixed = SAFE_STR(sd_ctx_params->t5xxl_path);
std::string taesd_path_fixed = taesd_path;

ModelLoader model_loader;

Expand All @@ -333,7 +338,9 @@ class StableDiffusionGGML {
}

bool is_unet = sd_version_is_unet(model_loader.get_sd_version());
int tempver = model_loader.get_sd_version();

// begin kcpp replacements
SDVersion tempver = model_loader.get_sd_version();

// kcpp fallback to separate diffusion model passed as model
if (tempver == VERSION_COUNT &&
Expand Down Expand Up @@ -365,73 +372,143 @@ class StableDiffusionGGML {
tempver = model_loader.get_sd_version();
}

bool iswan = (tempver==VERSION_WAN2 || tempver==VERSION_WAN2_2_I2V || tempver==VERSION_WAN2_2_TI2V);
bool isqwenimg = (tempver==VERSION_QWEN_IMAGE);
bool iszimg = (tempver==VERSION_Z_IMAGE);
bool isflux2 = (tempver==VERSION_FLUX2);
bool isflux2k = (tempver==VERSION_FLUX2_KLEIN);
bool is_wan = sd_version_is_wan(tempver);
bool is_wan21 = sd_version_is_wan(tempver) && tempver != VERSION_WAN2_2_TI2V;
bool is_qwenimg = sd_version_is_qwen_image(tempver);
bool is_ovis = (tempver==VERSION_OVIS_IMAGE);
bool is_anima = (tempver==VERSION_ANIMA);
bool conditioner_is_llm = (isqwenimg||iszimg||isflux2||isflux2k||is_ovis||is_anima);
bool has_t5 = sd_version_is_sd3(tempver) || (sd_version_is_flux(tempver) && !is_ovis) || is_wan;
bool has_clip_or_t5 = sd_version_is_unet(tempver) || has_t5;

auto swap_to = [](const char* name, std::string& dst, std::string& src) {
LOG_INFO("swap %s from '%s'", name, src.c_str());
dst.swap(src);
};

if (has_clip_or_t5) {

if(is_wan) {
if (clip_vision_fixed == "" && clipl_path_fixed != "") {
swap_to("clip_vision", clip_vision_fixed, clipl_path_fixed);
}
if (clipg_path_fixed != "") {
if (t5_path_fixed == "") {
swap_to("umt5_xxl", t5_path_fixed, clipg_path_fixed);
} else if (clip_vision_fixed == "") {
swap_to("clip_vision", clip_vision_fixed, clipg_path_fixed);
} else {
LOG_WARN("unused model '%s'", clipg_path_fixed.c_str());
clipg_path_fixed = "";
}
}
}

} else {

std::string * conditioners[3] = {&clipl_path_fixed, &clipg_path_fixed, &t5_path_fixed};

if(is_qwenimg) {
for (auto conditioner: conditioners) {
if (*conditioner != "") {
// assume the llm comes first, unless we see "mmproj" in the filename
if (clip_vision_fixed == "" && toLowerCase(*conditioner).find("mmproj") != std::string::npos) {
swap_to("clip_vision", clip_vision_fixed, *conditioner);
} else if (llm_path_fixed == "") {
swap_to("llm", llm_path_fixed, *conditioner);
} else if (clip_vision_fixed == "") {
swap_to("clip_vision", clip_vision_fixed, *conditioner);
}
}
}
}

else {
// assume it's a model with a single llm conditioner (z-image, flux2, ...)
for (auto conditioner: conditioners) {
if (llm_path_fixed == "" && *conditioner != "") {
swap_to("llm", llm_path_fixed, *conditioner);
break;
}
}
}

for (auto conditioner: conditioners) {
if (*conditioner != "") {
LOG_WARN("unused model '%s'", conditioner->c_str());
*conditioner = "";
}
}
}

//kcpp qol fallback: if qwen image, and they loaded the qwen2vl llm as t5 by mistake
if(conditioner_is_llm && t5_path_fixed!="")
if(taesd_path_fixed != "")
{
if(clipl_path_fixed=="" && clipg_path_fixed=="")
std::string to_search = "taesd.embd";
std::string to_replace = "";
if(sd_version_is_sd1(tempver) || sd_version_is_sd2(tempver))
{
clipl_path_fixed = t5_path_fixed;
t5_path_fixed = "";
to_replace = "taesd.embd";
}
else if(clipl_path_fixed=="" && clipg_path_fixed!="")
else if(sd_version_is_sdxl(tempver))
{
clipl_path_fixed = t5_path_fixed;
t5_path_fixed = "";
to_replace = "taesd_xl.embd";
}
else if(clipl_path_fixed!="" && clipg_path_fixed=="")
else if(sd_version_is_flux(tempver)||sd_version_is_z_image(tempver)||tempver == VERSION_OVIS_IMAGE)
{
//very tricky case. see if we can tell if clipl is an mmproj, if so move to right place
if(toLowerCase(clipl_path_fixed).find("mmproj") != std::string::npos)
{
clipg_path_fixed = clipl_path_fixed;
clipl_path_fixed = t5_path_fixed;
t5_path_fixed = "";
}
to_replace = "taesd_f.embd";
}
}

if (clipl_path_fixed!="") {
LOG_INFO("loading clip_l from '%s'", clipl_path_fixed.c_str());
std::string prefix = is_unet ? "cond_stage_model.transformer." : "text_encoders.clip_l.transformer.";
if(iswan)
else if(sd_version_is_sd3(tempver))
{
prefix = "cond_stage_model.transformer.";
LOG_INFO("swap clip_vision from '%s'", clipl_path_fixed.c_str());
to_replace = "taesd_3.embd";
}
if(conditioner_is_llm)
else if(sd_version_is_flux2(tempver))
{
prefix = "text_encoders.llm.";
LOG_INFO("swap llm from '%s'", clipl_path_fixed.c_str());
to_replace = "taesd_f2.embd";
}
if (!model_loader.init_from_file(clipl_path_fixed.c_str(), prefix)) {
LOG_WARN("loading clip_l from '%s' failed", clipl_path_fixed.c_str());
else if(is_wan21||is_qwenimg||sd_version_is_anima(tempver))
{
to_replace = "taesd_w21.embd";
}
}

if (clipg_path_fixed!="") {
LOG_INFO("loading clip_g from '%s'", clipg_path_fixed.c_str());
std::string prefix = is_unet ? "cond_stage_model.1.transformer." : "text_encoders.clip_g.transformer.";
if(iswan)
if(to_replace!="")
{
size_t pos = taesd_path_fixed.find(to_search);
if (pos != std::string::npos) {
taesd_path_fixed.replace(pos, to_search.length(), to_replace);
}
}
else
{
prefix = "cond_stage_model.transformer.";
LOG_INFO("swap clip_vision from '%s'", clipg_path_fixed.c_str());
printf("\nCannot use TAESD: Unknown tempver %d. TAESD Disabled!\n",tempver);
taesd_path_fixed = "";
}
if(isqwenimg)
if (taesd_path_fixed != "" && !file_exists(taesd_path_fixed))
{
prefix = "text_encoders.llm.visual.";
LOG_INFO("swap llm mmproj from '%s'", clipg_path_fixed.c_str());
printf("\nCannot use TAESD: \"%s\" not found. TAESD Disabled!\n", taesd_path_fixed.c_str());
taesd_path_fixed = "";
}
if (!model_loader.init_from_file(clipg_path_fixed.c_str(), prefix)) {
LOG_WARN("loading clip_g from '%s' failed", clipg_path_fixed.c_str());
}

sd_ctx_params->clip_g_path = clipg_path_fixed.c_str();
sd_ctx_params->clip_l_path = clipl_path_fixed.c_str();
sd_ctx_params->clip_vision_path = clip_vision_fixed.c_str();
sd_ctx_params->llm_path = llm_path_fixed.c_str();
sd_ctx_params->t5xxl_path = t5_path_fixed.c_str();
taesd_path = taesd_path_fixed;
use_tiny_autoencoder = (taesd_path != "");
// end kcpp replacements

if (strlen(SAFE_STR(sd_ctx_params->clip_l_path)) > 0) {
LOG_INFO("loading clip_l from '%s'", sd_ctx_params->clip_l_path);
std::string prefix = is_unet ? "cond_stage_model.transformer." : "text_encoders.clip_l.transformer.";
if (!model_loader.init_from_file(sd_ctx_params->clip_l_path, prefix)) {
LOG_WARN("loading clip_l from '%s' failed", sd_ctx_params->clip_l_path);
}
}

if (strlen(SAFE_STR(sd_ctx_params->clip_g_path)) > 0) {
LOG_INFO("loading clip_g from '%s'", sd_ctx_params->clip_g_path);
std::string prefix = is_unet ? "cond_stage_model.1.transformer." : "text_encoders.clip_g.transformer.";
if (!model_loader.init_from_file(sd_ctx_params->clip_g_path, prefix)) {
LOG_WARN("loading clip_g from '%s' failed", sd_ctx_params->clip_g_path);
}
}

Expand All @@ -443,10 +520,10 @@ class StableDiffusionGGML {
}
}

if (t5_path_fixed!="") {
LOG_INFO("loading t5xxl from '%s'", t5_path_fixed.c_str());
if (!model_loader.init_from_file(t5_path_fixed.c_str(), "text_encoders.t5xxl.transformer.")) {
LOG_WARN("loading t5xxl from '%s' failed", t5_path_fixed.c_str());
if (strlen(SAFE_STR(sd_ctx_params->t5xxl_path)) > 0) {
LOG_INFO("loading t5xxl from '%s'", sd_ctx_params->t5xxl_path);
if (!model_loader.init_from_file(sd_ctx_params->t5xxl_path, "text_encoders.t5xxl.transformer.")) {
LOG_WARN("loading t5xxl from '%s' failed", sd_ctx_params->t5xxl_path);
}
}

Expand Down Expand Up @@ -475,7 +552,6 @@ class StableDiffusionGGML {
model_loader.convert_tensors_name();

version = model_loader.get_sd_version();

if (version == VERSION_COUNT) {
LOG_ERROR("get sd version from file failed: '%s'", SAFE_STR(sd_ctx_params->model_path));
return false;
Expand All @@ -484,57 +560,6 @@ class StableDiffusionGGML {
auto& tensor_storage_map = model_loader.get_tensor_storage_map();

LOG_INFO("Version: %s ", model_version_to_str[version]);

if(use_tiny_autoencoder) // kcpp
{
std::string to_search = "taesd.embd";
std::string to_replace = "";
if(sd_version_is_sd1(version) || sd_version_is_sd2(version))
{
to_replace = "taesd.embd";
}
else if(sd_version_is_sdxl(version))
{
to_replace = "taesd_xl.embd";
}
else if(sd_version_is_flux(version)||sd_version_is_z_image(version)||version == VERSION_OVIS_IMAGE)
{
to_replace = "taesd_f.embd";
}
else if(sd_version_is_sd3(version))
{
to_replace = "taesd_3.embd";
}
else if(sd_version_is_flux2(version))
{
to_replace = "taesd_f2.embd";
}
else if((sd_version_is_wan(version) && version != VERSION_WAN2_2_TI2V)||sd_version_is_qwen_image(version)||sd_version_is_anima(version))
{
to_replace = "taesd_w21.embd";
}

if(to_replace!="")
{
size_t pos = taesd_path_fixed.find(to_search);
if (pos != std::string::npos) {
taesd_path_fixed.replace(pos, to_search.length(), to_replace);
}
}
else
{
printf("\nCannot use TAESD: Unknown version %d. TAESD Disabled!\n",version);
taesd_path_fixed = "";
use_tiny_autoencoder = false;
}
if (use_tiny_autoencoder && !file_exists(taesd_path_fixed))
{
printf("\nCannot use TAESD: \"%s\" not found. TAESD Disabled!\n", taesd_path_fixed.c_str());
taesd_path_fixed = "";
use_tiny_autoencoder = false;
}
}

ggml_type wtype = (int)sd_ctx_params->wtype < std::min<int>(SD_TYPE_COUNT, GGML_TYPE_COUNT)
? (ggml_type)sd_ctx_params->wtype
: GGML_TYPE_COUNT;
Expand Down Expand Up @@ -1025,7 +1050,7 @@ class StableDiffusionGGML {
vae_params_mem_size = first_stage_model->get_params_buffer_size();
}
if (use_tiny_autoencoder || version == VERSION_SDXS) {
if (use_tiny_autoencoder && !tae_first_stage->load_from_file(taesd_path_fixed, n_threads)) {
if (use_tiny_autoencoder && !tae_first_stage->load_from_file(taesd_path, n_threads)) {
return false;
}
use_tiny_autoencoder = true; // now the processing is identical for VERSION_SDXS
Expand Down