Skip to content

Commit 47ef4b0

Browse files
committed
feat: enable SIGUSR1 to soft-cancel pending batches
This is mostly an example of using sd_cancel to asynchronously cancel the current batch without discarding already generated images.
1 parent 771edfa commit 47ef4b0

1 file changed

Lines changed: 67 additions & 0 deletions

File tree

examples/cli/main.cpp

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -474,6 +474,13 @@ bool save_results(const SDCliParams& cli_params,
474474
return sucessful_reults != 0;
475475
}
476476

477+
#if defined(__unix__) || defined(__APPLE__) || defined(_POSIX_VERSION)
478+
#define SD_ENABLE_SIGNAL_HANDLER
479+
static void set_signal_cancel_handler(sd_ctx_t* sd_ctx);
480+
#else
481+
#define set_signal_cancel_handler(SD_CTX) ((void)SD_CTX)
482+
#endif
483+
477484
int main(int argc, const char* argv[]) {
478485
if (argc > 1 && std::string(argv[1]) == "--version") {
479486
std::cout << version_string() << "\n";
@@ -711,6 +718,8 @@ int main(int argc, const char* argv[]) {
711718
return 1;
712719
}
713720

721+
set_signal_cancel_handler(sd_ctx);
722+
714723
if (gen_params.sample_params.sample_method == SAMPLE_METHOD_COUNT) {
715724
gen_params.sample_params.sample_method = sd_get_default_sample_method(sd_ctx);
716725
}
@@ -783,6 +792,8 @@ int main(int argc, const char* argv[]) {
783792
results = generate_video(sd_ctx, &vid_gen_params, &num_results);
784793
}
785794

795+
set_signal_cancel_handler(nullptr);
796+
786797
if (results == nullptr) {
787798
LOG_ERROR("generate failed");
788799
free_sd_ctx(sd_ctx);
@@ -836,3 +847,59 @@ int main(int argc, const char* argv[]) {
836847

837848
return 0;
838849
}
850+
851+
#ifdef SD_ENABLE_SIGNAL_HANDLER
852+
853+
#include <atomic>
854+
#include <csignal>
855+
#include <thread>
856+
#include <unistd.h>
857+
858+
// this lock is needed to avoid a race condition between
859+
// free_sd_ctx and a pending sd_cancel_generation call
860+
std::atomic_flag signal_lock = ATOMIC_FLAG_INIT;
861+
static int g_sigint_cnt;
862+
static sd_ctx_t* g_sd_ctx;
863+
864+
static void sig_cancel_handler(int /* signum */)
865+
{
866+
if (!signal_lock.test_and_set(std::memory_order_acquire)) {
867+
if (g_sd_ctx != nullptr) {
868+
if (g_sigint_cnt == 1) {
869+
char msg[] = "\ngot cancel signal, cancelling new generations\n";
870+
write(2, msg, sizeof(msg)-1);
871+
/* first signal cancels only the remaining latents on a batch */
872+
sd_cancel_generation(g_sd_ctx, SD_CANCEL_NEW_LATENTS);
873+
++g_sigint_cnt;
874+
} else {
875+
char msg[] = "\ngot cancel signal, cancelling everything\n";
876+
write(2, msg, sizeof(msg)-1);
877+
/* cancels everything */
878+
sd_cancel_generation(g_sd_ctx, SD_CANCEL_ALL);
879+
}
880+
}
881+
signal_lock.clear(std::memory_order_release);
882+
}
883+
}
884+
885+
static void set_signal_cancel_handler(sd_ctx_t* sd_ctx)
886+
{
887+
if (g_sigint_cnt == 0) {
888+
g_sigint_cnt++;
889+
struct sigaction sa{};
890+
sa.sa_handler = sig_cancel_handler;
891+
sa.sa_flags = SA_RESTART;
892+
sigaction(SIGUSR1, &sa, nullptr);
893+
}
894+
895+
while (signal_lock.test_and_set(std::memory_order_acquire)) {
896+
std::this_thread::yield();
897+
}
898+
899+
g_sd_ctx = sd_ctx;
900+
901+
signal_lock.clear(std::memory_order_release);
902+
}
903+
904+
#endif
905+

0 commit comments

Comments
 (0)