Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 107 additions & 54 deletions examples/server/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -404,8 +404,8 @@ int main(int argc, const char** argv) {
std::string size = j.value("size", "");
std::string output_format = j.value("output_format", "png");
int output_compression = j.value("output_compression", 100);
int width = 512;
int height = 512;
int width = default_gen_params.width > 0 ? default_gen_params.width : 512;
int height = default_gen_params.width > 0 ? default_gen_params.height : 512;
if (!size.empty()) {
auto pos = size.find('x');
if (pos != std::string::npos) {
Expand Down Expand Up @@ -593,7 +593,7 @@ int main(int argc, const char** argv) {
n = std::clamp(n, 1, 8);

std::string size = req.form.get_field("size");
int width = 512, height = 512;
int width = -1, height = -1;
if (!size.empty()) {
auto pos = size.find('x');
if (pos != std::string::npos) {
Expand Down Expand Up @@ -650,15 +650,31 @@ int main(int argc, const char** argv) {

LOG_DEBUG("%s\n", gen_params.to_string().c_str());

sd_image_t init_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr};
sd_image_t control_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr};
sd_image_t init_image = {0, 0, 3, nullptr};
sd_image_t control_image = {0, 0, 3, nullptr};
std::vector<sd_image_t> pmid_images;

auto get_resolved_width = [&gen_params, &default_gen_params]() -> int {
if (gen_params.width > 0)
return gen_params.width;
if (default_gen_params.width > 0)
return default_gen_params.width;
return 512;
};
auto get_resolved_height = [&gen_params, &default_gen_params]() -> int {
if (gen_params.height > 0)
return gen_params.height;
if (default_gen_params.height > 0)
return default_gen_params.height;
return 512;
};

std::vector<sd_image_t> ref_images;
ref_images.reserve(images_bytes.size());
for (auto& bytes : images_bytes) {
int img_w = width;
int img_h = height;
int img_w;
int img_h;

uint8_t* raw_pixels = load_image_from_memory(
reinterpret_cast<const char*>(bytes.data()),
static_cast<int>(bytes.size()),
Expand All @@ -670,22 +686,31 @@ int main(int argc, const char** argv) {
}

sd_image_t img{(uint32_t)img_w, (uint32_t)img_h, 3, raw_pixels};
gen_params.set_width_and_height_if_unset(img.width, img.height);
ref_images.push_back(img);
}

sd_image_t mask_image = {0};
if (!mask_bytes.empty()) {
int mask_w = width;
int mask_h = height;
int expected_width = 0;
int expected_height = 0;
if (gen_params.width_and_height_are_set()) {
expected_width = gen_params.width;
expected_height = gen_params.height;
}
int mask_w;
int mask_h;

uint8_t* mask_raw = load_image_from_memory(
reinterpret_cast<const char*>(mask_bytes.data()),
static_cast<int>(mask_bytes.size()),
mask_w, mask_h,
width, height, 1);
expected_width, expected_height, 1);
mask_image = {(uint32_t)mask_w, (uint32_t)mask_h, 1, mask_raw};
gen_params.set_width_and_height_if_unset(mask_image.width, mask_image.height);
} else {
mask_image.width = width;
mask_image.height = height;
mask_image.width = get_resolved_width();
mask_image.height = get_resolved_height();
mask_image.channel = 1;
mask_image.data = nullptr;
}
Expand All @@ -702,8 +727,8 @@ int main(int argc, const char** argv) {
gen_params.auto_resize_ref_image,
gen_params.increase_ref_index,
mask_image,
gen_params.width,
gen_params.height,
get_resolved_width(),
get_resolved_height(),
gen_params.sample_params,
gen_params.strength,
gen_params.seed,
Expand Down Expand Up @@ -886,8 +911,6 @@ int main(int argc, const char** argv) {
SDGenerationParams gen_params = default_gen_params;
gen_params.prompt = prompt;
gen_params.negative_prompt = negative_prompt;
gen_params.width = width;
gen_params.height = height;
gen_params.seed = seed;
gen_params.sample_params.sample_steps = steps;
gen_params.batch_count = batch_size;
Expand All @@ -905,38 +928,66 @@ int main(int argc, const char** argv) {
gen_params.sample_params.scheduler = scheduler;
}

// re-read to avoid applying 512 as default before the provided
// images and/or server command-line
gen_params.width = j.value("width", -1);
gen_params.height = j.value("height", -1);

LOG_DEBUG("%s\n", gen_params.to_string().c_str());

sd_image_t init_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr};
sd_image_t control_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr};
sd_image_t mask_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 1, nullptr};
sd_image_t init_image = {0, 0, 3, nullptr};
sd_image_t control_image = {0, 0, 3, nullptr};
sd_image_t mask_image = {0, 0, 1, nullptr};
std::vector<uint8_t> mask_data;
std::vector<sd_image_t> pmid_images;
std::vector<sd_image_t> ref_images;

if (img2img) {
auto decode_image = [](sd_image_t& image, std::string encoded) -> bool {
// remove data URI prefix if present ("data:image/png;base64,")
auto comma_pos = encoded.find(',');
if (comma_pos != std::string::npos) {
encoded = encoded.substr(comma_pos + 1);
auto get_resolved_width = [&gen_params, &default_gen_params]() -> int {
if (gen_params.width > 0)
return gen_params.width;
if (default_gen_params.width > 0)
return default_gen_params.width;
return 512;
};
auto get_resolved_height = [&gen_params, &default_gen_params]() -> int {
if (gen_params.height > 0)
return gen_params.height;
if (default_gen_params.height > 0)
return default_gen_params.height;
return 512;
};

auto decode_image = [&gen_params](sd_image_t& image, std::string encoded) -> bool {
// remove data URI prefix if present ("data:image/png;base64,")
auto comma_pos = encoded.find(',');
if (comma_pos != std::string::npos) {
encoded = encoded.substr(comma_pos + 1);
}
std::vector<uint8_t> img_data = base64_decode(encoded);
if (!img_data.empty()) {
int expected_width = 0;
int expected_height = 0;
if (gen_params.width_and_height_are_set()) {
expected_width = gen_params.width;
expected_height = gen_params.height;
}
std::vector<uint8_t> img_data = base64_decode(encoded);
if (!img_data.empty()) {
int img_w = image.width;
int img_h = image.height;
uint8_t* raw_data = load_image_from_memory(
(const char*)img_data.data(), (int)img_data.size(),
img_w, img_h,
image.width, image.height, image.channel);
if (raw_data) {
image = {(uint32_t)img_w, (uint32_t)img_h, image.channel, raw_data};
return true;
}
int img_w;
int img_h;

uint8_t* raw_data = load_image_from_memory(
(const char*)img_data.data(), (int)img_data.size(),
img_w, img_h,
expected_width, expected_height, image.channel);
if (raw_data) {
image = {(uint32_t)img_w, (uint32_t)img_h, image.channel, raw_data};
gen_params.set_width_and_height_if_unset(image.width, image.height);
return true;
}
return false;
};
}
return false;
};

if (img2img) {
if (j.contains("init_images") && j["init_images"].is_array() && !j["init_images"].empty()) {
std::string encoded = j["init_images"][0].get<std::string>();
decode_image(init_image, encoded);
Expand All @@ -952,30 +1003,32 @@ int main(int argc, const char** argv) {
}
}
} else {
mask_data = std::vector<uint8_t>(width * height, 255);
mask_image.width = width;
mask_image.height = height;
int m_width = get_resolved_width();
int m_height = get_resolved_height();
mask_data = std::vector<uint8_t>(m_width * m_height, 255);
mask_image.width = m_width;
mask_image.height = m_height;
mask_image.channel = 1;
mask_image.data = mask_data.data();
}

if (j.contains("extra_images") && j["extra_images"].is_array()) {
for (auto extra_image : j["extra_images"]) {
std::string encoded = extra_image.get<std::string>();
sd_image_t tmp_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr};
if (decode_image(tmp_image, encoded)) {
ref_images.push_back(tmp_image);
}
}
}

float denoising_strength = j.value("denoising_strength", -1.f);
if (denoising_strength >= 0.f) {
denoising_strength = std::min(denoising_strength, 1.0f);
gen_params.strength = denoising_strength;
}
}

if (j.contains("extra_images") && j["extra_images"].is_array()) {
for (auto extra_image : j["extra_images"]) {
std::string encoded = extra_image.get<std::string>();
sd_image_t tmp_image = {(uint32_t)gen_params.width, (uint32_t)gen_params.height, 3, nullptr};
if (decode_image(tmp_image, encoded)) {
ref_images.push_back(tmp_image);
}
}
}

sd_img_gen_params_t img_gen_params = {
sd_loras.data(),
static_cast<uint32_t>(sd_loras.size()),
Expand All @@ -988,8 +1041,8 @@ int main(int argc, const char** argv) {
gen_params.auto_resize_ref_image,
gen_params.increase_ref_index,
mask_image,
gen_params.width,
gen_params.height,
get_resolved_width(),
get_resolved_height(),
gen_params.sample_params,
gen_params.strength,
gen_params.seed,
Expand Down
Loading