Skip to content

Commit 707ffb1

Browse files
wbrunadonington
andcommitted
feat(server): cancel current generation on client disconnect
Co-authored-by: donington <jandastroy@gmail.com>
1 parent 3e937be commit 707ffb1

1 file changed

Lines changed: 25 additions & 6 deletions

File tree

examples/server/main.cpp

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#include <chrono>
33
#include <filesystem>
44
#include <fstream>
5+
#include <future>
56
#include <iomanip>
67
#include <iostream>
78
#include <mutex>
@@ -368,6 +369,18 @@ int main(int argc, const char** argv) {
368369
return httplib::Server::HandlerResponse::Unhandled;
369370
});
370371

372+
auto wait_for_generation = [](std::future<void>& ft, sd_ctx_t* sd_ctx, const httplib::Request& req) {
373+
std::future_status ft_status;
374+
do {
375+
if (!ft.valid())
376+
break;
377+
ft_status = ft.wait_for(std::chrono::milliseconds(1000));
378+
if (req.is_connection_closed()) {
379+
sd_cancel_generation(sd_ctx, SD_CANCEL_ALL);
380+
}
381+
} while (ft_status != std::future_status::ready);
382+
};
383+
371384
// root
372385
svr.Get("/", [&](const httplib::Request&, httplib::Response& res) {
373386
if (!svr_params.serve_html_path.empty()) {
@@ -510,11 +523,13 @@ int main(int argc, const char** argv) {
510523
sd_image_t* results = nullptr;
511524
int num_results = 0;
512525

513-
{
526+
std::future<void> ft = std::async(std::launch::async, [&]() {
514527
std::lock_guard<std::mutex> lock(sd_ctx_mutex);
515528
results = generate_image(sd_ctx, &img_gen_params);
516529
num_results = gen_params.batch_count;
517-
}
530+
});
531+
532+
wait_for_generation(ft, sd_ctx, req);
518533

519534
for (int i = 0; i < num_results; i++) {
520535
if (results[i].data == nullptr) {
@@ -756,11 +771,13 @@ int main(int argc, const char** argv) {
756771
sd_image_t* results = nullptr;
757772
int num_results = 0;
758773

759-
{
774+
std::future<void> ft = std::async(std::launch::async, [&]() {
760775
std::lock_guard<std::mutex> lock(sd_ctx_mutex);
761776
results = generate_image(sd_ctx, &img_gen_params);
762777
num_results = gen_params.batch_count;
763-
}
778+
});
779+
780+
wait_for_generation(ft, sd_ctx, req);
764781

765782
json out;
766783
out["created"] = static_cast<long long>(std::time(nullptr));
@@ -1071,11 +1088,13 @@ int main(int argc, const char** argv) {
10711088
sd_image_t* results = nullptr;
10721089
int num_results = 0;
10731090

1074-
{
1091+
std::future<void> ft = std::async(std::launch::async, [&]() {
10751092
std::lock_guard<std::mutex> lock(sd_ctx_mutex);
10761093
results = generate_image(sd_ctx, &img_gen_params);
10771094
num_results = gen_params.batch_count;
1078-
}
1095+
});
1096+
1097+
wait_for_generation(ft, sd_ctx, req);
10791098

10801099
json out;
10811100
out["images"] = json::array();

0 commit comments

Comments
 (0)