Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions include/stable-diffusion.h
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,14 @@ SD_API void sd_img_gen_params_init(sd_img_gen_params_t* sd_img_gen_params);
SD_API char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params);
SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_gen_params);

enum sd_cancel_mode_t {
SD_CANCEL_ALL,
SD_CANCEL_NEW_LATENTS,
SD_CANCEL_RESET
};

SD_API void sd_cancel_generation(sd_ctx_t* sd_ctx, enum sd_cancel_mode_t mode);

SD_API void sd_vid_gen_params_init(sd_vid_gen_params_t* sd_vid_gen_params);
SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* sd_vid_gen_params, int* num_frames_out);

Expand Down
49 changes: 49 additions & 0 deletions src/stable-diffusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
#include "latent-preview.h"
#include "name_conversion.h"

#include <atomic>

const char* model_version_to_str[] = {
"SD 1.x",
"SD 1.x Inpaint",
Expand Down Expand Up @@ -106,6 +108,9 @@ static float get_cache_reuse_threshold(const sd_cache_params_t& params) {

/*=============================================== StableDiffusionGGML ================================================*/

static_assert(std::atomic<sd_cancel_mode_t>::is_always_lock_free,
"sd_cancel_mode_t must be lock-free");

class StableDiffusionGGML {
public:
ggml_backend_t backend = nullptr; // general backend
Expand Down Expand Up @@ -171,6 +176,20 @@ class StableDiffusionGGML {
ggml_backend_free(backend);
}

std::atomic<sd_cancel_mode_t> cancellation_flag;

void set_cancel_flag(enum sd_cancel_mode_t flag) {
cancellation_flag.store(flag, std::memory_order_release);
}

void reset_cancel_flag() {
set_cancel_flag(SD_CANCEL_RESET);
}

enum sd_cancel_mode_t get_cancel_flag() {
return cancellation_flag.load(std::memory_order_acquire);
}

void init_backend() {
#ifdef SD_USE_CUDA
LOG_DEBUG("Using CUDA backend");
Expand Down Expand Up @@ -1646,6 +1665,12 @@ class StableDiffusionGGML {
SamplePreviewContext preview = prepare_sample_preview_context();

auto denoise = [&](const sd::Tensor<float>& x, float sigma, int step) -> sd::Tensor<float> {
enum sd_cancel_mode_t cancel_flag = get_cancel_flag();
if (cancel_flag != SD_CANCEL_RESET) {
LOG_DEBUG("cancelling generation");
return {};
}

if (step == 1 || step == -1) {
pretty_progress(0, (int)steps, 0);
}
Expand Down Expand Up @@ -2480,6 +2505,15 @@ void free_sd_ctx(sd_ctx_t* sd_ctx) {
free(sd_ctx);
}

SD_API void sd_cancel_generation(sd_ctx_t* sd_ctx, enum sd_cancel_mode_t mode) {
if (sd_ctx && sd_ctx->sd) {
if (mode < SD_CANCEL_ALL || mode > SD_CANCEL_RESET) {
mode = SD_CANCEL_ALL;
}
sd_ctx->sd->set_cancel_flag(mode);
}
}

SD_API bool sd_ctx_supports_image_generation(const sd_ctx_t* sd_ctx) {
if (sd_ctx == nullptr || sd_ctx->sd == nullptr) {
return false;
Expand Down Expand Up @@ -3222,6 +3256,10 @@ static sd_image_t* decode_image_outputs(sd_ctx_t* sd_ctx,
int64_t t0 = ggml_time_ms();

for (size_t i = 0; i < final_latents.size(); i++) {
if (sd_ctx->sd->get_cancel_flag() == SD_CANCEL_ALL) {
LOG_ERROR("cancelling latent decodings");
break;
}
int64_t t1 = ggml_time_ms();
sd::Tensor<float> image = sd_ctx->sd->decode_first_stage(final_latents[i]);
if (image.empty()) {
Expand Down Expand Up @@ -3389,6 +3427,8 @@ SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* s
return nullptr;
}

sd_ctx->sd->reset_cancel_flag();

int64_t t0 = ggml_time_ms();
sd_ctx->sd->vae_tiling_params = sd_img_gen_params->vae_tiling_params;
GenerationRequest request(sd_ctx, sd_img_gen_params);
Expand Down Expand Up @@ -3424,6 +3464,12 @@ SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* s
std::vector<sd::Tensor<float>> final_latents;
int64_t denoise_start = ggml_time_ms();
for (int b = 0; b < request.batch_count; b++) {
sd_cancel_mode_t cancel = sd_ctx->sd->get_cancel_flag();
if (cancel == SD_CANCEL_NEW_LATENTS || cancel == SD_CANCEL_ALL) {
LOG_ERROR("cancelling generation");
break;
}

int64_t sampling_start = ggml_time_ms();
int64_t cur_seed = request.seed + b;
LOG_INFO("generating image: %i/%i - seed %" PRId64, b + 1, request.batch_count, cur_seed);
Expand Down Expand Up @@ -3876,6 +3922,9 @@ SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* s
if (sd_ctx == nullptr || sd_vid_gen_params == nullptr) {
return nullptr;
}

sd_ctx->sd->reset_cancel_flag();

if (num_frames_out != nullptr) {
*num_frames_out = 0;
}
Expand Down
Loading