@@ -474,6 +474,13 @@ bool save_results(const SDCliParams& cli_params,
474474 return sucessful_reults != 0 ;
475475}
476476
477+ #if defined(__unix__) || defined(__APPLE__) || defined(_POSIX_VERSION)
478+ #define SD_ENABLE_SIGNAL_HANDLER
479+ static void set_signal_cancel_handler (sd_ctx_t * sd_ctx);
480+ #else
481+ #define set_signal_cancel_handler (SD_CTX ) ((void )SD_CTX)
482+ #endif
483+
477484int main (int argc, const char * argv[]) {
478485 if (argc > 1 && std::string (argv[1 ]) == " --version" ) {
479486 std::cout << version_string () << " \n " ;
@@ -711,6 +718,8 @@ int main(int argc, const char* argv[]) {
711718 return 1 ;
712719 }
713720
721+ set_signal_cancel_handler (sd_ctx);
722+
714723 if (gen_params.sample_params .sample_method == SAMPLE_METHOD_COUNT) {
715724 gen_params.sample_params .sample_method = sd_get_default_sample_method (sd_ctx);
716725 }
@@ -783,6 +792,8 @@ int main(int argc, const char* argv[]) {
783792 results = generate_video (sd_ctx, &vid_gen_params, &num_results);
784793 }
785794
795+ set_signal_cancel_handler (nullptr );
796+
786797 if (results == nullptr ) {
787798 LOG_ERROR (" generate failed" );
788799 free_sd_ctx (sd_ctx);
@@ -836,3 +847,59 @@ int main(int argc, const char* argv[]) {
836847
837848 return 0 ;
838849}
850+
851+ #ifdef SD_ENABLE_SIGNAL_HANDLER
852+
853+ #include < atomic>
854+ #include < csignal>
855+ #include < thread>
856+ #include < unistd.h>
857+
858+ // this lock is needed to avoid a race condition between
859+ // free_sd_ctx and a pending sd_cancel_generation call
860+ std::atomic_flag signal_lock = ATOMIC_FLAG_INIT;
861+ static int g_sigint_cnt;
862+ static sd_ctx_t * g_sd_ctx;
863+
864+ static void sig_cancel_handler (int /* signum */ )
865+ {
866+ if (!signal_lock.test_and_set (std::memory_order_acquire)) {
867+ if (g_sd_ctx != nullptr ) {
868+ if (g_sigint_cnt == 1 ) {
869+ char msg[] = " \n got cancel signal, cancelling new generations\n " ;
870+ write (2 , msg, sizeof (msg)-1 );
871+ /* first signal cancels only the remaining latents on a batch */
872+ sd_cancel_generation (g_sd_ctx, SD_CANCEL_NEW_LATENTS);
873+ ++g_sigint_cnt;
874+ } else {
875+ char msg[] = " \n got cancel signal, cancelling everything\n " ;
876+ write (2 , msg, sizeof (msg)-1 );
877+ /* cancels everything */
878+ sd_cancel_generation (g_sd_ctx, SD_CANCEL_ALL);
879+ }
880+ }
881+ signal_lock.clear (std::memory_order_release);
882+ }
883+ }
884+
885+ static void set_signal_cancel_handler (sd_ctx_t * sd_ctx)
886+ {
887+ if (g_sigint_cnt == 0 ) {
888+ g_sigint_cnt++;
889+ struct sigaction sa{};
890+ sa.sa_handler = sig_cancel_handler;
891+ sa.sa_flags = SA_RESTART;
892+ sigaction (SIGUSR1, &sa, nullptr );
893+ }
894+
895+ while (signal_lock.test_and_set (std::memory_order_acquire)) {
896+ std::this_thread::yield ();
897+ }
898+
899+ g_sd_ctx = sd_ctx;
900+
901+ signal_lock.clear (std::memory_order_release);
902+ }
903+
904+ #endif
905+
0 commit comments